diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index 06f2a943c..bf9b80717 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -1101,8 +1101,6 @@ class Channel(AbstractChannel): util.trigger_callback('channel', self.lnworker.wallet, self) def is_frozen_for_receiving(self) -> bool: - if self.lnworker.uses_trampoline() and not self.lnworker.is_trampoline_peer(self.node_id): - return True return self.storage.get('frozen_for_receiving', False) def set_frozen_for_receiving(self, b: bool) -> None: diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 01b86f738..b532e5a19 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -2545,7 +2545,8 @@ class LNWallet(Logger): def _get_invoice_features(self, amount_msat: Optional[int]) -> LnFeatures: invoice_features = self.features.for_invoice() - if not self.uses_trampoline(): + if not all((not c.is_open() or c.is_frozen_for_receiving()) or self.is_trampoline_peer(c.node_id) \ + for c in self.channels.values()): invoice_features &= ~ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM needs_jit: bool = self.receive_requires_jit_channel(amount_msat) if needs_jit: @@ -2572,7 +2573,13 @@ class LNWallet(Logger): assert amount_msat is None or amount_msat > 0 timestamp = int(time.time()) - routing_hints = self.calc_routing_hints_for_invoice(amount_msat, channels=channels) + routing_hints = self.calc_routing_hints_for_invoice( + amount_msat, + channels=channels, + # if the invoice_features signal trampoline support all included r_tags should support trampoline forwarding + # TODO: make invoice_features dynamic depending on available trampoline channels + only_trampoline=payment_info.invoice_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM), + ) formatted_r_hints = LnAddr.format_bolt11_routing_info_as_human_readable(routing_hints, has_explicit_r_tagtype=True) self.logger.info(f"creating bolt11 invoice with routing_hints: {formatted_r_hints}, sat: {(amount_msat or 0) // 1000}") payment_secret = self.get_payment_secret(payment_info.payment_hash) @@ -3159,7 +3166,7 @@ class LNWallet(Logger): else: self.logger.info(f'htlc_failed: waiting for other htlcs to fail (phash={payment_hash.hex()})') - def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None): + def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], *, channels=None, only_trampoline: bool = False): """calculate routing hints (BOLT-11 'r' field)""" routing_hints = [] if self.receive_requires_jit_channel(amount_msat): @@ -3178,6 +3185,8 @@ class LNWallet(Logger): if chan.short_channel_id is not None } for chan in channels: + if only_trampoline and not self.is_trampoline_peer(chan.node_id): + continue alias_or_scid = chan.get_remote_scid_alias() or chan.short_channel_id assert isinstance(alias_or_scid, bytes), alias_or_scid channel_info = get_mychannel_info(chan.short_channel_id, scid_to_my_channels) diff --git a/tests/test_lnwallet.py b/tests/test_lnwallet.py index f50d0d64c..1b7637dee 100644 --- a/tests/test_lnwallet.py +++ b/tests/test_lnwallet.py @@ -1,9 +1,12 @@ import logging import os +import electrum.trampoline from . import ElectrumTestCase +from .test_lnchannel import create_test_channels -from electrum.lnutil import RECEIVED, MIN_FINAL_CLTV_DELTA_ACCEPTED +from electrum.lnutil import RECEIVED, MIN_FINAL_CLTV_DELTA_ACCEPTED, LnFeatures +from electrum.lntransport import LNPeerAddr from electrum.logging import console_stderr_handler from electrum.invoices import LN_EXPIRY_NEVER, PR_UNPAID @@ -52,3 +55,92 @@ class TestLNWallet(ElectrumTestCase): min_final_cltv_delta=min_final_cltv_delta, exp_delay=exp_delay, ) + + async def test_trampoline_invoice_features_and_routing_hints(self): + """ + When the invoice_features signal trampoline support, routing hints must only + contain trampoline nodes. When it does not, all channel can be added as r_tags. + We only signal trampoline support in the invoice if all open channels do support trampoline. + """ + wallet = self.lnwallet_anchors + self.assertFalse(wallet.uses_trampoline()) + + trampoline_peer = self.create_mock_lnwallet(name='trampoline_peer', has_anchors=True) + trampoline_pubkey = trampoline_peer.node_keypair.pubkey + + regular_peer = self.create_mock_lnwallet(name='regular_peer', has_anchors=True) + regular_pubkey = regular_peer.node_keypair.pubkey + + chan_t, _ = create_test_channels(alice_lnwallet=wallet, bob_lnwallet=trampoline_peer, anchor_outputs=True) + chan_r, _ = create_test_channels(alice_lnwallet=wallet, bob_lnwallet=regular_peer, anchor_outputs=True) + wallet._add_channel(chan_t) + wallet._add_channel(chan_r) + + # only trampoline_peer is a known trampoline forwarder + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'trampoline_peer': LNPeerAddr( + host="127.0.0.1", + port=9735, + pubkey=trampoline_pubkey, + ), + } + self.addCleanup(lambda: electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS.clear()) + + amount_msat = 100_000 + + # mixed peers: trampoline feature must be stripped, all peers in hints + payment_hash = wallet.create_payment_info(amount_msat=amount_msat) + pi = wallet.get_payment_info(payment_hash, direction=RECEIVED) + self.assertFalse( + pi.invoice_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM), + "trampoline bit should be stripped when not all peers are trampoline", + ) + + lnaddr, _ = wallet.get_bolt11_invoice(payment_info=pi, message='test', fallback_address=None) + hint_node_ids = {route[0][0] for route in lnaddr.get_routing_info('r')} + self.assertEqual(hint_node_ids, {trampoline_pubkey, regular_pubkey}) + + # trampoline feature should not be set if we use trampoline but one peer is not a trampoline + old_check, wallet.uses_trampoline = wallet.uses_trampoline, lambda: True + self.assertTrue(wallet.uses_trampoline()) + + payment_hash = wallet.create_payment_info(amount_msat=amount_msat) + pi = wallet.get_payment_info(payment_hash, direction=RECEIVED) + self.assertFalse( + pi.invoice_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM), + "trampoline feature should not be set if we use trampoline but one peer is not a trampoline", + ) + + wallet.clear_invoices_cache() + lnaddr, _ = wallet.get_bolt11_invoice(payment_info=pi, message='test', fallback_address=None) + hint_node_ids = {route[0][0] for route in lnaddr.get_routing_info('r')} + self.assertEqual(hint_node_ids, {trampoline_pubkey, regular_pubkey}) + + wallet.uses_trampoline = old_check + self.assertFalse(wallet.uses_trampoline()) + + # all peers trampoline: we signal trampoline support, even with trampoline disabled + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS['regular_peer'] = LNPeerAddr( + host="127.0.0.1", + port=9735, + pubkey=regular_pubkey, + ) + + payment_hash2 = wallet.create_payment_info(amount_msat=amount_msat) + pi2 = wallet.get_payment_info(payment_hash2, direction=RECEIVED) + self.assertTrue( + pi2.invoice_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM), + "trampoline bit should be present when all peers are trampoline", + ) + + wallet.clear_invoices_cache() + lnaddr2, _ = wallet.get_bolt11_invoice(payment_info=pi2, message='test', fallback_address=None) + hint_node_ids2 = {route[0][0] for route in lnaddr2.get_routing_info('r')} + self.assertEqual(hint_node_ids2, {trampoline_pubkey, regular_pubkey}) + + # assert only trampoline peers are included in r_tags if the invoice_features signal trampoline + del electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS['regular_peer'] + wallet.clear_invoices_cache() + lnaddr3, _ = wallet.get_bolt11_invoice(payment_info=pi2, message='test', fallback_address=None) + hint_node_ids3 = {route[0][0] for route in lnaddr3.get_routing_info('r')} + self.assertEqual(hint_node_ids3, {trampoline_pubkey})