lnworker: split LNWallet and LNWorker: LNWallet "has an" LNWorker

- LNWallet no longer "is-an" LNWorker, instead LNWallet "has-an" LNWorker
- the motivation is to make the unit tests nicer, and allow writing unit tests for more things
  - I hope this makes it possible to e.g. test lnsweep in the unit tests
  - some stuff we would previously have to write a regtest for, maybe we can write a unit test for, now
- in unit tests, MockLNWallet now
  - inherits LNWallet
  - the Wallet is no longer being mocked
This commit is contained in:
SomberNight
2025-12-17 15:16:05 +00:00
parent bdcd3f9c7c
commit 1006e8092f
17 changed files with 345 additions and 354 deletions
+5 -5
View File
@@ -1687,7 +1687,7 @@ class Commands(Logger):
arg:int:timeout:Timeout in seconds (default=20)
"""
lnworker = self.network.lngossip if gossip else wallet.lnworker
peer = await lnworker.add_peer(connection_string)
peer = await lnworker.lnpeermgr.add_peer(connection_string)
try:
await util.wait_for2(peer.initialized, timeout=LN_P2P_NETWORK_TIMEOUT)
except (CancelledError, Exception) as e:
@@ -1700,7 +1700,7 @@ class Commands(Logger):
"""Display statistics about lightninig gossip"""
lngossip = self.network.lngossip
channel_db = lngossip.channel_db
forwarded = dict([(key.hex(), p._num_gossip_messages_forwarded) for key, p in wallet.lnworker.peers.items()]),
forwarded = dict([(key.hex(), p._num_gossip_messages_forwarded) for key, p in wallet.lnworker.lnpeermgr.peers.items()]),
out = {
'received': {
'channel_announcements': lngossip._num_chan_ann,
@@ -1731,7 +1731,7 @@ class Commands(Logger):
'initialized': p.is_initialized(),
'features': str(LnFeatures(p.features)),
'channels': [c.funding_outpoint.to_str() for c in p.channels.values()],
} for p in lnworker.peers.values()]
} for p in lnworker.lnpeermgr.peers.values()]
@command('wpnl')
async def open_channel(self, connection_string, amount, push_amount=0, public=False, zeroconf=False, password=None, wallet: Abstract_Wallet = None):
@@ -1748,7 +1748,7 @@ class Commands(Logger):
raise UserFacingException("This wallet cannot create new channels")
funding_sat = satoshis(amount)
push_sat = satoshis(push_amount)
peer = await wallet.lnworker.add_peer(connection_string)
peer = await wallet.lnworker.lnpeermgr.add_peer(connection_string)
chan, funding_tx = await wallet.lnworker.open_channel_with_peer(
peer, funding_sat,
push_sat=push_sat,
@@ -2197,7 +2197,7 @@ class Commands(Logger):
pubkey = bfh(node_id)
assert len(pubkey) == 33, 'invalid node_id'
peer = wallet.lnworker.peers[pubkey]
peer = wallet.lnworker.lnpeermgr.peers[pubkey]
assert peer, 'node_id not a peer'
path = [pubkey, wallet.lnworker.node_keypair.pubkey]
+1 -1
View File
@@ -94,7 +94,7 @@ class QEChannelDetails(AuthMixin, QObject, QtEventListener):
def name(self) -> str:
if not self._channel:
return ''
return self._wallet.wallet.lnworker.get_node_alias(self._channel.node_id) or ''
return self._wallet.wallet.lnworker.lnpeermgr.get_node_alias(self._channel.node_id) or ''
@pyqtProperty(str, notify=channelChanged)
def pubkey(self) -> str:
+1 -1
View File
@@ -88,7 +88,7 @@ class QEChannelListModel(QAbstractListModel, QtEventListener):
item = {
'cid': lnc.channel_id.hex(),
'node_id': lnc.node_id.hex(),
'node_alias': lnworker.get_node_alias(lnc.node_id) or '',
'node_alias': lnworker.lnpeermgr.get_node_alias(lnc.node_id) or '',
'short_cid': lnc.short_id_for_GUI(),
'state': lnc.get_state_for_GUI(),
'state_code': int(lnc.get_state()),
+1 -1
View File
@@ -258,7 +258,7 @@ class QEInvoice(QObject, QtEventListener):
def name_for_node_id(self, node_id):
lnworker = self._wallet.wallet.lnworker
return (lnworker.get_node_alias(node_id) if lnworker else None) or node_id.hex()
return (lnworker.lnpeermgr.get_node_alias(node_id) if lnworker else None) or node_id.hex()
def set_effective_invoice(self, invoice: Invoice):
self._effectiveInvoice = invoice
+1 -1
View File
@@ -525,7 +525,7 @@ class QEWallet(AuthMixin, QObject, QtEventListener):
@pyqtProperty(int, notify=peersUpdated)
def lightningNumPeers(self):
if self.isLightning:
return self.wallet.lnworker.num_peers()
return self.wallet.lnworker.lnpeermgr.num_peers()
return 0
@pyqtSlot()
+1 -1
View File
@@ -98,7 +98,7 @@ class ChannelsList(MyTreeView):
labels[subject] = label
status = chan.get_state_for_GUI()
closed = chan.is_closed()
node_alias = self.lnworker.get_node_alias(chan.node_id) or chan.node_id.hex()
node_alias = self.lnworker.lnpeermgr.get_node_alias(chan.node_id) or chan.node_id.hex()
capacity_str = self.main_window.format_amount(chan.get_capacity(), whitespaces=True)
return {
self.Columns.SHORT_CHANID: chan.short_id_for_GUI(),
+1 -1
View File
@@ -62,7 +62,7 @@ class LightningDialog(QDialog, QtEventListener):
self.register_callbacks()
self.network.channel_db.update_counts() # trigger callback
if self.network.lngossip:
self.on_event_gossip_peers(self.network.lngossip.num_peers())
self.on_event_gossip_peers(self.network.lngossip.lnpeermgr.num_peers())
self.on_event_unknown_channels(len(self.network.lngossip.unknown_ids))
else:
self.num_peers.setText(_('Lightning gossip not active.'))
+5 -5
View File
@@ -417,9 +417,9 @@ class AbstractChannel(Logger, ABC):
if not self.is_funding_tx_mined(funding_height):
# funding tx is invalid (invalid amount or address) we need to get rid of the channel again
self.should_request_force_close = True
if self.lnworker and self.node_id in self.lnworker.peers:
if self.lnworker and (peer := self.lnworker.lnpeermgr.get_peer_by_pubkey(self.node_id)):
# reconnect to trigger force close request
self.lnworker.peers[self.node_id].close_and_cleanup()
peer.close_and_cleanup()
else:
# remove zeroconf flag as we are now confirmed, this is to prevent an electrum server causing
# us to remove a channel later in update_unfunded_state by omitting its funding tx
@@ -779,7 +779,7 @@ class Channel(AbstractChannel):
self,
state: 'StoredDict', *,
name=None,
lnworker=None, # None only in unittests
lnworker: 'LNWallet' = None, # None only in unittests
initial_feerate=None,
jit_opening_fee: Optional[int] = None,
):
@@ -1022,8 +1022,8 @@ class Channel(AbstractChannel):
elif self.is_static_remotekey_enabled():
our_payment_pubkey = self.config[LOCAL].payment_basepoint.pubkey
addr = make_commitment_output_to_remote_address(our_payment_pubkey, has_anchors=self.has_anchors())
if self.lnworker:
assert self.lnworker.wallet.is_mine(addr)
#if self.lnworker:
# assert self.lnworker.wallet.is_mine(addr) # FIXME xxxxx chan should be deterministic. NEEDS to be fixed before merge
return addr
def has_anchors(self) -> bool:
+4 -4
View File
@@ -81,7 +81,7 @@ class Peer(Logger, EventListener):
def __init__(
self,
lnworker: Union['LNGossip', 'LNWallet'],
lnworker: Union['LNWallet', 'LNGossip'],
pubkey: bytes,
transport: LNTransportBase,
*, is_channel_backup= False):
@@ -402,7 +402,7 @@ class Peer(Logger, EventListener):
if constants.net.rev_genesis_bytes() not in their_chains:
raise GracefulDisconnect(f"no common chain found with remote. (they sent: {their_chains})")
# all checks passed
self.lnworker.on_peer_successfully_established(self)
self.lnworker.lnpeermgr.on_peer_successfully_established(self)
self._received_init = True
self.maybe_set_initialized()
@@ -888,7 +888,7 @@ class Peer(Logger, EventListener):
self.transport.close()
except Exception:
pass
self.lnworker.peer_closed(self)
self.lnworker.lnpeermgr.peer_closed(self)
self.got_disconnected.set()
def is_shutdown_anysegwit(self):
@@ -3064,7 +3064,7 @@ class Peer(Logger, EventListener):
or not self.lnworker.is_payment_bundle_complete(payment_key):
# maybe this set is COMPLETE but the bundle is not yet completed, so the bundle can be considered WAITING
if int(time.time()) - first_htlc_timestamp > self.lnworker.MPP_EXPIRY \
or self.lnworker.stopping_soon:
or self.lnworker.lnpeermgr.stopping_soon:
_log_fail_reason(f"MPP TIMEOUT (> {self.lnworker.MPP_EXPIRY} sec)")
return OnionFailureCode.MPP_TIMEOUT, None, None
+188 -94
View File
@@ -39,7 +39,7 @@ from .util import (
profiler, OldTaskGroup, ESocksProxy, NetworkRetryManager, JsonRPCClient, NotEnoughFunds, EventListener,
event_listener, bfh, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions, ignore_exceptions,
make_aiohttp_session, random_shuffled_copy, is_private_netaddress,
UnrelatedTransactionException, LightningHistoryItem
UnrelatedTransactionException, LightningHistoryItem, get_asyncio_loop,
)
from .fee_policy import (
FeePolicy, FEERATE_FALLBACK_STATIC_FEE, FEE_LN_ETA_TARGET, FEE_LN_LOW_ETA_TARGET,
@@ -215,9 +215,15 @@ LNGOSSIP_FEATURES = (
)
class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
class LNPeerManager(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
def __init__(self, node_keypair, features: LnFeatures, *, config: 'SimpleConfig'):
def __init__(
self, node_keypair,
*,
lnwallet_or_lngossip: 'LNWallet | LNGossip',
features: LnFeatures,
config: 'SimpleConfig',
):
Logger.__init__(self)
NetworkRetryManager.__init__(
self,
@@ -228,6 +234,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
)
self.lock = threading.RLock()
self.node_keypair = node_keypair
self._lnwallet_or_lngossip = lnwallet_or_lngossip
self._peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer # needs self.lock
self._channelless_incoming_peers = set() # type: Set[bytes] # node_ids # needs self.lock
self.taskgroup = OldTaskGroup()
@@ -252,7 +259,10 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
return self._peers.copy()
def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
return {}
return self._lnwallet_or_lngossip.channels_for_peer(node_id)
def get_peer_by_pubkey(self, pubkey: bytes) -> Optional[Peer]:
return self._peers.get(pubkey)
def get_node_alias(self, node_id: bytes) -> Optional[str]:
"""Returns the alias of the node, or None if unknown."""
@@ -269,7 +279,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
return node_alias
async def maybe_listen(self):
# FIXME: only one LNWorker can listen at a time (single port)
# FIXME: only one LNPeerManager can listen at a time (single port)
listen_addr = self.config.LIGHTNING_LISTEN
if listen_addr:
self.logger.info(f'lightning_listen enabled. will try to bind: {listen_addr!r}')
@@ -368,13 +378,17 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
return None
self._channelless_incoming_peers.add(node_id)
# checks done: we are adding this peer.
peer = Peer(self, node_id, transport)
peer = Peer(self._lnwallet_or_lngossip, node_id, transport)
assert node_id not in self._peers
self._peers[node_id] = peer
await self.taskgroup.spawn(peer.main_loop())
return peer
def peer_closed(self, peer: Peer) -> None:
if isinstance(self._lnwallet_or_lngossip, LNWallet):
for chan in self.channels_for_peer(peer.pubkey).values():
chan.peer_state = PeerState.DISCONNECTED
util.trigger_callback('channel', self._lnwallet_or_lngossip.wallet, chan)
with self.lock:
peer2 = self._peers.get(peer.pubkey)
if peer2 is peer:
@@ -392,14 +406,25 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
return True
return False
def start_network(self, network: 'Network'):
def start_network(
self, network: 'Network', *,
listen: bool = False,
maintain_random_peers: bool = False,
) -> None:
assert network
assert self.network is None, "already started"
self.network = network
self._add_peers_from_config()
asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
asyncio.run_coroutine_threadsafe(self.main_loop(), get_asyncio_loop())
if listen:
tg_coro = self.taskgroup.spawn(self.maybe_listen())
asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop())
if maintain_random_peers:
tg_coro = self.taskgroup.spawn(self._maintain_connectivity())
asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop())
async def stop(self):
self.stopping_soon = True
if self.listen_server:
self.listen_server.close()
self.unregister_callbacks()
@@ -410,7 +435,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
for host, port, pubkey in peer_list:
asyncio.run_coroutine_threadsafe(
self._add_peer(host, int(port), bfh(pubkey)),
self.network.asyncio_loop)
get_asyncio_loop())
def is_good_peer(self, peer: LNPeerAddr) -> bool:
# the purpose of this method is to filter peers that advertise the desired feature bits
@@ -573,8 +598,44 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
peer = await self._add_peer(host, port, node_id)
return peer
async def reestablish_peer_for_given_channel(self, chan: Channel) -> None:
await self.taskgroup.spawn(self._reestablish_peer_for_given_channel(chan))
class LNGossip(LNWorker):
@ignore_exceptions
@log_exceptions
async def _reestablish_peer_for_given_channel(self, chan: Channel) -> None:
now = time.time()
peer_addresses = []
if self.uses_trampoline():
addr = trampolines_by_id().get(chan.node_id)
if addr:
peer_addresses.append(addr)
else:
# will try last good address first, from gossip
last_good_addr = self.channel_db.get_last_good_address(chan.node_id)
if last_good_addr:
peer_addresses.append(last_good_addr)
# will try addresses for node_id from gossip
addrs_from_gossip = self.channel_db.get_node_addresses(chan.node_id) or []
for host, port, ts in addrs_from_gossip:
peer_addresses.append(LNPeerAddr(host, port, chan.node_id))
# will try addresses stored in channel storage
peer_addresses += list(chan.get_peer_addresses())
# Done gathering addresses.
# Now select first one that has not failed recently.
for peer in peer_addresses:
if self._can_retry_addr(peer, urgent=True, now=now):
await self._add_peer(peer.host, peer.port, peer.pubkey)
return
async def reestablish_peer_for_zero_conf_trusted_node(self) -> None:
if self.config.ZEROCONF_TRUSTED_NODE:
peer = LNPeerAddr.from_str(self.config.ZEROCONF_TRUSTED_NODE)
if self._can_retry_addr(peer, urgent=True):
await self._add_peer(peer.host, peer.port, peer.pubkey)
class LNGossip(Logger):
"""The LNGossip class is a separate, unannounced Lightning node with random id that is just querying
gossip from other nodes. The LNGossip node does not satisfy gossip queries, this is done by the
LNWallet class(es). LNWallets are the advertised nodes used for actual payments and only satisfy
@@ -584,11 +645,14 @@ class LNGossip(LNWorker):
max_age = 14*24*3600
def __init__(self, config: 'SimpleConfig'):
self.config = config
seed = os.urandom(32)
node = BIP32Node.from_rootseed(seed, xtype='standard')
xprv = node.to_xprv()
node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
LNWorker.__init__(self, node_keypair, LNGOSSIP_FEATURES, config=config)
Logger.__init__(self)
self.lnpeermgr = LNPeerManager(node_keypair, features=LNGOSSIP_FEATURES, config=self.config, lnwallet_or_lngossip=self)
self.taskgroup = OldTaskGroup()
self.unknown_ids = set()
self._forwarding_gossip = [] # type: List[GossipForwardingMessage]
self._last_gossip_batch_ts = 0 # type: int
@@ -600,15 +664,44 @@ class LNGossip(LNWorker):
self._num_chan_upd = 0
self._num_chan_upd_good = 0
@property
def features(self) -> 'LnFeatures':
return self.lnpeermgr.features
@property
def network(self) -> Optional['Network']:
return self.lnpeermgr.network
@property
def channel_db(self) -> 'ChannelDB':
return self.network.channel_db if self.network else None
def uses_trampoline(self) -> bool:
return not bool(self.channel_db)
async def main_loop(self):
self.logger.info("starting taskgroup.")
try:
async with self.taskgroup as group:
await group.spawn(asyncio.Event().wait) # run forever (until cancel)
except Exception as e:
self.logger.exception("taskgroup died.")
finally:
self.logger.info("taskgroup stopped.")
def start_network(self, network: 'Network'):
super().start_network(network)
asyncio.run_coroutine_threadsafe(self.main_loop(), get_asyncio_loop())
self.lnpeermgr.start_network(network, maintain_random_peers=True)
for coro in [
self._maintain_connectivity(),
self.maintain_db(),
self._maintain_forwarding_gossip()
]:
tg_coro = self.taskgroup.spawn(coro)
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop())
async def stop(self):
await self.lnpeermgr.stop()
await self.taskgroup.cancel_remaining()
async def maintain_db(self):
await self.channel_db.data_loaded.wait()
@@ -637,7 +730,7 @@ class LNGossip(LNWorker):
new = set(ids) - set(known)
self.unknown_ids.update(new)
util.trigger_callback('unknown_channels', len(self.unknown_ids))
util.trigger_callback('gossip_peers', self.num_peers())
util.trigger_callback('gossip_peers', self.lnpeermgr.num_peers())
util.trigger_callback('ln_gossip_sync_progress')
def get_ids_to_query(self) -> Sequence[bytes]:
@@ -652,7 +745,7 @@ class LNGossip(LNWorker):
"""Estimates the gossip synchronization process and returns the number
of synchronized channels, the total channels in the network and a
rescaled percentage of the synchronization process."""
if self.num_peers() == 0:
if self.lnpeermgr.num_peers() == 0:
return None, None, None
nchans_with_0p, nchans_with_1p, nchans_with_2p = self.channel_db.get_num_channels_partitioned_by_policy_count()
num_db_channels = nchans_with_0p + nchans_with_1p + nchans_with_2p
@@ -730,6 +823,9 @@ class LNGossip(LNWorker):
# flush the gossip queue so we don't forward old gossip after sync is complete
self.channel_db.get_forwarding_gossip_batch()
def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
return {}
class PaySession(Logger):
@@ -875,7 +971,7 @@ class PaySession(Logger):
return nhtlcs_resolved == self._nhtlcs_inflight
class LNWallet(LNWorker):
class LNWallet(Logger):
lnwatcher: Optional['LNWatcher']
MPP_EXPIRY = 120
@@ -884,7 +980,7 @@ class LNWallet(LNWorker):
MPP_SPLIT_PART_FRACTION = 0.2
MPP_SPLIT_PART_MINAMT_MSAT = 5_000_000
def __init__(self, wallet: 'Abstract_Wallet', xprv):
def __init__(self, wallet: 'Abstract_Wallet', xprv, *, features: LnFeatures = None):
self.wallet = wallet
self.config = wallet.config
self.db = wallet.db
@@ -894,16 +990,20 @@ class LNWallet(LNWorker):
self.payment_secret_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.PAYMENT_SECRET_KEY).privkey
self.funding_root_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.FUNDING_ROOT_KEY)
Logger.__init__(self)
features = LNWALLET_FEATURES
if self.config.ENABLE_ANCHOR_CHANNELS:
features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT
if self.config.ACCEPT_ZEROCONF_CHANNELS:
features |= LnFeatures.OPTION_ZEROCONF_OPT
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS or self.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS:
features |= LnFeatures.OPTION_ONION_MESSAGE_OPT
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP:
features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch
LNWorker.__init__(self, self.node_keypair, features, config=self.config)
if features is None:
features = LNWALLET_FEATURES
if self.config.ENABLE_ANCHOR_CHANNELS:
features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT
if self.config.ACCEPT_ZEROCONF_CHANNELS:
features |= LnFeatures.OPTION_ZEROCONF_OPT
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS or self.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS:
features |= LnFeatures.OPTION_ONION_MESSAGE_OPT
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP:
features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch
Logger.__init__(self)
self.lock = threading.RLock()
self.lnpeermgr = LNPeerManager(self.node_keypair, features=features, config=self.config, lnwallet_or_lngossip=self)
self.taskgroup = OldTaskGroup()
self.lnwatcher = LNWatcher(self)
self.lnrater: LNRater = None
# "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts, invoice_features
@@ -997,6 +1097,21 @@ class LNWallet(LNWorker):
return any(chan.has_anchors() and not chan.is_closed()
for chan in self.channels.values())
@property
def features(self) -> 'LnFeatures':
return self.lnpeermgr.features
@property
def network(self) -> Optional['Network']:
return self.lnpeermgr.network
@property
def channel_db(self) -> 'ChannelDB':
return self.network.channel_db if self.network else None
def uses_trampoline(self) -> bool:
return not bool(self.channel_db)
@property
def channels(self) -> Mapping[bytes, Channel]:
"""Returns a read-only copy of channels."""
@@ -1060,29 +1175,39 @@ class LNWallet(LNWorker):
await watchtower.add_sweep_tx(outpoint, ctn, tx.inputs()[0].prevout.to_str(), tx.serialize())
self.watchtower_ctns[outpoint] = ctn
async def main_loop(self):
self.logger.info("starting taskgroup.")
try:
async with self.taskgroup as group:
await group.spawn(asyncio.Event().wait) # run forever (until cancel)
except Exception as e:
self.logger.exception("taskgroup died.")
finally:
self.logger.info("taskgroup stopped.")
def start_network(self, network: 'Network'):
super().start_network(network)
asyncio.run_coroutine_threadsafe(self.main_loop(), get_asyncio_loop())
self.lnpeermgr.start_network(network, listen=True)
self.lnwatcher.start_network(network)
self.swap_manager.start_network(network)
self.lnrater = LNRater(self, network)
self.onion_message_manager.start_network(network=network)
for coro in [
self.maybe_listen(),
self.lnwatcher.trigger_callbacks(), # shortcut (don't block) if funding tx locked and verified
self.reestablish_peers_and_channels(),
self.sync_with_remote_watchtower(),
]:
tg_coro = self.taskgroup.spawn(coro)
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop())
async def stop(self):
self.stopping_soon = True
if self.listen_server: # stop accepting new peers
self.listen_server.close()
self.lnpeermgr.stopping_soon = True
if self.lnpeermgr.listen_server: # stop accepting new peers
self.lnpeermgr.listen_server.close()
async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):
await self.wait_for_received_pending_htlcs_to_get_removed()
await LNWorker.stop(self)
await self.lnpeermgr.stop()
if self.lnwatcher:
self.lnwatcher.stop()
self.lnwatcher = None
@@ -1090,30 +1215,25 @@ class LNWallet(LNWorker):
await self.swap_manager.stop()
if self.onion_message_manager:
await self.onion_message_manager.stop()
await self.taskgroup.cancel_remaining()
async def wait_for_received_pending_htlcs_to_get_removed(self):
assert self.stopping_soon is True
assert self.lnpeermgr.stopping_soon is True
# We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
# Note: even without MPP, if we just failed/fulfilled an HTLC, it is good
# to wait a bit for it to become irrevocably removed.
# Note: we don't wait for *all htlcs* to get removed, only for those
# that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
async with OldTaskGroup() as group:
for peer in self.peers.values():
for peer in self.lnpeermgr.peers.values():
await group.spawn(peer.wait_one_htlc_switch_iteration())
while True:
if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()):
if all(not peer.received_htlcs_pending_removal for peer in self.lnpeermgr.peers.values()):
break
async with OldTaskGroup(wait=any) as group:
for peer in self.peers.values():
for peer in self.lnpeermgr.peers.values():
await group.spawn(peer.received_htlc_removed_event.wait())
def peer_closed(self, peer):
for chan in self.channels_for_peer(peer.pubkey).values():
chan.peer_state = PeerState.DISCONNECTED
util.trigger_callback('channel', self.wallet, chan)
super().peer_closed(peer)
def get_payments(self, *, status=None) -> Mapping[bytes, List[HTLCWithStatus]]:
out = defaultdict(list)
for chan in self.channels.values():
@@ -1263,7 +1383,7 @@ class LNWallet(LNWorker):
node_ids = [chan.node_id for chan in self.channels.values() if not chan.is_closed()]
return node_ids
def channels_for_peer(self, node_id):
def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
assert type(node_id) is bytes
return {chan_id: chan for (chan_id, chan) in self.channels.items()
if chan.node_id == node_id}
@@ -1307,12 +1427,12 @@ class LNWallet(LNWorker):
await self.schedule_force_closing(chan.channel_id)
elif chan.get_state() == ChannelState.FUNDED:
peer = self._peers.get(chan.node_id)
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
peer.send_channel_ready(chan)
elif chan.get_state() == ChannelState.OPEN:
peer = self._peers.get(chan.node_id)
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
peer.maybe_update_fee(chan)
peer.maybe_send_announcement_signatures(chan)
@@ -1326,9 +1446,10 @@ class LNWallet(LNWorker):
await self.network.try_broadcasting(force_close_tx, 'force-close')
def get_peer_by_static_jit_scid_alias(self, scid_alias: bytes) -> Optional[Peer]:
for nodeid, peer in self.peers.items():
for nodeid, peer in self.lnpeermgr.peers.items():
if scid_alias == self._scid_alias_of_node(nodeid):
return peer
return None
def _scid_alias_of_node(self, nodeid: bytes) -> bytes:
# scid alias for just-in-time channels
@@ -1557,7 +1678,7 @@ class LNWallet(LNWorker):
password: str = None,
) -> Tuple[Channel, PartialTransaction]:
fut = asyncio.run_coroutine_threadsafe(self.add_peer(connect_str), self.network.asyncio_loop)
fut = asyncio.run_coroutine_threadsafe(self.lnpeermgr.add_peer(connect_str), get_asyncio_loop())
try:
peer = fut.result()
except concurrent.futures.TimeoutError:
@@ -1569,7 +1690,7 @@ class LNWallet(LNWorker):
push_sat=push_amt_sat,
public=public,
password=password)
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
fut = asyncio.run_coroutine_threadsafe(coro, get_asyncio_loop())
try:
chan, funding_tx = fut.result()
except concurrent.futures.TimeoutError:
@@ -1860,7 +1981,7 @@ class LNWallet(LNWorker):
short_channel_id = shi.route[0].short_channel_id
chan = self.get_channel_by_short_id(short_channel_id)
assert chan, ShortChannelID(short_channel_id)
peer = self._peers.get(shi.route[0].node_id)
peer = self.lnpeermgr.get_peer_by_pubkey(shi.route[0].node_id)
if not peer:
raise PaymentFailure('Dropped peer')
await peer.initialized
@@ -2040,7 +2161,7 @@ class LNWallet(LNWorker):
# until trampoline is advertised in lnfeatures, check against hardcoded list
if is_hardcoded_trampoline(node_id):
return True
peer = self._peers.get(node_id)
peer = self.lnpeermgr.get_peer_by_pubkey(node_id)
if not peer:
return False
return (peer.their_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ECLAIR)
@@ -2794,7 +2915,7 @@ class LNWallet(LNWorker):
return
upstream_chan_scid, _ = deserialize_htlc_key(upstream_key)
upstream_chan = self.get_channel_by_short_id(upstream_chan_scid)
upstream_peer = self.peers.get(upstream_chan.node_id) if upstream_chan else None
upstream_peer = self.lnpeermgr.get_peer_by_pubkey(upstream_chan.node_id) if upstream_chan else None
if upstream_peer:
upstream_peer.downstream_htlc_resolved_event.set()
upstream_peer.downstream_htlc_resolved_event.clear()
@@ -3110,7 +3231,7 @@ class LNWallet(LNWorker):
# invalid connection string
return False
# only return True if we are connected to the zeroconf provider
return node_id in self.peers
return self.lnpeermgr.get_peer_by_pubkey(node_id) is not None
def _suggest_channels_for_rebalance(self, direction, amount_sat) -> Sequence[Tuple[Channel, int]]:
"""
@@ -3239,7 +3360,9 @@ class LNWallet(LNWorker):
async def close_channel(self, chan_id):
chan = self._channels[chan_id]
peer = self._peers[chan.node_id]
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
if peer is None:
raise KeyError
return await peer.close_channel(chan_id)
def _force_close_channel(self, chan_id: bytes) -> Transaction:
@@ -3288,52 +3411,23 @@ class LNWallet(LNWorker):
util.trigger_callback('channels_updated', self.wallet)
util.trigger_callback('wallet_updated', self.wallet)
@ignore_exceptions
@log_exceptions
async def reestablish_peer_for_given_channel(self, chan: Channel) -> None:
now = time.time()
peer_addresses = []
if self.uses_trampoline():
addr = trampolines_by_id().get(chan.node_id)
if addr:
peer_addresses.append(addr)
else:
# will try last good address first, from gossip
last_good_addr = self.channel_db.get_last_good_address(chan.node_id)
if last_good_addr:
peer_addresses.append(last_good_addr)
# will try addresses for node_id from gossip
addrs_from_gossip = self.channel_db.get_node_addresses(chan.node_id) or []
for host, port, ts in addrs_from_gossip:
peer_addresses.append(LNPeerAddr(host, port, chan.node_id))
# will try addresses stored in channel storage
peer_addresses += list(chan.get_peer_addresses())
# Done gathering addresses.
# Now select first one that has not failed recently.
for peer in peer_addresses:
if self._can_retry_addr(peer, urgent=True, now=now):
await self._add_peer(peer.host, peer.port, peer.pubkey)
return
async def reestablish_peers_and_channels(self):
while True:
await asyncio.sleep(1)
if self.stopping_soon:
if self.lnpeermgr.stopping_soon:
return
if self.config.ZEROCONF_TRUSTED_NODE:
peer = LNPeerAddr.from_str(self.config.ZEROCONF_TRUSTED_NODE)
if self._can_retry_addr(peer, urgent=True):
await self._add_peer(peer.host, peer.port, peer.pubkey)
await self.lnpeermgr.reestablish_peer_for_zero_conf_trusted_node()
for chan in self.channels.values():
# reestablish
# note: we delegate filtering out uninteresting chans to this:
if not chan.should_try_to_reestablish_peer():
continue
peer = self._peers.get(chan.node_id, None)
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
if peer:
# FIXME maybe this should be the responsibility of the peer itself, done in peer.main_loop:
await peer.taskgroup.spawn(peer.reestablish_channel(chan))
else:
await self.taskgroup.spawn(self.reestablish_peer_for_given_channel(chan))
await self.lnpeermgr.reestablish_peer_for_given_channel(chan)
def current_target_feerate_per_kw(self, *, has_anchors: bool) -> Optional[int]:
target: int = FEE_LN_MINIMUM_ETA_TARGET if has_anchors else FEE_LN_ETA_TARGET
@@ -3396,12 +3490,12 @@ class LNWallet(LNWorker):
async def request_force_close(self, channel_id: bytes, *, connect_str=None) -> None:
if chan := self.get_channel_by_id(channel_id):
peer = self._peers.get(chan.node_id)
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
chan.should_request_force_close = True
if peer:
peer.close_and_cleanup() # to force a reconnect
elif connect_str:
peer = await self.add_peer(connect_str)
peer = await self.lnpeermgr.add_peer(connect_str)
await peer.request_force_close(channel_id)
elif channel_id in self.channel_backups:
await self._request_force_close_from_backup(channel_id)
@@ -3688,7 +3782,7 @@ class LNWallet(LNWorker):
f"maybe_forward_htlc. will forward HTLC: inc_chan={incoming_chan.short_channel_id}. inc_htlc={str(htlc)}. "
f"next_chan={next_chan.get_id_for_log()}.")
next_peer = self.peers.get(next_chan.node_id)
next_peer = self.lnpeermgr.get_peer_by_pubkey(next_chan.node_id)
if next_peer is None:
log_fail_reason(f"next_peer offline ({next_chan.node_id.hex()})")
raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
@@ -3774,7 +3868,7 @@ class LNWallet(LNWorker):
raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'')
# do we have a connection to the node?
next_peer = self.peers.get(outgoing_node_id)
next_peer = self.lnpeermgr.get_peer_by_pubkey(outgoing_node_id)
if next_peer and next_peer.accepts_zeroconf():
self.logger.info(f'JIT: found next_peer')
for next_chan in next_peer.channels.values():
+12 -12
View File
@@ -162,7 +162,7 @@ def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> Seque
): return path
# alt: dest is existing peer?
if lnwallet.peers.get(node_id):
if lnwallet.lnpeermgr.get_peer_by_pubkey(node_id):
return [PathEdge(short_channel_id=None, start_node=None, end_node=node_id)]
# if we have an address, pass it.
@@ -219,7 +219,7 @@ def send_onion_message_to(
encrypted_recipient_data=our_payload['encrypted_recipient_data']
)
peer = lnwallet.peers.get(recipient_data['next_node_id']['node_id'])
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(recipient_data['next_node_id']['node_id'])
assert peer, 'next_node_id not a peer'
# blinding override?
@@ -241,7 +241,7 @@ def send_onion_message_to(
if not isinstance(remaining_blinded_path, list): # doesn't return list when num items == 1
remaining_blinded_path = [remaining_blinded_path]
peer = lnwallet.peers.get(introduction_point)
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(introduction_point)
# if blinded path introduction point is our direct peer, no need to route-find
if peer:
# start of blinded path is our peer
@@ -250,7 +250,7 @@ def send_onion_message_to(
path = create_onion_message_route_to(lnwallet, introduction_point)
# first edge must be to our peer
peer = lnwallet.peers.get(path[0].end_node)
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(path[0].end_node)
assert peer, 'first hop not a peer'
# last edge is to introduction point and start of blinded path. remove from route
@@ -321,7 +321,7 @@ def send_onion_message_to(
raise Exception('cannot send to myself')
hops_data = []
peer = lnwallet.peers.get(pubkey)
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(pubkey)
if peer:
# destination is our direct peer, no need to route-find
@@ -330,7 +330,7 @@ def send_onion_message_to(
path = create_onion_message_route_to(lnwallet, pubkey)
# first edge must be to our peer
peer = lnwallet.peers.get(path[0].end_node)
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(path[0].end_node)
assert peer, 'first hop not a peer'
hops_data = [
@@ -379,9 +379,9 @@ def get_blinded_reply_paths(
- reply_path introduction points are direct peers only (TODO: longer reply paths)"""
# TODO: build longer paths and/or add dummy hops to increase privacy
my_active_channels = [chan for chan in lnwallet.channels.values() if chan.is_active()]
my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.peers.get(chan.node_id) and
lnwallet.peers.get(chan.node_id).their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)]
my_onionmsg_peers = [peer for peer in lnwallet.peers.values() if peer.their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)]
my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.lnpeermgr.get_peer_by_pubkey(chan.node_id) and
lnwallet.lnpeermgr.get_peer_by_pubkey(chan.node_id).their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)]
my_onionmsg_peers = [peer for peer in lnwallet.lnpeermgr.peers.values() if peer.their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)]
result = []
mynodeid = lnwallet.node_keypair.pubkey
@@ -472,7 +472,7 @@ class OnionMessageManager(Logger):
try:
onion_packet_b = onion_packet.to_bytes()
next_peer = self.lnwallet.peers.get(node_id)
next_peer = self.lnwallet.lnpeermgr.get_peer_by_pubkey(node_id)
if not next_peer.their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT):
self.logger.debug('forward dropped, next peer is not ONION_MESSAGE capable')
@@ -528,7 +528,7 @@ class OnionMessageManager(Logger):
req.future.set_exception(copy.copy(e))
# NOTE: above, when passing the caught exception instance e directly it leads to GeneratorExit() in
if isinstance(e, NoRouteFound) and e.peer_address:
await self.lnwallet.add_peer(str(e.peer_address))
await self.lnwallet.lnpeermgr.add_peer(str(e.peer_address))
else:
self.logger.debug(f'resubmit {key=}')
self.send_queue.put_nowait((now() + self.REQUEST_REPLY_RETRY_DELAY, expires, key))
@@ -700,7 +700,7 @@ class OnionMessageManager(Logger):
'onion_message dropped (not forwarding due to lightning_forward_payments config option disabled')
return
# is next_node one of our peers?
next_peer = self.lnwallet.peers.get(next_node_id)
next_peer = self.lnwallet.lnpeermgr.get_peer_by_pubkey(next_node_id)
if not next_peer:
self.logger.info(f'next node {next_node_id.hex()} not a peer, dropping message')
return
+1 -1
View File
@@ -81,7 +81,7 @@ async def worker(work_queue: asyncio.Queue, results_queue: asyncio.Queue, flag):
print(f"worker connecting to {connect_str}")
try:
peer = await wallet.lnworker.add_peer(connect_str)
peer = await wallet.lnworker.lnpeermgr.add_peer(connect_str)
res = await util.wait_for2(peer.initialized, TIMEOUT)
if res:
if peer.features & flag == work['features'] & flag:
+2 -2
View File
@@ -3451,7 +3451,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
lightning_has_channels = (
self.lnworker and len([chan for chan in self.lnworker.channels.values() if chan.is_open()]) > 0
)
lightning_online = self.lnworker and self.lnworker.num_peers() > 0
lightning_online = self.lnworker and self.lnworker.lnpeermgr.num_peers() > 0
num_sats_can_receive = self.lnworker.num_sats_can_receive() if self.lnworker else 0
can_receive_lightning = self.lnworker and num_sats_can_receive > 0 and amount_sat <= num_sats_can_receive
try:
@@ -3459,7 +3459,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
except Exception:
zeroconf_nodeid = None
can_get_zeroconf_channel = (self.lnworker and self.config.ACCEPT_ZEROCONF_CHANNELS
and zeroconf_nodeid in self.lnworker.peers)
and self.lnworker.lnpeermgr.get_peer_by_pubkey(zeroconf_nodeid) is not None)
status = self.get_invoice_status(req)
if status == PR_EXPIRED: