partial merge 1: f321x's "lightning: refactor htlc switch"

split-off from https://github.com/spesmilo/electrum/pull/10230
This commit is contained in:
SomberNight
2025-09-29 17:12:25 +00:00
11 changed files with 229 additions and 169 deletions
+6 -6
View File
@@ -1438,7 +1438,7 @@ class Commands(Logger):
assert payment_hash in wallet.lnworker.payment_info, \
f"Couldn't find lightning invoice for {payment_hash=}"
assert payment_hash in wallet.lnworker.dont_settle_htlcs, f"Invoice {payment_hash=} not a hold invoice?"
assert wallet.lnworker.is_accepted_mpp(bfh(payment_hash)), \
assert wallet.lnworker.is_complete_mpp(bfh(payment_hash)), \
f"MPP incomplete, cannot settle hold invoice {payment_hash} yet"
info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash))
assert (wallet.lnworker.get_payment_mpp_amount_msat(bfh(payment_hash)) or 0) >= (info.amount_msat or 0)
@@ -1465,7 +1465,7 @@ class Commands(Logger):
wallet.lnworker.set_payment_status(bfh(payment_hash), PR_UNPAID)
wallet.lnworker.delete_payment_info(payment_hash)
wallet.set_label(payment_hash, None)
while wallet.lnworker.is_accepted_mpp(bfh(payment_hash)):
while wallet.lnworker.is_complete_mpp(bfh(payment_hash)):
# wait until the htlcs got failed so the payment won't get settled accidentally in a race
await asyncio.sleep(0.1)
del wallet.lnworker.dont_settle_htlcs[payment_hash]
@@ -1490,7 +1490,7 @@ class Commands(Logger):
"""
assert len(payment_hash) == 64, f"Invalid payment_hash length: {len(payment_hash)} != 64"
info: Optional['PaymentInfo'] = wallet.lnworker.get_payment_info(bfh(payment_hash))
is_accepted_mpp: bool = wallet.lnworker.is_accepted_mpp(bfh(payment_hash))
is_complete_mpp: bool = wallet.lnworker.is_complete_mpp(bfh(payment_hash))
amount_sat = (wallet.lnworker.get_payment_mpp_amount_msat(bfh(payment_hash)) or 0) // 1000
result = {
"status": "unknown",
@@ -1498,10 +1498,10 @@ class Commands(Logger):
}
if info is None:
pass
elif not is_accepted_mpp and not wallet.lnworker.get_preimage_hex(payment_hash):
# is_accepted_mpp is False for settled payments
elif not is_complete_mpp and not wallet.lnworker.get_preimage_hex(payment_hash):
# is_complete_mpp is False for settled payments
result["status"] = "unpaid"
elif is_accepted_mpp and payment_hash in wallet.lnworker.dont_settle_htlcs:
elif is_complete_mpp and payment_hash in wallet.lnworker.dont_settle_htlcs:
result["status"] = "paid"
payment_key: str = wallet.lnworker._get_payment_key(bfh(payment_hash)).hex()
htlc_status = wallet.lnworker.received_mpp_htlcs[payment_key]
+3 -6
View File
@@ -17,6 +17,7 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import dataclasses
import enum
from collections import defaultdict
from enum import IntEnum, Enum
@@ -1202,12 +1203,10 @@ class Channel(AbstractChannel):
"""Adds a new LOCAL HTLC to the channel.
Action must be initiated by LOCAL.
"""
if isinstance(htlc, dict): # legacy conversion # FIXME remove
htlc = UpdateAddHtlc(**htlc)
assert isinstance(htlc, UpdateAddHtlc)
self._assert_can_add_htlc(htlc_proposer=LOCAL, amount_msat=htlc.amount_msat)
if htlc.htlc_id is None:
htlc = attr.evolve(htlc, htlc_id=self.hm.get_next_htlc_id(LOCAL))
htlc = dataclasses.replace(htlc, htlc_id=self.hm.get_next_htlc_id(LOCAL))
with self.db_lock:
self.hm.send_htlc(htlc)
self.logger.info("add_htlc")
@@ -1217,15 +1216,13 @@ class Channel(AbstractChannel):
"""Adds a new REMOTE HTLC to the channel.
Action must be initiated by REMOTE.
"""
if isinstance(htlc, dict): # legacy conversion # FIXME remove
htlc = UpdateAddHtlc(**htlc)
assert isinstance(htlc, UpdateAddHtlc)
try:
self._assert_can_add_htlc(htlc_proposer=REMOTE, amount_msat=htlc.amount_msat)
except PaymentFailure as e:
raise RemoteMisbehaving(e) from e
if htlc.htlc_id is None: # used in unit tests
htlc = attr.evolve(htlc, htlc_id=self.hm.get_next_htlc_id(REMOTE))
htlc = dataclasses.replace(htlc, htlc_id=self.hm.get_next_htlc_id(REMOTE))
with self.db_lock:
self.hm.recv_htlc(htlc)
if onion_packet:
+31 -14
View File
@@ -36,6 +36,7 @@ from .lnutil import (PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH,
NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, OnionFailureCodeMetaFlag)
from .lnmsg import OnionWireSerializer, read_bigsize_int, write_bigsize_int
from . import lnmsg
from . import util
if TYPE_CHECKING:
from .lnrouter import LNPaymentRoute
@@ -113,11 +114,11 @@ class OnionHopsDataSingle: # called HopData in lnd
class OnionPacket:
def __init__(self, public_key: bytes, hops_data: bytes, hmac: bytes):
def __init__(self, public_key: bytes, hops_data: bytes, hmac: bytes, version: int = 0):
assert len(public_key) == 33
assert len(hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]
assert len(hmac) == PER_HOP_HMAC_SIZE
self.version = 0
self.version = version
self.public_key = public_key
self.hops_data = hops_data # also called RoutingInfo in bolt-04
self.hmac = hmac
@@ -140,13 +141,11 @@ class OnionPacket:
def from_bytes(cls, b: bytes):
if len(b) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]:
raise Exception('unexpected length {}'.format(len(b)))
version = b[0]
if version != 0:
raise UnsupportedOnionPacketVersion('version {} is not supported'.format(version))
return OnionPacket(
public_key=b[1:34],
hops_data=b[34:-32],
hmac=b[-32:]
hmac=b[-32:],
version=b[0],
)
@@ -361,6 +360,9 @@ def process_onion_packet(
associated_data: bytes = b'',
is_trampoline=False,
tlv_stream_name='payload') -> ProcessedOnionPacket:
# TODO: check Onion features ( PERM|NODE|3 (required_node_feature_missing )
if onion_packet.version != 0:
raise UnsupportedOnionPacketVersion()
if not ecc.ECPubkey.is_pubkey_bytes(onion_packet.public_key):
raise InvalidOnionPubkey()
shared_secret = get_ecdh(our_onion_private_key, onion_packet.public_key)
@@ -369,7 +371,7 @@ def process_onion_packet(
calculated_mac = hmac_oneshot(
mu_key, msg=onion_packet.hops_data+associated_data,
digest=hashlib.sha256)
if onion_packet.hmac != calculated_mac:
if not util.constant_time_compare(onion_packet.hmac, calculated_mac):
raise InvalidOnionMac()
# peel an onion layer off
rho_key = get_bolt04_onion_key(b'rho', shared_secret)
@@ -484,23 +486,38 @@ def obfuscate_onion_error(error_packet, their_public_key, our_onion_private_key)
def _decode_onion_error(error_packet: bytes, payment_path_pubkeys: Sequence[bytes],
session_key: bytes) -> Tuple[bytes, int]:
"""Returns the decoded error bytes, and the index of the sender of the error."""
"""
Returns the decoded error bytes, and the index of the sender of the error.
https://github.com/lightning/bolts/blob/14272b1bd9361750cfdb3e5d35740889a6b510b5/04-onion-routing.md?plain=1#L1096
"""
num_hops = len(payment_path_pubkeys)
hop_shared_secrets, _ = get_shared_secrets_along_route(payment_path_pubkeys, session_key)
for i in range(num_hops):
ammag_key = get_bolt04_onion_key(b'ammag', hop_shared_secrets[i])
um_key = get_bolt04_onion_key(b'um', hop_shared_secrets[i])
result = None
dummy_secret = bytes(32)
# SHOULD continue decrypting, until the loop has been repeated 27 times
for i in range(27):
if i < num_hops:
ammag_key = get_bolt04_onion_key(b'ammag', hop_shared_secrets[i])
um_key = get_bolt04_onion_key(b'um', hop_shared_secrets[i])
else:
# SHOULD use constant `ammag` and `um` keys to obfuscate the route length.
ammag_key = get_bolt04_onion_key(b'ammag', dummy_secret)
um_key = get_bolt04_onion_key(b'um', dummy_secret)
stream_bytes = generate_cipher_stream(ammag_key, len(error_packet))
error_packet = xor_bytes(error_packet, stream_bytes)
hmac_computed = hmac_oneshot(um_key, msg=error_packet[32:], digest=hashlib.sha256)
hmac_found = error_packet[:32]
if hmac_computed == hmac_found:
return error_packet, i
if util.constant_time_compare(hmac_found, hmac_computed) and i < num_hops:
result = error_packet, i
if result is not None:
return result
raise FailedToDecodeOnionError()
def decode_onion_error(error_packet: bytes, payment_path_pubkeys: Sequence[bytes],
session_key: bytes) -> (OnionRoutingFailure, int):
session_key: bytes) -> Tuple[OnionRoutingFailure, int]:
"""Returns the failure message, and the index of the sender of the error."""
decrypted_error, sender_index = _decode_onion_error(error_packet, payment_path_pubkeys, session_key)
failure_msg = get_failure_msg_from_onion_error(decrypted_error)
+12 -16
View File
@@ -24,7 +24,7 @@ from aiorpcx import ignore_after
from .crypto import sha256, sha256d, privkey_to_pubkey
from . import bitcoin, util
from . import constants
from .util import (bfh, log_exceptions, ignore_exceptions, chunks, OldTaskGroup,
from .util import (log_exceptions, ignore_exceptions, chunks, OldTaskGroup,
UnrelatedTransactionException, error_text_bytes_to_safe_str, AsyncHangDetector,
NoDynamicFeeEstimates, event_listener, EventListener)
from . import transaction
@@ -36,7 +36,7 @@ from .lnonion import (new_onion_packet, OnionFailureCode, calc_hops_data_for_pay
OnionPacket, construct_onion_error, obfuscate_onion_error, OnionRoutingFailure,
ProcessedOnionPacket, UnsupportedOnionPacketVersion, InvalidOnionMac, InvalidOnionPubkey,
OnionFailureCodeMetaFlag)
from .lnchannel import Channel, RevokeAndAck, RemoteCtnTooFarInFuture, ChannelState, PeerState, ChanCloseOption, CF_ANNOUNCE_CHANNEL
from .lnchannel import Channel, RevokeAndAck, ChannelState, PeerState, ChanCloseOption, CF_ANNOUNCE_CHANNEL
from . import lnutil
from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConfig,
RemoteConfig, OnlyPubkeyKeypair, ChannelConstraints, RevocationStore,
@@ -45,19 +45,17 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConf
LOCAL, REMOTE, HTLCOwner,
ln_compare_features, MIN_FINAL_CLTV_DELTA_ACCEPTED,
RemoteMisbehaving, ShortChannelID,
IncompatibleLightningFeatures, derive_payment_secret_from_payment_preimage,
ChannelType, LNProtocolWarning, validate_features,
IncompatibleLightningFeatures, ChannelType, LNProtocolWarning, validate_features,
IncompatibleOrInsaneFeatures, FeeBudgetExceeded,
GossipForwardingMessage, GossipTimestampFilter)
from .lnutil import FeeUpdate, channel_id_from_funding_tx, PaymentFeeBudget
from .lnutil import serialize_htlc_key, Keypair
GossipForwardingMessage, GossipTimestampFilter, channel_id_from_funding_tx,
PaymentFeeBudget, serialize_htlc_key, Keypair, RecvMPPResolution)
from .lntransport import LNTransport, LNTransportBase, LightningPeerConnectionClosed, HandshakeFailed
from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType, FailedToParseMsg
from .interface import GracefulDisconnect
from .lnrouter import fee_for_edge_msat
from .json_db import StoredDict
from .invoices import PR_PAID
from .fee_policy import FEE_LN_ETA_TARGET, FEE_LN_MINIMUM_ETA_TARGET, FEERATE_PER_KW_MIN_RELAY_LIGHTNING
from .fee_policy import FEE_LN_ETA_TARGET, FEERATE_PER_KW_MIN_RELAY_LIGHTNING
from .trampoline import decode_routing_info
if TYPE_CHECKING:
@@ -525,12 +523,13 @@ class Peer(Logger, EventListener):
@handle_disconnect
async def main_loop(self):
async with self.taskgroup as group:
await group.spawn(self.htlc_switch())
await group.spawn(self._message_loop())
await group.spawn(self._query_gossip())
await group.spawn(self._process_gossip())
await group.spawn(self._send_own_gossip())
await group.spawn(self._forward_gossip())
if self.network.lngossip != self.lnworker:
await group.spawn(self.htlc_switch())
async def _process_gossip(self):
while True:
@@ -2467,7 +2466,6 @@ class Peer(Logger, EventListener):
exc_incorrect_or_unknown_pd: OnionRoutingFailure,
log_fail_reason: Callable[[str], None],
) -> bool:
from .lnworker import RecvMPPResolution
mpp_resolution = self.lnworker.check_mpp_status(
payment_secret=payment_secret,
short_channel_id=short_channel_id,
@@ -2482,7 +2480,7 @@ class Peer(Logger, EventListener):
elif mpp_resolution == RecvMPPResolution.FAILED:
log_fail_reason(f"mpp_resolution is FAILED")
raise exc_incorrect_or_unknown_pd
elif mpp_resolution == RecvMPPResolution.ACCEPTED:
elif mpp_resolution == RecvMPPResolution.COMPLETE:
return False
else:
raise Exception(f"unexpected {mpp_resolution=}")
@@ -2592,11 +2590,9 @@ class Peer(Logger, EventListener):
raise exc_incorrect_or_unknown_pd
preimage = self.lnworker.get_preimage(payment_hash)
expected_payment_secrets = [self.lnworker.get_payment_secret(htlc.payment_hash)]
if preimage:
expected_payment_secrets.append(derive_payment_secret_from_payment_preimage(preimage)) # legacy secret for old invoices
if payment_secret_from_onion not in expected_payment_secrets:
log_fail_reason(f'incorrect payment secret {payment_secret_from_onion.hex()} != {expected_payment_secrets[0].hex()}')
expected_payment_secret = self.lnworker.get_payment_secret(htlc.payment_hash)
if payment_secret_from_onion != expected_payment_secret:
log_fail_reason(f'incorrect payment secret {payment_secret_from_onion.hex()} != {expected_payment_secret.hex()}')
raise exc_incorrect_or_unknown_pd
invoice_msat = info.amount_msat
if channel_opening_fee:
+2 -2
View File
@@ -432,7 +432,7 @@ def sweep_our_ctx(
ctn=ctn)
for (direction, htlc), (ctx_output_idx, htlc_relative_idx) in htlc_to_ctx_output_idx_map.items():
if direction == RECEIVED:
if not chan.lnworker.is_accepted_mpp(htlc.payment_hash):
if not chan.lnworker.is_complete_mpp(htlc.payment_hash):
# do not redeem this, it might publish the preimage of an incomplete MPP
continue
preimage = chan.lnworker.get_preimage(htlc.payment_hash)
@@ -727,7 +727,7 @@ def sweep_their_ctx(
for (direction, htlc), (ctx_output_idx, htlc_relative_idx) in htlc_to_ctx_output_idx_map.items():
is_received_htlc = direction == RECEIVED
if not is_received_htlc and not is_revocation:
if not chan.lnworker.is_accepted_mpp(htlc.payment_hash):
if not chan.lnworker.is_complete_mpp(htlc.payment_hash):
# do not redeem this, it might publish the preimage of an incomplete MPP
continue
preimage = chan.lnworker.get_preimage(htlc.payment_hash)
+48 -24
View File
@@ -12,9 +12,10 @@ from functools import lru_cache
import electrum_ecc as ecc
from electrum_ecc import CURVE_ORDER, ecdsa_sig64_from_der_sig
from electrum_ecc.util import bip340_tagged_hash
import dataclasses
import attr
from .util import bfh, UserFacingException, list_enabled_bits
from .util import bfh, UserFacingException, list_enabled_bits, is_hex_str
from .util import ShortID as ShortChannelID, format_short_id as format_short_channel_id
from .crypto import sha256, pw_decode_with_version_and_mac
@@ -22,7 +23,8 @@ from .transaction import (
Transaction, PartialTransaction, PartialTxInput, TxOutpoint, PartialTxOutput, opcodes, OPPushDataPubkey
)
from . import bitcoin, crypto, transaction, descriptor, segwit_addr
from .bitcoin import redeem_script_to_address, address_to_script, construct_witness, construct_script
from .bitcoin import redeem_script_to_address, address_to_script, construct_witness, \
construct_script, NLOCKTIME_BLOCKHEIGHT_MAX
from .i18n import _
from .bip32 import BIP32Node, BIP32_PRIME
from .transaction import BCDataStream, OPPushDataGeneric
@@ -1825,19 +1827,6 @@ def validate_features(features: int) -> LnFeatures:
return features
def derive_payment_secret_from_payment_preimage(payment_preimage: bytes) -> bytes:
"""Returns secret to be put into invoice.
Derivation is deterministic, based on the preimage.
Crucially the payment_hash must be derived in an independent way from this.
"""
# Note that this could be random data too, but then we would need to store it.
# We derive it identically to clightning, so that we cannot be distinguished:
# https://github.com/ElementsProject/lightning/blob/faac4b28adee5221e83787d64cd5d30b16b62097/lightningd/invoice.c#L115
modified = bytearray(payment_preimage)
modified[0] ^= 1
return sha256(bytes(modified))
def get_compressed_pubkey_from_bech32(bech32_pubkey: str) -> bytes:
decoded_bech32 = segwit_addr.bech32_decode(bech32_pubkey)
hrp = decoded_bech32.hrp
@@ -1908,26 +1897,61 @@ NUM_MAX_HOPS_IN_PAYMENT_PATH = 20
NUM_MAX_EDGES_IN_PAYMENT_PATH = NUM_MAX_HOPS_IN_PAYMENT_PATH
@attr.s(frozen=True)
@dataclasses.dataclass(frozen=True, kw_only=True)
class UpdateAddHtlc:
amount_msat = attr.ib(type=int, kw_only=True)
payment_hash = attr.ib(type=bytes, kw_only=True, converter=hex_to_bytes, repr=lambda val: val.hex())
cltv_abs = attr.ib(type=int, kw_only=True)
timestamp = attr.ib(type=int, kw_only=True)
htlc_id = attr.ib(type=int, kw_only=True, default=None)
amount_msat: int
payment_hash: bytes
cltv_abs: int
htlc_id: Optional[int] = dataclasses.field(default=None)
timestamp: int = dataclasses.field(default_factory=lambda: int(time.time()))
@staticmethod
@stored_in('adds', tuple)
def from_tuple(amount_msat, payment_hash, cltv_abs, htlc_id, timestamp) -> 'UpdateAddHtlc':
def from_tuple(amount_msat, rhash, cltv_abs, htlc_id, timestamp) -> 'UpdateAddHtlc':
return UpdateAddHtlc(
amount_msat=amount_msat,
payment_hash=payment_hash,
payment_hash=bytes.fromhex(rhash),
cltv_abs=cltv_abs,
htlc_id=htlc_id,
timestamp=timestamp)
def to_json(self):
return self.amount_msat, self.payment_hash, self.cltv_abs, self.htlc_id, self.timestamp
self._validate()
return dataclasses.astuple(self)
def _validate(self):
assert isinstance(self.amount_msat, int), self.amount_msat
assert isinstance(self.payment_hash, bytes) and len(self.payment_hash) == 32
assert isinstance(self.cltv_abs, int) and self.cltv_abs <= NLOCKTIME_BLOCKHEIGHT_MAX, self.cltv_abs
assert isinstance(self.htlc_id, int) or self.htlc_id is None, self.htlc_id
assert isinstance(self.timestamp, int), self.timestamp
def __post_init__(self):
self._validate()
# Note: these states are persisted in the wallet file.
# Do not modify them without performing a wallet db upgrade
class RecvMPPResolution(IntEnum):
WAITING = 0
EXPIRED = 1
COMPLETE = 2
FAILED = 3
class ReceivedMPPStatus(NamedTuple):
resolution: RecvMPPResolution
expected_msat: int
htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
@staticmethod
@stored_in('received_mpp_htlcs', tuple)
def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus':
htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid, x) in htlc_list])
return ReceivedMPPStatus(
resolution=RecvMPPResolution(resolution),
expected_msat=expected_msat,
htlc_set=htlc_set)
class OnionFailureCodeMetaFlag(IntFlag):
+44 -33
View File
@@ -20,6 +20,7 @@ import concurrent
from concurrent import futures
import urllib.parse
import itertools
import dataclasses
import aiohttp
import dns.asyncresolver
@@ -67,7 +68,8 @@ from .lnutil import (
LnKeyFamily, LOCAL, REMOTE, MIN_FINAL_CLTV_DELTA_FOR_INVOICE, SENT, RECEIVED, HTLCOwner, UpdateAddHtlc, LnFeatures,
ShortChannelID, HtlcLog, NoPathFound, InvalidGossipMsg, FeeBudgetExceeded, ImportedChannelBackupStorage,
OnchainChannelBackupStorage, ln_compare_features, IncompatibleLightningFeatures, PaymentFeeBudget,
NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage, MIN_FUNDING_SAT
NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage, MIN_FUNDING_SAT,
RecvMPPResolution, ReceivedMPPStatus,
)
from .lnonion import decode_onion_error, OnionFailureCode, OnionRoutingFailure, OnionPacket
from .lnmsg import decode_msg
@@ -106,35 +108,24 @@ class PaymentDirection(IntEnum):
FORWARDING = 3
class PaymentInfo(NamedTuple):
@dataclasses.dataclass(frozen=True, kw_only=True)
class PaymentInfo:
"""Information required to handle incoming htlcs for a payment request"""
payment_hash: bytes
amount_msat: Optional[int]
direction: int
status: int
def validate(self):
assert isinstance(self.payment_hash, bytes) and len(self.payment_hash) == 32
assert self.amount_msat is None or isinstance(self.amount_msat, int)
assert isinstance(self.direction, int)
assert isinstance(self.status, int)
# Note: these states are persisted in the wallet file.
# Do not modify them without performing a wallet db upgrade
class RecvMPPResolution(IntEnum):
WAITING = 0
EXPIRED = 1
ACCEPTED = 2
FAILED = 3
def __post_init__(self):
self.validate()
class ReceivedMPPStatus(NamedTuple):
resolution: RecvMPPResolution
expected_msat: int
htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
@stored_in('received_mpp_htlcs', tuple)
def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus':
htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid, x) in htlc_list])
return ReceivedMPPStatus(
resolution=RecvMPPResolution(resolution),
expected_msat=expected_msat,
htlc_set=htlc_set)
SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id
@@ -1567,7 +1558,12 @@ class LNWallet(LNWorker):
raise PaymentFailure(_("A payment was already initiated for this invoice"))
if payment_hash in self.get_payments(status='inflight'):
raise PaymentFailure(_("A previous attempt to pay this invoice did not clear"))
info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID)
info = PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount_to_pay,
direction=SENT,
status=PR_UNPAID,
)
self.save_payment_info(info)
self.wallet.set_label(key, lnaddr.get_description())
self.set_invoice_status(key, PR_INFLIGHT)
@@ -2302,7 +2298,12 @@ class LNWallet(LNWorker):
def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes:
payment_preimage = os.urandom(32)
payment_hash = sha256(payment_preimage)
info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID)
info = PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount_msat,
direction=RECEIVED,
status=PR_UNPAID,
)
self.save_preimage(payment_hash, payment_preimage, write_to_disk=False)
self.save_payment_info(info, write_to_disk=False)
if write_to_disk:
@@ -2376,12 +2377,22 @@ class LNWallet(LNWorker):
with self.lock:
if key in self.payment_info:
amount_msat, direction, status = self.payment_info[key]
return PaymentInfo(payment_hash, amount_msat, direction, status)
return PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount_msat,
direction=direction,
status=status,
)
return None
def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: Optional[int]):
amount = lightning_amount_sat * 1000 if lightning_amount_sat else None
info = PaymentInfo(payment_hash, amount, RECEIVED, PR_UNPAID)
info = PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount,
direction=RECEIVED,
status=PR_UNPAID,
)
self.save_payment_info(info, write_to_disk=False)
def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]):
@@ -2396,11 +2407,11 @@ class LNWallet(LNWorker):
if old_info := self.get_payment_info(payment_hash=info.payment_hash):
if info == old_info:
return # already saved
if info != old_info._replace(status=info.status):
if info != dataclasses.replace(old_info, status=info.status):
# differs more than in status. let's fail
raise Exception("payment_hash already in use")
key = info.payment_hash.hex()
self.payment_info[key] = info.amount_msat, info.direction, info.status
self.payment_info[key] = dataclasses.astuple(info)[1:] # drop the payment hash at index 0
if write_to_disk:
self.wallet.save_db()
@@ -2433,12 +2444,12 @@ class LNWallet(LNWorker):
payment_keys = [payment_key]
first_timestamp = min([self.get_first_timestamp_of_mpp(pkey) for pkey in payment_keys])
if self.get_payment_status(payment_hash) == PR_PAID:
mpp_resolution = RecvMPPResolution.ACCEPTED
mpp_resolution = RecvMPPResolution.COMPLETE
elif self.stopping_soon:
# try to time out pending HTLCs before shutting down
mpp_resolution = RecvMPPResolution.EXPIRED
elif all([self.is_mpp_amount_reached(pkey) for pkey in payment_keys]):
mpp_resolution = RecvMPPResolution.ACCEPTED
mpp_resolution = RecvMPPResolution.COMPLETE
elif time.time() - first_timestamp > self.MPP_EXPIRY:
mpp_resolution = RecvMPPResolution.EXPIRED
# save resolution, if any.
@@ -2486,10 +2497,10 @@ class LNWallet(LNWorker):
total, expected = amounts
return total >= expected
def is_accepted_mpp(self, payment_hash: bytes) -> bool:
def is_complete_mpp(self, payment_hash: bytes) -> bool:
payment_key = self._get_payment_key(payment_hash)
status = self.received_mpp_htlcs.get(payment_key.hex())
return status and status.resolution == RecvMPPResolution.ACCEPTED
return status and status.resolution == RecvMPPResolution.COMPLETE
def get_payment_mpp_amount_msat(self, payment_hash: bytes) -> Optional[int]:
"""Returns the received mpp amount for given payment hash."""
@@ -2577,7 +2588,7 @@ class LNWallet(LNWorker):
if info is None:
# if we are forwarding
return
info = info._replace(status=status)
info = dataclasses.replace(info, status=status)
self.save_payment_info(info)
def is_forwarded_htlc(self, htlc_key) -> Optional[str]:
+1 -6
View File
@@ -7,7 +7,7 @@ import unittest
from electrum.lnaddr import shorten_amount, unshorten_amount, LnAddr, lnencode, lndecode
from electrum.segwit_addr import bech32_encode, bech32_decode
from electrum import segwit_addr
from electrum.lnutil import UnknownEvenFeatureBits, derive_payment_secret_from_payment_preimage, LnFeatures, IncompatibleLightningFeatures
from electrum.lnutil import UnknownEvenFeatureBits, LnFeatures, IncompatibleLightningFeatures
from electrum import constants
from . import ElectrumTestCase
@@ -164,11 +164,6 @@ class TestBolt11(ElectrumTestCase):
self.assertEqual((1 << 9) + (1 << 15) + (1 << 99), lnaddr.get_tag('9'))
self.assertEqual(b"\x11" * 32, lnaddr.payment_secret)
def test_derive_payment_secret_from_payment_preimage(self):
preimage = bytes.fromhex("cc3fc000bdeff545acee53ada12ff96060834be263f77d645abbebc3a8d53b92")
self.assertEqual("bfd660b559b3f452c6bb05b8d2906f520c151c107b733863ed0cc53fc77021a8",
derive_payment_secret_from_payment_preimage(preimage).hex())
def test_validate_and_compare_features(self):
lnaddr = lndecode("lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqsp5zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zygsdq5vdhkven9v5sxyetpdees9q5sqqqqqqqqqqqqqqqpqsqvvh7ut50r00p3pg34ea68k7zfw64f8yx9jcdk35lh5ft8qdr8g4r0xzsdcrmcy9hex8un8d8yraewvhqc9l0sh8l0e0yvmtxde2z0hgpzsje5l")
lnaddr.validate_and_compare_features(LnFeatures((1 << 8) + (1 << 14) + (1 << 15)))
+1 -1
View File
@@ -547,7 +547,7 @@ class TestCommandsTestnet(ElectrumTestCase):
mock_htlc2.amount_msat = 5_500_000
mock_htlc_status = mock.Mock()
mock_htlc_status.htlc_set = [(None, mock_htlc1), (None, mock_htlc2)]
mock_htlc_status.resolution = RecvMPPResolution.ACCEPTED
mock_htlc_status.resolution = RecvMPPResolution.COMPLETE
payment_key = wallet.lnworker._get_payment_key(bytes.fromhex(payment_hash)).hex()
with mock.patch.dict(wallet.lnworker.received_mpp_htlcs, {payment_key: mock_htlc_status}):
+75 -60
View File
@@ -27,6 +27,7 @@ import os
import binascii
from pprint import pformat
import logging
import dataclasses
from electrum import bitcoin
from electrum import lnpeer
@@ -257,31 +258,34 @@ class TestChannel(ElectrumTestCase):
self.paymentPreimage = b"\x01" * 32
paymentHash = bitcoin.sha256(self.paymentPreimage)
self.htlc_dict = {
'payment_hash': paymentHash,
'amount_msat': one_bitcoin_in_msat,
'cltv_abs': 5,
'timestamp': 0,
}
self.htlc = UpdateAddHtlc(
payment_hash=paymentHash,
amount_msat=one_bitcoin_in_msat,
cltv_abs=5,
timestamp=0,
)
# First Alice adds the outgoing HTLC to her local channel's state
# update log. Then Alice sends this wire message over to Bob who adds
# this htlc to his remote state update log.
self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict).htlc_id
self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc).htlc_id
self.assertNotEqual(list(self.alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED, 1).values()), [])
before = self.bob_channel.balance_minus_outgoing_htlcs(REMOTE)
beforeLocal = self.bob_channel.balance_minus_outgoing_htlcs(LOCAL)
self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc_dict).htlc_id
self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc).htlc_id
self.htlc = self.bob_channel.hm.log[REMOTE]['adds'][0]
def test_concurrent_reversed_payment(self):
self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02')
self.htlc_dict['amount_msat'] += 1000
self.bob_channel.add_htlc(self.htlc_dict)
self.alice_channel.receive_htlc(self.htlc_dict)
self.htlc = dataclasses.replace(
self.htlc,
payment_hash=bitcoin.sha256(32 * b'\x02'),
amount_msat=self.htlc.amount_msat + 1000,
)
self.bob_channel.add_htlc(self.htlc)
self.alice_channel.receive_htlc(self.htlc)
self.assertNumberNonAnchorOutputs(2, self.alice_channel.get_latest_commitment(LOCAL))
self.assertNumberNonAnchorOutputs(3, self.alice_channel.get_next_commitment(LOCAL))
@@ -561,9 +565,12 @@ class TestChannel(ElectrumTestCase):
tx6 = str(alice_channel.force_close_tx())
self.assertNotEqual(tx5, tx6)
self.htlc_dict['amount_msat'] *= 5
bob_index = bob_channel.add_htlc(self.htlc_dict).htlc_id
alice_index = alice_channel.receive_htlc(self.htlc_dict).htlc_id
self.htlc = dataclasses.replace(
self.htlc,
amount_msat=self.htlc.amount_msat * 5,
)
bob_index = bob_channel.add_htlc(self.htlc).htlc_id
alice_index = alice_channel.receive_htlc(self.htlc).htlc_id
force_state_transition(bob_channel, alice_channel)
@@ -662,18 +669,26 @@ class TestChannel(ElectrumTestCase):
self.alice_to_bob_fee_update(0)
force_state_transition(self.alice_channel, self.bob_channel)
self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02')
self.alice_channel.add_htlc(self.htlc_dict)
self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x03')
self.alice_channel.add_htlc(self.htlc_dict)
self.htlc = dataclasses.replace(
self.htlc,
payment_hash=bitcoin.sha256(32 * b'\x02'),
)
self.alice_channel.add_htlc(self.htlc)
self.htlc = dataclasses.replace(
self.htlc,
payment_hash=bitcoin.sha256(32 * b'\x03'),
)
self.alice_channel.add_htlc(self.htlc)
# now there are three htlcs (one was in setUp)
# Alice now has an available balance of 2 BTC. We'll add a new HTLC of
# value 2 BTC, which should make Alice's balance negative (since she
# has to pay a commitment fee).
new = dict(self.htlc_dict)
new['amount_msat'] *= 2.5
new['payment_hash'] = bitcoin.sha256(32 * b'\x04')
new = dataclasses.replace(
self.htlc,
amount_msat=int(self.htlc.amount_msat * 2.5),
payment_hash=bitcoin.sha256(32 * b'\x04'),
)
with self.assertRaises(lnutil.PaymentFailure) as cm:
self.alice_channel.add_htlc(new)
self.assertIn('Not enough local balance', cm.exception.args[0])
@@ -822,14 +837,14 @@ class TestChanReserve(ElectrumTestCase):
# Bob: 5.0
paymentPreimage = b"\x01" * 32
paymentHash = bitcoin.sha256(paymentPreimage)
htlc_dict = {
'payment_hash': paymentHash,
'amount_msat': int(.5 * one_bitcoin_in_msat),
'cltv_abs': 5,
'timestamp': 0,
}
self.alice_channel.add_htlc(htlc_dict)
self.bob_channel.receive_htlc(htlc_dict)
htlc = UpdateAddHtlc(
payment_hash=paymentHash,
amount_msat=int(.5 * one_bitcoin_in_msat),
cltv_abs=5,
timestamp=0,
)
self.alice_channel.add_htlc(htlc)
self.bob_channel.receive_htlc(htlc)
# Force a state transition, making sure this HTLC is considered valid
# even though the channel reserves are not met.
force_state_transition(self.alice_channel, self.bob_channel)
@@ -847,10 +862,10 @@ class TestChanReserve(ElectrumTestCase):
# Alice: 4.5
# Bob: 5.0
with self.assertRaises(lnutil.PaymentFailure):
htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02')
self.bob_channel.add_htlc(htlc_dict)
htlc = dataclasses.replace(htlc, payment_hash=bitcoin.sha256(32 * b'\x02'))
self.bob_channel.add_htlc(htlc)
with self.assertRaises(lnutil.RemoteMisbehaving):
self.alice_channel.receive_htlc(htlc_dict)
self.alice_channel.receive_htlc(htlc)
def part2(self):
paymentPreimage = b"\x01" * 32
@@ -861,22 +876,22 @@ class TestChanReserve(ElectrumTestCase):
# Resulting balances:
# Alice: 1.5
# Bob: 9.5
htlc_dict = {
'payment_hash': paymentHash,
'amount_msat': int(3.5 * one_bitcoin_in_msat),
'cltv_abs': 5,
}
self.alice_channel.add_htlc(htlc_dict)
self.bob_channel.receive_htlc(htlc_dict)
htlc = UpdateAddHtlc(
payment_hash=paymentHash,
amount_msat=int(3.5 * one_bitcoin_in_msat),
cltv_abs=5,
)
self.alice_channel.add_htlc(htlc)
self.bob_channel.receive_htlc(htlc)
# Add a second HTLC of 1 BTC. This should fail because it will take
# Alice's balance all the way down to her channel reserve, but since
# she is the initiator the additional transaction fee makes her
# balance dip below.
htlc_dict['amount_msat'] = one_bitcoin_in_msat
htlc = dataclasses.replace(htlc, amount_msat=one_bitcoin_in_msat)
with self.assertRaises(lnutil.PaymentFailure):
self.alice_channel.add_htlc(htlc_dict)
self.alice_channel.add_htlc(htlc)
with self.assertRaises(lnutil.RemoteMisbehaving):
self.bob_channel.receive_htlc(htlc_dict)
self.bob_channel.receive_htlc(htlc)
def part3(self):
# Add a HTLC of 2 BTC to Alice, and the settle it.
@@ -885,14 +900,14 @@ class TestChanReserve(ElectrumTestCase):
# Bob: 7.0
paymentPreimage = b"\x01" * 32
paymentHash = bitcoin.sha256(paymentPreimage)
htlc_dict = {
'payment_hash': paymentHash,
'amount_msat': int(2 * one_bitcoin_in_msat),
'cltv_abs': 5,
'timestamp': 0,
}
alice_idx = self.alice_channel.add_htlc(htlc_dict).htlc_id
bob_idx = self.bob_channel.receive_htlc(htlc_dict).htlc_id
htlc = UpdateAddHtlc(
payment_hash=paymentHash,
amount_msat=int(2 * one_bitcoin_in_msat),
cltv_abs=5,
timestamp=0,
)
alice_idx = self.alice_channel.add_htlc(htlc).htlc_id
bob_idx = self.bob_channel.receive_htlc(htlc).htlc_id
force_state_transition(self.alice_channel, self.bob_channel)
self.check_bals(one_bitcoin_in_msat * 3
- self.alice_channel.get_next_fee(LOCAL),
@@ -906,9 +921,9 @@ class TestChanReserve(ElectrumTestCase):
# And now let Bob add an HTLC of 1 BTC. This will take Bob's balance
# all the way down to his channel reserve, but since he is not paying
# the fee this is okay.
htlc_dict['amount_msat'] = one_bitcoin_in_msat
self.bob_channel.add_htlc(htlc_dict)
self.alice_channel.receive_htlc(htlc_dict)
htlc = dataclasses.replace(htlc, amount_msat=one_bitcoin_in_msat)
self.bob_channel.add_htlc(htlc)
self.alice_channel.receive_htlc(htlc)
force_state_transition(self.alice_channel, self.bob_channel)
self.check_bals(one_bitcoin_in_msat * 3 \
- self.alice_channel.get_next_fee(LOCAL),
@@ -943,12 +958,12 @@ class TestDust(ElectrumTestCase):
# to pay for his htlc success transaction
below_dust_for_bob = dust_limit_bob - 1
htlc_amt = below_dust_for_bob + success_weight * (fee_per_kw // 1000)
htlc = {
'payment_hash': paymentHash,
'amount_msat': 1000 * htlc_amt,
'cltv_abs': 5, # consistent with channel policy
'timestamp': 0,
}
htlc = UpdateAddHtlc(
payment_hash=paymentHash,
amount_msat=1000 * htlc_amt,
cltv_abs=5, # consistent with channel policy
timestamp=0,
)
# add the htlc
alice_htlc_id = alice_channel.add_htlc(htlc).htlc_id
+6 -1
View File
@@ -559,7 +559,12 @@ class TestPeer(ElectrumTestCase):
payment_preimage = os.urandom(32)
if payment_hash is None:
payment_hash = sha256(payment_preimage)
info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID)
info = PaymentInfo(
payment_hash=payment_hash,
amount_msat=amount_msat,
direction=RECEIVED,
status=PR_UNPAID,
)
if payment_preimage:
w2.save_preimage(payment_hash, payment_preimage)
w2.save_payment_info(info)