diff --git a/electrum/interface.py b/electrum/interface.py index 4b5d0c818..cf6d52821 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -140,6 +140,7 @@ class NotificationSession(RPCSession): self.cache = {} self._msg_counter = itertools.count(start=1) self.interface = interface + self.taskgroup = interface.taskgroup self.cost_hard_limit = 0 # disable aiorpcx resource limits async def handle_request(self, request): @@ -267,11 +268,6 @@ class ConnectError(NetworkException): pass class _RSClient(RSClient): - def __init__(self, *, transport=None, **kwargs): - if transport is None: - transport = PaddedRSTransport - RSClient.__init__(self, transport=transport, **kwargs) - async def create_connection(self): try: return await super().create_connection() @@ -310,7 +306,8 @@ class PaddedRSTransport(RSTransport): self._sbuffer_has_data_evt.set() self._maybe_consume_sbuffer() - def _maybe_consume_sbuffer(self): + def _maybe_consume_sbuffer(self) -> None: + """Maybe take some data from sbuffer and send it on the wire.""" if not self._can_send.is_set() or self.is_closing(): return buf = self._sbuffer @@ -376,7 +373,7 @@ class PaddedRSTransport(RSTransport): def connection_made(self, transport: asyncio.BaseTransport): super().connection_made(transport) if isinstance(self.session, NotificationSession): - coro = self.session.interface.taskgroup.spawn(self._poll_sbuffer()) + coro = self.session.taskgroup.spawn(self._poll_sbuffer()) self._sbuffer_task = self.loop.create_task(coro) else: # This a short-lived "fetch_certificate"-type session. @@ -716,9 +713,13 @@ class Interface(Logger): sslc = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) sslc.check_hostname = False sslc.verify_mode = ssl.CERT_NONE - async with _RSClient(session_factory=RPCSession, - host=self.host, port=self.port, - ssl=sslc, proxy=self.proxy) as session: + async with _RSClient( + session_factory=RPCSession, + host=self.host, port=self.port, + ssl=sslc, + proxy=self.proxy, + transport=PaddedRSTransport, + ) as session: asyncio_transport = session.transport._asyncio_transport # type: asyncio.BaseTransport ssl_object = asyncio_transport.get_extra_info("ssl_object") # type: ssl.SSLObject return ssl_object.getpeercert(binary_form=True) @@ -787,9 +788,13 @@ class Interface(Logger): async def open_session(self, sslc, exit_early=False): session_factory = lambda *args, iface=self, **kwargs: NotificationSession(*args, **kwargs, interface=iface) - async with _RSClient(session_factory=session_factory, - host=self.host, port=self.port, - ssl=sslc, proxy=self.proxy) as session: + async with _RSClient( + session_factory=session_factory, + host=self.host, port=self.port, + ssl=sslc, + proxy=self.proxy, + transport=PaddedRSTransport, + ) as session: self.session = session # type: NotificationSession self.session.set_default_timeout(self.network.get_network_timeout_seconds(NetworkTimeout.Generic)) try: