lnworker: split LNWallet and LNWorker: LNWallet "has an" LNWorker

- LNWallet no longer "is-an" LNWorker, instead LNWallet "has-an" LNWorker
- the motivation is to make the unit tests nicer, and allow writing unit tests for more things
  - I hope this makes it possible to e.g. test lnsweep in the unit tests
  - some stuff we would previously have to write a regtest for, maybe we can write a unit test for, now
- in unit tests, MockLNWallet now
  - inherits LNWallet
  - the Wallet is no longer being mocked
This commit is contained in:
SomberNight
2025-12-17 15:16:05 +00:00
parent bdcd3f9c7c
commit 1006e8092f
17 changed files with 345 additions and 354 deletions
+5 -5
View File
@@ -1687,7 +1687,7 @@ class Commands(Logger):
arg:int:timeout:Timeout in seconds (default=20)
"""
lnworker = self.network.lngossip if gossip else wallet.lnworker
peer = await lnworker.add_peer(connection_string)
peer = await lnworker.lnpeermgr.add_peer(connection_string)
try:
await util.wait_for2(peer.initialized, timeout=LN_P2P_NETWORK_TIMEOUT)
except (CancelledError, Exception) as e:
@@ -1700,7 +1700,7 @@ class Commands(Logger):
"""Display statistics about lightninig gossip"""
lngossip = self.network.lngossip
channel_db = lngossip.channel_db
forwarded = dict([(key.hex(), p._num_gossip_messages_forwarded) for key, p in wallet.lnworker.peers.items()]),
forwarded = dict([(key.hex(), p._num_gossip_messages_forwarded) for key, p in wallet.lnworker.lnpeermgr.peers.items()]),
out = {
'received': {
'channel_announcements': lngossip._num_chan_ann,
@@ -1731,7 +1731,7 @@ class Commands(Logger):
'initialized': p.is_initialized(),
'features': str(LnFeatures(p.features)),
'channels': [c.funding_outpoint.to_str() for c in p.channels.values()],
} for p in lnworker.peers.values()]
} for p in lnworker.lnpeermgr.peers.values()]
@command('wpnl')
async def open_channel(self, connection_string, amount, push_amount=0, public=False, zeroconf=False, password=None, wallet: Abstract_Wallet = None):
@@ -1748,7 +1748,7 @@ class Commands(Logger):
raise UserFacingException("This wallet cannot create new channels")
funding_sat = satoshis(amount)
push_sat = satoshis(push_amount)
peer = await wallet.lnworker.add_peer(connection_string)
peer = await wallet.lnworker.lnpeermgr.add_peer(connection_string)
chan, funding_tx = await wallet.lnworker.open_channel_with_peer(
peer, funding_sat,
push_sat=push_sat,
@@ -2197,7 +2197,7 @@ class Commands(Logger):
pubkey = bfh(node_id)
assert len(pubkey) == 33, 'invalid node_id'
peer = wallet.lnworker.peers[pubkey]
peer = wallet.lnworker.lnpeermgr.peers[pubkey]
assert peer, 'node_id not a peer'
path = [pubkey, wallet.lnworker.node_keypair.pubkey]
+1 -1
View File
@@ -94,7 +94,7 @@ class QEChannelDetails(AuthMixin, QObject, QtEventListener):
def name(self) -> str:
if not self._channel:
return ''
return self._wallet.wallet.lnworker.get_node_alias(self._channel.node_id) or ''
return self._wallet.wallet.lnworker.lnpeermgr.get_node_alias(self._channel.node_id) or ''
@pyqtProperty(str, notify=channelChanged)
def pubkey(self) -> str:
+1 -1
View File
@@ -88,7 +88,7 @@ class QEChannelListModel(QAbstractListModel, QtEventListener):
item = {
'cid': lnc.channel_id.hex(),
'node_id': lnc.node_id.hex(),
'node_alias': lnworker.get_node_alias(lnc.node_id) or '',
'node_alias': lnworker.lnpeermgr.get_node_alias(lnc.node_id) or '',
'short_cid': lnc.short_id_for_GUI(),
'state': lnc.get_state_for_GUI(),
'state_code': int(lnc.get_state()),
+1 -1
View File
@@ -258,7 +258,7 @@ class QEInvoice(QObject, QtEventListener):
def name_for_node_id(self, node_id):
lnworker = self._wallet.wallet.lnworker
return (lnworker.get_node_alias(node_id) if lnworker else None) or node_id.hex()
return (lnworker.lnpeermgr.get_node_alias(node_id) if lnworker else None) or node_id.hex()
def set_effective_invoice(self, invoice: Invoice):
self._effectiveInvoice = invoice
+1 -1
View File
@@ -525,7 +525,7 @@ class QEWallet(AuthMixin, QObject, QtEventListener):
@pyqtProperty(int, notify=peersUpdated)
def lightningNumPeers(self):
if self.isLightning:
return self.wallet.lnworker.num_peers()
return self.wallet.lnworker.lnpeermgr.num_peers()
return 0
@pyqtSlot()
+1 -1
View File
@@ -98,7 +98,7 @@ class ChannelsList(MyTreeView):
labels[subject] = label
status = chan.get_state_for_GUI()
closed = chan.is_closed()
node_alias = self.lnworker.get_node_alias(chan.node_id) or chan.node_id.hex()
node_alias = self.lnworker.lnpeermgr.get_node_alias(chan.node_id) or chan.node_id.hex()
capacity_str = self.main_window.format_amount(chan.get_capacity(), whitespaces=True)
return {
self.Columns.SHORT_CHANID: chan.short_id_for_GUI(),
+1 -1
View File
@@ -62,7 +62,7 @@ class LightningDialog(QDialog, QtEventListener):
self.register_callbacks()
self.network.channel_db.update_counts() # trigger callback
if self.network.lngossip:
self.on_event_gossip_peers(self.network.lngossip.num_peers())
self.on_event_gossip_peers(self.network.lngossip.lnpeermgr.num_peers())
self.on_event_unknown_channels(len(self.network.lngossip.unknown_ids))
else:
self.num_peers.setText(_('Lightning gossip not active.'))
+5 -5
View File
@@ -417,9 +417,9 @@ class AbstractChannel(Logger, ABC):
if not self.is_funding_tx_mined(funding_height):
# funding tx is invalid (invalid amount or address) we need to get rid of the channel again
self.should_request_force_close = True
if self.lnworker and self.node_id in self.lnworker.peers:
if self.lnworker and (peer := self.lnworker.lnpeermgr.get_peer_by_pubkey(self.node_id)):
# reconnect to trigger force close request
self.lnworker.peers[self.node_id].close_and_cleanup()
peer.close_and_cleanup()
else:
# remove zeroconf flag as we are now confirmed, this is to prevent an electrum server causing
# us to remove a channel later in update_unfunded_state by omitting its funding tx
@@ -779,7 +779,7 @@ class Channel(AbstractChannel):
self,
state: 'StoredDict', *,
name=None,
lnworker=None, # None only in unittests
lnworker: 'LNWallet' = None, # None only in unittests
initial_feerate=None,
jit_opening_fee: Optional[int] = None,
):
@@ -1022,8 +1022,8 @@ class Channel(AbstractChannel):
elif self.is_static_remotekey_enabled():
our_payment_pubkey = self.config[LOCAL].payment_basepoint.pubkey
addr = make_commitment_output_to_remote_address(our_payment_pubkey, has_anchors=self.has_anchors())
if self.lnworker:
assert self.lnworker.wallet.is_mine(addr)
#if self.lnworker:
# assert self.lnworker.wallet.is_mine(addr) # FIXME xxxxx chan should be deterministic. NEEDS to be fixed before merge
return addr
def has_anchors(self) -> bool:
+4 -4
View File
@@ -81,7 +81,7 @@ class Peer(Logger, EventListener):
def __init__(
self,
lnworker: Union['LNGossip', 'LNWallet'],
lnworker: Union['LNWallet', 'LNGossip'],
pubkey: bytes,
transport: LNTransportBase,
*, is_channel_backup= False):
@@ -402,7 +402,7 @@ class Peer(Logger, EventListener):
if constants.net.rev_genesis_bytes() not in their_chains:
raise GracefulDisconnect(f"no common chain found with remote. (they sent: {their_chains})")
# all checks passed
self.lnworker.on_peer_successfully_established(self)
self.lnworker.lnpeermgr.on_peer_successfully_established(self)
self._received_init = True
self.maybe_set_initialized()
@@ -888,7 +888,7 @@ class Peer(Logger, EventListener):
self.transport.close()
except Exception:
pass
self.lnworker.peer_closed(self)
self.lnworker.lnpeermgr.peer_closed(self)
self.got_disconnected.set()
def is_shutdown_anysegwit(self):
@@ -3064,7 +3064,7 @@ class Peer(Logger, EventListener):
or not self.lnworker.is_payment_bundle_complete(payment_key):
# maybe this set is COMPLETE but the bundle is not yet completed, so the bundle can be considered WAITING
if int(time.time()) - first_htlc_timestamp > self.lnworker.MPP_EXPIRY \
or self.lnworker.stopping_soon:
or self.lnworker.lnpeermgr.stopping_soon:
_log_fail_reason(f"MPP TIMEOUT (> {self.lnworker.MPP_EXPIRY} sec)")
return OnionFailureCode.MPP_TIMEOUT, None, None
+188 -94
View File
@@ -39,7 +39,7 @@ from .util import (
profiler, OldTaskGroup, ESocksProxy, NetworkRetryManager, JsonRPCClient, NotEnoughFunds, EventListener,
event_listener, bfh, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions, ignore_exceptions,
make_aiohttp_session, random_shuffled_copy, is_private_netaddress,
UnrelatedTransactionException, LightningHistoryItem
UnrelatedTransactionException, LightningHistoryItem, get_asyncio_loop,
)
from .fee_policy import (
FeePolicy, FEERATE_FALLBACK_STATIC_FEE, FEE_LN_ETA_TARGET, FEE_LN_LOW_ETA_TARGET,
@@ -215,9 +215,15 @@ LNGOSSIP_FEATURES = (
)
class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
class LNPeerManager(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
def __init__(self, node_keypair, features: LnFeatures, *, config: 'SimpleConfig'):
def __init__(
self, node_keypair,
*,
lnwallet_or_lngossip: 'LNWallet | LNGossip',
features: LnFeatures,
config: 'SimpleConfig',
):
Logger.__init__(self)
NetworkRetryManager.__init__(
self,
@@ -228,6 +234,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
)
self.lock = threading.RLock()
self.node_keypair = node_keypair
self._lnwallet_or_lngossip = lnwallet_or_lngossip
self._peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer # needs self.lock
self._channelless_incoming_peers = set() # type: Set[bytes] # node_ids # needs self.lock
self.taskgroup = OldTaskGroup()
@@ -252,7 +259,10 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
return self._peers.copy()
def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
return {}
return self._lnwallet_or_lngossip.channels_for_peer(node_id)
def get_peer_by_pubkey(self, pubkey: bytes) -> Optional[Peer]:
return self._peers.get(pubkey)
def get_node_alias(self, node_id: bytes) -> Optional[str]:
"""Returns the alias of the node, or None if unknown."""
@@ -269,7 +279,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
return node_alias
async def maybe_listen(self):
# FIXME: only one LNWorker can listen at a time (single port)
# FIXME: only one LNPeerManager can listen at a time (single port)
listen_addr = self.config.LIGHTNING_LISTEN
if listen_addr:
self.logger.info(f'lightning_listen enabled. will try to bind: {listen_addr!r}')
@@ -368,13 +378,17 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
return None
self._channelless_incoming_peers.add(node_id)
# checks done: we are adding this peer.
peer = Peer(self, node_id, transport)
peer = Peer(self._lnwallet_or_lngossip, node_id, transport)
assert node_id not in self._peers
self._peers[node_id] = peer
await self.taskgroup.spawn(peer.main_loop())
return peer
def peer_closed(self, peer: Peer) -> None:
if isinstance(self._lnwallet_or_lngossip, LNWallet):
for chan in self.channels_for_peer(peer.pubkey).values():
chan.peer_state = PeerState.DISCONNECTED
util.trigger_callback('channel', self._lnwallet_or_lngossip.wallet, chan)
with self.lock:
peer2 = self._peers.get(peer.pubkey)
if peer2 is peer:
@@ -392,14 +406,25 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
return True
return False
def start_network(self, network: 'Network'):
def start_network(
self, network: 'Network', *,
listen: bool = False,
maintain_random_peers: bool = False,
) -> None:
assert network
assert self.network is None, "already started"
self.network = network
self._add_peers_from_config()
asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
asyncio.run_coroutine_threadsafe(self.main_loop(), get_asyncio_loop())
if listen:
tg_coro = self.taskgroup.spawn(self.maybe_listen())
asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop())
if maintain_random_peers:
tg_coro = self.taskgroup.spawn(self._maintain_connectivity())
asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop())
async def stop(self):
self.stopping_soon = True
if self.listen_server:
self.listen_server.close()
self.unregister_callbacks()
@@ -410,7 +435,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
for host, port, pubkey in peer_list:
asyncio.run_coroutine_threadsafe(
self._add_peer(host, int(port), bfh(pubkey)),
self.network.asyncio_loop)
get_asyncio_loop())
def is_good_peer(self, peer: LNPeerAddr) -> bool:
# the purpose of this method is to filter peers that advertise the desired feature bits
@@ -573,8 +598,44 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
peer = await self._add_peer(host, port, node_id)
return peer
async def reestablish_peer_for_given_channel(self, chan: Channel) -> None:
await self.taskgroup.spawn(self._reestablish_peer_for_given_channel(chan))
class LNGossip(LNWorker):
@ignore_exceptions
@log_exceptions
async def _reestablish_peer_for_given_channel(self, chan: Channel) -> None:
now = time.time()
peer_addresses = []
if self.uses_trampoline():
addr = trampolines_by_id().get(chan.node_id)
if addr:
peer_addresses.append(addr)
else:
# will try last good address first, from gossip
last_good_addr = self.channel_db.get_last_good_address(chan.node_id)
if last_good_addr:
peer_addresses.append(last_good_addr)
# will try addresses for node_id from gossip
addrs_from_gossip = self.channel_db.get_node_addresses(chan.node_id) or []
for host, port, ts in addrs_from_gossip:
peer_addresses.append(LNPeerAddr(host, port, chan.node_id))
# will try addresses stored in channel storage
peer_addresses += list(chan.get_peer_addresses())
# Done gathering addresses.
# Now select first one that has not failed recently.
for peer in peer_addresses:
if self._can_retry_addr(peer, urgent=True, now=now):
await self._add_peer(peer.host, peer.port, peer.pubkey)
return
async def reestablish_peer_for_zero_conf_trusted_node(self) -> None:
if self.config.ZEROCONF_TRUSTED_NODE:
peer = LNPeerAddr.from_str(self.config.ZEROCONF_TRUSTED_NODE)
if self._can_retry_addr(peer, urgent=True):
await self._add_peer(peer.host, peer.port, peer.pubkey)
class LNGossip(Logger):
"""The LNGossip class is a separate, unannounced Lightning node with random id that is just querying
gossip from other nodes. The LNGossip node does not satisfy gossip queries, this is done by the
LNWallet class(es). LNWallets are the advertised nodes used for actual payments and only satisfy
@@ -584,11 +645,14 @@ class LNGossip(LNWorker):
max_age = 14*24*3600
def __init__(self, config: 'SimpleConfig'):
self.config = config
seed = os.urandom(32)
node = BIP32Node.from_rootseed(seed, xtype='standard')
xprv = node.to_xprv()
node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
LNWorker.__init__(self, node_keypair, LNGOSSIP_FEATURES, config=config)
Logger.__init__(self)
self.lnpeermgr = LNPeerManager(node_keypair, features=LNGOSSIP_FEATURES, config=self.config, lnwallet_or_lngossip=self)
self.taskgroup = OldTaskGroup()
self.unknown_ids = set()
self._forwarding_gossip = [] # type: List[GossipForwardingMessage]
self._last_gossip_batch_ts = 0 # type: int
@@ -600,15 +664,44 @@ class LNGossip(LNWorker):
self._num_chan_upd = 0
self._num_chan_upd_good = 0
@property
def features(self) -> 'LnFeatures':
return self.lnpeermgr.features
@property
def network(self) -> Optional['Network']:
return self.lnpeermgr.network
@property
def channel_db(self) -> 'ChannelDB':
return self.network.channel_db if self.network else None
def uses_trampoline(self) -> bool:
return not bool(self.channel_db)
async def main_loop(self):
self.logger.info("starting taskgroup.")
try:
async with self.taskgroup as group:
await group.spawn(asyncio.Event().wait) # run forever (until cancel)
except Exception as e:
self.logger.exception("taskgroup died.")
finally:
self.logger.info("taskgroup stopped.")
def start_network(self, network: 'Network'):
super().start_network(network)
asyncio.run_coroutine_threadsafe(self.main_loop(), get_asyncio_loop())
self.lnpeermgr.start_network(network, maintain_random_peers=True)
for coro in [
self._maintain_connectivity(),
self.maintain_db(),
self._maintain_forwarding_gossip()
]:
tg_coro = self.taskgroup.spawn(coro)
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop())
async def stop(self):
await self.lnpeermgr.stop()
await self.taskgroup.cancel_remaining()
async def maintain_db(self):
await self.channel_db.data_loaded.wait()
@@ -637,7 +730,7 @@ class LNGossip(LNWorker):
new = set(ids) - set(known)
self.unknown_ids.update(new)
util.trigger_callback('unknown_channels', len(self.unknown_ids))
util.trigger_callback('gossip_peers', self.num_peers())
util.trigger_callback('gossip_peers', self.lnpeermgr.num_peers())
util.trigger_callback('ln_gossip_sync_progress')
def get_ids_to_query(self) -> Sequence[bytes]:
@@ -652,7 +745,7 @@ class LNGossip(LNWorker):
"""Estimates the gossip synchronization process and returns the number
of synchronized channels, the total channels in the network and a
rescaled percentage of the synchronization process."""
if self.num_peers() == 0:
if self.lnpeermgr.num_peers() == 0:
return None, None, None
nchans_with_0p, nchans_with_1p, nchans_with_2p = self.channel_db.get_num_channels_partitioned_by_policy_count()
num_db_channels = nchans_with_0p + nchans_with_1p + nchans_with_2p
@@ -730,6 +823,9 @@ class LNGossip(LNWorker):
# flush the gossip queue so we don't forward old gossip after sync is complete
self.channel_db.get_forwarding_gossip_batch()
def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
return {}
class PaySession(Logger):
@@ -875,7 +971,7 @@ class PaySession(Logger):
return nhtlcs_resolved == self._nhtlcs_inflight
class LNWallet(LNWorker):
class LNWallet(Logger):
lnwatcher: Optional['LNWatcher']
MPP_EXPIRY = 120
@@ -884,7 +980,7 @@ class LNWallet(LNWorker):
MPP_SPLIT_PART_FRACTION = 0.2
MPP_SPLIT_PART_MINAMT_MSAT = 5_000_000
def __init__(self, wallet: 'Abstract_Wallet', xprv):
def __init__(self, wallet: 'Abstract_Wallet', xprv, *, features: LnFeatures = None):
self.wallet = wallet
self.config = wallet.config
self.db = wallet.db
@@ -894,16 +990,20 @@ class LNWallet(LNWorker):
self.payment_secret_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.PAYMENT_SECRET_KEY).privkey
self.funding_root_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.FUNDING_ROOT_KEY)
Logger.__init__(self)
features = LNWALLET_FEATURES
if self.config.ENABLE_ANCHOR_CHANNELS:
features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT
if self.config.ACCEPT_ZEROCONF_CHANNELS:
features |= LnFeatures.OPTION_ZEROCONF_OPT
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS or self.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS:
features |= LnFeatures.OPTION_ONION_MESSAGE_OPT
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP:
features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch
LNWorker.__init__(self, self.node_keypair, features, config=self.config)
if features is None:
features = LNWALLET_FEATURES
if self.config.ENABLE_ANCHOR_CHANNELS:
features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT
if self.config.ACCEPT_ZEROCONF_CHANNELS:
features |= LnFeatures.OPTION_ZEROCONF_OPT
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS or self.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS:
features |= LnFeatures.OPTION_ONION_MESSAGE_OPT
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP:
features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch
Logger.__init__(self)
self.lock = threading.RLock()
self.lnpeermgr = LNPeerManager(self.node_keypair, features=features, config=self.config, lnwallet_or_lngossip=self)
self.taskgroup = OldTaskGroup()
self.lnwatcher = LNWatcher(self)
self.lnrater: LNRater = None
# "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts, invoice_features
@@ -997,6 +1097,21 @@ class LNWallet(LNWorker):
return any(chan.has_anchors() and not chan.is_closed()
for chan in self.channels.values())
@property
def features(self) -> 'LnFeatures':
return self.lnpeermgr.features
@property
def network(self) -> Optional['Network']:
return self.lnpeermgr.network
@property
def channel_db(self) -> 'ChannelDB':
return self.network.channel_db if self.network else None
def uses_trampoline(self) -> bool:
return not bool(self.channel_db)
@property
def channels(self) -> Mapping[bytes, Channel]:
"""Returns a read-only copy of channels."""
@@ -1060,29 +1175,39 @@ class LNWallet(LNWorker):
await watchtower.add_sweep_tx(outpoint, ctn, tx.inputs()[0].prevout.to_str(), tx.serialize())
self.watchtower_ctns[outpoint] = ctn
async def main_loop(self):
self.logger.info("starting taskgroup.")
try:
async with self.taskgroup as group:
await group.spawn(asyncio.Event().wait) # run forever (until cancel)
except Exception as e:
self.logger.exception("taskgroup died.")
finally:
self.logger.info("taskgroup stopped.")
def start_network(self, network: 'Network'):
super().start_network(network)
asyncio.run_coroutine_threadsafe(self.main_loop(), get_asyncio_loop())
self.lnpeermgr.start_network(network, listen=True)
self.lnwatcher.start_network(network)
self.swap_manager.start_network(network)
self.lnrater = LNRater(self, network)
self.onion_message_manager.start_network(network=network)
for coro in [
self.maybe_listen(),
self.lnwatcher.trigger_callbacks(), # shortcut (don't block) if funding tx locked and verified
self.reestablish_peers_and_channels(),
self.sync_with_remote_watchtower(),
]:
tg_coro = self.taskgroup.spawn(coro)
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop())
async def stop(self):
self.stopping_soon = True
if self.listen_server: # stop accepting new peers
self.listen_server.close()
self.lnpeermgr.stopping_soon = True
if self.lnpeermgr.listen_server: # stop accepting new peers
self.lnpeermgr.listen_server.close()
async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):
await self.wait_for_received_pending_htlcs_to_get_removed()
await LNWorker.stop(self)
await self.lnpeermgr.stop()
if self.lnwatcher:
self.lnwatcher.stop()
self.lnwatcher = None
@@ -1090,30 +1215,25 @@ class LNWallet(LNWorker):
await self.swap_manager.stop()
if self.onion_message_manager:
await self.onion_message_manager.stop()
await self.taskgroup.cancel_remaining()
async def wait_for_received_pending_htlcs_to_get_removed(self):
assert self.stopping_soon is True
assert self.lnpeermgr.stopping_soon is True
# We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
# Note: even without MPP, if we just failed/fulfilled an HTLC, it is good
# to wait a bit for it to become irrevocably removed.
# Note: we don't wait for *all htlcs* to get removed, only for those
# that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
async with OldTaskGroup() as group:
for peer in self.peers.values():
for peer in self.lnpeermgr.peers.values():
await group.spawn(peer.wait_one_htlc_switch_iteration())
while True:
if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()):
if all(not peer.received_htlcs_pending_removal for peer in self.lnpeermgr.peers.values()):
break
async with OldTaskGroup(wait=any) as group:
for peer in self.peers.values():
for peer in self.lnpeermgr.peers.values():
await group.spawn(peer.received_htlc_removed_event.wait())
def peer_closed(self, peer):
for chan in self.channels_for_peer(peer.pubkey).values():
chan.peer_state = PeerState.DISCONNECTED
util.trigger_callback('channel', self.wallet, chan)
super().peer_closed(peer)
def get_payments(self, *, status=None) -> Mapping[bytes, List[HTLCWithStatus]]:
out = defaultdict(list)
for chan in self.channels.values():
@@ -1263,7 +1383,7 @@ class LNWallet(LNWorker):
node_ids = [chan.node_id for chan in self.channels.values() if not chan.is_closed()]
return node_ids
def channels_for_peer(self, node_id):
def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
assert type(node_id) is bytes
return {chan_id: chan for (chan_id, chan) in self.channels.items()
if chan.node_id == node_id}
@@ -1307,12 +1427,12 @@ class LNWallet(LNWorker):
await self.schedule_force_closing(chan.channel_id)
elif chan.get_state() == ChannelState.FUNDED:
peer = self._peers.get(chan.node_id)
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
peer.send_channel_ready(chan)
elif chan.get_state() == ChannelState.OPEN:
peer = self._peers.get(chan.node_id)
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
peer.maybe_update_fee(chan)
peer.maybe_send_announcement_signatures(chan)
@@ -1326,9 +1446,10 @@ class LNWallet(LNWorker):
await self.network.try_broadcasting(force_close_tx, 'force-close')
def get_peer_by_static_jit_scid_alias(self, scid_alias: bytes) -> Optional[Peer]:
for nodeid, peer in self.peers.items():
for nodeid, peer in self.lnpeermgr.peers.items():
if scid_alias == self._scid_alias_of_node(nodeid):
return peer
return None
def _scid_alias_of_node(self, nodeid: bytes) -> bytes:
# scid alias for just-in-time channels
@@ -1557,7 +1678,7 @@ class LNWallet(LNWorker):
password: str = None,
) -> Tuple[Channel, PartialTransaction]:
fut = asyncio.run_coroutine_threadsafe(self.add_peer(connect_str), self.network.asyncio_loop)
fut = asyncio.run_coroutine_threadsafe(self.lnpeermgr.add_peer(connect_str), get_asyncio_loop())
try:
peer = fut.result()
except concurrent.futures.TimeoutError:
@@ -1569,7 +1690,7 @@ class LNWallet(LNWorker):
push_sat=push_amt_sat,
public=public,
password=password)
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
fut = asyncio.run_coroutine_threadsafe(coro, get_asyncio_loop())
try:
chan, funding_tx = fut.result()
except concurrent.futures.TimeoutError:
@@ -1860,7 +1981,7 @@ class LNWallet(LNWorker):
short_channel_id = shi.route[0].short_channel_id
chan = self.get_channel_by_short_id(short_channel_id)
assert chan, ShortChannelID(short_channel_id)
peer = self._peers.get(shi.route[0].node_id)
peer = self.lnpeermgr.get_peer_by_pubkey(shi.route[0].node_id)
if not peer:
raise PaymentFailure('Dropped peer')
await peer.initialized
@@ -2040,7 +2161,7 @@ class LNWallet(LNWorker):
# until trampoline is advertised in lnfeatures, check against hardcoded list
if is_hardcoded_trampoline(node_id):
return True
peer = self._peers.get(node_id)
peer = self.lnpeermgr.get_peer_by_pubkey(node_id)
if not peer:
return False
return (peer.their_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ECLAIR)
@@ -2794,7 +2915,7 @@ class LNWallet(LNWorker):
return
upstream_chan_scid, _ = deserialize_htlc_key(upstream_key)
upstream_chan = self.get_channel_by_short_id(upstream_chan_scid)
upstream_peer = self.peers.get(upstream_chan.node_id) if upstream_chan else None
upstream_peer = self.lnpeermgr.get_peer_by_pubkey(upstream_chan.node_id) if upstream_chan else None
if upstream_peer:
upstream_peer.downstream_htlc_resolved_event.set()
upstream_peer.downstream_htlc_resolved_event.clear()
@@ -3110,7 +3231,7 @@ class LNWallet(LNWorker):
# invalid connection string
return False
# only return True if we are connected to the zeroconf provider
return node_id in self.peers
return self.lnpeermgr.get_peer_by_pubkey(node_id) is not None
def _suggest_channels_for_rebalance(self, direction, amount_sat) -> Sequence[Tuple[Channel, int]]:
"""
@@ -3239,7 +3360,9 @@ class LNWallet(LNWorker):
async def close_channel(self, chan_id):
chan = self._channels[chan_id]
peer = self._peers[chan.node_id]
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
if peer is None:
raise KeyError
return await peer.close_channel(chan_id)
def _force_close_channel(self, chan_id: bytes) -> Transaction:
@@ -3288,52 +3411,23 @@ class LNWallet(LNWorker):
util.trigger_callback('channels_updated', self.wallet)
util.trigger_callback('wallet_updated', self.wallet)
@ignore_exceptions
@log_exceptions
async def reestablish_peer_for_given_channel(self, chan: Channel) -> None:
now = time.time()
peer_addresses = []
if self.uses_trampoline():
addr = trampolines_by_id().get(chan.node_id)
if addr:
peer_addresses.append(addr)
else:
# will try last good address first, from gossip
last_good_addr = self.channel_db.get_last_good_address(chan.node_id)
if last_good_addr:
peer_addresses.append(last_good_addr)
# will try addresses for node_id from gossip
addrs_from_gossip = self.channel_db.get_node_addresses(chan.node_id) or []
for host, port, ts in addrs_from_gossip:
peer_addresses.append(LNPeerAddr(host, port, chan.node_id))
# will try addresses stored in channel storage
peer_addresses += list(chan.get_peer_addresses())
# Done gathering addresses.
# Now select first one that has not failed recently.
for peer in peer_addresses:
if self._can_retry_addr(peer, urgent=True, now=now):
await self._add_peer(peer.host, peer.port, peer.pubkey)
return
async def reestablish_peers_and_channels(self):
while True:
await asyncio.sleep(1)
if self.stopping_soon:
if self.lnpeermgr.stopping_soon:
return
if self.config.ZEROCONF_TRUSTED_NODE:
peer = LNPeerAddr.from_str(self.config.ZEROCONF_TRUSTED_NODE)
if self._can_retry_addr(peer, urgent=True):
await self._add_peer(peer.host, peer.port, peer.pubkey)
await self.lnpeermgr.reestablish_peer_for_zero_conf_trusted_node()
for chan in self.channels.values():
# reestablish
# note: we delegate filtering out uninteresting chans to this:
if not chan.should_try_to_reestablish_peer():
continue
peer = self._peers.get(chan.node_id, None)
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
if peer:
# FIXME maybe this should be the responsibility of the peer itself, done in peer.main_loop:
await peer.taskgroup.spawn(peer.reestablish_channel(chan))
else:
await self.taskgroup.spawn(self.reestablish_peer_for_given_channel(chan))
await self.lnpeermgr.reestablish_peer_for_given_channel(chan)
def current_target_feerate_per_kw(self, *, has_anchors: bool) -> Optional[int]:
target: int = FEE_LN_MINIMUM_ETA_TARGET if has_anchors else FEE_LN_ETA_TARGET
@@ -3396,12 +3490,12 @@ class LNWallet(LNWorker):
async def request_force_close(self, channel_id: bytes, *, connect_str=None) -> None:
if chan := self.get_channel_by_id(channel_id):
peer = self._peers.get(chan.node_id)
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
chan.should_request_force_close = True
if peer:
peer.close_and_cleanup() # to force a reconnect
elif connect_str:
peer = await self.add_peer(connect_str)
peer = await self.lnpeermgr.add_peer(connect_str)
await peer.request_force_close(channel_id)
elif channel_id in self.channel_backups:
await self._request_force_close_from_backup(channel_id)
@@ -3688,7 +3782,7 @@ class LNWallet(LNWorker):
f"maybe_forward_htlc. will forward HTLC: inc_chan={incoming_chan.short_channel_id}. inc_htlc={str(htlc)}. "
f"next_chan={next_chan.get_id_for_log()}.")
next_peer = self.peers.get(next_chan.node_id)
next_peer = self.lnpeermgr.get_peer_by_pubkey(next_chan.node_id)
if next_peer is None:
log_fail_reason(f"next_peer offline ({next_chan.node_id.hex()})")
raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
@@ -3774,7 +3868,7 @@ class LNWallet(LNWorker):
raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'')
# do we have a connection to the node?
next_peer = self.peers.get(outgoing_node_id)
next_peer = self.lnpeermgr.get_peer_by_pubkey(outgoing_node_id)
if next_peer and next_peer.accepts_zeroconf():
self.logger.info(f'JIT: found next_peer')
for next_chan in next_peer.channels.values():
+12 -12
View File
@@ -162,7 +162,7 @@ def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> Seque
): return path
# alt: dest is existing peer?
if lnwallet.peers.get(node_id):
if lnwallet.lnpeermgr.get_peer_by_pubkey(node_id):
return [PathEdge(short_channel_id=None, start_node=None, end_node=node_id)]
# if we have an address, pass it.
@@ -219,7 +219,7 @@ def send_onion_message_to(
encrypted_recipient_data=our_payload['encrypted_recipient_data']
)
peer = lnwallet.peers.get(recipient_data['next_node_id']['node_id'])
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(recipient_data['next_node_id']['node_id'])
assert peer, 'next_node_id not a peer'
# blinding override?
@@ -241,7 +241,7 @@ def send_onion_message_to(
if not isinstance(remaining_blinded_path, list): # doesn't return list when num items == 1
remaining_blinded_path = [remaining_blinded_path]
peer = lnwallet.peers.get(introduction_point)
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(introduction_point)
# if blinded path introduction point is our direct peer, no need to route-find
if peer:
# start of blinded path is our peer
@@ -250,7 +250,7 @@ def send_onion_message_to(
path = create_onion_message_route_to(lnwallet, introduction_point)
# first edge must be to our peer
peer = lnwallet.peers.get(path[0].end_node)
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(path[0].end_node)
assert peer, 'first hop not a peer'
# last edge is to introduction point and start of blinded path. remove from route
@@ -321,7 +321,7 @@ def send_onion_message_to(
raise Exception('cannot send to myself')
hops_data = []
peer = lnwallet.peers.get(pubkey)
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(pubkey)
if peer:
# destination is our direct peer, no need to route-find
@@ -330,7 +330,7 @@ def send_onion_message_to(
path = create_onion_message_route_to(lnwallet, pubkey)
# first edge must be to our peer
peer = lnwallet.peers.get(path[0].end_node)
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(path[0].end_node)
assert peer, 'first hop not a peer'
hops_data = [
@@ -379,9 +379,9 @@ def get_blinded_reply_paths(
- reply_path introduction points are direct peers only (TODO: longer reply paths)"""
# TODO: build longer paths and/or add dummy hops to increase privacy
my_active_channels = [chan for chan in lnwallet.channels.values() if chan.is_active()]
my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.peers.get(chan.node_id) and
lnwallet.peers.get(chan.node_id).their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)]
my_onionmsg_peers = [peer for peer in lnwallet.peers.values() if peer.their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)]
my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.lnpeermgr.get_peer_by_pubkey(chan.node_id) and
lnwallet.lnpeermgr.get_peer_by_pubkey(chan.node_id).their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)]
my_onionmsg_peers = [peer for peer in lnwallet.lnpeermgr.peers.values() if peer.their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)]
result = []
mynodeid = lnwallet.node_keypair.pubkey
@@ -472,7 +472,7 @@ class OnionMessageManager(Logger):
try:
onion_packet_b = onion_packet.to_bytes()
next_peer = self.lnwallet.peers.get(node_id)
next_peer = self.lnwallet.lnpeermgr.get_peer_by_pubkey(node_id)
if not next_peer.their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT):
self.logger.debug('forward dropped, next peer is not ONION_MESSAGE capable')
@@ -528,7 +528,7 @@ class OnionMessageManager(Logger):
req.future.set_exception(copy.copy(e))
# NOTE: above, when passing the caught exception instance e directly it leads to GeneratorExit() in
if isinstance(e, NoRouteFound) and e.peer_address:
await self.lnwallet.add_peer(str(e.peer_address))
await self.lnwallet.lnpeermgr.add_peer(str(e.peer_address))
else:
self.logger.debug(f'resubmit {key=}')
self.send_queue.put_nowait((now() + self.REQUEST_REPLY_RETRY_DELAY, expires, key))
@@ -700,7 +700,7 @@ class OnionMessageManager(Logger):
'onion_message dropped (not forwarding due to lightning_forward_payments config option disabled')
return
# is next_node one of our peers?
next_peer = self.lnwallet.peers.get(next_node_id)
next_peer = self.lnwallet.lnpeermgr.get_peer_by_pubkey(next_node_id)
if not next_peer:
self.logger.info(f'next node {next_node_id.hex()} not a peer, dropping message')
return
+1 -1
View File
@@ -81,7 +81,7 @@ async def worker(work_queue: asyncio.Queue, results_queue: asyncio.Queue, flag):
print(f"worker connecting to {connect_str}")
try:
peer = await wallet.lnworker.add_peer(connect_str)
peer = await wallet.lnworker.lnpeermgr.add_peer(connect_str)
res = await util.wait_for2(peer.initialized, TIMEOUT)
if res:
if peer.features & flag == work['features'] & flag:
+2 -2
View File
@@ -3451,7 +3451,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
lightning_has_channels = (
self.lnworker and len([chan for chan in self.lnworker.channels.values() if chan.is_open()]) > 0
)
lightning_online = self.lnworker and self.lnworker.num_peers() > 0
lightning_online = self.lnworker and self.lnworker.lnpeermgr.num_peers() > 0
num_sats_can_receive = self.lnworker.num_sats_can_receive() if self.lnworker else 0
can_receive_lightning = self.lnworker and num_sats_can_receive > 0 and amount_sat <= num_sats_can_receive
try:
@@ -3459,7 +3459,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
except Exception:
zeroconf_nodeid = None
can_get_zeroconf_channel = (self.lnworker and self.config.ACCEPT_ZEROCONF_CHANNELS
and zeroconf_nodeid in self.lnworker.peers)
and self.lnworker.lnpeermgr.get_peer_by_pubkey(zeroconf_nodeid) is not None)
status = self.get_invoice_status(req)
if status == PR_EXPIRED:
+4 -3
View File
@@ -760,22 +760,23 @@ class TestCommandsTestnet(ElectrumTestCase):
# Mock the network and lnworker
mock_lnworker = mock.Mock()
mock_lnworker.lnpeermgr = mock.Mock()
w.lnworker = mock_lnworker
mock_peer = mock.Mock()
mock_peer.initialized = asyncio.Future()
connection_string = "test_node_id@127.0.0.1:9735"
called = False
async def lnworker_add_peer(*args, **kwargs):
async def lnpeermgr_add_peer(*args, **kwargs):
assert args[0] == connection_string
nonlocal called
called += 1
return mock_peer
mock_lnworker.add_peer = lnworker_add_peer
mock_lnworker.lnpeermgr.add_peer = lnpeermgr_add_peer
# check if add_peer times out if peer doesn't initialize (LN_P2P_NETWORK_TIMEOUT is 0.001s)
with self.assertRaises(UserFacingException):
await cmds.add_peer(connection_string=connection_string, wallet=w)
# check if add_peer called lnworker.add_peer
# check if add_peer called lnpeermgr.add_peer
assert called == 1
mock_peer.initialized = asyncio.Future()
+4 -2
View File
@@ -23,6 +23,7 @@
# (around commit 42de4400bff5105352d0552155f73589166d162b).
import unittest
from functools import lru_cache
from unittest import mock
import os
import binascii
@@ -40,7 +41,7 @@ from electrum.crypto import privkey_to_pubkey
from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED, UpdateAddHtlc
from electrum.lnutil import effective_htlc_tx_weight
from electrum.logging import console_stderr_handler
from electrum.lnchannel import ChannelState
from electrum.lnchannel import ChannelState, Channel
from electrum.json_db import StoredDict
from electrum.coinchooser import PRNG
@@ -124,6 +125,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
return StoredDict(state, None)
@lru_cache()
def bip32(sequence):
node = bip32_utils.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence)
k = node.eckey.get_secret_bytes()
@@ -137,7 +139,7 @@ def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None,
alice_pubkey=b"\x01"*33, bob_pubkey=b"\x02"*33, random_seed=None,
anchor_outputs=False,
local_max_inflight=None, remote_max_inflight=None,
max_accepted_htlcs=5):
max_accepted_htlcs=5) -> tuple[Channel, Channel]:
if random_seed is None: # needed for deterministic randomness
random_seed = os.urandom(32)
random_gen = PRNG(random_seed)
+103 -210
View File
@@ -10,6 +10,7 @@ from collections import defaultdict
import logging
import concurrent
from concurrent import futures
from functools import lru_cache
from unittest import mock
from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence
from types import MappingProxyType
@@ -24,6 +25,7 @@ import electrum.trampoline
from electrum import bitcoin
from electrum import util
from electrum import constants
from electrum import bip32
from electrum.network import Network
from electrum import simple_config, lnutil
from electrum.lnaddr import lnencode, LnAddr, lndecode
@@ -37,7 +39,7 @@ from electrum.lnutil import Keypair, PaymentFailure, LnFeatures, HTLCOwner, Paym
from electrum.lnchannel import ChannelState, PeerState, Channel
from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession
from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession, LNPeerManager
from electrum.lnmsg import encode_msg, decode_msg
from electrum import lnmsg
from electrum.logging import console_stderr_handler, Logger
@@ -49,10 +51,11 @@ from electrum.interface import GracefulDisconnect
from electrum.simple_config import SimpleConfig
from electrum.fee_policy import FeeTimeEstimates, FEE_ETA_TARGETS
from electrum.mpp_split import split_amount_normal
from electrum.wallet import Abstract_Wallet
from .test_lnchannel import create_test_channels
from .test_bitcoin import needs_test_with_all_chacha20_implementations
from . import ElectrumTestCase
from . import ElectrumTestCase, restore_wallet_from_text__for_unittest
def keypair():
@@ -62,9 +65,6 @@ def keypair():
privkey=priv)
return k1
@contextmanager
def noop_lock():
yield
class MockNetwork:
def __init__(self, tx_queue, *, config: SimpleConfig):
@@ -120,144 +120,100 @@ class MockADB:
def get_local_height(self):
return self._blockchain.height()
class MockWallet:
receive_requests = {}
adb = MockADB()
def get_invoice(self, key):
pass
def get_request(self, key):
pass
def get_key_for_receive_request(self, x):
pass
def set_label(self, x, y):
pass
def save_db(self):
pass
def is_lightning_backup(self):
return False
def is_mine(self, addr):
return True
def get_fingerprint(self):
return ''
def get_new_sweep_address_for_channel(self):
# note: sweep is not tested here, only in regtest
return "tb1qqu5newtapamjchgxf0nty6geuykhvwas45q4q4"
def is_up_to_date(self):
return True
class MockLNGossip:
def get_sync_progress_estimate(self):
return None, None, None
class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
class MockLNPeerManager(LNPeerManager):
def __init__(
self,
*,
node_keypair,
config: SimpleConfig,
features: LnFeatures,
lnwallet: LNWallet,
network: 'MockNetwork',
):
LNPeerManager.__init__(
self,
node_keypair=node_keypair,
lnwallet_or_lngossip=lnwallet,
features=features,
config=config,
)
self.network = network
@lru_cache()
def _bip32_from_name(name: str) -> bip32.BIP32Node:
# note: unlike a serialized xprv, the bip32 node can be cached easily,
# as it does not depend on constant.net (testnet/mainnet) network bytes
sequence = [ord(c) for c in name]
bip32_node = bip32.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence)
return bip32_node
class MockLNWallet(LNWallet):
MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
PAYMENT_TIMEOUT = 120
TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0
MPP_SPLIT_PART_FRACTION = 1 # this disables the forced splitting
MPP_SPLIT_PART_MINAMT_MSAT = 5_000_000
def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name, has_anchors):
def __init__(self, *, tx_queue, name, has_anchors, ln_xprv: str = None):
self.name = name
Logger.__init__(self)
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
self.node_keypair = local_keypair
self.payment_secret_key = os.urandom(32) # does not need to be deterministic in tests
self._user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-")
self.config = SimpleConfig({}, read_user_dir_function=lambda: self._user_dir)
self.network = MockNetwork(tx_queue, config=self.config)
self.taskgroup = OldTaskGroup()
self.config.ENABLE_ANCHOR_CHANNELS = has_anchors
self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0
network = MockNetwork(tx_queue, config=self.config)
wallet = restore_wallet_from_text__for_unittest(
"9dk", path=None, passphrase=name, config=self.config)['wallet'] # type: Abstract_Wallet
wallet.is_up_to_date = lambda: True
wallet.adb.network = wallet.network = network
features = LnFeatures(0)
features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
features |= LnFeatures.VAR_ONION_OPT
features |= LnFeatures.PAYMENT_SECRET_OPT
features |= LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
features |= LnFeatures.OPTION_CHANNEL_TYPE_OPT
features |= LnFeatures.OPTION_SCID_ALIAS_OPT
features |= LnFeatures.OPTION_STATIC_REMOTEKEY_OPT
if ln_xprv is None:
ln_xprv = _bip32_from_name(name).to_xprv()
LNWallet.__init__(self, wallet=wallet, xprv=ln_xprv, features=features)
self.lnpeermgr = MockLNPeerManager(
node_keypair=self.node_keypair,
config=self.config,
features=features,
lnwallet=self,
network=network,
)
self.lnwatcher = None
self.swap_manager = None
self.onion_message_manager = None
self.listen_server = None
self._channels = {chan.channel_id: chan for chan in chans}
self.payment_info = {}
self.logs = defaultdict(list)
self.wallet = MockWallet()
self.features = LnFeatures(0)
self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
self.features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
self.features |= LnFeatures.VAR_ONION_OPT
self.features |= LnFeatures.PAYMENT_SECRET_OPT
self.features |= LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
self.features |= LnFeatures.OPTION_CHANNEL_TYPE_OPT
self.features |= LnFeatures.OPTION_SCID_ALIAS_OPT
self.features |= LnFeatures.OPTION_STATIC_REMOTEKEY_OPT
self.config.ENABLE_ANCHOR_CHANNELS = has_anchors
for chan in chans:
chan.lnworker = self
self._peers = {} # bytes -> Peer
# used in tests
self.enable_htlc_settle = True
self.enable_htlc_forwarding = True
self.received_mpp_htlcs = dict()
self._paysessions = dict()
self.sent_htlcs_info = dict()
self.sent_buckets = defaultdict(set)
self.active_forwardings = {}
self.forwarding_failures = {}
self.inflight_payments = set()
self._preimages = {}
self.stopping_soon = False
self.downstream_to_upstream_htlc = {}
self.dont_expire_htlcs = {}
self.dont_settle_htlcs = {}
self.hold_invoice_callbacks = {}
self._payment_bundles_pkey_to_canon = {} # type: Dict[bytes, bytes]
self._payment_bundles_canon_to_pkeylist = {} # type: Dict[bytes, Sequence[bytes]]
self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0
self._channel_sending_capacity_lock = asyncio.Lock()
self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}")
self.logger.info(f"created LNWallet[{name}] with nodeID={self.node_keypair.pubkey.hex()}")
def clear_invoices_cache(self):
pass
def _add_channel(self, chan: Channel):
self._channels[chan.channel_id] = chan
chan.lnworker = self
def get_invoice_status(self, key):
pass
@property
def lock(self):
return noop_lock()
@property
def channel_db(self):
return self.network.channel_db if self.network else None
def uses_trampoline(self):
return not bool(self.channel_db)
@property
def channels(self):
return self._channels
@property
def peers(self):
return self._peers
def get_channel_by_short_id(self, short_channel_id):
with self.lock:
for chan in self._channels.values():
if chan.short_channel_id == short_channel_id:
return chan
def channel_state_changed(self, chan):
pass
@LNWallet.features.setter
def features(self, value):
self.lnpeermgr.features = value
def save_channel(self, chan):
print("Ignoring channel save")
pass
#print("Ignoring channel save")
def diagnostic_name(self):
return self.name
@@ -290,69 +246,6 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
budget=PaymentFeeBudget.from_invoice_amount(invoice_amount_msat=amount_msat, config=self.config),
)]
get_payments = LNWallet.get_payments
get_payment_secret = LNWallet.get_payment_secret
get_payment_info = LNWallet.get_payment_info
save_payment_info = LNWallet.save_payment_info
set_invoice_status = LNWallet.set_invoice_status
set_request_status = LNWallet.set_request_status
set_payment_status = LNWallet.set_payment_status
get_payment_status = LNWallet.get_payment_status
htlc_fulfilled = LNWallet.htlc_fulfilled
htlc_failed = LNWallet.htlc_failed
save_preimage = LNWallet.save_preimage
get_preimage = LNWallet.get_preimage
create_route_for_single_htlc = LNWallet.create_route_for_single_htlc
create_routes_for_payment = LNWallet.create_routes_for_payment
_check_bolt11_invoice = LNWallet._check_bolt11_invoice
pay_to_route = LNWallet.pay_to_route
pay_to_node = LNWallet.pay_to_node
pay_invoice = LNWallet.pay_invoice
force_close_channel = LNWallet.force_close_channel
schedule_force_closing = LNWallet.schedule_force_closing
on_peer_successfully_established = LNWallet.on_peer_successfully_established
get_channel_by_id = LNWallet.get_channel_by_id
channels_for_peer = LNWallet.channels_for_peer
calc_routing_hints_for_invoice = LNWallet.calc_routing_hints_for_invoice
get_channels_for_receiving = LNWallet.get_channels_for_receiving
handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
is_trampoline_peer = LNWallet.is_trampoline_peer
wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
#on_event_proxy_set = LNWallet.on_event_proxy_set
_decode_channel_update_msg = LNWallet._decode_channel_update_msg
_handle_chanupd_from_failed_htlc = LNWallet._handle_chanupd_from_failed_htlc
is_forwarded_htlc = LNWallet.is_forwarded_htlc
notify_upstream_peer = LNWallet.notify_upstream_peer
_force_close_channel = LNWallet._force_close_channel
suggest_payment_splits = LNWallet.suggest_payment_splits
register_hold_invoice = LNWallet.register_hold_invoice
unregister_hold_invoice = LNWallet.unregister_hold_invoice
add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice
update_or_create_mpp_with_received_htlc = LNWallet.update_or_create_mpp_with_received_htlc
set_mpp_resolution = LNWallet.set_mpp_resolution
get_mpp_amounts = LNWallet.get_mpp_amounts
bundle_payments = LNWallet.bundle_payments
get_payment_bundle = LNWallet.get_payment_bundle
_get_payment_key = LNWallet._get_payment_key
save_forwarding_failure = LNWallet.save_forwarding_failure
get_forwarding_failure = LNWallet.get_forwarding_failure
maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding
current_target_feerate_per_kw = LNWallet.current_target_feerate_per_kw
current_low_feerate_per_kw_srk_channel = LNWallet.current_low_feerate_per_kw_srk_channel
create_onion_for_route = LNWallet.create_onion_for_route
maybe_forward_htlc_set = LNWallet.maybe_forward_htlc_set
_maybe_forward_htlc = LNWallet._maybe_forward_htlc
_maybe_forward_trampoline = LNWallet._maybe_forward_trampoline
_maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created = LNWallet._maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created
set_htlc_set_error = LNWallet.set_htlc_set_error
is_payment_bundle_complete = LNWallet.is_payment_bundle_complete
delete_payment_bundle = LNWallet.delete_payment_bundle
_process_htlc_log = LNWallet._process_htlc_log
_get_invoice_features = LNWallet._get_invoice_features
receive_requires_jit_channel = LNWallet.receive_requires_jit_channel
can_get_zeroconf_channel = LNWallet.can_get_zeroconf_channel
class MockTransport:
def __init__(self, name):
@@ -667,25 +560,24 @@ class TestPeerDirect(TestPeer):
def prepare_peers(
self, alice_channel: Channel, bob_channel: Channel,
*, k1: Keypair = None, k2: Keypair = None,
):
if k1 is None:
k1 = keypair()
if k2 is None:
k2 = keypair()
q1, q2 = asyncio.Queue(), asyncio.Queue()
w1 = MockLNWallet(tx_queue=q1, name=bob_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
w2 = MockLNWallet(tx_queue=q2, name=alice_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
k1 = w1.node_keypair
k2 = w2.node_keypair
alice_channel.node_id = k2.pubkey
bob_channel.node_id = k1.pubkey
alice_channel.storage['node_id'] = alice_channel.node_id
bob_channel.storage['node_id'] = bob_channel.node_id
t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name)
q1, q2 = asyncio.Queue(), asyncio.Queue()
w1 = MockLNWallet(local_keypair=k1, chans=[alice_channel], tx_queue=q1, name=bob_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
w2 = MockLNWallet(local_keypair=k2, chans=[bob_channel], tx_queue=q2, name=alice_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
w1._add_channel(alice_channel)
w2._add_channel(bob_channel)
self._lnworkers_created.extend([w1, w2])
p1 = PeerInTests(w1, k2.pubkey, t1)
p2 = PeerInTests(w2, k1.pubkey, t2)
w1._peers[p1.pubkey] = p1
w2._peers[p2.pubkey] = p2
w1.lnpeermgr._peers[p1.pubkey] = p1
w2.lnpeermgr._peers[p2.pubkey] = p2
# mark_open won't work if state is already OPEN.
# so set it to FUNDED
alice_channel._state = ChannelState.FUNDED
@@ -790,10 +682,9 @@ class TestPeerDirect(TestPeer):
----sig-->
"""
chan_AB, chan_BA = create_test_channels()
k1, k2 = keypair(), keypair()
# note: we don't start peer.htlc_switch() so that the fake htlcs are left alone.
async def f():
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA)
async with OldTaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p2._message_loop())
@@ -807,7 +698,7 @@ class TestPeerDirect(TestPeer):
await group.cancel_remaining()
# simulating disconnection. recreate transports.
self.logger.info("simulating disconnection. recreating transports.")
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA)
for chan in (chan_AB, chan_BA):
chan.peer_state = PeerState.DISCONNECTED
async with OldTaskGroup() as group:
@@ -846,10 +737,9 @@ class TestPeerDirect(TestPeer):
----rev-->
"""
chan_AB, chan_BA = create_test_channels()
k1, k2 = keypair(), keypair()
# note: we don't start peer.htlc_switch() so that the fake htlcs are left alone.
async def f():
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA)
async with OldTaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p2._message_loop())
@@ -864,7 +754,7 @@ class TestPeerDirect(TestPeer):
await group.cancel_remaining()
# simulating disconnection. recreate transports.
self.logger.info("simulating disconnection. recreating transports.")
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA)
for chan in (chan_AB, chan_BA):
chan.peer_state = PeerState.DISCONNECTED
async with OldTaskGroup() as group:
@@ -1788,7 +1678,7 @@ class TestPeerDirect(TestPeer):
with self.assertRaises(NoPathFound) as e:
await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)
peer = w1.peers[route[0].node_id]
peer = w1.lnpeermgr._peers[route[0].node_id]
# AssertionError is ok since we shouldn't use old routes, and the
# route finding should fail when channel is closed
async def f():
@@ -2126,12 +2016,19 @@ class TestPeerDirect(TestPeer):
class TestPeerForwarding(TestPeer):
def prepare_chans_and_peers_in_graph(self, graph_definition) -> Graph:
keys = {k: keypair() for k in graph_definition}
workers = {} # type: Dict[str, MockLNWallet]
txs_queues = {k: asyncio.Queue() for k in graph_definition}
# create workers
for a, definition in graph_definition.items():
workers[a] = MockLNWallet(tx_queue=txs_queues[a], name=a, has_anchors=self.TEST_ANCHOR_CHANNELS)
self._lnworkers_created.extend(list(workers.values()))
keys = {name: w.node_keypair for name, w in workers.items()}
channels = {} # type: Dict[Tuple[str, str], Channel]
transports = {}
workers = {} # type: Dict[str, MockLNWallet]
peers = {}
# create channels
for a, definition in graph_definition.items():
for b, channel_def in definition.get('channels', {}).items():
@@ -2145,6 +2042,8 @@ class TestPeerForwarding(TestPeer):
anchor_outputs=self.TEST_ANCHOR_CHANNELS
)
channels[(a, b)], channels[(b, a)] = channel_ab, channel_ba
workers[a]._add_channel(channel_ab)
workers[b]._add_channel(channel_ba)
transport_ab, transport_ba = transport_pair(keys[a], keys[b], channel_ab.name, channel_ba.name)
transports[(a, b)], transports[(b, a)] = transport_ab, transport_ba
# set fees
@@ -2153,12 +2052,6 @@ class TestPeerForwarding(TestPeer):
channel_ba.forwarding_fee_proportional_millionths = channel_def['remote_fee_rate_millionths']
channel_ba.forwarding_fee_base_msat = channel_def['remote_base_fee_msat']
# create workers and peers
for a, definition in graph_definition.items():
channels_of_node = [c for k, c in channels.items() if k[0] == a]
workers[a] = MockLNWallet(local_keypair=keys[a], chans=channels_of_node, tx_queue=txs_queues[a], name=a, has_anchors=self.TEST_ANCHOR_CHANNELS)
self._lnworkers_created.extend(list(workers.values()))
# create peers
for ab in channels.keys():
peers[ab] = Peer(workers[ab[0]], keys[ab[1]].pubkey, transports[ab])
@@ -2167,7 +2060,7 @@ class TestPeerForwarding(TestPeer):
for a, w in workers.items():
for ab, peer_ab in peers.items():
if ab[0] == a:
w._peers[peer_ab.pubkey] = peer_ab
w.lnpeermgr._peers[peer_ab.pubkey] = peer_ab
# set forwarding properties
for a, definition in graph_definition.items():
+11 -10
View File
@@ -352,9 +352,8 @@ class TestOnionMessageManager(ElectrumTestCase):
async def test_request_and_reply(self):
n = MockNetwork()
k = keypair()
q1, q2 = asyncio.Queue(), asyncio.Queue()
lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test_request_and_reply', has_anchors=False)
lnw = MockLNWallet(tx_queue=q1, name='test_request_and_reply', has_anchors=False)
def slow(*args, **kwargs):
time.sleep(2*TIME_STEP)
@@ -369,10 +368,10 @@ class TestOnionMessageManager(ElectrumTestCase):
rkey1 = bfh('0102030405060708')
rkey2 = bfh('0102030405060709')
lnw.peers[self.alice.pubkey] = MockPeer(self.alice.pubkey)
lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=slow)
lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(withreply, rkey1))
lnw.peers[self.dave.pubkey] = MockPeer(self.dave.pubkey, on_send_message=partial(slowwithreply, rkey2))
lnw.lnpeermgr._peers[self.alice.pubkey] = MockPeer(self.alice.pubkey)
lnw.lnpeermgr._peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=slow)
lnw.lnpeermgr._peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(withreply, rkey1))
lnw.lnpeermgr._peers[self.dave.pubkey] = MockPeer(self.dave.pubkey, on_send_message=partial(slowwithreply, rkey2))
t = OnionMessageManager(lnw)
t.start_network(network=n)
@@ -401,7 +400,8 @@ class TestOnionMessageManager(ElectrumTestCase):
async def test_forward(self):
n = MockNetwork()
q1 = asyncio.Queue()
lnw = MockLNWallet(local_keypair=self.alice, chans=[], tx_queue=q1, name='alice', has_anchors=False)
lnw = MockLNWallet(tx_queue=q1, name='alice', has_anchors=False)
lnw.node_keypair = self.alice
self.was_sent = False
@@ -414,8 +414,8 @@ class TestOnionMessageManager(ElectrumTestCase):
self.assertEqual(message_type, 'onion_message')
self.assertEqual(payload['onion_message_packet'], kwargs['onion_message_packet'])
lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=partial(on_send, 'bob'))
lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(on_send, 'carol'))
lnw.lnpeermgr._peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=partial(on_send, 'bob'))
lnw.lnpeermgr._peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(on_send, 'carol'))
t = OnionMessageManager(lnw)
t.start_network(network=n)
@@ -438,7 +438,8 @@ class TestOnionMessageManager(ElectrumTestCase):
async def test_receive_unsolicited(self):
n = MockNetwork()
q1 = asyncio.Queue()
lnw = MockLNWallet(local_keypair=self.dave, chans=[], tx_queue=q1, name='dave', has_anchors=False)
lnw = MockLNWallet(tx_queue=q1, name='dave', has_anchors=False)
lnw.node_keypair = self.dave
t = OnionMessageManager(lnw)
t.start_network(network=n)