diff --git a/electrum/commands.py b/electrum/commands.py index 6ee60d139..09bc31a14 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -1438,7 +1438,7 @@ class Commands(Logger): assert payment_hash in wallet.lnworker.payment_info, \ f"Couldn't find lightning invoice for {payment_hash=}" assert payment_hash in wallet.lnworker.dont_settle_htlcs, f"Invoice {payment_hash=} not a hold invoice?" - assert wallet.lnworker.is_accepted_mpp(bfh(payment_hash)), \ + assert wallet.lnworker.is_complete_mpp(bfh(payment_hash)), \ f"MPP incomplete, cannot settle hold invoice {payment_hash} yet" info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash)) assert (wallet.lnworker.get_payment_mpp_amount_msat(bfh(payment_hash)) or 0) >= (info.amount_msat or 0) @@ -1465,7 +1465,7 @@ class Commands(Logger): wallet.lnworker.set_payment_status(bfh(payment_hash), PR_UNPAID) wallet.lnworker.delete_payment_info(payment_hash) wallet.set_label(payment_hash, None) - while wallet.lnworker.is_accepted_mpp(bfh(payment_hash)): + while wallet.lnworker.is_complete_mpp(bfh(payment_hash)): # wait until the htlcs got failed so the payment won't get settled accidentally in a race await asyncio.sleep(0.1) del wallet.lnworker.dont_settle_htlcs[payment_hash] @@ -1490,7 +1490,7 @@ class Commands(Logger): """ assert len(payment_hash) == 64, f"Invalid payment_hash length: {len(payment_hash)} != 64" info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash)) - is_accepted_mpp: bool = wallet.lnworker.is_accepted_mpp(bfh(payment_hash)) + is_complete_mpp: bool = wallet.lnworker.is_complete_mpp(bfh(payment_hash)) amount_sat = (wallet.lnworker.get_payment_mpp_amount_msat(bfh(payment_hash)) or 0) // 1000 result = { "status": "unknown", @@ -1498,10 +1498,10 @@ class Commands(Logger): } if info is None: pass - elif not is_accepted_mpp and not wallet.lnworker.get_preimage_hex(payment_hash): - # is_accepted_mpp is False for settled payments + elif not is_complete_mpp and not wallet.lnworker.get_preimage_hex(payment_hash): + # is_complete_mpp is False for settled payments result["status"] = "unpaid" - elif is_accepted_mpp and payment_hash in wallet.lnworker.dont_settle_htlcs: + elif is_complete_mpp and payment_hash in wallet.lnworker.dont_settle_htlcs: result["status"] = "paid" payment_key: str = wallet.lnworker._get_payment_key(bfh(payment_hash)).hex() htlc_status = wallet.lnworker.received_mpp_htlcs[payment_key] diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index b6e19647d..2d57a4a8b 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -17,6 +17,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +import dataclasses import enum from collections import defaultdict from enum import IntEnum, Enum @@ -1202,12 +1203,10 @@ class Channel(AbstractChannel): """Adds a new LOCAL HTLC to the channel. Action must be initiated by LOCAL. """ - if isinstance(htlc, dict): # legacy conversion # FIXME remove - htlc = UpdateAddHtlc(**htlc) assert isinstance(htlc, UpdateAddHtlc) self._assert_can_add_htlc(htlc_proposer=LOCAL, amount_msat=htlc.amount_msat) if htlc.htlc_id is None: - htlc = attr.evolve(htlc, htlc_id=self.hm.get_next_htlc_id(LOCAL)) + htlc = dataclasses.replace(htlc, htlc_id=self.hm.get_next_htlc_id(LOCAL)) with self.db_lock: self.hm.send_htlc(htlc) self.logger.info("add_htlc") @@ -1217,15 +1216,13 @@ class Channel(AbstractChannel): """Adds a new REMOTE HTLC to the channel. Action must be initiated by REMOTE. """ - if isinstance(htlc, dict): # legacy conversion # FIXME remove - htlc = UpdateAddHtlc(**htlc) assert isinstance(htlc, UpdateAddHtlc) try: self._assert_can_add_htlc(htlc_proposer=REMOTE, amount_msat=htlc.amount_msat) except PaymentFailure as e: raise RemoteMisbehaving(e) from e if htlc.htlc_id is None: # used in unit tests - htlc = attr.evolve(htlc, htlc_id=self.hm.get_next_htlc_id(REMOTE)) + htlc = dataclasses.replace(htlc, htlc_id=self.hm.get_next_htlc_id(REMOTE)) with self.db_lock: self.hm.recv_htlc(htlc) if onion_packet: diff --git a/electrum/lnonion.py b/electrum/lnonion.py index fa7337295..b80c8821a 100644 --- a/electrum/lnonion.py +++ b/electrum/lnonion.py @@ -36,6 +36,7 @@ from .lnutil import (PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH, NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, OnionFailureCodeMetaFlag) from .lnmsg import OnionWireSerializer, read_bigsize_int, write_bigsize_int from . import lnmsg +from . import util if TYPE_CHECKING: from .lnrouter import LNPaymentRoute @@ -113,11 +114,11 @@ class OnionHopsDataSingle: # called HopData in lnd class OnionPacket: - def __init__(self, public_key: bytes, hops_data: bytes, hmac: bytes): + def __init__(self, public_key: bytes, hops_data: bytes, hmac: bytes, version: int = 0): assert len(public_key) == 33 assert len(hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE] assert len(hmac) == PER_HOP_HMAC_SIZE - self.version = 0 + self.version = version self.public_key = public_key self.hops_data = hops_data # also called RoutingInfo in bolt-04 self.hmac = hmac @@ -140,13 +141,11 @@ class OnionPacket: def from_bytes(cls, b: bytes): if len(b) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]: raise Exception('unexpected length {}'.format(len(b))) - version = b[0] - if version != 0: - raise UnsupportedOnionPacketVersion('version {} is not supported'.format(version)) return OnionPacket( public_key=b[1:34], hops_data=b[34:-32], - hmac=b[-32:] + hmac=b[-32:], + version=b[0], ) @@ -361,6 +360,9 @@ def process_onion_packet( associated_data: bytes = b'', is_trampoline=False, tlv_stream_name='payload') -> ProcessedOnionPacket: + # TODO: check Onion features ( PERM|NODE|3 (required_node_feature_missing ) + if onion_packet.version != 0: + raise UnsupportedOnionPacketVersion() if not ecc.ECPubkey.is_pubkey_bytes(onion_packet.public_key): raise InvalidOnionPubkey() shared_secret = get_ecdh(our_onion_private_key, onion_packet.public_key) @@ -369,7 +371,7 @@ def process_onion_packet( calculated_mac = hmac_oneshot( mu_key, msg=onion_packet.hops_data+associated_data, digest=hashlib.sha256) - if onion_packet.hmac != calculated_mac: + if not util.constant_time_compare(onion_packet.hmac, calculated_mac): raise InvalidOnionMac() # peel an onion layer off rho_key = get_bolt04_onion_key(b'rho', shared_secret) @@ -484,23 +486,38 @@ def obfuscate_onion_error(error_packet, their_public_key, our_onion_private_key) def _decode_onion_error(error_packet: bytes, payment_path_pubkeys: Sequence[bytes], session_key: bytes) -> Tuple[bytes, int]: - """Returns the decoded error bytes, and the index of the sender of the error.""" + """ + Returns the decoded error bytes, and the index of the sender of the error. + https://github.com/lightning/bolts/blob/14272b1bd9361750cfdb3e5d35740889a6b510b5/04-onion-routing.md?plain=1#L1096 + """ num_hops = len(payment_path_pubkeys) hop_shared_secrets, _ = get_shared_secrets_along_route(payment_path_pubkeys, session_key) - for i in range(num_hops): - ammag_key = get_bolt04_onion_key(b'ammag', hop_shared_secrets[i]) - um_key = get_bolt04_onion_key(b'um', hop_shared_secrets[i]) + result = None + dummy_secret = bytes(32) + # SHOULD continue decrypting, until the loop has been repeated 27 times + for i in range(27): + if i < num_hops: + ammag_key = get_bolt04_onion_key(b'ammag', hop_shared_secrets[i]) + um_key = get_bolt04_onion_key(b'um', hop_shared_secrets[i]) + else: + # SHOULD use constant `ammag` and `um` keys to obfuscate the route length. + ammag_key = get_bolt04_onion_key(b'ammag', dummy_secret) + um_key = get_bolt04_onion_key(b'um', dummy_secret) + stream_bytes = generate_cipher_stream(ammag_key, len(error_packet)) error_packet = xor_bytes(error_packet, stream_bytes) hmac_computed = hmac_oneshot(um_key, msg=error_packet[32:], digest=hashlib.sha256) hmac_found = error_packet[:32] - if hmac_computed == hmac_found: - return error_packet, i + if util.constant_time_compare(hmac_found, hmac_computed) and i < num_hops: + result = error_packet, i + + if result is not None: + return result raise FailedToDecodeOnionError() def decode_onion_error(error_packet: bytes, payment_path_pubkeys: Sequence[bytes], - session_key: bytes) -> (OnionRoutingFailure, int): + session_key: bytes) -> Tuple[OnionRoutingFailure, int]: """Returns the failure message, and the index of the sender of the error.""" decrypted_error, sender_index = _decode_onion_error(error_packet, payment_path_pubkeys, session_key) failure_msg = get_failure_msg_from_onion_error(decrypted_error) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index ea0d16ec2..48b442693 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -24,7 +24,7 @@ from aiorpcx import ignore_after from .crypto import sha256, sha256d, privkey_to_pubkey from . import bitcoin, util from . import constants -from .util import (bfh, log_exceptions, ignore_exceptions, chunks, OldTaskGroup, +from .util import (log_exceptions, ignore_exceptions, chunks, OldTaskGroup, UnrelatedTransactionException, error_text_bytes_to_safe_str, AsyncHangDetector, NoDynamicFeeEstimates, event_listener, EventListener) from . import transaction @@ -36,7 +36,7 @@ from .lnonion import (new_onion_packet, OnionFailureCode, calc_hops_data_for_pay OnionPacket, construct_onion_error, obfuscate_onion_error, OnionRoutingFailure, ProcessedOnionPacket, UnsupportedOnionPacketVersion, InvalidOnionMac, InvalidOnionPubkey, OnionFailureCodeMetaFlag) -from .lnchannel import Channel, RevokeAndAck, RemoteCtnTooFarInFuture, ChannelState, PeerState, ChanCloseOption, CF_ANNOUNCE_CHANNEL +from .lnchannel import Channel, RevokeAndAck, ChannelState, PeerState, ChanCloseOption, CF_ANNOUNCE_CHANNEL from . import lnutil from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConfig, RemoteConfig, OnlyPubkeyKeypair, ChannelConstraints, RevocationStore, @@ -45,19 +45,17 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConf LOCAL, REMOTE, HTLCOwner, ln_compare_features, MIN_FINAL_CLTV_DELTA_ACCEPTED, RemoteMisbehaving, ShortChannelID, - IncompatibleLightningFeatures, derive_payment_secret_from_payment_preimage, - ChannelType, LNProtocolWarning, validate_features, + IncompatibleLightningFeatures, ChannelType, LNProtocolWarning, validate_features, IncompatibleOrInsaneFeatures, FeeBudgetExceeded, - GossipForwardingMessage, GossipTimestampFilter) -from .lnutil import FeeUpdate, channel_id_from_funding_tx, PaymentFeeBudget -from .lnutil import serialize_htlc_key, Keypair + GossipForwardingMessage, GossipTimestampFilter, channel_id_from_funding_tx, + PaymentFeeBudget, serialize_htlc_key, Keypair, RecvMPPResolution) from .lntransport import LNTransport, LNTransportBase, LightningPeerConnectionClosed, HandshakeFailed from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType, FailedToParseMsg from .interface import GracefulDisconnect from .lnrouter import fee_for_edge_msat from .json_db import StoredDict from .invoices import PR_PAID -from .fee_policy import FEE_LN_ETA_TARGET, FEE_LN_MINIMUM_ETA_TARGET, FEERATE_PER_KW_MIN_RELAY_LIGHTNING +from .fee_policy import FEE_LN_ETA_TARGET, FEERATE_PER_KW_MIN_RELAY_LIGHTNING from .trampoline import decode_routing_info if TYPE_CHECKING: @@ -525,12 +523,13 @@ class Peer(Logger, EventListener): @handle_disconnect async def main_loop(self): async with self.taskgroup as group: - await group.spawn(self.htlc_switch()) await group.spawn(self._message_loop()) await group.spawn(self._query_gossip()) await group.spawn(self._process_gossip()) await group.spawn(self._send_own_gossip()) await group.spawn(self._forward_gossip()) + if self.network.lngossip != self.lnworker: + await group.spawn(self.htlc_switch()) async def _process_gossip(self): while True: @@ -2467,7 +2466,6 @@ class Peer(Logger, EventListener): exc_incorrect_or_unknown_pd: OnionRoutingFailure, log_fail_reason: Callable[[str], None], ) -> bool: - from .lnworker import RecvMPPResolution mpp_resolution = self.lnworker.check_mpp_status( payment_secret=payment_secret, short_channel_id=short_channel_id, @@ -2482,7 +2480,7 @@ class Peer(Logger, EventListener): elif mpp_resolution == RecvMPPResolution.FAILED: log_fail_reason(f"mpp_resolution is FAILED") raise exc_incorrect_or_unknown_pd - elif mpp_resolution == RecvMPPResolution.ACCEPTED: + elif mpp_resolution == RecvMPPResolution.COMPLETE: return False else: raise Exception(f"unexpected {mpp_resolution=}") @@ -2592,11 +2590,9 @@ class Peer(Logger, EventListener): raise exc_incorrect_or_unknown_pd preimage = self.lnworker.get_preimage(payment_hash) - expected_payment_secrets = [self.lnworker.get_payment_secret(htlc.payment_hash)] - if preimage: - expected_payment_secrets.append(derive_payment_secret_from_payment_preimage(preimage)) # legacy secret for old invoices - if payment_secret_from_onion not in expected_payment_secrets: - log_fail_reason(f'incorrect payment secret {payment_secret_from_onion.hex()} != {expected_payment_secrets[0].hex()}') + expected_payment_secret = self.lnworker.get_payment_secret(htlc.payment_hash) + if payment_secret_from_onion != expected_payment_secret: + log_fail_reason(f'incorrect payment secret {payment_secret_from_onion.hex()} != {expected_payment_secret.hex()}') raise exc_incorrect_or_unknown_pd invoice_msat = info.amount_msat if channel_opening_fee: diff --git a/electrum/lnsweep.py b/electrum/lnsweep.py index c7170e1e9..f63064162 100644 --- a/electrum/lnsweep.py +++ b/electrum/lnsweep.py @@ -432,7 +432,7 @@ def sweep_our_ctx( ctn=ctn) for (direction, htlc), (ctx_output_idx, htlc_relative_idx) in htlc_to_ctx_output_idx_map.items(): if direction == RECEIVED: - if not chan.lnworker.is_accepted_mpp(htlc.payment_hash): + if not chan.lnworker.is_complete_mpp(htlc.payment_hash): # do not redeem this, it might publish the preimage of an incomplete MPP continue preimage = chan.lnworker.get_preimage(htlc.payment_hash) @@ -727,7 +727,7 @@ def sweep_their_ctx( for (direction, htlc), (ctx_output_idx, htlc_relative_idx) in htlc_to_ctx_output_idx_map.items(): is_received_htlc = direction == RECEIVED if not is_received_htlc and not is_revocation: - if not chan.lnworker.is_accepted_mpp(htlc.payment_hash): + if not chan.lnworker.is_complete_mpp(htlc.payment_hash): # do not redeem this, it might publish the preimage of an incomplete MPP continue preimage = chan.lnworker.get_preimage(htlc.payment_hash) diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 3903f6979..025cc9ad8 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -12,9 +12,10 @@ from functools import lru_cache import electrum_ecc as ecc from electrum_ecc import CURVE_ORDER, ecdsa_sig64_from_der_sig from electrum_ecc.util import bip340_tagged_hash +import dataclasses import attr -from .util import bfh, UserFacingException, list_enabled_bits +from .util import bfh, UserFacingException, list_enabled_bits, is_hex_str from .util import ShortID as ShortChannelID, format_short_id as format_short_channel_id from .crypto import sha256, pw_decode_with_version_and_mac @@ -22,7 +23,8 @@ from .transaction import ( Transaction, PartialTransaction, PartialTxInput, TxOutpoint, PartialTxOutput, opcodes, OPPushDataPubkey ) from . import bitcoin, crypto, transaction, descriptor, segwit_addr -from .bitcoin import redeem_script_to_address, address_to_script, construct_witness, construct_script +from .bitcoin import redeem_script_to_address, address_to_script, construct_witness, \ + construct_script, NLOCKTIME_BLOCKHEIGHT_MAX from .i18n import _ from .bip32 import BIP32Node, BIP32_PRIME from .transaction import BCDataStream, OPPushDataGeneric @@ -1825,19 +1827,6 @@ def validate_features(features: int) -> LnFeatures: return features -def derive_payment_secret_from_payment_preimage(payment_preimage: bytes) -> bytes: - """Returns secret to be put into invoice. - Derivation is deterministic, based on the preimage. - Crucially the payment_hash must be derived in an independent way from this. - """ - # Note that this could be random data too, but then we would need to store it. - # We derive it identically to clightning, so that we cannot be distinguished: - # https://github.com/ElementsProject/lightning/blob/faac4b28adee5221e83787d64cd5d30b16b62097/lightningd/invoice.c#L115 - modified = bytearray(payment_preimage) - modified[0] ^= 1 - return sha256(bytes(modified)) - - def get_compressed_pubkey_from_bech32(bech32_pubkey: str) -> bytes: decoded_bech32 = segwit_addr.bech32_decode(bech32_pubkey) hrp = decoded_bech32.hrp @@ -1908,26 +1897,61 @@ NUM_MAX_HOPS_IN_PAYMENT_PATH = 20 NUM_MAX_EDGES_IN_PAYMENT_PATH = NUM_MAX_HOPS_IN_PAYMENT_PATH -@attr.s(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class UpdateAddHtlc: - amount_msat = attr.ib(type=int, kw_only=True) - payment_hash = attr.ib(type=bytes, kw_only=True, converter=hex_to_bytes, repr=lambda val: val.hex()) - cltv_abs = attr.ib(type=int, kw_only=True) - timestamp = attr.ib(type=int, kw_only=True) - htlc_id = attr.ib(type=int, kw_only=True, default=None) + amount_msat: int + payment_hash: bytes + cltv_abs: int + htlc_id: Optional[int] = dataclasses.field(default=None) + timestamp: int = dataclasses.field(default_factory=lambda: int(time.time())) @staticmethod @stored_in('adds', tuple) - def from_tuple(amount_msat, payment_hash, cltv_abs, htlc_id, timestamp) -> 'UpdateAddHtlc': + def from_tuple(amount_msat, rhash, cltv_abs, htlc_id, timestamp) -> 'UpdateAddHtlc': return UpdateAddHtlc( amount_msat=amount_msat, - payment_hash=payment_hash, + payment_hash=bytes.fromhex(rhash), cltv_abs=cltv_abs, htlc_id=htlc_id, timestamp=timestamp) def to_json(self): - return self.amount_msat, self.payment_hash, self.cltv_abs, self.htlc_id, self.timestamp + self._validate() + return dataclasses.astuple(self) + + def _validate(self): + assert isinstance(self.amount_msat, int), self.amount_msat + assert isinstance(self.payment_hash, bytes) and len(self.payment_hash) == 32 + assert isinstance(self.cltv_abs, int) and self.cltv_abs <= NLOCKTIME_BLOCKHEIGHT_MAX, self.cltv_abs + assert isinstance(self.htlc_id, int) or self.htlc_id is None, self.htlc_id + assert isinstance(self.timestamp, int), self.timestamp + + def __post_init__(self): + self._validate() + + +# Note: these states are persisted in the wallet file. +# Do not modify them without performing a wallet db upgrade +class RecvMPPResolution(IntEnum): + WAITING = 0 + EXPIRED = 1 + COMPLETE = 2 + FAILED = 3 + + +class ReceivedMPPStatus(NamedTuple): + resolution: RecvMPPResolution + expected_msat: int + htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] + + @staticmethod + @stored_in('received_mpp_htlcs', tuple) + def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus': + htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid, x) in htlc_list]) + return ReceivedMPPStatus( + resolution=RecvMPPResolution(resolution), + expected_msat=expected_msat, + htlc_set=htlc_set) class OnionFailureCodeMetaFlag(IntFlag): diff --git a/electrum/lnworker.py b/electrum/lnworker.py index c66843ee0..98bd12988 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -20,6 +20,7 @@ import concurrent from concurrent import futures import urllib.parse import itertools +import dataclasses import aiohttp import dns.asyncresolver @@ -67,7 +68,8 @@ from .lnutil import ( LnKeyFamily, LOCAL, REMOTE, MIN_FINAL_CLTV_DELTA_FOR_INVOICE, SENT, RECEIVED, HTLCOwner, UpdateAddHtlc, LnFeatures, ShortChannelID, HtlcLog, NoPathFound, InvalidGossipMsg, FeeBudgetExceeded, ImportedChannelBackupStorage, OnchainChannelBackupStorage, ln_compare_features, IncompatibleLightningFeatures, PaymentFeeBudget, - NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage, MIN_FUNDING_SAT + NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage, MIN_FUNDING_SAT, + RecvMPPResolution, ReceivedMPPStatus, ) from .lnonion import decode_onion_error, OnionFailureCode, OnionRoutingFailure, OnionPacket from .lnmsg import decode_msg @@ -106,35 +108,24 @@ class PaymentDirection(IntEnum): FORWARDING = 3 -class PaymentInfo(NamedTuple): +@dataclasses.dataclass(frozen=True, kw_only=True) +class PaymentInfo: + """Information required to handle incoming htlcs for a payment request""" payment_hash: bytes amount_msat: Optional[int] direction: int status: int + def validate(self): + assert isinstance(self.payment_hash, bytes) and len(self.payment_hash) == 32 + assert self.amount_msat is None or isinstance(self.amount_msat, int) + assert isinstance(self.direction, int) + assert isinstance(self.status, int) -# Note: these states are persisted in the wallet file. -# Do not modify them without performing a wallet db upgrade -class RecvMPPResolution(IntEnum): - WAITING = 0 - EXPIRED = 1 - ACCEPTED = 2 - FAILED = 3 + def __post_init__(self): + self.validate() -class ReceivedMPPStatus(NamedTuple): - resolution: RecvMPPResolution - expected_msat: int - htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] - - @stored_in('received_mpp_htlcs', tuple) - def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus': - htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid, x) in htlc_list]) - return ReceivedMPPStatus( - resolution=RecvMPPResolution(resolution), - expected_msat=expected_msat, - htlc_set=htlc_set) - SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id @@ -1567,7 +1558,12 @@ class LNWallet(LNWorker): raise PaymentFailure(_("A payment was already initiated for this invoice")) if payment_hash in self.get_payments(status='inflight'): raise PaymentFailure(_("A previous attempt to pay this invoice did not clear")) - info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID) + info = PaymentInfo( + payment_hash=payment_hash, + amount_msat=amount_to_pay, + direction=SENT, + status=PR_UNPAID, + ) self.save_payment_info(info) self.wallet.set_label(key, lnaddr.get_description()) self.set_invoice_status(key, PR_INFLIGHT) @@ -2302,7 +2298,12 @@ class LNWallet(LNWorker): def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes: payment_preimage = os.urandom(32) payment_hash = sha256(payment_preimage) - info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID) + info = PaymentInfo( + payment_hash=payment_hash, + amount_msat=amount_msat, + direction=RECEIVED, + status=PR_UNPAID, + ) self.save_preimage(payment_hash, payment_preimage, write_to_disk=False) self.save_payment_info(info, write_to_disk=False) if write_to_disk: @@ -2376,12 +2377,22 @@ class LNWallet(LNWorker): with self.lock: if key in self.payment_info: amount_msat, direction, status = self.payment_info[key] - return PaymentInfo(payment_hash, amount_msat, direction, status) + return PaymentInfo( + payment_hash=payment_hash, + amount_msat=amount_msat, + direction=direction, + status=status, + ) return None def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: Optional[int]): amount = lightning_amount_sat * 1000 if lightning_amount_sat else None - info = PaymentInfo(payment_hash, amount, RECEIVED, PR_UNPAID) + info = PaymentInfo( + payment_hash=payment_hash, + amount_msat=amount, + direction=RECEIVED, + status=PR_UNPAID, + ) self.save_payment_info(info, write_to_disk=False) def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]): @@ -2396,11 +2407,11 @@ class LNWallet(LNWorker): if old_info := self.get_payment_info(payment_hash=info.payment_hash): if info == old_info: return # already saved - if info != old_info._replace(status=info.status): + if info != dataclasses.replace(old_info, status=info.status): # differs more than in status. let's fail raise Exception("payment_hash already in use") key = info.payment_hash.hex() - self.payment_info[key] = info.amount_msat, info.direction, info.status + self.payment_info[key] = dataclasses.astuple(info)[1:] # drop the payment hash at index 0 if write_to_disk: self.wallet.save_db() @@ -2433,12 +2444,12 @@ class LNWallet(LNWorker): payment_keys = [payment_key] first_timestamp = min([self.get_first_timestamp_of_mpp(pkey) for pkey in payment_keys]) if self.get_payment_status(payment_hash) == PR_PAID: - mpp_resolution = RecvMPPResolution.ACCEPTED + mpp_resolution = RecvMPPResolution.COMPLETE elif self.stopping_soon: # try to time out pending HTLCs before shutting down mpp_resolution = RecvMPPResolution.EXPIRED elif all([self.is_mpp_amount_reached(pkey) for pkey in payment_keys]): - mpp_resolution = RecvMPPResolution.ACCEPTED + mpp_resolution = RecvMPPResolution.COMPLETE elif time.time() - first_timestamp > self.MPP_EXPIRY: mpp_resolution = RecvMPPResolution.EXPIRED # save resolution, if any. @@ -2486,10 +2497,10 @@ class LNWallet(LNWorker): total, expected = amounts return total >= expected - def is_accepted_mpp(self, payment_hash: bytes) -> bool: + def is_complete_mpp(self, payment_hash: bytes) -> bool: payment_key = self._get_payment_key(payment_hash) status = self.received_mpp_htlcs.get(payment_key.hex()) - return status and status.resolution == RecvMPPResolution.ACCEPTED + return status and status.resolution == RecvMPPResolution.COMPLETE def get_payment_mpp_amount_msat(self, payment_hash: bytes) -> Optional[int]: """Returns the received mpp amount for given payment hash.""" @@ -2577,7 +2588,7 @@ class LNWallet(LNWorker): if info is None: # if we are forwarding return - info = info._replace(status=status) + info = dataclasses.replace(info, status=status) self.save_payment_info(info) def is_forwarded_htlc(self, htlc_key) -> Optional[str]: diff --git a/tests/test_bolt11.py b/tests/test_bolt11.py index c4f756b1b..fd1884ff3 100644 --- a/tests/test_bolt11.py +++ b/tests/test_bolt11.py @@ -7,7 +7,7 @@ import unittest from electrum.lnaddr import shorten_amount, unshorten_amount, LnAddr, lnencode, lndecode from electrum.segwit_addr import bech32_encode, bech32_decode from electrum import segwit_addr -from electrum.lnutil import UnknownEvenFeatureBits, derive_payment_secret_from_payment_preimage, LnFeatures, IncompatibleLightningFeatures +from electrum.lnutil import UnknownEvenFeatureBits, LnFeatures, IncompatibleLightningFeatures from electrum import constants from . import ElectrumTestCase @@ -164,11 +164,6 @@ class TestBolt11(ElectrumTestCase): self.assertEqual((1 << 9) + (1 << 15) + (1 << 99), lnaddr.get_tag('9')) self.assertEqual(b"\x11" * 32, lnaddr.payment_secret) - def test_derive_payment_secret_from_payment_preimage(self): - preimage = bytes.fromhex("cc3fc000bdeff545acee53ada12ff96060834be263f77d645abbebc3a8d53b92") - self.assertEqual("bfd660b559b3f452c6bb05b8d2906f520c151c107b733863ed0cc53fc77021a8", - derive_payment_secret_from_payment_preimage(preimage).hex()) - def test_validate_and_compare_features(self): lnaddr = lndecode("lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqsp5zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zygsdq5vdhkven9v5sxyetpdees9q5sqqqqqqqqqqqqqqqpqsqvvh7ut50r00p3pg34ea68k7zfw64f8yx9jcdk35lh5ft8qdr8g4r0xzsdcrmcy9hex8un8d8yraewvhqc9l0sh8l0e0yvmtxde2z0hgpzsje5l") lnaddr.validate_and_compare_features(LnFeatures((1 << 8) + (1 << 14) + (1 << 15))) diff --git a/tests/test_commands.py b/tests/test_commands.py index f1169ec40..ccd052b38 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -547,7 +547,7 @@ class TestCommandsTestnet(ElectrumTestCase): mock_htlc2.amount_msat = 5_500_000 mock_htlc_status = mock.Mock() mock_htlc_status.htlc_set = [(None, mock_htlc1), (None, mock_htlc2)] - mock_htlc_status.resolution = RecvMPPResolution.ACCEPTED + mock_htlc_status.resolution = RecvMPPResolution.COMPLETE payment_key = wallet.lnworker._get_payment_key(bytes.fromhex(payment_hash)).hex() with mock.patch.dict(wallet.lnworker.received_mpp_htlcs, {payment_key: mock_htlc_status}): diff --git a/tests/test_lnchannel.py b/tests/test_lnchannel.py index 39d8d935c..59d7f0024 100644 --- a/tests/test_lnchannel.py +++ b/tests/test_lnchannel.py @@ -27,6 +27,7 @@ import os import binascii from pprint import pformat import logging +import dataclasses from electrum import bitcoin from electrum import lnpeer @@ -257,31 +258,34 @@ class TestChannel(ElectrumTestCase): self.paymentPreimage = b"\x01" * 32 paymentHash = bitcoin.sha256(self.paymentPreimage) - self.htlc_dict = { - 'payment_hash': paymentHash, - 'amount_msat': one_bitcoin_in_msat, - 'cltv_abs': 5, - 'timestamp': 0, - } + self.htlc = UpdateAddHtlc( + payment_hash=paymentHash, + amount_msat=one_bitcoin_in_msat, + cltv_abs=5, + timestamp=0, + ) # First Alice adds the outgoing HTLC to her local channel's state # update log. Then Alice sends this wire message over to Bob who adds # this htlc to his remote state update log. - self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict).htlc_id + self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc).htlc_id self.assertNotEqual(list(self.alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED, 1).values()), []) before = self.bob_channel.balance_minus_outgoing_htlcs(REMOTE) beforeLocal = self.bob_channel.balance_minus_outgoing_htlcs(LOCAL) - self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc_dict).htlc_id + self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc).htlc_id self.htlc = self.bob_channel.hm.log[REMOTE]['adds'][0] def test_concurrent_reversed_payment(self): - self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02') - self.htlc_dict['amount_msat'] += 1000 - self.bob_channel.add_htlc(self.htlc_dict) - self.alice_channel.receive_htlc(self.htlc_dict) + self.htlc = dataclasses.replace( + self.htlc, + payment_hash=bitcoin.sha256(32 * b'\x02'), + amount_msat=self.htlc.amount_msat + 1000, + ) + self.bob_channel.add_htlc(self.htlc) + self.alice_channel.receive_htlc(self.htlc) self.assertNumberNonAnchorOutputs(2, self.alice_channel.get_latest_commitment(LOCAL)) self.assertNumberNonAnchorOutputs(3, self.alice_channel.get_next_commitment(LOCAL)) @@ -561,9 +565,12 @@ class TestChannel(ElectrumTestCase): tx6 = str(alice_channel.force_close_tx()) self.assertNotEqual(tx5, tx6) - self.htlc_dict['amount_msat'] *= 5 - bob_index = bob_channel.add_htlc(self.htlc_dict).htlc_id - alice_index = alice_channel.receive_htlc(self.htlc_dict).htlc_id + self.htlc = dataclasses.replace( + self.htlc, + amount_msat=self.htlc.amount_msat * 5, + ) + bob_index = bob_channel.add_htlc(self.htlc).htlc_id + alice_index = alice_channel.receive_htlc(self.htlc).htlc_id force_state_transition(bob_channel, alice_channel) @@ -662,18 +669,26 @@ class TestChannel(ElectrumTestCase): self.alice_to_bob_fee_update(0) force_state_transition(self.alice_channel, self.bob_channel) - self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02') - self.alice_channel.add_htlc(self.htlc_dict) - self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x03') - self.alice_channel.add_htlc(self.htlc_dict) + self.htlc = dataclasses.replace( + self.htlc, + payment_hash=bitcoin.sha256(32 * b'\x02'), + ) + self.alice_channel.add_htlc(self.htlc) + self.htlc = dataclasses.replace( + self.htlc, + payment_hash=bitcoin.sha256(32 * b'\x03'), + ) + self.alice_channel.add_htlc(self.htlc) # now there are three htlcs (one was in setUp) # Alice now has an available balance of 2 BTC. We'll add a new HTLC of # value 2 BTC, which should make Alice's balance negative (since she # has to pay a commitment fee). - new = dict(self.htlc_dict) - new['amount_msat'] *= 2.5 - new['payment_hash'] = bitcoin.sha256(32 * b'\x04') + new = dataclasses.replace( + self.htlc, + amount_msat=int(self.htlc.amount_msat * 2.5), + payment_hash=bitcoin.sha256(32 * b'\x04'), + ) with self.assertRaises(lnutil.PaymentFailure) as cm: self.alice_channel.add_htlc(new) self.assertIn('Not enough local balance', cm.exception.args[0]) @@ -822,14 +837,14 @@ class TestChanReserve(ElectrumTestCase): # Bob: 5.0 paymentPreimage = b"\x01" * 32 paymentHash = bitcoin.sha256(paymentPreimage) - htlc_dict = { - 'payment_hash': paymentHash, - 'amount_msat': int(.5 * one_bitcoin_in_msat), - 'cltv_abs': 5, - 'timestamp': 0, - } - self.alice_channel.add_htlc(htlc_dict) - self.bob_channel.receive_htlc(htlc_dict) + htlc = UpdateAddHtlc( + payment_hash=paymentHash, + amount_msat=int(.5 * one_bitcoin_in_msat), + cltv_abs=5, + timestamp=0, + ) + self.alice_channel.add_htlc(htlc) + self.bob_channel.receive_htlc(htlc) # Force a state transition, making sure this HTLC is considered valid # even though the channel reserves are not met. force_state_transition(self.alice_channel, self.bob_channel) @@ -847,10 +862,10 @@ class TestChanReserve(ElectrumTestCase): # Alice: 4.5 # Bob: 5.0 with self.assertRaises(lnutil.PaymentFailure): - htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02') - self.bob_channel.add_htlc(htlc_dict) + htlc = dataclasses.replace(htlc, payment_hash=bitcoin.sha256(32 * b'\x02')) + self.bob_channel.add_htlc(htlc) with self.assertRaises(lnutil.RemoteMisbehaving): - self.alice_channel.receive_htlc(htlc_dict) + self.alice_channel.receive_htlc(htlc) def part2(self): paymentPreimage = b"\x01" * 32 @@ -861,22 +876,22 @@ class TestChanReserve(ElectrumTestCase): # Resulting balances: # Alice: 1.5 # Bob: 9.5 - htlc_dict = { - 'payment_hash': paymentHash, - 'amount_msat': int(3.5 * one_bitcoin_in_msat), - 'cltv_abs': 5, - } - self.alice_channel.add_htlc(htlc_dict) - self.bob_channel.receive_htlc(htlc_dict) + htlc = UpdateAddHtlc( + payment_hash=paymentHash, + amount_msat=int(3.5 * one_bitcoin_in_msat), + cltv_abs=5, + ) + self.alice_channel.add_htlc(htlc) + self.bob_channel.receive_htlc(htlc) # Add a second HTLC of 1 BTC. This should fail because it will take # Alice's balance all the way down to her channel reserve, but since # she is the initiator the additional transaction fee makes her # balance dip below. - htlc_dict['amount_msat'] = one_bitcoin_in_msat + htlc = dataclasses.replace(htlc, amount_msat=one_bitcoin_in_msat) with self.assertRaises(lnutil.PaymentFailure): - self.alice_channel.add_htlc(htlc_dict) + self.alice_channel.add_htlc(htlc) with self.assertRaises(lnutil.RemoteMisbehaving): - self.bob_channel.receive_htlc(htlc_dict) + self.bob_channel.receive_htlc(htlc) def part3(self): # Add a HTLC of 2 BTC to Alice, and the settle it. @@ -885,14 +900,14 @@ class TestChanReserve(ElectrumTestCase): # Bob: 7.0 paymentPreimage = b"\x01" * 32 paymentHash = bitcoin.sha256(paymentPreimage) - htlc_dict = { - 'payment_hash': paymentHash, - 'amount_msat': int(2 * one_bitcoin_in_msat), - 'cltv_abs': 5, - 'timestamp': 0, - } - alice_idx = self.alice_channel.add_htlc(htlc_dict).htlc_id - bob_idx = self.bob_channel.receive_htlc(htlc_dict).htlc_id + htlc = UpdateAddHtlc( + payment_hash=paymentHash, + amount_msat=int(2 * one_bitcoin_in_msat), + cltv_abs=5, + timestamp=0, + ) + alice_idx = self.alice_channel.add_htlc(htlc).htlc_id + bob_idx = self.bob_channel.receive_htlc(htlc).htlc_id force_state_transition(self.alice_channel, self.bob_channel) self.check_bals(one_bitcoin_in_msat * 3 - self.alice_channel.get_next_fee(LOCAL), @@ -906,9 +921,9 @@ class TestChanReserve(ElectrumTestCase): # And now let Bob add an HTLC of 1 BTC. This will take Bob's balance # all the way down to his channel reserve, but since he is not paying # the fee this is okay. - htlc_dict['amount_msat'] = one_bitcoin_in_msat - self.bob_channel.add_htlc(htlc_dict) - self.alice_channel.receive_htlc(htlc_dict) + htlc = dataclasses.replace(htlc, amount_msat=one_bitcoin_in_msat) + self.bob_channel.add_htlc(htlc) + self.alice_channel.receive_htlc(htlc) force_state_transition(self.alice_channel, self.bob_channel) self.check_bals(one_bitcoin_in_msat * 3 \ - self.alice_channel.get_next_fee(LOCAL), @@ -943,12 +958,12 @@ class TestDust(ElectrumTestCase): # to pay for his htlc success transaction below_dust_for_bob = dust_limit_bob - 1 htlc_amt = below_dust_for_bob + success_weight * (fee_per_kw // 1000) - htlc = { - 'payment_hash': paymentHash, - 'amount_msat': 1000 * htlc_amt, - 'cltv_abs': 5, # consistent with channel policy - 'timestamp': 0, - } + htlc = UpdateAddHtlc( + payment_hash=paymentHash, + amount_msat=1000 * htlc_amt, + cltv_abs=5, # consistent with channel policy + timestamp=0, + ) # add the htlc alice_htlc_id = alice_channel.add_htlc(htlc).htlc_id diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index b1bc6c5af..3eb4313f6 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -559,7 +559,12 @@ class TestPeer(ElectrumTestCase): payment_preimage = os.urandom(32) if payment_hash is None: payment_hash = sha256(payment_preimage) - info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID) + info = PaymentInfo( + payment_hash=payment_hash, + amount_msat=amount_msat, + direction=RECEIVED, + status=PR_UNPAID, + ) if payment_preimage: w2.save_preimage(payment_hash, payment_preimage) w2.save_payment_info(info)