diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index ea47dec90..06f2a943c 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -1479,7 +1479,7 @@ class Channel(AbstractChannel): # small value htlcs: even a large htlc might not appear in the outgoing channel's ctx, e.g. maybe it was # not committed yet - we should still make sure it gets removed on the incoming channel. (see #9631) if preimage: - self.lnworker.save_preimage(payment_hash, preimage) + self.lnworker.save_preimage(payment_hash, preimage, mark_as_public=True) for htlc, is_sent in found.values(): if is_sent: self.lnworker.htlc_fulfilled(self, payment_hash, htlc.htlc_id) @@ -1720,6 +1720,7 @@ class Channel(AbstractChannel): assert htlc_id not in self.hm.log[REMOTE]['settles'] self.hm.send_settle(htlc_id) self.htlc_settle_time[htlc_id] = now() + self.lnworker.save_preimage(htlc.payment_hash, preimage, mark_as_public=True) def get_payment_hash(self, htlc_id: int) -> bytes: htlc = self.hm.get_htlc_by_id(LOCAL, htlc_id) @@ -1737,6 +1738,7 @@ class Channel(AbstractChannel): assert htlc_id not in self.hm.log[LOCAL]['settles'] with self.db_lock: self.hm.recv_settle(htlc_id) + self.lnworker.save_preimage(htlc.payment_hash, preimage, mark_as_public=True) def fail_htlc(self, htlc_id: int) -> None: """Fail a pending received HTLC. diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 5e83affdb..820181fb9 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1983,7 +1983,6 @@ class Peer(Logger, EventListener): f"chan={chan.get_id_for_log()}. {htlc_id=}. {chan.get_state()=!r}. {chan.peer_state=!r}") return chan.receive_htlc_settle(preimage, htlc_id) # TODO handle exc and maybe fail channel (e.g. bad htlc_id) - self.lnworker.save_preimage(payment_hash, preimage) self.maybe_send_commitment(chan) def on_update_fail_malformed_htlc(self, chan: Channel, payload): diff --git a/electrum/lnsweep.py b/electrum/lnsweep.py index 761f7061e..42c2bc482 100644 --- a/electrum/lnsweep.py +++ b/electrum/lnsweep.py @@ -480,14 +480,22 @@ def _maybe_reveal_preimage_for_htlc( htlc: 'UpdateAddHtlc', sweep_info_name: str, ) -> Tuple[Optional[bytes], Optional[KeepWatchingTXO]]: - """Given a Remote-added-HTLC, return the preimage if it's okay to reveal it on-chain.""" - if not chan.lnworker.is_complete_mpp(htlc.payment_hash): + """Given a Remote-added-HTLC, return the preimage if it's okay to reveal it on-chain. + + note: to be safe, even if we don't/can't reveal the preimage now, we should tell lnwatcher to + keep watching this HTLC at least until its CLTV, in case circumstances change. + """ + if not chan.lnworker.is_preimage_public(htlc.payment_hash) and not chan.lnworker.is_complete_mpp(htlc.payment_hash): # - do not redeem this, it might publish the preimage of an incomplete MPP # - OTOH maybe this chan just got closed, and we are still receiving new htlcs # for this MPP set. So the MPP set might still transition to complete! # The MPP_TIMEOUT is only around 2 minutes, so this window is short. # The default keep_watching logic in lnwatcher is sufficient to call us again. - return None, None + keep_watching_txo = KeepWatchingTXO( + name=sweep_info_name + "_preimage_not_public", + until_height=htlc.cltv_abs, + ) + return None, keep_watching_txo if htlc.payment_hash.hex() in chan.lnworker.dont_settle_htlcs: # we should not reveal the preimage *for now*, but we might still decide to reveal it later keep_watching_txo = KeepWatchingTXO( @@ -496,6 +504,15 @@ def _maybe_reveal_preimage_for_htlc( ) return None, keep_watching_txo preimage = chan.lnworker.get_preimage(htlc.payment_hash) + if preimage is None: + keep_watching_txo = KeepWatchingTXO( + name=sweep_info_name + "_preimage_missing", + until_height=htlc.cltv_abs, + ) + return None, keep_watching_txo + # this preimage will be revealed + assert preimage + chan.lnworker.save_preimage(htlc.payment_hash, preimage, mark_as_public=True) return preimage, None diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 20253a1b1..462eaf507 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1022,7 +1022,7 @@ class LNWallet(Logger): self.lnrater: LNRater = None # "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts, invoice_features self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, Tuple[Optional[int], int, int, int, int, int]] - self._preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage + self._preimages = self.db.get_dict('lightning_preimages') # RHASH -> (preimage, is_public) self._bolt11_cache = {} # note: this sweep_address is only used as fallback; as it might result in address-reuse self.logs = defaultdict(list) # type: Dict[str, List[HtlcLog]] # key is RHASH # (not persisted) @@ -2699,19 +2699,35 @@ class LNWallet(Logger): del self._payment_bundles_pkey_to_canon[pkey] del self._payment_bundles_canon_to_pkeylist[canon_pkey] - def save_preimage(self, payment_hash: bytes, preimage: bytes, *, write_to_disk: bool = True): + def save_preimage( + self, + payment_hash: bytes, + preimage: bytes, + *, + write_to_disk: bool = True, + mark_as_public: Optional[bool] = None, # see is_preimage_public + ): + assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}" + assert isinstance(preimage, bytes), f"expected bytes, but got {type(preimage)}" if sha256(preimage) != payment_hash: raise Exception("tried to save incorrect preimage for payment_hash") - if self._preimages.get(payment_hash.hex()) is not None: - return # we already have this preimage - self.logger.debug(f"saving preimage for {payment_hash.hex()}") - self._preimages[payment_hash.hex()] = preimage.hex() + old_tuple = _, old_is_public = self._preimages.get(payment_hash.hex(), (None, None)) + if mark_as_public is None: # if unset, keep current DB value + mark_as_public = old_is_public or False + if old_is_public and not mark_as_public: + raise Exception("preimage mark_as_public: True->False transition is forbidden") + # sanity checks and conversions done. + new_tuple = preimage.hex(), mark_as_public + if old_tuple == new_tuple: # no change + return + self.logger.debug(f"saving preimage for {payment_hash.hex()} (public={mark_as_public})") + self._preimages[payment_hash.hex()] = new_tuple if write_to_disk: self.wallet.save_db() def get_preimage(self, payment_hash: bytes) -> Optional[bytes]: assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}" - preimage_hex = self._preimages.get(payment_hash.hex()) + preimage_hex, _ = self._preimages.get(payment_hash.hex(), (None, None)) if preimage_hex is None: return None preimage_bytes = bytes.fromhex(preimage_hex) @@ -2723,6 +2739,20 @@ class LNWallet(Logger): preimage_bytes = self.get_preimage(bytes.fromhex(payment_hash)) or b"" return preimage_bytes.hex() or None + def is_preimage_public(self, payment_hash: bytes) -> bool: + """If another LN node knows a preimage besides us, we consider it public. + If a preimage is public, it is safe to reveal it in an arbitrary context. + + For example, if there is a pending incoming partial MPP for an invoice we created, + we must not reveal the preimage, otherwise we will get paid less than invoice amount. + What if there is a force-close around that time? When is it safe to reveal the preimage on-chain? + e.g. if we already revealed the preimage either offchain or onchain, it is fine to reveal it again. + """ + assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}" + preimage_hex, is_public = self._preimages.get(payment_hash.hex(), (None, None)) + assert preimage_hex is not None + return bool(is_public) + def get_payment_info(self, payment_hash: bytes, *, direction: lnutil.Direction) -> Optional[PaymentInfo]: """returns None if payment_hash is a payment we are forwarding""" key = PaymentInfo.calc_db_key(payment_hash_hex=payment_hash.hex(), direction=direction) diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index 8ea53df59..1a761230d 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -488,7 +488,7 @@ class SwapManager(Logger): if preimage: swap.preimage = preimage self.logger.info(f'found preimage: {preimage.hex()}') - self.lnworker.save_preimage(swap.payment_hash, preimage) + self.lnworker.save_preimage(swap.payment_hash, preimage, mark_as_public=True) else: # this is our refund tx if spent_height > 0: diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index 36df85498..7f0c89370 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -69,7 +69,7 @@ class WalletUnfinished(WalletFileException): # seed_version is now used for the version of the wallet file OLD_SEED_VERSION = 4 # electrum versions < 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0 -FINAL_SEED_VERSION = 67 # electrum >= 2.7 will set this to prevent +FINAL_SEED_VERSION = 68 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format @@ -242,6 +242,7 @@ class WalletDBUpgrader(Logger): self._convert_version_65() self._convert_version_66() self._convert_version_67() + self._convert_version_68() self.put('seed_version', FINAL_SEED_VERSION) # just to be sure def _convert_wallet_type(self): @@ -1354,6 +1355,16 @@ class WalletDBUpgrader(Logger): self.data['channels'] = channels self.data['seed_version'] = 67 + def _convert_version_68(self): + if not self._is_upgrade_method_needed(67, 67): + return + old_preimages = self.data.get('lightning_preimages', {}) + new_preimages = {} + for _hash, preimage in old_preimages.items(): + new_preimages[_hash] = (preimage, False) + self.data['lightning_preimages'] = new_preimages + self.data['seed_version'] = 68 + def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): return diff --git a/tests/test_commands.py b/tests/test_commands.py index eb6d634da..dddf293fd 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -571,7 +571,7 @@ class TestCommandsTestnet(ElectrumTestCase): wallet=wallet, ) assert settle_result['settled'] == payment_hash - assert wallet.lnworker._preimages[payment_hash] == preimage.hex() + assert wallet.lnworker._preimages[payment_hash][0] == preimage.hex() with (mock.patch.object( wallet.lnworker, 'get_payment_value', diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 6ba3f750a..e1a9c0b67 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -472,11 +472,12 @@ class TestPeer(ElectrumTestCase): def prepare_recipient(self, w2, payment_hash, test_hold_invoice, test_failure): if not test_hold_invoice and not test_failure: return - preimage = bytes.fromhex(w2._preimages.pop(payment_hash.hex())) + preimage_hex, is_public = w2._preimages.pop(payment_hash.hex()) + preimage = bytes.fromhex(preimage_hex) if test_hold_invoice: async def cb(payment_hash): if not test_failure: - w2.save_preimage(payment_hash, preimage) + w2.save_preimage(payment_hash, preimage, mark_as_public=is_public) else: raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') w2.register_hold_invoice(payment_hash, cb)