interface: address feedback for PaddedRSTransport

This commit is contained in:
SomberNight
2025-05-29 13:45:26 +00:00
parent 447052b4ff
commit c9ed8779fc

View File

@@ -140,6 +140,7 @@ class NotificationSession(RPCSession):
self.cache = {}
self._msg_counter = itertools.count(start=1)
self.interface = interface
self.taskgroup = interface.taskgroup
self.cost_hard_limit = 0 # disable aiorpcx resource limits
async def handle_request(self, request):
@@ -267,11 +268,6 @@ class ConnectError(NetworkException): pass
class _RSClient(RSClient):
def __init__(self, *, transport=None, **kwargs):
if transport is None:
transport = PaddedRSTransport
RSClient.__init__(self, transport=transport, **kwargs)
async def create_connection(self):
try:
return await super().create_connection()
@@ -310,7 +306,8 @@ class PaddedRSTransport(RSTransport):
self._sbuffer_has_data_evt.set()
self._maybe_consume_sbuffer()
def _maybe_consume_sbuffer(self):
def _maybe_consume_sbuffer(self) -> None:
"""Maybe take some data from sbuffer and send it on the wire."""
if not self._can_send.is_set() or self.is_closing():
return
buf = self._sbuffer
@@ -376,7 +373,7 @@ class PaddedRSTransport(RSTransport):
def connection_made(self, transport: asyncio.BaseTransport):
super().connection_made(transport)
if isinstance(self.session, NotificationSession):
coro = self.session.interface.taskgroup.spawn(self._poll_sbuffer())
coro = self.session.taskgroup.spawn(self._poll_sbuffer())
self._sbuffer_task = self.loop.create_task(coro)
else:
# This a short-lived "fetch_certificate"-type session.
@@ -716,9 +713,13 @@ class Interface(Logger):
sslc = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
sslc.check_hostname = False
sslc.verify_mode = ssl.CERT_NONE
async with _RSClient(session_factory=RPCSession,
host=self.host, port=self.port,
ssl=sslc, proxy=self.proxy) as session:
async with _RSClient(
session_factory=RPCSession,
host=self.host, port=self.port,
ssl=sslc,
proxy=self.proxy,
transport=PaddedRSTransport,
) as session:
asyncio_transport = session.transport._asyncio_transport # type: asyncio.BaseTransport
ssl_object = asyncio_transport.get_extra_info("ssl_object") # type: ssl.SSLObject
return ssl_object.getpeercert(binary_form=True)
@@ -787,9 +788,13 @@ class Interface(Logger):
async def open_session(self, sslc, exit_early=False):
session_factory = lambda *args, iface=self, **kwargs: NotificationSession(*args, **kwargs, interface=iface)
async with _RSClient(session_factory=session_factory,
host=self.host, port=self.port,
ssl=sslc, proxy=self.proxy) as session:
async with _RSClient(
session_factory=session_factory,
host=self.host, port=self.port,
ssl=sslc,
proxy=self.proxy,
transport=PaddedRSTransport,
) as session:
self.session = session # type: NotificationSession
self.session.set_default_timeout(self.network.get_network_timeout_seconds(NetworkTimeout.Generic))
try: