interface: address feedback for PaddedRSTransport
This commit is contained in:
@@ -140,6 +140,7 @@ class NotificationSession(RPCSession):
|
||||
self.cache = {}
|
||||
self._msg_counter = itertools.count(start=1)
|
||||
self.interface = interface
|
||||
self.taskgroup = interface.taskgroup
|
||||
self.cost_hard_limit = 0 # disable aiorpcx resource limits
|
||||
|
||||
async def handle_request(self, request):
|
||||
@@ -267,11 +268,6 @@ class ConnectError(NetworkException): pass
|
||||
|
||||
|
||||
class _RSClient(RSClient):
|
||||
def __init__(self, *, transport=None, **kwargs):
|
||||
if transport is None:
|
||||
transport = PaddedRSTransport
|
||||
RSClient.__init__(self, transport=transport, **kwargs)
|
||||
|
||||
async def create_connection(self):
|
||||
try:
|
||||
return await super().create_connection()
|
||||
@@ -310,7 +306,8 @@ class PaddedRSTransport(RSTransport):
|
||||
self._sbuffer_has_data_evt.set()
|
||||
self._maybe_consume_sbuffer()
|
||||
|
||||
def _maybe_consume_sbuffer(self):
|
||||
def _maybe_consume_sbuffer(self) -> None:
|
||||
"""Maybe take some data from sbuffer and send it on the wire."""
|
||||
if not self._can_send.is_set() or self.is_closing():
|
||||
return
|
||||
buf = self._sbuffer
|
||||
@@ -376,7 +373,7 @@ class PaddedRSTransport(RSTransport):
|
||||
def connection_made(self, transport: asyncio.BaseTransport):
|
||||
super().connection_made(transport)
|
||||
if isinstance(self.session, NotificationSession):
|
||||
coro = self.session.interface.taskgroup.spawn(self._poll_sbuffer())
|
||||
coro = self.session.taskgroup.spawn(self._poll_sbuffer())
|
||||
self._sbuffer_task = self.loop.create_task(coro)
|
||||
else:
|
||||
# This a short-lived "fetch_certificate"-type session.
|
||||
@@ -716,9 +713,13 @@ class Interface(Logger):
|
||||
sslc = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
|
||||
sslc.check_hostname = False
|
||||
sslc.verify_mode = ssl.CERT_NONE
|
||||
async with _RSClient(session_factory=RPCSession,
|
||||
host=self.host, port=self.port,
|
||||
ssl=sslc, proxy=self.proxy) as session:
|
||||
async with _RSClient(
|
||||
session_factory=RPCSession,
|
||||
host=self.host, port=self.port,
|
||||
ssl=sslc,
|
||||
proxy=self.proxy,
|
||||
transport=PaddedRSTransport,
|
||||
) as session:
|
||||
asyncio_transport = session.transport._asyncio_transport # type: asyncio.BaseTransport
|
||||
ssl_object = asyncio_transport.get_extra_info("ssl_object") # type: ssl.SSLObject
|
||||
return ssl_object.getpeercert(binary_form=True)
|
||||
@@ -787,9 +788,13 @@ class Interface(Logger):
|
||||
|
||||
async def open_session(self, sslc, exit_early=False):
|
||||
session_factory = lambda *args, iface=self, **kwargs: NotificationSession(*args, **kwargs, interface=iface)
|
||||
async with _RSClient(session_factory=session_factory,
|
||||
host=self.host, port=self.port,
|
||||
ssl=sslc, proxy=self.proxy) as session:
|
||||
async with _RSClient(
|
||||
session_factory=session_factory,
|
||||
host=self.host, port=self.port,
|
||||
ssl=sslc,
|
||||
proxy=self.proxy,
|
||||
transport=PaddedRSTransport,
|
||||
) as session:
|
||||
self.session = session # type: NotificationSession
|
||||
self.session.set_default_timeout(self.network.get_network_timeout_seconds(NetworkTimeout.Generic))
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user