diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 5270561b4..a5200c9a0 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1455,10 +1455,12 @@ class Peer(Logger, EventListener): # until this msg is processed. If we are behind (lost state), and send chan_reest to the remote, # when the remote realizes we are behind, they might send an "error" message - but the spec mandates # they send chan_reest first. If we processed the error first, we might force-close and lose money! + # FIXME there are a lot of "SHOULD send an error and fail the channel" BOLT-02 cases here + # where we don't send the error, but directly fail the channel their_next_local_ctn = msg["next_commitment_number"] their_oldest_unrevoked_remote_ctn = msg["next_revocation_number"] - their_local_pcp = msg.get("my_current_per_commitment_point") - their_claim_of_our_last_per_commitment_secret = msg.get("your_last_per_commitment_secret") + their_local_pcp = msg["my_current_per_commitment_point"] + their_claim_of_our_last_per_commitment_secret = msg["your_last_per_commitment_secret"] self.logger.info( f'channel_reestablish ({chan.get_id_for_log()}): received channel_reestablish with ' f'(their_next_local_ctn={their_next_local_ctn}, ' @@ -1469,21 +1471,18 @@ class Peer(Logger, EventListener): f"chan={chan.get_id_for_log()}. {chan.get_state()=!r}. {chan.peer_state=!r}") return # sanity checks of received values - if their_next_local_ctn < 0: - raise RemoteMisbehaving(f"channel reestablish: their_next_local_ctn < 0") - if their_oldest_unrevoked_remote_ctn < 0: - raise RemoteMisbehaving(f"channel reestablish: their_oldest_unrevoked_remote_ctn < 0") + assert their_next_local_ctn >= 0 # already done by lnmsg, as type is u64 + assert their_oldest_unrevoked_remote_ctn >= 0 # ctns oldest_unrevoked_local_ctn = chan.get_oldest_unrevoked_ctn(LOCAL) - latest_local_ctn = chan.get_latest_ctn(LOCAL) - next_local_ctn = chan.get_next_ctn(LOCAL) - oldest_unrevoked_remote_ctn = chan.get_oldest_unrevoked_ctn(REMOTE) latest_remote_ctn = chan.get_latest_ctn(REMOTE) next_remote_ctn = chan.get_next_ctn(REMOTE) # compare remote ctns we_are_ahead = False - they_are_ahead = False + they_are_ahead_with_proof = False + they_are_ahead_without_proof = False we_must_resend_revoke_and_ack = False + # check "next_commitment_number" if next_remote_ctn != their_next_local_ctn: if their_next_local_ctn == latest_remote_ctn and chan.hm.is_revack_pending(REMOTE): # We will replay the local updates (see reestablish_channel), which should contain a commitment_signed @@ -1496,8 +1495,8 @@ class Peer(Logger, EventListener): if their_next_local_ctn < next_remote_ctn: we_are_ahead = True else: - they_are_ahead = True - # compare local ctns + they_are_ahead_without_proof = True + # check "next_revocation_number" if oldest_unrevoked_local_ctn != their_oldest_unrevoked_remote_ctn: if oldest_unrevoked_local_ctn - 1 == their_oldest_unrevoked_remote_ctn: # A node: @@ -1512,12 +1511,10 @@ class Peer(Logger, EventListener): if their_oldest_unrevoked_remote_ctn < oldest_unrevoked_local_ctn: we_are_ahead = True else: - they_are_ahead = True - # option_data_loss_protect + they_are_ahead_with_proof = True # the claimed value will be checked against DLP + # option_data_loss_protect (DLP) assert self.features.supports(LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT) def are_datalossprotect_fields_valid() -> bool: - if their_local_pcp is None or their_claim_of_our_last_per_commitment_secret is None: - return False if their_oldest_unrevoked_remote_ctn > 0: our_pcs, __ = chan.get_secret_and_point(LOCAL, their_oldest_unrevoked_remote_ctn - 1) else: @@ -1531,12 +1528,13 @@ class Peer(Logger, EventListener): assert chan.is_static_remotekey_enabled() return True if not are_datalossprotect_fields_valid(): + self.schedule_force_closing(chan.channel_id) raise RemoteMisbehaving("channel_reestablish: data loss protect fields invalid") fut = self.channel_reestablish_msg[chan.channel_id] - if they_are_ahead: + if they_are_ahead_with_proof: # order matters, WE_ARE_TOXIC case must be checked first. self.logger.warning( f"channel_reestablish ({chan.get_id_for_log()}): " - f"remote is ahead of us! They should force-close. Remote PCP: {their_local_pcp.hex()}") + f"remote is ahead of us (with proof)! They should force-close.") # data_loss_protect_remote_pcp is used in lnsweep chan.set_data_loss_protect_remote_pcp(their_next_local_ctn - 1, their_local_pcp) chan.set_state(ChannelState.WE_ARE_TOXIC) @@ -1544,7 +1542,14 @@ class Peer(Logger, EventListener): chan.peer_state = PeerState.BAD # raise after we send channel_reestablish, so the remote can realize they are ahead # FIXME what if we have multiple chans with peer? timing... - fut.set_exception(GracefulDisconnect("remote ahead of us")) + fut.set_exception(GracefulDisconnect("remote ahead of us (with proof)")) + elif they_are_ahead_without_proof: + self.logger.warning( + f"channel_reestablish ({chan.get_id_for_log()}): " + f"remote is ahead of us (without proof)! trying to force-close.") + self.schedule_force_closing(chan.channel_id) + # FIXME what if we have multiple chans with peer? timing... + fut.set_exception(GracefulDisconnect("remote ahead of us (without proof)")) elif we_are_ahead: self.logger.warning(f"channel_reestablish ({chan.get_id_for_log()}): we are ahead of remote! trying to force-close.") self.schedule_force_closing(chan.channel_id) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 8669931c2..43ccef6d9 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -19,6 +19,7 @@ import statistics from aiorpcx import timeout_after, TaskTimeout from electrum_ecc import ECPrivkey +import electrum_ecc as ecc import electrum import electrum.trampoline @@ -45,7 +46,7 @@ from electrum import lnmsg from electrum.logging import console_stderr_handler, Logger from electrum.lnworker import PaymentInfo from electrum.lnonion import OnionFailureCode, OnionRoutingFailure, OnionHopsDataSingle, OnionPacket -from electrum.lnutil import LOCAL, REMOTE, UpdateAddHtlc, RecvMPPResolution +from electrum.lnutil import LOCAL, REMOTE, UpdateAddHtlc, RecvMPPResolution, RevocationStore from electrum.invoices import PR_PAID, PR_UNPAID, Invoice, LN_EXPIRY_NEVER from electrum.interface import GracefulDisconnect from electrum.simple_config import SimpleConfig @@ -713,6 +714,99 @@ class TestPeerDirect(TestPeer): with self.subTest(msg="bob is slow"): await f(alice_slow=False, bob_slow=True) + async def test_reestablish_fake_data(self): + async def f( + alice_slow: bool, + bob_slow: bool, + *, + ctn_delta: int = 0, + revnum_delta: int = 0, + last_rev_secret: bytes = None, + ) -> tuple[Channel, Channel]: + alice_lnwallet, bob_lnwallet = self.prepare_lnwallets(self.GRAPH_DEFINITIONS['single_chan']).values() + alice_channel, bob_channel = create_test_channels(alice_lnwallet=alice_lnwallet, bob_lnwallet=bob_lnwallet) + p1, p2, w1, w2 = self.prepare_peers(alice_channel, bob_channel) + # first make some payments, to bump the channel ctns a bit + async def pay(): + for pnum in range(2): + lnaddr, pay_req = self.prepare_invoice(w2) + result, log = await w1.pay_invoice(pay_req) + self.assertEqual(result, True) + gath.cancel() + gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) + with self.assertRaises(asyncio.CancelledError): + await gath + for chan in (alice_channel, bob_channel): + chan.peer_state = PeerState.DISCONNECTED + + # now reestablish the channel + async def alice_sends_reest(): + nonlocal last_rev_secret + if alice_slow: await asyncio.sleep(0.05) + chan = alice_channel + next_local_ctn = chan.get_next_ctn(LOCAL) + ctn_delta + assert next_local_ctn >= 0, next_local_ctn + oldest_unrevoked_remote_ctn = chan.get_oldest_unrevoked_ctn(REMOTE) + revnum_delta + assert oldest_unrevoked_remote_ctn >= 0, oldest_unrevoked_remote_ctn + if last_rev_secret is None: + if revnum_delta <= 0: + last_rev_secret = chan.revocation_store.retrieve_secret(RevocationStore.START_INDEX - oldest_unrevoked_remote_ctn + 1) + else: # Alice is using *magic* here, i.e. cheating: she uses Bob's channel to learn future unrevealed secrets + last_rev_secret, _point = bob_channel.get_secret_and_point(LOCAL, oldest_unrevoked_remote_ctn - 1) + p1.send_message( + "channel_reestablish", + channel_id=chan.channel_id, + next_commitment_number=next_local_ctn, + next_revocation_number=oldest_unrevoked_remote_ctn, + your_last_per_commitment_secret=last_rev_secret, + my_current_per_commitment_point=ecc.GENERATOR.get_public_key_bytes(compressed=True), + ) + async def bob_sends_reest(): + if bob_slow: await asyncio.sleep(0.05) + await p2.reestablish_channel(bob_channel) + + async def exit_after_bob_receives_reest(): + await p2.channel_reestablish_msg[bob_channel.channel_id] + + with self.assertRaises((GracefulDisconnect, lnutil.RemoteMisbehaving)): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await p1.initialized + await p2.initialized + await group.spawn(alice_sends_reest) + await group.spawn(bob_sends_reest) + await group.spawn(exit_after_bob_receives_reest) + return alice_channel, bob_channel + + cs = ChannelState + for (alice_slow, bob_slow) in ( + (False, False), # both fast: FIXME: we want to test the case where both Alice and Bob sends channel-reestablish before + # receiving what the other sent. This is not a reliable way to do that... + (True, False), + (False, True), + ): + kwargs = {"alice_slow": alice_slow, "bob_slow": bob_slow} + # note: Alice's channel state will stay OPEN in every case: + # she is intentionally sending weird data that does not reflect her true state. + with self.subTest(msg="next_local_ctn from past", **kwargs): + a_chan, b_chan = await f(ctn_delta=-2, **kwargs) + self.assertEqual((a_chan._state, b_chan._state), (cs.OPEN, cs.FORCE_CLOSING)) + with self.subTest(msg="next_local_ctn from future", **kwargs): + a_chan, b_chan = await f(ctn_delta=1000, **kwargs) + self.assertEqual((a_chan._state, b_chan._state), (cs.OPEN, cs.FORCE_CLOSING)) + with self.subTest(msg="oldest_unrevoked_remote_ctn from past", **kwargs): + a_chan, b_chan = await f(revnum_delta=-2, **kwargs) + self.assertEqual((a_chan._state, b_chan._state), (cs.OPEN, cs.FORCE_CLOSING)) + with self.subTest(msg="oldest_unrevoked_remote_ctn from future", **kwargs): + a_chan, b_chan = await f(revnum_delta=1000, **kwargs) + self.assertEqual((a_chan._state, b_chan._state), (cs.OPEN, cs.WE_ARE_TOXIC)) + with self.subTest(msg="invalid last_rev_secret", **kwargs): + a_chan, b_chan = await f(last_rev_secret=sha256("fake_data"), **kwargs) + self.assertEqual((a_chan._state, b_chan._state), (cs.OPEN, cs.FORCE_CLOSING)) + @staticmethod def _send_fake_htlc(peer: Peer, chan: Channel) -> UpdateAddHtlc: htlc = UpdateAddHtlc(amount_msat=10000, payment_hash=os.urandom(32), cltv_abs=999, timestamp=1)