diff --git a/electrum/onion_message.py b/electrum/onion_message.py index 8e4baffc8..e94632937 100644 --- a/electrum/onion_message.py +++ b/electrum/onion_message.py @@ -31,7 +31,7 @@ import dataclasses from random import random from types import MappingProxyType -from typing import TYPE_CHECKING, Optional, Sequence, NamedTuple, Tuple +from typing import TYPE_CHECKING, Optional, Sequence, NamedTuple, Tuple, Union import electrum_ecc as ecc @@ -496,8 +496,7 @@ class OnionMessageManager(Logger): - forwards are best-effort. They should not need retrying, but a queue is used to limit the pacing of forwarding, and limiting the number of outstanding forwards. Any onion message forwards arriving when the forward queue is full will be dropped. - - TODO: iterate through routes for each request""" + """ SLEEP_DELAY = 1 REQUEST_REPLY_TIMEOUT = 30 @@ -506,10 +505,12 @@ class OnionMessageManager(Logger): FORWARD_RETRY_DELAY = 2 FORWARD_MAX_QUEUE = 3 - class Request(NamedTuple): - future: asyncio.Future - payload: dict - node_id_or_blinded_path: bytes + class Request: + def __init__(self, *, payload: dict, node_id_or_blinded_paths: Union[bytes, Sequence[bytes]]): + self.future = asyncio.Future() + self.payload = payload + self.node_id_or_blinded_paths = node_id_or_blinded_paths + self.current_index: int = 0 def __init__(self, lnwallet: 'LNWallet'): Logger.__init__(self) @@ -623,8 +624,8 @@ class OnionMessageManager(Logger): def submit_send( self, *, payload: dict, - node_id_or_blinded_path: bytes, - key: bytes = None) -> 'Task': + node_id_or_blinded_paths: Union[bytes, Sequence[bytes]], + key: Optional[bytes] = None) -> 'Task': """Add onion message to queue for sending. Queued onion message payloads are supplied with a path_id and a reply_path to determine which request corresponds with arriving replies. @@ -636,13 +637,9 @@ class OnionMessageManager(Logger): key = os.urandom(8) assert type(key) is bytes and len(key) >= 8 - self.logger.debug(f'submit_send {key=} {payload=} {node_id_or_blinded_path=}') + self.logger.debug(f'submit_send {key=} {payload=} {node_id_or_blinded_paths=}') - req = OnionMessageManager.Request( - future=asyncio.Future(), - payload=payload, - node_id_or_blinded_path=node_id_or_blinded_path - ) + req = OnionMessageManager.Request(payload=payload, node_id_or_blinded_paths=node_id_or_blinded_paths) with self.pending_lock: if key in self.pending: raise Exception(f'{key=} already exists!') @@ -665,8 +662,15 @@ class OnionMessageManager(Logger): """adds reply_path to payload""" req = self.pending.get(key) payload = req.payload - node_id_or_blinded_path = req.node_id_or_blinded_path - self.logger.debug(f'send_pending_message {key=} {payload=} {node_id_or_blinded_path=}') + + # get next path (round robin) + dests = req.node_id_or_blinded_paths + if isinstance(req.node_id_or_blinded_paths, bytes): + dests = [req.node_id_or_blinded_paths] + dest = dests[req.current_index] + req.current_index = (req.current_index + 1) % len(dests) + + self.logger.debug(f'send_pending_message {key=} {payload=} {dest=}') final_payload = copy.deepcopy(payload) @@ -679,9 +683,10 @@ class OnionMessageManager(Logger): final_payload['reply_path'] = {'path': reply_paths} - # TODO: we should try alternate paths when retrying, this is currently not done. + # NOTE: we could also try alternate paths to introduction point (the non-blinded part of the route) + # when retrying, this is currently not done. # (send_onion_message_to decides path, without knowledge of prev attempts) - send_onion_message_to(self.lnwallet, node_id_or_blinded_path, final_payload) + send_onion_message_to(self.lnwallet, dest, final_payload) def _path_id_from_payload_and_key(self, payload: dict, key: bytes) -> bytes: # TODO: use payload to determine prefix? diff --git a/tests/test_onion_message.py b/tests/test_onion_message.py index d261bbe6d..a4be179b7 100644 --- a/tests/test_onion_message.py +++ b/tests/test_onion_message.py @@ -314,11 +314,12 @@ class TestOnionMessageManager(ElectrumTestCase): self.carol = keypair(ECPrivkey(privkey_bytes=b'\x43'*32)) self.dave = keypair(ECPrivkey(privkey_bytes=b'\x44'*32)) self.eve = keypair(ECPrivkey(privkey_bytes=b'\x45'*32)) + self.fred = keypair(ECPrivkey(privkey_bytes=b'\x46'*32)) async def run_test1(self, t): t1 = t.submit_send( payload={'message': {'text': 'alice_timeout'.encode('utf-8')}}, - node_id_or_blinded_path=self.alice.pubkey) + node_id_or_blinded_paths=self.alice.pubkey) with self.assertRaises(Timeout): await t1 @@ -326,7 +327,7 @@ class TestOnionMessageManager(ElectrumTestCase): async def run_test2(self, t): t2 = t.submit_send( payload={'message': {'text': 'bob_slow_timeout'.encode('utf-8')}}, - node_id_or_blinded_path=self.bob.pubkey) + node_id_or_blinded_paths=self.bob.pubkey) with self.assertRaises(Timeout): await t2 @@ -334,7 +335,7 @@ class TestOnionMessageManager(ElectrumTestCase): async def run_test3(self, t, rkey): t3 = t.submit_send( payload={'message': {'text': 'carol_with_immediate_reply'.encode('utf-8')}}, - node_id_or_blinded_path=self.carol.pubkey, + node_id_or_blinded_paths=self.carol.pubkey, key=rkey) t3_result = await t3 @@ -343,7 +344,7 @@ class TestOnionMessageManager(ElectrumTestCase): async def run_test4(self, t, rkey): t4 = t.submit_send( payload={'message': {'text': 'dave_with_slow_reply'.encode('utf-8')}}, - node_id_or_blinded_path=self.dave.pubkey, + node_id_or_blinded_paths=self.dave.pubkey, key=rkey) t4_result = await t4 @@ -352,13 +353,24 @@ class TestOnionMessageManager(ElectrumTestCase): async def run_test5(self, t): t5 = t.submit_send( payload={'message': {'text': 'no_peer'.encode('utf-8')}}, - node_id_or_blinded_path=self.eve.pubkey) + node_id_or_blinded_paths=self.eve.pubkey) # will not find route to eve, but has eve's address, but we are configured to not direct connect with self.assertRaises(NoRouteFound) as c: await t5 self.assertEqual(c.exception.peer_address, LNPeerAddr('localhost', 1234, self.eve.pubkey)) + async def run_test6(self, t, rkey): + # bob will not reply, fred will + t6 = t.submit_send( + payload={'message': {'text': 'send_dest_roundrobin'.encode('utf-8')}}, + node_id_or_blinded_paths=[self.bob.pubkey, self.fred.pubkey], + key=rkey + ) + + t6_result = await t6 + self.assertEqual(t6_result, ({'path_id': {'data': b'electrum' + rkey}}, {})) + async def test_request_and_reply(self): n = MockNetwork() lnw = self.create_mock_lnwallet(name='test_request_and_reply', has_anchors=False) @@ -382,12 +394,14 @@ class TestOnionMessageManager(ElectrumTestCase): rkey1 = bfh('0102030405060708') rkey2 = bfh('0102030405060709') + rkey3 = bfh('010203040506070a') lnw.lnpeermgr._peers[self.alice.pubkey] = MockPeer(self.alice.pubkey) lnw.lnpeermgr._peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=slow) lnw.lnpeermgr._peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(withreply, rkey1)) lnw.lnpeermgr._peers[self.dave.pubkey] = MockPeer(self.dave.pubkey, on_send_message=partial(slowwithreply, rkey2)) lnw.channel_db._addresses[self.eve.pubkey] = {NetAddress('localhost', '1234'): int(time.time())} + lnw.lnpeermgr._peers[self.fred.pubkey] = MockPeer(self.fred.pubkey, on_send_message=partial(withreply, rkey3)) t = OnionMessageManager(lnw) t.start_network(network=n) @@ -399,6 +413,7 @@ class TestOnionMessageManager(ElectrumTestCase): await self.run_test3(t, rkey1) await self.run_test4(t, rkey2) await self.run_test5(t) + await self.run_test6(t, rkey3) self.logger.debug('tests in parallel') async with OldTaskGroup() as group: await group.spawn(self.run_test1(t)) @@ -406,6 +421,7 @@ class TestOnionMessageManager(ElectrumTestCase): await group.spawn(self.run_test3(t, rkey1)) await group.spawn(self.run_test4(t, rkey2)) await group.spawn(self.run_test5(t)) + await group.spawn(self.run_test6(t, rkey3)) finally: await asyncio.sleep(TIME_STEP)