lnworker: split LNWallet and LNWorker: LNWallet "has an" LNWorker
- LNWallet no longer "is-an" LNWorker, instead LNWallet "has-an" LNWorker - the motivation is to make the unit tests nicer, and allow writing unit tests for more things - I hope this makes it possible to e.g. test lnsweep in the unit tests - some stuff we would previously have to write a regtest for, maybe we can write a unit test for, now - in unit tests, MockLNWallet now - inherits LNWallet - the Wallet is no longer being mocked
This commit is contained in:
@@ -1687,7 +1687,7 @@ class Commands(Logger):
|
||||
arg:int:timeout:Timeout in seconds (default=20)
|
||||
"""
|
||||
lnworker = self.network.lngossip if gossip else wallet.lnworker
|
||||
peer = await lnworker.add_peer(connection_string)
|
||||
peer = await lnworker.lnpeermgr.add_peer(connection_string)
|
||||
try:
|
||||
await util.wait_for2(peer.initialized, timeout=LN_P2P_NETWORK_TIMEOUT)
|
||||
except (CancelledError, Exception) as e:
|
||||
@@ -1700,7 +1700,7 @@ class Commands(Logger):
|
||||
"""Display statistics about lightninig gossip"""
|
||||
lngossip = self.network.lngossip
|
||||
channel_db = lngossip.channel_db
|
||||
forwarded = dict([(key.hex(), p._num_gossip_messages_forwarded) for key, p in wallet.lnworker.peers.items()]),
|
||||
forwarded = dict([(key.hex(), p._num_gossip_messages_forwarded) for key, p in wallet.lnworker.lnpeermgr.peers.items()]),
|
||||
out = {
|
||||
'received': {
|
||||
'channel_announcements': lngossip._num_chan_ann,
|
||||
@@ -1731,7 +1731,7 @@ class Commands(Logger):
|
||||
'initialized': p.is_initialized(),
|
||||
'features': str(LnFeatures(p.features)),
|
||||
'channels': [c.funding_outpoint.to_str() for c in p.channels.values()],
|
||||
} for p in lnworker.peers.values()]
|
||||
} for p in lnworker.lnpeermgr.peers.values()]
|
||||
|
||||
@command('wpnl')
|
||||
async def open_channel(self, connection_string, amount, push_amount=0, public=False, zeroconf=False, password=None, wallet: Abstract_Wallet = None):
|
||||
@@ -1748,7 +1748,7 @@ class Commands(Logger):
|
||||
raise UserFacingException("This wallet cannot create new channels")
|
||||
funding_sat = satoshis(amount)
|
||||
push_sat = satoshis(push_amount)
|
||||
peer = await wallet.lnworker.add_peer(connection_string)
|
||||
peer = await wallet.lnworker.lnpeermgr.add_peer(connection_string)
|
||||
chan, funding_tx = await wallet.lnworker.open_channel_with_peer(
|
||||
peer, funding_sat,
|
||||
push_sat=push_sat,
|
||||
@@ -2197,7 +2197,7 @@ class Commands(Logger):
|
||||
pubkey = bfh(node_id)
|
||||
assert len(pubkey) == 33, 'invalid node_id'
|
||||
|
||||
peer = wallet.lnworker.peers[pubkey]
|
||||
peer = wallet.lnworker.lnpeermgr.peers[pubkey]
|
||||
assert peer, 'node_id not a peer'
|
||||
|
||||
path = [pubkey, wallet.lnworker.node_keypair.pubkey]
|
||||
|
||||
@@ -94,7 +94,7 @@ class QEChannelDetails(AuthMixin, QObject, QtEventListener):
|
||||
def name(self) -> str:
|
||||
if not self._channel:
|
||||
return ''
|
||||
return self._wallet.wallet.lnworker.get_node_alias(self._channel.node_id) or ''
|
||||
return self._wallet.wallet.lnworker.lnpeermgr.get_node_alias(self._channel.node_id) or ''
|
||||
|
||||
@pyqtProperty(str, notify=channelChanged)
|
||||
def pubkey(self) -> str:
|
||||
|
||||
@@ -88,7 +88,7 @@ class QEChannelListModel(QAbstractListModel, QtEventListener):
|
||||
item = {
|
||||
'cid': lnc.channel_id.hex(),
|
||||
'node_id': lnc.node_id.hex(),
|
||||
'node_alias': lnworker.get_node_alias(lnc.node_id) or '',
|
||||
'node_alias': lnworker.lnpeermgr.get_node_alias(lnc.node_id) or '',
|
||||
'short_cid': lnc.short_id_for_GUI(),
|
||||
'state': lnc.get_state_for_GUI(),
|
||||
'state_code': int(lnc.get_state()),
|
||||
|
||||
@@ -258,7 +258,7 @@ class QEInvoice(QObject, QtEventListener):
|
||||
|
||||
def name_for_node_id(self, node_id):
|
||||
lnworker = self._wallet.wallet.lnworker
|
||||
return (lnworker.get_node_alias(node_id) if lnworker else None) or node_id.hex()
|
||||
return (lnworker.lnpeermgr.get_node_alias(node_id) if lnworker else None) or node_id.hex()
|
||||
|
||||
def set_effective_invoice(self, invoice: Invoice):
|
||||
self._effectiveInvoice = invoice
|
||||
|
||||
@@ -525,7 +525,7 @@ class QEWallet(AuthMixin, QObject, QtEventListener):
|
||||
@pyqtProperty(int, notify=peersUpdated)
|
||||
def lightningNumPeers(self):
|
||||
if self.isLightning:
|
||||
return self.wallet.lnworker.num_peers()
|
||||
return self.wallet.lnworker.lnpeermgr.num_peers()
|
||||
return 0
|
||||
|
||||
@pyqtSlot()
|
||||
|
||||
@@ -98,7 +98,7 @@ class ChannelsList(MyTreeView):
|
||||
labels[subject] = label
|
||||
status = chan.get_state_for_GUI()
|
||||
closed = chan.is_closed()
|
||||
node_alias = self.lnworker.get_node_alias(chan.node_id) or chan.node_id.hex()
|
||||
node_alias = self.lnworker.lnpeermgr.get_node_alias(chan.node_id) or chan.node_id.hex()
|
||||
capacity_str = self.main_window.format_amount(chan.get_capacity(), whitespaces=True)
|
||||
return {
|
||||
self.Columns.SHORT_CHANID: chan.short_id_for_GUI(),
|
||||
|
||||
@@ -62,7 +62,7 @@ class LightningDialog(QDialog, QtEventListener):
|
||||
self.register_callbacks()
|
||||
self.network.channel_db.update_counts() # trigger callback
|
||||
if self.network.lngossip:
|
||||
self.on_event_gossip_peers(self.network.lngossip.num_peers())
|
||||
self.on_event_gossip_peers(self.network.lngossip.lnpeermgr.num_peers())
|
||||
self.on_event_unknown_channels(len(self.network.lngossip.unknown_ids))
|
||||
else:
|
||||
self.num_peers.setText(_('Lightning gossip not active.'))
|
||||
|
||||
@@ -417,9 +417,9 @@ class AbstractChannel(Logger, ABC):
|
||||
if not self.is_funding_tx_mined(funding_height):
|
||||
# funding tx is invalid (invalid amount or address) we need to get rid of the channel again
|
||||
self.should_request_force_close = True
|
||||
if self.lnworker and self.node_id in self.lnworker.peers:
|
||||
if self.lnworker and (peer := self.lnworker.lnpeermgr.get_peer_by_pubkey(self.node_id)):
|
||||
# reconnect to trigger force close request
|
||||
self.lnworker.peers[self.node_id].close_and_cleanup()
|
||||
peer.close_and_cleanup()
|
||||
else:
|
||||
# remove zeroconf flag as we are now confirmed, this is to prevent an electrum server causing
|
||||
# us to remove a channel later in update_unfunded_state by omitting its funding tx
|
||||
@@ -779,7 +779,7 @@ class Channel(AbstractChannel):
|
||||
self,
|
||||
state: 'StoredDict', *,
|
||||
name=None,
|
||||
lnworker=None, # None only in unittests
|
||||
lnworker: 'LNWallet' = None, # None only in unittests
|
||||
initial_feerate=None,
|
||||
jit_opening_fee: Optional[int] = None,
|
||||
):
|
||||
@@ -1022,8 +1022,8 @@ class Channel(AbstractChannel):
|
||||
elif self.is_static_remotekey_enabled():
|
||||
our_payment_pubkey = self.config[LOCAL].payment_basepoint.pubkey
|
||||
addr = make_commitment_output_to_remote_address(our_payment_pubkey, has_anchors=self.has_anchors())
|
||||
if self.lnworker:
|
||||
assert self.lnworker.wallet.is_mine(addr)
|
||||
#if self.lnworker:
|
||||
# assert self.lnworker.wallet.is_mine(addr) # FIXME xxxxx chan should be deterministic. NEEDS to be fixed before merge
|
||||
return addr
|
||||
|
||||
def has_anchors(self) -> bool:
|
||||
|
||||
+4
-4
@@ -81,7 +81,7 @@ class Peer(Logger, EventListener):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lnworker: Union['LNGossip', 'LNWallet'],
|
||||
lnworker: Union['LNWallet', 'LNGossip'],
|
||||
pubkey: bytes,
|
||||
transport: LNTransportBase,
|
||||
*, is_channel_backup= False):
|
||||
@@ -402,7 +402,7 @@ class Peer(Logger, EventListener):
|
||||
if constants.net.rev_genesis_bytes() not in their_chains:
|
||||
raise GracefulDisconnect(f"no common chain found with remote. (they sent: {their_chains})")
|
||||
# all checks passed
|
||||
self.lnworker.on_peer_successfully_established(self)
|
||||
self.lnworker.lnpeermgr.on_peer_successfully_established(self)
|
||||
self._received_init = True
|
||||
self.maybe_set_initialized()
|
||||
|
||||
@@ -888,7 +888,7 @@ class Peer(Logger, EventListener):
|
||||
self.transport.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.lnworker.peer_closed(self)
|
||||
self.lnworker.lnpeermgr.peer_closed(self)
|
||||
self.got_disconnected.set()
|
||||
|
||||
def is_shutdown_anysegwit(self):
|
||||
@@ -3064,7 +3064,7 @@ class Peer(Logger, EventListener):
|
||||
or not self.lnworker.is_payment_bundle_complete(payment_key):
|
||||
# maybe this set is COMPLETE but the bundle is not yet completed, so the bundle can be considered WAITING
|
||||
if int(time.time()) - first_htlc_timestamp > self.lnworker.MPP_EXPIRY \
|
||||
or self.lnworker.stopping_soon:
|
||||
or self.lnworker.lnpeermgr.stopping_soon:
|
||||
_log_fail_reason(f"MPP TIMEOUT (> {self.lnworker.MPP_EXPIRY} sec)")
|
||||
return OnionFailureCode.MPP_TIMEOUT, None, None
|
||||
|
||||
|
||||
+188
-94
@@ -39,7 +39,7 @@ from .util import (
|
||||
profiler, OldTaskGroup, ESocksProxy, NetworkRetryManager, JsonRPCClient, NotEnoughFunds, EventListener,
|
||||
event_listener, bfh, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions, ignore_exceptions,
|
||||
make_aiohttp_session, random_shuffled_copy, is_private_netaddress,
|
||||
UnrelatedTransactionException, LightningHistoryItem
|
||||
UnrelatedTransactionException, LightningHistoryItem, get_asyncio_loop,
|
||||
)
|
||||
from .fee_policy import (
|
||||
FeePolicy, FEERATE_FALLBACK_STATIC_FEE, FEE_LN_ETA_TARGET, FEE_LN_LOW_ETA_TARGET,
|
||||
@@ -215,9 +215,15 @@ LNGOSSIP_FEATURES = (
|
||||
)
|
||||
|
||||
|
||||
class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
class LNPeerManager(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
|
||||
def __init__(self, node_keypair, features: LnFeatures, *, config: 'SimpleConfig'):
|
||||
def __init__(
|
||||
self, node_keypair,
|
||||
*,
|
||||
lnwallet_or_lngossip: 'LNWallet | LNGossip',
|
||||
features: LnFeatures,
|
||||
config: 'SimpleConfig',
|
||||
):
|
||||
Logger.__init__(self)
|
||||
NetworkRetryManager.__init__(
|
||||
self,
|
||||
@@ -228,6 +234,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
)
|
||||
self.lock = threading.RLock()
|
||||
self.node_keypair = node_keypair
|
||||
self._lnwallet_or_lngossip = lnwallet_or_lngossip
|
||||
self._peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer # needs self.lock
|
||||
self._channelless_incoming_peers = set() # type: Set[bytes] # node_ids # needs self.lock
|
||||
self.taskgroup = OldTaskGroup()
|
||||
@@ -252,7 +259,10 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
return self._peers.copy()
|
||||
|
||||
def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
|
||||
return {}
|
||||
return self._lnwallet_or_lngossip.channels_for_peer(node_id)
|
||||
|
||||
def get_peer_by_pubkey(self, pubkey: bytes) -> Optional[Peer]:
|
||||
return self._peers.get(pubkey)
|
||||
|
||||
def get_node_alias(self, node_id: bytes) -> Optional[str]:
|
||||
"""Returns the alias of the node, or None if unknown."""
|
||||
@@ -269,7 +279,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
return node_alias
|
||||
|
||||
async def maybe_listen(self):
|
||||
# FIXME: only one LNWorker can listen at a time (single port)
|
||||
# FIXME: only one LNPeerManager can listen at a time (single port)
|
||||
listen_addr = self.config.LIGHTNING_LISTEN
|
||||
if listen_addr:
|
||||
self.logger.info(f'lightning_listen enabled. will try to bind: {listen_addr!r}')
|
||||
@@ -368,13 +378,17 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
return None
|
||||
self._channelless_incoming_peers.add(node_id)
|
||||
# checks done: we are adding this peer.
|
||||
peer = Peer(self, node_id, transport)
|
||||
peer = Peer(self._lnwallet_or_lngossip, node_id, transport)
|
||||
assert node_id not in self._peers
|
||||
self._peers[node_id] = peer
|
||||
await self.taskgroup.spawn(peer.main_loop())
|
||||
return peer
|
||||
|
||||
def peer_closed(self, peer: Peer) -> None:
|
||||
if isinstance(self._lnwallet_or_lngossip, LNWallet):
|
||||
for chan in self.channels_for_peer(peer.pubkey).values():
|
||||
chan.peer_state = PeerState.DISCONNECTED
|
||||
util.trigger_callback('channel', self._lnwallet_or_lngossip.wallet, chan)
|
||||
with self.lock:
|
||||
peer2 = self._peers.get(peer.pubkey)
|
||||
if peer2 is peer:
|
||||
@@ -392,14 +406,25 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
return True
|
||||
return False
|
||||
|
||||
def start_network(self, network: 'Network'):
|
||||
def start_network(
|
||||
self, network: 'Network', *,
|
||||
listen: bool = False,
|
||||
maintain_random_peers: bool = False,
|
||||
) -> None:
|
||||
assert network
|
||||
assert self.network is None, "already started"
|
||||
self.network = network
|
||||
self._add_peers_from_config()
|
||||
asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
|
||||
asyncio.run_coroutine_threadsafe(self.main_loop(), get_asyncio_loop())
|
||||
if listen:
|
||||
tg_coro = self.taskgroup.spawn(self.maybe_listen())
|
||||
asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop())
|
||||
if maintain_random_peers:
|
||||
tg_coro = self.taskgroup.spawn(self._maintain_connectivity())
|
||||
asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop())
|
||||
|
||||
async def stop(self):
|
||||
self.stopping_soon = True
|
||||
if self.listen_server:
|
||||
self.listen_server.close()
|
||||
self.unregister_callbacks()
|
||||
@@ -410,7 +435,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
for host, port, pubkey in peer_list:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._add_peer(host, int(port), bfh(pubkey)),
|
||||
self.network.asyncio_loop)
|
||||
get_asyncio_loop())
|
||||
|
||||
def is_good_peer(self, peer: LNPeerAddr) -> bool:
|
||||
# the purpose of this method is to filter peers that advertise the desired feature bits
|
||||
@@ -573,8 +598,44 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
peer = await self._add_peer(host, port, node_id)
|
||||
return peer
|
||||
|
||||
async def reestablish_peer_for_given_channel(self, chan: Channel) -> None:
|
||||
await self.taskgroup.spawn(self._reestablish_peer_for_given_channel(chan))
|
||||
|
||||
class LNGossip(LNWorker):
|
||||
@ignore_exceptions
|
||||
@log_exceptions
|
||||
async def _reestablish_peer_for_given_channel(self, chan: Channel) -> None:
|
||||
now = time.time()
|
||||
peer_addresses = []
|
||||
if self.uses_trampoline():
|
||||
addr = trampolines_by_id().get(chan.node_id)
|
||||
if addr:
|
||||
peer_addresses.append(addr)
|
||||
else:
|
||||
# will try last good address first, from gossip
|
||||
last_good_addr = self.channel_db.get_last_good_address(chan.node_id)
|
||||
if last_good_addr:
|
||||
peer_addresses.append(last_good_addr)
|
||||
# will try addresses for node_id from gossip
|
||||
addrs_from_gossip = self.channel_db.get_node_addresses(chan.node_id) or []
|
||||
for host, port, ts in addrs_from_gossip:
|
||||
peer_addresses.append(LNPeerAddr(host, port, chan.node_id))
|
||||
# will try addresses stored in channel storage
|
||||
peer_addresses += list(chan.get_peer_addresses())
|
||||
# Done gathering addresses.
|
||||
# Now select first one that has not failed recently.
|
||||
for peer in peer_addresses:
|
||||
if self._can_retry_addr(peer, urgent=True, now=now):
|
||||
await self._add_peer(peer.host, peer.port, peer.pubkey)
|
||||
return
|
||||
|
||||
async def reestablish_peer_for_zero_conf_trusted_node(self) -> None:
|
||||
if self.config.ZEROCONF_TRUSTED_NODE:
|
||||
peer = LNPeerAddr.from_str(self.config.ZEROCONF_TRUSTED_NODE)
|
||||
if self._can_retry_addr(peer, urgent=True):
|
||||
await self._add_peer(peer.host, peer.port, peer.pubkey)
|
||||
|
||||
|
||||
class LNGossip(Logger):
|
||||
"""The LNGossip class is a separate, unannounced Lightning node with random id that is just querying
|
||||
gossip from other nodes. The LNGossip node does not satisfy gossip queries, this is done by the
|
||||
LNWallet class(es). LNWallets are the advertised nodes used for actual payments and only satisfy
|
||||
@@ -584,11 +645,14 @@ class LNGossip(LNWorker):
|
||||
max_age = 14*24*3600
|
||||
|
||||
def __init__(self, config: 'SimpleConfig'):
|
||||
self.config = config
|
||||
seed = os.urandom(32)
|
||||
node = BIP32Node.from_rootseed(seed, xtype='standard')
|
||||
xprv = node.to_xprv()
|
||||
node_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NODE_KEY)
|
||||
LNWorker.__init__(self, node_keypair, LNGOSSIP_FEATURES, config=config)
|
||||
Logger.__init__(self)
|
||||
self.lnpeermgr = LNPeerManager(node_keypair, features=LNGOSSIP_FEATURES, config=self.config, lnwallet_or_lngossip=self)
|
||||
self.taskgroup = OldTaskGroup()
|
||||
self.unknown_ids = set()
|
||||
self._forwarding_gossip = [] # type: List[GossipForwardingMessage]
|
||||
self._last_gossip_batch_ts = 0 # type: int
|
||||
@@ -600,15 +664,44 @@ class LNGossip(LNWorker):
|
||||
self._num_chan_upd = 0
|
||||
self._num_chan_upd_good = 0
|
||||
|
||||
@property
|
||||
def features(self) -> 'LnFeatures':
|
||||
return self.lnpeermgr.features
|
||||
|
||||
@property
|
||||
def network(self) -> Optional['Network']:
|
||||
return self.lnpeermgr.network
|
||||
|
||||
@property
|
||||
def channel_db(self) -> 'ChannelDB':
|
||||
return self.network.channel_db if self.network else None
|
||||
|
||||
def uses_trampoline(self) -> bool:
|
||||
return not bool(self.channel_db)
|
||||
|
||||
async def main_loop(self):
|
||||
self.logger.info("starting taskgroup.")
|
||||
try:
|
||||
async with self.taskgroup as group:
|
||||
await group.spawn(asyncio.Event().wait) # run forever (until cancel)
|
||||
except Exception as e:
|
||||
self.logger.exception("taskgroup died.")
|
||||
finally:
|
||||
self.logger.info("taskgroup stopped.")
|
||||
|
||||
def start_network(self, network: 'Network'):
|
||||
super().start_network(network)
|
||||
asyncio.run_coroutine_threadsafe(self.main_loop(), get_asyncio_loop())
|
||||
self.lnpeermgr.start_network(network, maintain_random_peers=True)
|
||||
for coro in [
|
||||
self._maintain_connectivity(),
|
||||
self.maintain_db(),
|
||||
self._maintain_forwarding_gossip()
|
||||
]:
|
||||
tg_coro = self.taskgroup.spawn(coro)
|
||||
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
|
||||
asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop())
|
||||
|
||||
async def stop(self):
|
||||
await self.lnpeermgr.stop()
|
||||
await self.taskgroup.cancel_remaining()
|
||||
|
||||
async def maintain_db(self):
|
||||
await self.channel_db.data_loaded.wait()
|
||||
@@ -637,7 +730,7 @@ class LNGossip(LNWorker):
|
||||
new = set(ids) - set(known)
|
||||
self.unknown_ids.update(new)
|
||||
util.trigger_callback('unknown_channels', len(self.unknown_ids))
|
||||
util.trigger_callback('gossip_peers', self.num_peers())
|
||||
util.trigger_callback('gossip_peers', self.lnpeermgr.num_peers())
|
||||
util.trigger_callback('ln_gossip_sync_progress')
|
||||
|
||||
def get_ids_to_query(self) -> Sequence[bytes]:
|
||||
@@ -652,7 +745,7 @@ class LNGossip(LNWorker):
|
||||
"""Estimates the gossip synchronization process and returns the number
|
||||
of synchronized channels, the total channels in the network and a
|
||||
rescaled percentage of the synchronization process."""
|
||||
if self.num_peers() == 0:
|
||||
if self.lnpeermgr.num_peers() == 0:
|
||||
return None, None, None
|
||||
nchans_with_0p, nchans_with_1p, nchans_with_2p = self.channel_db.get_num_channels_partitioned_by_policy_count()
|
||||
num_db_channels = nchans_with_0p + nchans_with_1p + nchans_with_2p
|
||||
@@ -730,6 +823,9 @@ class LNGossip(LNWorker):
|
||||
# flush the gossip queue so we don't forward old gossip after sync is complete
|
||||
self.channel_db.get_forwarding_gossip_batch()
|
||||
|
||||
def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
|
||||
return {}
|
||||
|
||||
|
||||
class PaySession(Logger):
|
||||
|
||||
@@ -875,7 +971,7 @@ class PaySession(Logger):
|
||||
return nhtlcs_resolved == self._nhtlcs_inflight
|
||||
|
||||
|
||||
class LNWallet(LNWorker):
|
||||
class LNWallet(Logger):
|
||||
|
||||
lnwatcher: Optional['LNWatcher']
|
||||
MPP_EXPIRY = 120
|
||||
@@ -884,7 +980,7 @@ class LNWallet(LNWorker):
|
||||
MPP_SPLIT_PART_FRACTION = 0.2
|
||||
MPP_SPLIT_PART_MINAMT_MSAT = 5_000_000
|
||||
|
||||
def __init__(self, wallet: 'Abstract_Wallet', xprv):
|
||||
def __init__(self, wallet: 'Abstract_Wallet', xprv, *, features: LnFeatures = None):
|
||||
self.wallet = wallet
|
||||
self.config = wallet.config
|
||||
self.db = wallet.db
|
||||
@@ -894,16 +990,20 @@ class LNWallet(LNWorker):
|
||||
self.payment_secret_key = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.PAYMENT_SECRET_KEY).privkey
|
||||
self.funding_root_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.FUNDING_ROOT_KEY)
|
||||
Logger.__init__(self)
|
||||
features = LNWALLET_FEATURES
|
||||
if self.config.ENABLE_ANCHOR_CHANNELS:
|
||||
features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT
|
||||
if self.config.ACCEPT_ZEROCONF_CHANNELS:
|
||||
features |= LnFeatures.OPTION_ZEROCONF_OPT
|
||||
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS or self.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS:
|
||||
features |= LnFeatures.OPTION_ONION_MESSAGE_OPT
|
||||
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP:
|
||||
features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch
|
||||
LNWorker.__init__(self, self.node_keypair, features, config=self.config)
|
||||
if features is None:
|
||||
features = LNWALLET_FEATURES
|
||||
if self.config.ENABLE_ANCHOR_CHANNELS:
|
||||
features |= LnFeatures.OPTION_ANCHORS_ZERO_FEE_HTLC_OPT
|
||||
if self.config.ACCEPT_ZEROCONF_CHANNELS:
|
||||
features |= LnFeatures.OPTION_ZEROCONF_OPT
|
||||
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS or self.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS:
|
||||
features |= LnFeatures.OPTION_ONION_MESSAGE_OPT
|
||||
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP:
|
||||
features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch
|
||||
Logger.__init__(self)
|
||||
self.lock = threading.RLock()
|
||||
self.lnpeermgr = LNPeerManager(self.node_keypair, features=features, config=self.config, lnwallet_or_lngossip=self)
|
||||
self.taskgroup = OldTaskGroup()
|
||||
self.lnwatcher = LNWatcher(self)
|
||||
self.lnrater: LNRater = None
|
||||
# "RHASH:direction" -> amount_msat, status, min_final_cltv_delta, expiry_delay, creation_ts, invoice_features
|
||||
@@ -997,6 +1097,21 @@ class LNWallet(LNWorker):
|
||||
return any(chan.has_anchors() and not chan.is_closed()
|
||||
for chan in self.channels.values())
|
||||
|
||||
@property
|
||||
def features(self) -> 'LnFeatures':
|
||||
return self.lnpeermgr.features
|
||||
|
||||
@property
|
||||
def network(self) -> Optional['Network']:
|
||||
return self.lnpeermgr.network
|
||||
|
||||
@property
|
||||
def channel_db(self) -> 'ChannelDB':
|
||||
return self.network.channel_db if self.network else None
|
||||
|
||||
def uses_trampoline(self) -> bool:
|
||||
return not bool(self.channel_db)
|
||||
|
||||
@property
|
||||
def channels(self) -> Mapping[bytes, Channel]:
|
||||
"""Returns a read-only copy of channels."""
|
||||
@@ -1060,29 +1175,39 @@ class LNWallet(LNWorker):
|
||||
await watchtower.add_sweep_tx(outpoint, ctn, tx.inputs()[0].prevout.to_str(), tx.serialize())
|
||||
self.watchtower_ctns[outpoint] = ctn
|
||||
|
||||
async def main_loop(self):
|
||||
self.logger.info("starting taskgroup.")
|
||||
try:
|
||||
async with self.taskgroup as group:
|
||||
await group.spawn(asyncio.Event().wait) # run forever (until cancel)
|
||||
except Exception as e:
|
||||
self.logger.exception("taskgroup died.")
|
||||
finally:
|
||||
self.logger.info("taskgroup stopped.")
|
||||
|
||||
def start_network(self, network: 'Network'):
|
||||
super().start_network(network)
|
||||
asyncio.run_coroutine_threadsafe(self.main_loop(), get_asyncio_loop())
|
||||
self.lnpeermgr.start_network(network, listen=True)
|
||||
self.lnwatcher.start_network(network)
|
||||
self.swap_manager.start_network(network)
|
||||
self.lnrater = LNRater(self, network)
|
||||
self.onion_message_manager.start_network(network=network)
|
||||
|
||||
for coro in [
|
||||
self.maybe_listen(),
|
||||
self.lnwatcher.trigger_callbacks(), # shortcut (don't block) if funding tx locked and verified
|
||||
self.reestablish_peers_and_channels(),
|
||||
self.sync_with_remote_watchtower(),
|
||||
]:
|
||||
tg_coro = self.taskgroup.spawn(coro)
|
||||
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
|
||||
asyncio.run_coroutine_threadsafe(tg_coro, get_asyncio_loop())
|
||||
|
||||
async def stop(self):
|
||||
self.stopping_soon = True
|
||||
if self.listen_server: # stop accepting new peers
|
||||
self.listen_server.close()
|
||||
self.lnpeermgr.stopping_soon = True
|
||||
if self.lnpeermgr.listen_server: # stop accepting new peers
|
||||
self.lnpeermgr.listen_server.close()
|
||||
async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):
|
||||
await self.wait_for_received_pending_htlcs_to_get_removed()
|
||||
await LNWorker.stop(self)
|
||||
await self.lnpeermgr.stop()
|
||||
if self.lnwatcher:
|
||||
self.lnwatcher.stop()
|
||||
self.lnwatcher = None
|
||||
@@ -1090,30 +1215,25 @@ class LNWallet(LNWorker):
|
||||
await self.swap_manager.stop()
|
||||
if self.onion_message_manager:
|
||||
await self.onion_message_manager.stop()
|
||||
await self.taskgroup.cancel_remaining()
|
||||
|
||||
async def wait_for_received_pending_htlcs_to_get_removed(self):
|
||||
assert self.stopping_soon is True
|
||||
assert self.lnpeermgr.stopping_soon is True
|
||||
# We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
|
||||
# Note: even without MPP, if we just failed/fulfilled an HTLC, it is good
|
||||
# to wait a bit for it to become irrevocably removed.
|
||||
# Note: we don't wait for *all htlcs* to get removed, only for those
|
||||
# that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
|
||||
async with OldTaskGroup() as group:
|
||||
for peer in self.peers.values():
|
||||
for peer in self.lnpeermgr.peers.values():
|
||||
await group.spawn(peer.wait_one_htlc_switch_iteration())
|
||||
while True:
|
||||
if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()):
|
||||
if all(not peer.received_htlcs_pending_removal for peer in self.lnpeermgr.peers.values()):
|
||||
break
|
||||
async with OldTaskGroup(wait=any) as group:
|
||||
for peer in self.peers.values():
|
||||
for peer in self.lnpeermgr.peers.values():
|
||||
await group.spawn(peer.received_htlc_removed_event.wait())
|
||||
|
||||
def peer_closed(self, peer):
|
||||
for chan in self.channels_for_peer(peer.pubkey).values():
|
||||
chan.peer_state = PeerState.DISCONNECTED
|
||||
util.trigger_callback('channel', self.wallet, chan)
|
||||
super().peer_closed(peer)
|
||||
|
||||
def get_payments(self, *, status=None) -> Mapping[bytes, List[HTLCWithStatus]]:
|
||||
out = defaultdict(list)
|
||||
for chan in self.channels.values():
|
||||
@@ -1263,7 +1383,7 @@ class LNWallet(LNWorker):
|
||||
node_ids = [chan.node_id for chan in self.channels.values() if not chan.is_closed()]
|
||||
return node_ids
|
||||
|
||||
def channels_for_peer(self, node_id):
|
||||
def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
|
||||
assert type(node_id) is bytes
|
||||
return {chan_id: chan for (chan_id, chan) in self.channels.items()
|
||||
if chan.node_id == node_id}
|
||||
@@ -1307,12 +1427,12 @@ class LNWallet(LNWorker):
|
||||
await self.schedule_force_closing(chan.channel_id)
|
||||
|
||||
elif chan.get_state() == ChannelState.FUNDED:
|
||||
peer = self._peers.get(chan.node_id)
|
||||
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
|
||||
if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
|
||||
peer.send_channel_ready(chan)
|
||||
|
||||
elif chan.get_state() == ChannelState.OPEN:
|
||||
peer = self._peers.get(chan.node_id)
|
||||
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
|
||||
if peer and peer.is_initialized() and chan.peer_state == PeerState.GOOD:
|
||||
peer.maybe_update_fee(chan)
|
||||
peer.maybe_send_announcement_signatures(chan)
|
||||
@@ -1326,9 +1446,10 @@ class LNWallet(LNWorker):
|
||||
await self.network.try_broadcasting(force_close_tx, 'force-close')
|
||||
|
||||
def get_peer_by_static_jit_scid_alias(self, scid_alias: bytes) -> Optional[Peer]:
|
||||
for nodeid, peer in self.peers.items():
|
||||
for nodeid, peer in self.lnpeermgr.peers.items():
|
||||
if scid_alias == self._scid_alias_of_node(nodeid):
|
||||
return peer
|
||||
return None
|
||||
|
||||
def _scid_alias_of_node(self, nodeid: bytes) -> bytes:
|
||||
# scid alias for just-in-time channels
|
||||
@@ -1557,7 +1678,7 @@ class LNWallet(LNWorker):
|
||||
password: str = None,
|
||||
) -> Tuple[Channel, PartialTransaction]:
|
||||
|
||||
fut = asyncio.run_coroutine_threadsafe(self.add_peer(connect_str), self.network.asyncio_loop)
|
||||
fut = asyncio.run_coroutine_threadsafe(self.lnpeermgr.add_peer(connect_str), get_asyncio_loop())
|
||||
try:
|
||||
peer = fut.result()
|
||||
except concurrent.futures.TimeoutError:
|
||||
@@ -1569,7 +1690,7 @@ class LNWallet(LNWorker):
|
||||
push_sat=push_amt_sat,
|
||||
public=public,
|
||||
password=password)
|
||||
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
|
||||
fut = asyncio.run_coroutine_threadsafe(coro, get_asyncio_loop())
|
||||
try:
|
||||
chan, funding_tx = fut.result()
|
||||
except concurrent.futures.TimeoutError:
|
||||
@@ -1860,7 +1981,7 @@ class LNWallet(LNWorker):
|
||||
short_channel_id = shi.route[0].short_channel_id
|
||||
chan = self.get_channel_by_short_id(short_channel_id)
|
||||
assert chan, ShortChannelID(short_channel_id)
|
||||
peer = self._peers.get(shi.route[0].node_id)
|
||||
peer = self.lnpeermgr.get_peer_by_pubkey(shi.route[0].node_id)
|
||||
if not peer:
|
||||
raise PaymentFailure('Dropped peer')
|
||||
await peer.initialized
|
||||
@@ -2040,7 +2161,7 @@ class LNWallet(LNWorker):
|
||||
# until trampoline is advertised in lnfeatures, check against hardcoded list
|
||||
if is_hardcoded_trampoline(node_id):
|
||||
return True
|
||||
peer = self._peers.get(node_id)
|
||||
peer = self.lnpeermgr.get_peer_by_pubkey(node_id)
|
||||
if not peer:
|
||||
return False
|
||||
return (peer.their_features.supports(LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ECLAIR)
|
||||
@@ -2794,7 +2915,7 @@ class LNWallet(LNWorker):
|
||||
return
|
||||
upstream_chan_scid, _ = deserialize_htlc_key(upstream_key)
|
||||
upstream_chan = self.get_channel_by_short_id(upstream_chan_scid)
|
||||
upstream_peer = self.peers.get(upstream_chan.node_id) if upstream_chan else None
|
||||
upstream_peer = self.lnpeermgr.get_peer_by_pubkey(upstream_chan.node_id) if upstream_chan else None
|
||||
if upstream_peer:
|
||||
upstream_peer.downstream_htlc_resolved_event.set()
|
||||
upstream_peer.downstream_htlc_resolved_event.clear()
|
||||
@@ -3110,7 +3231,7 @@ class LNWallet(LNWorker):
|
||||
# invalid connection string
|
||||
return False
|
||||
# only return True if we are connected to the zeroconf provider
|
||||
return node_id in self.peers
|
||||
return self.lnpeermgr.get_peer_by_pubkey(node_id) is not None
|
||||
|
||||
def _suggest_channels_for_rebalance(self, direction, amount_sat) -> Sequence[Tuple[Channel, int]]:
|
||||
"""
|
||||
@@ -3239,7 +3360,9 @@ class LNWallet(LNWorker):
|
||||
|
||||
async def close_channel(self, chan_id):
|
||||
chan = self._channels[chan_id]
|
||||
peer = self._peers[chan.node_id]
|
||||
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
|
||||
if peer is None:
|
||||
raise KeyError
|
||||
return await peer.close_channel(chan_id)
|
||||
|
||||
def _force_close_channel(self, chan_id: bytes) -> Transaction:
|
||||
@@ -3288,52 +3411,23 @@ class LNWallet(LNWorker):
|
||||
util.trigger_callback('channels_updated', self.wallet)
|
||||
util.trigger_callback('wallet_updated', self.wallet)
|
||||
|
||||
@ignore_exceptions
|
||||
@log_exceptions
|
||||
async def reestablish_peer_for_given_channel(self, chan: Channel) -> None:
|
||||
now = time.time()
|
||||
peer_addresses = []
|
||||
if self.uses_trampoline():
|
||||
addr = trampolines_by_id().get(chan.node_id)
|
||||
if addr:
|
||||
peer_addresses.append(addr)
|
||||
else:
|
||||
# will try last good address first, from gossip
|
||||
last_good_addr = self.channel_db.get_last_good_address(chan.node_id)
|
||||
if last_good_addr:
|
||||
peer_addresses.append(last_good_addr)
|
||||
# will try addresses for node_id from gossip
|
||||
addrs_from_gossip = self.channel_db.get_node_addresses(chan.node_id) or []
|
||||
for host, port, ts in addrs_from_gossip:
|
||||
peer_addresses.append(LNPeerAddr(host, port, chan.node_id))
|
||||
# will try addresses stored in channel storage
|
||||
peer_addresses += list(chan.get_peer_addresses())
|
||||
# Done gathering addresses.
|
||||
# Now select first one that has not failed recently.
|
||||
for peer in peer_addresses:
|
||||
if self._can_retry_addr(peer, urgent=True, now=now):
|
||||
await self._add_peer(peer.host, peer.port, peer.pubkey)
|
||||
return
|
||||
|
||||
async def reestablish_peers_and_channels(self):
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
if self.stopping_soon:
|
||||
if self.lnpeermgr.stopping_soon:
|
||||
return
|
||||
if self.config.ZEROCONF_TRUSTED_NODE:
|
||||
peer = LNPeerAddr.from_str(self.config.ZEROCONF_TRUSTED_NODE)
|
||||
if self._can_retry_addr(peer, urgent=True):
|
||||
await self._add_peer(peer.host, peer.port, peer.pubkey)
|
||||
await self.lnpeermgr.reestablish_peer_for_zero_conf_trusted_node()
|
||||
for chan in self.channels.values():
|
||||
# reestablish
|
||||
# note: we delegate filtering out uninteresting chans to this:
|
||||
if not chan.should_try_to_reestablish_peer():
|
||||
continue
|
||||
peer = self._peers.get(chan.node_id, None)
|
||||
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
|
||||
if peer:
|
||||
# FIXME maybe this should be the responsibility of the peer itself, done in peer.main_loop:
|
||||
await peer.taskgroup.spawn(peer.reestablish_channel(chan))
|
||||
else:
|
||||
await self.taskgroup.spawn(self.reestablish_peer_for_given_channel(chan))
|
||||
await self.lnpeermgr.reestablish_peer_for_given_channel(chan)
|
||||
|
||||
def current_target_feerate_per_kw(self, *, has_anchors: bool) -> Optional[int]:
|
||||
target: int = FEE_LN_MINIMUM_ETA_TARGET if has_anchors else FEE_LN_ETA_TARGET
|
||||
@@ -3396,12 +3490,12 @@ class LNWallet(LNWorker):
|
||||
|
||||
async def request_force_close(self, channel_id: bytes, *, connect_str=None) -> None:
|
||||
if chan := self.get_channel_by_id(channel_id):
|
||||
peer = self._peers.get(chan.node_id)
|
||||
peer = self.lnpeermgr.get_peer_by_pubkey(chan.node_id)
|
||||
chan.should_request_force_close = True
|
||||
if peer:
|
||||
peer.close_and_cleanup() # to force a reconnect
|
||||
elif connect_str:
|
||||
peer = await self.add_peer(connect_str)
|
||||
peer = await self.lnpeermgr.add_peer(connect_str)
|
||||
await peer.request_force_close(channel_id)
|
||||
elif channel_id in self.channel_backups:
|
||||
await self._request_force_close_from_backup(channel_id)
|
||||
@@ -3688,7 +3782,7 @@ class LNWallet(LNWorker):
|
||||
f"maybe_forward_htlc. will forward HTLC: inc_chan={incoming_chan.short_channel_id}. inc_htlc={str(htlc)}. "
|
||||
f"next_chan={next_chan.get_id_for_log()}.")
|
||||
|
||||
next_peer = self.peers.get(next_chan.node_id)
|
||||
next_peer = self.lnpeermgr.get_peer_by_pubkey(next_chan.node_id)
|
||||
if next_peer is None:
|
||||
log_fail_reason(f"next_peer offline ({next_chan.node_id.hex()})")
|
||||
raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
|
||||
@@ -3774,7 +3868,7 @@ class LNWallet(LNWorker):
|
||||
raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'')
|
||||
|
||||
# do we have a connection to the node?
|
||||
next_peer = self.peers.get(outgoing_node_id)
|
||||
next_peer = self.lnpeermgr.get_peer_by_pubkey(outgoing_node_id)
|
||||
if next_peer and next_peer.accepts_zeroconf():
|
||||
self.logger.info(f'JIT: found next_peer')
|
||||
for next_chan in next_peer.channels.values():
|
||||
|
||||
+12
-12
@@ -162,7 +162,7 @@ def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> Seque
|
||||
): return path
|
||||
|
||||
# alt: dest is existing peer?
|
||||
if lnwallet.peers.get(node_id):
|
||||
if lnwallet.lnpeermgr.get_peer_by_pubkey(node_id):
|
||||
return [PathEdge(short_channel_id=None, start_node=None, end_node=node_id)]
|
||||
|
||||
# if we have an address, pass it.
|
||||
@@ -219,7 +219,7 @@ def send_onion_message_to(
|
||||
encrypted_recipient_data=our_payload['encrypted_recipient_data']
|
||||
)
|
||||
|
||||
peer = lnwallet.peers.get(recipient_data['next_node_id']['node_id'])
|
||||
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(recipient_data['next_node_id']['node_id'])
|
||||
assert peer, 'next_node_id not a peer'
|
||||
|
||||
# blinding override?
|
||||
@@ -241,7 +241,7 @@ def send_onion_message_to(
|
||||
if not isinstance(remaining_blinded_path, list): # doesn't return list when num items == 1
|
||||
remaining_blinded_path = [remaining_blinded_path]
|
||||
|
||||
peer = lnwallet.peers.get(introduction_point)
|
||||
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(introduction_point)
|
||||
# if blinded path introduction point is our direct peer, no need to route-find
|
||||
if peer:
|
||||
# start of blinded path is our peer
|
||||
@@ -250,7 +250,7 @@ def send_onion_message_to(
|
||||
path = create_onion_message_route_to(lnwallet, introduction_point)
|
||||
|
||||
# first edge must be to our peer
|
||||
peer = lnwallet.peers.get(path[0].end_node)
|
||||
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(path[0].end_node)
|
||||
assert peer, 'first hop not a peer'
|
||||
|
||||
# last edge is to introduction point and start of blinded path. remove from route
|
||||
@@ -321,7 +321,7 @@ def send_onion_message_to(
|
||||
raise Exception('cannot send to myself')
|
||||
|
||||
hops_data = []
|
||||
peer = lnwallet.peers.get(pubkey)
|
||||
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(pubkey)
|
||||
|
||||
if peer:
|
||||
# destination is our direct peer, no need to route-find
|
||||
@@ -330,7 +330,7 @@ def send_onion_message_to(
|
||||
path = create_onion_message_route_to(lnwallet, pubkey)
|
||||
|
||||
# first edge must be to our peer
|
||||
peer = lnwallet.peers.get(path[0].end_node)
|
||||
peer = lnwallet.lnpeermgr.get_peer_by_pubkey(path[0].end_node)
|
||||
assert peer, 'first hop not a peer'
|
||||
|
||||
hops_data = [
|
||||
@@ -379,9 +379,9 @@ def get_blinded_reply_paths(
|
||||
- reply_path introduction points are direct peers only (TODO: longer reply paths)"""
|
||||
# TODO: build longer paths and/or add dummy hops to increase privacy
|
||||
my_active_channels = [chan for chan in lnwallet.channels.values() if chan.is_active()]
|
||||
my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.peers.get(chan.node_id) and
|
||||
lnwallet.peers.get(chan.node_id).their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)]
|
||||
my_onionmsg_peers = [peer for peer in lnwallet.peers.values() if peer.their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)]
|
||||
my_onionmsg_channels = [chan for chan in my_active_channels if lnwallet.lnpeermgr.get_peer_by_pubkey(chan.node_id) and
|
||||
lnwallet.lnpeermgr.get_peer_by_pubkey(chan.node_id).their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)]
|
||||
my_onionmsg_peers = [peer for peer in lnwallet.lnpeermgr.peers.values() if peer.their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)]
|
||||
|
||||
result = []
|
||||
mynodeid = lnwallet.node_keypair.pubkey
|
||||
@@ -472,7 +472,7 @@ class OnionMessageManager(Logger):
|
||||
|
||||
try:
|
||||
onion_packet_b = onion_packet.to_bytes()
|
||||
next_peer = self.lnwallet.peers.get(node_id)
|
||||
next_peer = self.lnwallet.lnpeermgr.get_peer_by_pubkey(node_id)
|
||||
|
||||
if not next_peer.their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT):
|
||||
self.logger.debug('forward dropped, next peer is not ONION_MESSAGE capable')
|
||||
@@ -528,7 +528,7 @@ class OnionMessageManager(Logger):
|
||||
req.future.set_exception(copy.copy(e))
|
||||
# NOTE: above, when passing the caught exception instance e directly it leads to GeneratorExit() in
|
||||
if isinstance(e, NoRouteFound) and e.peer_address:
|
||||
await self.lnwallet.add_peer(str(e.peer_address))
|
||||
await self.lnwallet.lnpeermgr.add_peer(str(e.peer_address))
|
||||
else:
|
||||
self.logger.debug(f'resubmit {key=}')
|
||||
self.send_queue.put_nowait((now() + self.REQUEST_REPLY_RETRY_DELAY, expires, key))
|
||||
@@ -700,7 +700,7 @@ class OnionMessageManager(Logger):
|
||||
'onion_message dropped (not forwarding due to lightning_forward_payments config option disabled')
|
||||
return
|
||||
# is next_node one of our peers?
|
||||
next_peer = self.lnwallet.peers.get(next_node_id)
|
||||
next_peer = self.lnwallet.lnpeermgr.get_peer_by_pubkey(next_node_id)
|
||||
if not next_peer:
|
||||
self.logger.info(f'next node {next_node_id.hex()} not a peer, dropping message')
|
||||
return
|
||||
|
||||
@@ -81,7 +81,7 @@ async def worker(work_queue: asyncio.Queue, results_queue: asyncio.Queue, flag):
|
||||
|
||||
print(f"worker connecting to {connect_str}")
|
||||
try:
|
||||
peer = await wallet.lnworker.add_peer(connect_str)
|
||||
peer = await wallet.lnworker.lnpeermgr.add_peer(connect_str)
|
||||
res = await util.wait_for2(peer.initialized, TIMEOUT)
|
||||
if res:
|
||||
if peer.features & flag == work['features'] & flag:
|
||||
|
||||
+2
-2
@@ -3451,7 +3451,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
|
||||
lightning_has_channels = (
|
||||
self.lnworker and len([chan for chan in self.lnworker.channels.values() if chan.is_open()]) > 0
|
||||
)
|
||||
lightning_online = self.lnworker and self.lnworker.num_peers() > 0
|
||||
lightning_online = self.lnworker and self.lnworker.lnpeermgr.num_peers() > 0
|
||||
num_sats_can_receive = self.lnworker.num_sats_can_receive() if self.lnworker else 0
|
||||
can_receive_lightning = self.lnworker and num_sats_can_receive > 0 and amount_sat <= num_sats_can_receive
|
||||
try:
|
||||
@@ -3459,7 +3459,7 @@ class Abstract_Wallet(ABC, Logger, EventListener):
|
||||
except Exception:
|
||||
zeroconf_nodeid = None
|
||||
can_get_zeroconf_channel = (self.lnworker and self.config.ACCEPT_ZEROCONF_CHANNELS
|
||||
and zeroconf_nodeid in self.lnworker.peers)
|
||||
and self.lnworker.lnpeermgr.get_peer_by_pubkey(zeroconf_nodeid) is not None)
|
||||
status = self.get_invoice_status(req)
|
||||
|
||||
if status == PR_EXPIRED:
|
||||
|
||||
@@ -760,22 +760,23 @@ class TestCommandsTestnet(ElectrumTestCase):
|
||||
|
||||
# Mock the network and lnworker
|
||||
mock_lnworker = mock.Mock()
|
||||
mock_lnworker.lnpeermgr = mock.Mock()
|
||||
w.lnworker = mock_lnworker
|
||||
mock_peer = mock.Mock()
|
||||
mock_peer.initialized = asyncio.Future()
|
||||
connection_string = "test_node_id@127.0.0.1:9735"
|
||||
called = False
|
||||
async def lnworker_add_peer(*args, **kwargs):
|
||||
async def lnpeermgr_add_peer(*args, **kwargs):
|
||||
assert args[0] == connection_string
|
||||
nonlocal called
|
||||
called += 1
|
||||
return mock_peer
|
||||
mock_lnworker.add_peer = lnworker_add_peer
|
||||
mock_lnworker.lnpeermgr.add_peer = lnpeermgr_add_peer
|
||||
|
||||
# check if add_peer times out if peer doesn't initialize (LN_P2P_NETWORK_TIMEOUT is 0.001s)
|
||||
with self.assertRaises(UserFacingException):
|
||||
await cmds.add_peer(connection_string=connection_string, wallet=w)
|
||||
# check if add_peer called lnworker.add_peer
|
||||
# check if add_peer called lnpeermgr.add_peer
|
||||
assert called == 1
|
||||
|
||||
mock_peer.initialized = asyncio.Future()
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
# (around commit 42de4400bff5105352d0552155f73589166d162b).
|
||||
|
||||
import unittest
|
||||
from functools import lru_cache
|
||||
from unittest import mock
|
||||
import os
|
||||
import binascii
|
||||
@@ -40,7 +41,7 @@ from electrum.crypto import privkey_to_pubkey
|
||||
from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED, UpdateAddHtlc
|
||||
from electrum.lnutil import effective_htlc_tx_weight
|
||||
from electrum.logging import console_stderr_handler
|
||||
from electrum.lnchannel import ChannelState
|
||||
from electrum.lnchannel import ChannelState, Channel
|
||||
from electrum.json_db import StoredDict
|
||||
from electrum.coinchooser import PRNG
|
||||
|
||||
@@ -124,6 +125,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
|
||||
return StoredDict(state, None)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bip32(sequence):
|
||||
node = bip32_utils.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence)
|
||||
k = node.eckey.get_secret_bytes()
|
||||
@@ -137,7 +139,7 @@ def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None,
|
||||
alice_pubkey=b"\x01"*33, bob_pubkey=b"\x02"*33, random_seed=None,
|
||||
anchor_outputs=False,
|
||||
local_max_inflight=None, remote_max_inflight=None,
|
||||
max_accepted_htlcs=5):
|
||||
max_accepted_htlcs=5) -> tuple[Channel, Channel]:
|
||||
if random_seed is None: # needed for deterministic randomness
|
||||
random_seed = os.urandom(32)
|
||||
random_gen = PRNG(random_seed)
|
||||
|
||||
+103
-210
@@ -10,6 +10,7 @@ from collections import defaultdict
|
||||
import logging
|
||||
import concurrent
|
||||
from concurrent import futures
|
||||
from functools import lru_cache
|
||||
from unittest import mock
|
||||
from typing import Iterable, NamedTuple, Tuple, List, Dict, Sequence
|
||||
from types import MappingProxyType
|
||||
@@ -24,6 +25,7 @@ import electrum.trampoline
|
||||
from electrum import bitcoin
|
||||
from electrum import util
|
||||
from electrum import constants
|
||||
from electrum import bip32
|
||||
from electrum.network import Network
|
||||
from electrum import simple_config, lnutil
|
||||
from electrum.lnaddr import lnencode, LnAddr, lndecode
|
||||
@@ -37,7 +39,7 @@ from electrum.lnutil import Keypair, PaymentFailure, LnFeatures, HTLCOwner, Paym
|
||||
from electrum.lnchannel import ChannelState, PeerState, Channel
|
||||
from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
|
||||
from electrum.channel_db import ChannelDB
|
||||
from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession
|
||||
from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession, LNPeerManager
|
||||
from electrum.lnmsg import encode_msg, decode_msg
|
||||
from electrum import lnmsg
|
||||
from electrum.logging import console_stderr_handler, Logger
|
||||
@@ -49,10 +51,11 @@ from electrum.interface import GracefulDisconnect
|
||||
from electrum.simple_config import SimpleConfig
|
||||
from electrum.fee_policy import FeeTimeEstimates, FEE_ETA_TARGETS
|
||||
from electrum.mpp_split import split_amount_normal
|
||||
from electrum.wallet import Abstract_Wallet
|
||||
|
||||
from .test_lnchannel import create_test_channels
|
||||
from .test_bitcoin import needs_test_with_all_chacha20_implementations
|
||||
from . import ElectrumTestCase
|
||||
from . import ElectrumTestCase, restore_wallet_from_text__for_unittest
|
||||
|
||||
|
||||
def keypair():
|
||||
@@ -62,9 +65,6 @@ def keypair():
|
||||
privkey=priv)
|
||||
return k1
|
||||
|
||||
@contextmanager
|
||||
def noop_lock():
|
||||
yield
|
||||
|
||||
class MockNetwork:
|
||||
def __init__(self, tx_queue, *, config: SimpleConfig):
|
||||
@@ -120,144 +120,100 @@ class MockADB:
|
||||
def get_local_height(self):
|
||||
return self._blockchain.height()
|
||||
|
||||
class MockWallet:
|
||||
receive_requests = {}
|
||||
adb = MockADB()
|
||||
|
||||
def get_invoice(self, key):
|
||||
pass
|
||||
|
||||
def get_request(self, key):
|
||||
pass
|
||||
|
||||
def get_key_for_receive_request(self, x):
|
||||
pass
|
||||
|
||||
def set_label(self, x, y):
|
||||
pass
|
||||
|
||||
def save_db(self):
|
||||
pass
|
||||
|
||||
def is_lightning_backup(self):
|
||||
return False
|
||||
|
||||
def is_mine(self, addr):
|
||||
return True
|
||||
|
||||
def get_fingerprint(self):
|
||||
return ''
|
||||
|
||||
def get_new_sweep_address_for_channel(self):
|
||||
# note: sweep is not tested here, only in regtest
|
||||
return "tb1qqu5newtapamjchgxf0nty6geuykhvwas45q4q4"
|
||||
|
||||
def is_up_to_date(self):
|
||||
return True
|
||||
|
||||
|
||||
class MockLNGossip:
|
||||
def get_sync_progress_estimate(self):
|
||||
return None, None, None
|
||||
|
||||
|
||||
class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
class MockLNPeerManager(LNPeerManager):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
node_keypair,
|
||||
config: SimpleConfig,
|
||||
features: LnFeatures,
|
||||
lnwallet: LNWallet,
|
||||
network: 'MockNetwork',
|
||||
):
|
||||
LNPeerManager.__init__(
|
||||
self,
|
||||
node_keypair=node_keypair,
|
||||
lnwallet_or_lngossip=lnwallet,
|
||||
features=features,
|
||||
config=config,
|
||||
)
|
||||
self.network = network
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _bip32_from_name(name: str) -> bip32.BIP32Node:
|
||||
# note: unlike a serialized xprv, the bip32 node can be cached easily,
|
||||
# as it does not depend on constant.net (testnet/mainnet) network bytes
|
||||
sequence = [ord(c) for c in name]
|
||||
bip32_node = bip32.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence)
|
||||
return bip32_node
|
||||
|
||||
|
||||
class MockLNWallet(LNWallet):
|
||||
MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
|
||||
PAYMENT_TIMEOUT = 120
|
||||
TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0
|
||||
MPP_SPLIT_PART_FRACTION = 1 # this disables the forced splitting
|
||||
MPP_SPLIT_PART_MINAMT_MSAT = 5_000_000
|
||||
|
||||
def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name, has_anchors):
|
||||
def __init__(self, *, tx_queue, name, has_anchors, ln_xprv: str = None):
|
||||
self.name = name
|
||||
Logger.__init__(self)
|
||||
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
|
||||
self.node_keypair = local_keypair
|
||||
self.payment_secret_key = os.urandom(32) # does not need to be deterministic in tests
|
||||
|
||||
self._user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-")
|
||||
self.config = SimpleConfig({}, read_user_dir_function=lambda: self._user_dir)
|
||||
self.network = MockNetwork(tx_queue, config=self.config)
|
||||
self.taskgroup = OldTaskGroup()
|
||||
self.config.ENABLE_ANCHOR_CHANNELS = has_anchors
|
||||
self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0
|
||||
|
||||
network = MockNetwork(tx_queue, config=self.config)
|
||||
|
||||
wallet = restore_wallet_from_text__for_unittest(
|
||||
"9dk", path=None, passphrase=name, config=self.config)['wallet'] # type: Abstract_Wallet
|
||||
wallet.is_up_to_date = lambda: True
|
||||
wallet.adb.network = wallet.network = network
|
||||
|
||||
features = LnFeatures(0)
|
||||
features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
|
||||
features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
|
||||
features |= LnFeatures.VAR_ONION_OPT
|
||||
features |= LnFeatures.PAYMENT_SECRET_OPT
|
||||
features |= LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
|
||||
features |= LnFeatures.OPTION_CHANNEL_TYPE_OPT
|
||||
features |= LnFeatures.OPTION_SCID_ALIAS_OPT
|
||||
features |= LnFeatures.OPTION_STATIC_REMOTEKEY_OPT
|
||||
|
||||
if ln_xprv is None:
|
||||
ln_xprv = _bip32_from_name(name).to_xprv()
|
||||
LNWallet.__init__(self, wallet=wallet, xprv=ln_xprv, features=features)
|
||||
|
||||
self.lnpeermgr = MockLNPeerManager(
|
||||
node_keypair=self.node_keypair,
|
||||
config=self.config,
|
||||
features=features,
|
||||
lnwallet=self,
|
||||
network=network,
|
||||
)
|
||||
self.lnwatcher = None
|
||||
self.swap_manager = None
|
||||
self.onion_message_manager = None
|
||||
self.listen_server = None
|
||||
self._channels = {chan.channel_id: chan for chan in chans}
|
||||
self.payment_info = {}
|
||||
self.logs = defaultdict(list)
|
||||
self.wallet = MockWallet()
|
||||
self.features = LnFeatures(0)
|
||||
self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
|
||||
self.features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
|
||||
self.features |= LnFeatures.VAR_ONION_OPT
|
||||
self.features |= LnFeatures.PAYMENT_SECRET_OPT
|
||||
self.features |= LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
|
||||
self.features |= LnFeatures.OPTION_CHANNEL_TYPE_OPT
|
||||
self.features |= LnFeatures.OPTION_SCID_ALIAS_OPT
|
||||
self.features |= LnFeatures.OPTION_STATIC_REMOTEKEY_OPT
|
||||
self.config.ENABLE_ANCHOR_CHANNELS = has_anchors
|
||||
for chan in chans:
|
||||
chan.lnworker = self
|
||||
self._peers = {} # bytes -> Peer
|
||||
# used in tests
|
||||
self.enable_htlc_settle = True
|
||||
self.enable_htlc_forwarding = True
|
||||
self.received_mpp_htlcs = dict()
|
||||
self._paysessions = dict()
|
||||
self.sent_htlcs_info = dict()
|
||||
self.sent_buckets = defaultdict(set)
|
||||
self.active_forwardings = {}
|
||||
self.forwarding_failures = {}
|
||||
self.inflight_payments = set()
|
||||
self._preimages = {}
|
||||
self.stopping_soon = False
|
||||
self.downstream_to_upstream_htlc = {}
|
||||
self.dont_expire_htlcs = {}
|
||||
self.dont_settle_htlcs = {}
|
||||
self.hold_invoice_callbacks = {}
|
||||
self._payment_bundles_pkey_to_canon = {} # type: Dict[bytes, bytes]
|
||||
self._payment_bundles_canon_to_pkeylist = {} # type: Dict[bytes, Sequence[bytes]]
|
||||
self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0
|
||||
self._channel_sending_capacity_lock = asyncio.Lock()
|
||||
|
||||
self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}")
|
||||
self.logger.info(f"created LNWallet[{name}] with nodeID={self.node_keypair.pubkey.hex()}")
|
||||
|
||||
def clear_invoices_cache(self):
|
||||
pass
|
||||
def _add_channel(self, chan: Channel):
|
||||
self._channels[chan.channel_id] = chan
|
||||
chan.lnworker = self
|
||||
|
||||
def get_invoice_status(self, key):
|
||||
pass
|
||||
|
||||
@property
|
||||
def lock(self):
|
||||
return noop_lock()
|
||||
|
||||
@property
|
||||
def channel_db(self):
|
||||
return self.network.channel_db if self.network else None
|
||||
|
||||
def uses_trampoline(self):
|
||||
return not bool(self.channel_db)
|
||||
|
||||
@property
|
||||
def channels(self):
|
||||
return self._channels
|
||||
|
||||
@property
|
||||
def peers(self):
|
||||
return self._peers
|
||||
|
||||
def get_channel_by_short_id(self, short_channel_id):
|
||||
with self.lock:
|
||||
for chan in self._channels.values():
|
||||
if chan.short_channel_id == short_channel_id:
|
||||
return chan
|
||||
|
||||
def channel_state_changed(self, chan):
|
||||
pass
|
||||
@LNWallet.features.setter
|
||||
def features(self, value):
|
||||
self.lnpeermgr.features = value
|
||||
|
||||
def save_channel(self, chan):
|
||||
print("Ignoring channel save")
|
||||
pass
|
||||
#print("Ignoring channel save")
|
||||
|
||||
def diagnostic_name(self):
|
||||
return self.name
|
||||
@@ -290,69 +246,6 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
|
||||
budget=PaymentFeeBudget.from_invoice_amount(invoice_amount_msat=amount_msat, config=self.config),
|
||||
)]
|
||||
|
||||
get_payments = LNWallet.get_payments
|
||||
get_payment_secret = LNWallet.get_payment_secret
|
||||
get_payment_info = LNWallet.get_payment_info
|
||||
save_payment_info = LNWallet.save_payment_info
|
||||
set_invoice_status = LNWallet.set_invoice_status
|
||||
set_request_status = LNWallet.set_request_status
|
||||
set_payment_status = LNWallet.set_payment_status
|
||||
get_payment_status = LNWallet.get_payment_status
|
||||
htlc_fulfilled = LNWallet.htlc_fulfilled
|
||||
htlc_failed = LNWallet.htlc_failed
|
||||
save_preimage = LNWallet.save_preimage
|
||||
get_preimage = LNWallet.get_preimage
|
||||
create_route_for_single_htlc = LNWallet.create_route_for_single_htlc
|
||||
create_routes_for_payment = LNWallet.create_routes_for_payment
|
||||
_check_bolt11_invoice = LNWallet._check_bolt11_invoice
|
||||
pay_to_route = LNWallet.pay_to_route
|
||||
pay_to_node = LNWallet.pay_to_node
|
||||
pay_invoice = LNWallet.pay_invoice
|
||||
force_close_channel = LNWallet.force_close_channel
|
||||
schedule_force_closing = LNWallet.schedule_force_closing
|
||||
on_peer_successfully_established = LNWallet.on_peer_successfully_established
|
||||
get_channel_by_id = LNWallet.get_channel_by_id
|
||||
channels_for_peer = LNWallet.channels_for_peer
|
||||
calc_routing_hints_for_invoice = LNWallet.calc_routing_hints_for_invoice
|
||||
get_channels_for_receiving = LNWallet.get_channels_for_receiving
|
||||
handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
|
||||
is_trampoline_peer = LNWallet.is_trampoline_peer
|
||||
wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
|
||||
#on_event_proxy_set = LNWallet.on_event_proxy_set
|
||||
_decode_channel_update_msg = LNWallet._decode_channel_update_msg
|
||||
_handle_chanupd_from_failed_htlc = LNWallet._handle_chanupd_from_failed_htlc
|
||||
is_forwarded_htlc = LNWallet.is_forwarded_htlc
|
||||
notify_upstream_peer = LNWallet.notify_upstream_peer
|
||||
_force_close_channel = LNWallet._force_close_channel
|
||||
suggest_payment_splits = LNWallet.suggest_payment_splits
|
||||
register_hold_invoice = LNWallet.register_hold_invoice
|
||||
unregister_hold_invoice = LNWallet.unregister_hold_invoice
|
||||
add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice
|
||||
|
||||
update_or_create_mpp_with_received_htlc = LNWallet.update_or_create_mpp_with_received_htlc
|
||||
set_mpp_resolution = LNWallet.set_mpp_resolution
|
||||
get_mpp_amounts = LNWallet.get_mpp_amounts
|
||||
bundle_payments = LNWallet.bundle_payments
|
||||
get_payment_bundle = LNWallet.get_payment_bundle
|
||||
_get_payment_key = LNWallet._get_payment_key
|
||||
save_forwarding_failure = LNWallet.save_forwarding_failure
|
||||
get_forwarding_failure = LNWallet.get_forwarding_failure
|
||||
maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding
|
||||
current_target_feerate_per_kw = LNWallet.current_target_feerate_per_kw
|
||||
current_low_feerate_per_kw_srk_channel = LNWallet.current_low_feerate_per_kw_srk_channel
|
||||
create_onion_for_route = LNWallet.create_onion_for_route
|
||||
maybe_forward_htlc_set = LNWallet.maybe_forward_htlc_set
|
||||
_maybe_forward_htlc = LNWallet._maybe_forward_htlc
|
||||
_maybe_forward_trampoline = LNWallet._maybe_forward_trampoline
|
||||
_maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created = LNWallet._maybe_refuse_to_forward_htlc_that_corresponds_to_payreq_we_created
|
||||
set_htlc_set_error = LNWallet.set_htlc_set_error
|
||||
is_payment_bundle_complete = LNWallet.is_payment_bundle_complete
|
||||
delete_payment_bundle = LNWallet.delete_payment_bundle
|
||||
_process_htlc_log = LNWallet._process_htlc_log
|
||||
_get_invoice_features = LNWallet._get_invoice_features
|
||||
receive_requires_jit_channel = LNWallet.receive_requires_jit_channel
|
||||
can_get_zeroconf_channel = LNWallet.can_get_zeroconf_channel
|
||||
|
||||
|
||||
class MockTransport:
|
||||
def __init__(self, name):
|
||||
@@ -667,25 +560,24 @@ class TestPeerDirect(TestPeer):
|
||||
|
||||
def prepare_peers(
|
||||
self, alice_channel: Channel, bob_channel: Channel,
|
||||
*, k1: Keypair = None, k2: Keypair = None,
|
||||
):
|
||||
if k1 is None:
|
||||
k1 = keypair()
|
||||
if k2 is None:
|
||||
k2 = keypair()
|
||||
q1, q2 = asyncio.Queue(), asyncio.Queue()
|
||||
w1 = MockLNWallet(tx_queue=q1, name=bob_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
|
||||
w2 = MockLNWallet(tx_queue=q2, name=alice_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
|
||||
k1 = w1.node_keypair
|
||||
k2 = w2.node_keypair
|
||||
alice_channel.node_id = k2.pubkey
|
||||
bob_channel.node_id = k1.pubkey
|
||||
alice_channel.storage['node_id'] = alice_channel.node_id
|
||||
bob_channel.storage['node_id'] = bob_channel.node_id
|
||||
t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name)
|
||||
q1, q2 = asyncio.Queue(), asyncio.Queue()
|
||||
w1 = MockLNWallet(local_keypair=k1, chans=[alice_channel], tx_queue=q1, name=bob_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
|
||||
w2 = MockLNWallet(local_keypair=k2, chans=[bob_channel], tx_queue=q2, name=alice_channel.name, has_anchors=self.TEST_ANCHOR_CHANNELS)
|
||||
w1._add_channel(alice_channel)
|
||||
w2._add_channel(bob_channel)
|
||||
self._lnworkers_created.extend([w1, w2])
|
||||
p1 = PeerInTests(w1, k2.pubkey, t1)
|
||||
p2 = PeerInTests(w2, k1.pubkey, t2)
|
||||
w1._peers[p1.pubkey] = p1
|
||||
w2._peers[p2.pubkey] = p2
|
||||
w1.lnpeermgr._peers[p1.pubkey] = p1
|
||||
w2.lnpeermgr._peers[p2.pubkey] = p2
|
||||
# mark_open won't work if state is already OPEN.
|
||||
# so set it to FUNDED
|
||||
alice_channel._state = ChannelState.FUNDED
|
||||
@@ -790,10 +682,9 @@ class TestPeerDirect(TestPeer):
|
||||
----sig-->
|
||||
"""
|
||||
chan_AB, chan_BA = create_test_channels()
|
||||
k1, k2 = keypair(), keypair()
|
||||
# note: we don't start peer.htlc_switch() so that the fake htlcs are left alone.
|
||||
async def f():
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA)
|
||||
async with OldTaskGroup() as group:
|
||||
await group.spawn(p1._message_loop())
|
||||
await group.spawn(p2._message_loop())
|
||||
@@ -807,7 +698,7 @@ class TestPeerDirect(TestPeer):
|
||||
await group.cancel_remaining()
|
||||
# simulating disconnection. recreate transports.
|
||||
self.logger.info("simulating disconnection. recreating transports.")
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA)
|
||||
for chan in (chan_AB, chan_BA):
|
||||
chan.peer_state = PeerState.DISCONNECTED
|
||||
async with OldTaskGroup() as group:
|
||||
@@ -846,10 +737,9 @@ class TestPeerDirect(TestPeer):
|
||||
----rev-->
|
||||
"""
|
||||
chan_AB, chan_BA = create_test_channels()
|
||||
k1, k2 = keypair(), keypair()
|
||||
# note: we don't start peer.htlc_switch() so that the fake htlcs are left alone.
|
||||
async def f():
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA)
|
||||
async with OldTaskGroup() as group:
|
||||
await group.spawn(p1._message_loop())
|
||||
await group.spawn(p2._message_loop())
|
||||
@@ -864,7 +754,7 @@ class TestPeerDirect(TestPeer):
|
||||
await group.cancel_remaining()
|
||||
# simulating disconnection. recreate transports.
|
||||
self.logger.info("simulating disconnection. recreating transports.")
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
|
||||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA)
|
||||
for chan in (chan_AB, chan_BA):
|
||||
chan.peer_state = PeerState.DISCONNECTED
|
||||
async with OldTaskGroup() as group:
|
||||
@@ -1788,7 +1678,7 @@ class TestPeerDirect(TestPeer):
|
||||
with self.assertRaises(NoPathFound) as e:
|
||||
await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)
|
||||
|
||||
peer = w1.peers[route[0].node_id]
|
||||
peer = w1.lnpeermgr._peers[route[0].node_id]
|
||||
# AssertionError is ok since we shouldn't use old routes, and the
|
||||
# route finding should fail when channel is closed
|
||||
async def f():
|
||||
@@ -2126,12 +2016,19 @@ class TestPeerDirect(TestPeer):
|
||||
class TestPeerForwarding(TestPeer):
|
||||
|
||||
def prepare_chans_and_peers_in_graph(self, graph_definition) -> Graph:
|
||||
keys = {k: keypair() for k in graph_definition}
|
||||
workers = {} # type: Dict[str, MockLNWallet]
|
||||
txs_queues = {k: asyncio.Queue() for k in graph_definition}
|
||||
|
||||
# create workers
|
||||
for a, definition in graph_definition.items():
|
||||
workers[a] = MockLNWallet(tx_queue=txs_queues[a], name=a, has_anchors=self.TEST_ANCHOR_CHANNELS)
|
||||
self._lnworkers_created.extend(list(workers.values()))
|
||||
keys = {name: w.node_keypair for name, w in workers.items()}
|
||||
|
||||
channels = {} # type: Dict[Tuple[str, str], Channel]
|
||||
transports = {}
|
||||
workers = {} # type: Dict[str, MockLNWallet]
|
||||
peers = {}
|
||||
|
||||
# create channels
|
||||
for a, definition in graph_definition.items():
|
||||
for b, channel_def in definition.get('channels', {}).items():
|
||||
@@ -2145,6 +2042,8 @@ class TestPeerForwarding(TestPeer):
|
||||
anchor_outputs=self.TEST_ANCHOR_CHANNELS
|
||||
)
|
||||
channels[(a, b)], channels[(b, a)] = channel_ab, channel_ba
|
||||
workers[a]._add_channel(channel_ab)
|
||||
workers[b]._add_channel(channel_ba)
|
||||
transport_ab, transport_ba = transport_pair(keys[a], keys[b], channel_ab.name, channel_ba.name)
|
||||
transports[(a, b)], transports[(b, a)] = transport_ab, transport_ba
|
||||
# set fees
|
||||
@@ -2153,12 +2052,6 @@ class TestPeerForwarding(TestPeer):
|
||||
channel_ba.forwarding_fee_proportional_millionths = channel_def['remote_fee_rate_millionths']
|
||||
channel_ba.forwarding_fee_base_msat = channel_def['remote_base_fee_msat']
|
||||
|
||||
# create workers and peers
|
||||
for a, definition in graph_definition.items():
|
||||
channels_of_node = [c for k, c in channels.items() if k[0] == a]
|
||||
workers[a] = MockLNWallet(local_keypair=keys[a], chans=channels_of_node, tx_queue=txs_queues[a], name=a, has_anchors=self.TEST_ANCHOR_CHANNELS)
|
||||
self._lnworkers_created.extend(list(workers.values()))
|
||||
|
||||
# create peers
|
||||
for ab in channels.keys():
|
||||
peers[ab] = Peer(workers[ab[0]], keys[ab[1]].pubkey, transports[ab])
|
||||
@@ -2167,7 +2060,7 @@ class TestPeerForwarding(TestPeer):
|
||||
for a, w in workers.items():
|
||||
for ab, peer_ab in peers.items():
|
||||
if ab[0] == a:
|
||||
w._peers[peer_ab.pubkey] = peer_ab
|
||||
w.lnpeermgr._peers[peer_ab.pubkey] = peer_ab
|
||||
|
||||
# set forwarding properties
|
||||
for a, definition in graph_definition.items():
|
||||
|
||||
+11
-10
@@ -352,9 +352,8 @@ class TestOnionMessageManager(ElectrumTestCase):
|
||||
|
||||
async def test_request_and_reply(self):
|
||||
n = MockNetwork()
|
||||
k = keypair()
|
||||
q1, q2 = asyncio.Queue(), asyncio.Queue()
|
||||
lnw = MockLNWallet(local_keypair=k, chans=[], tx_queue=q1, name='test_request_and_reply', has_anchors=False)
|
||||
lnw = MockLNWallet(tx_queue=q1, name='test_request_and_reply', has_anchors=False)
|
||||
|
||||
def slow(*args, **kwargs):
|
||||
time.sleep(2*TIME_STEP)
|
||||
@@ -369,10 +368,10 @@ class TestOnionMessageManager(ElectrumTestCase):
|
||||
rkey1 = bfh('0102030405060708')
|
||||
rkey2 = bfh('0102030405060709')
|
||||
|
||||
lnw.peers[self.alice.pubkey] = MockPeer(self.alice.pubkey)
|
||||
lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=slow)
|
||||
lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(withreply, rkey1))
|
||||
lnw.peers[self.dave.pubkey] = MockPeer(self.dave.pubkey, on_send_message=partial(slowwithreply, rkey2))
|
||||
lnw.lnpeermgr._peers[self.alice.pubkey] = MockPeer(self.alice.pubkey)
|
||||
lnw.lnpeermgr._peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=slow)
|
||||
lnw.lnpeermgr._peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(withreply, rkey1))
|
||||
lnw.lnpeermgr._peers[self.dave.pubkey] = MockPeer(self.dave.pubkey, on_send_message=partial(slowwithreply, rkey2))
|
||||
t = OnionMessageManager(lnw)
|
||||
t.start_network(network=n)
|
||||
|
||||
@@ -401,7 +400,8 @@ class TestOnionMessageManager(ElectrumTestCase):
|
||||
async def test_forward(self):
|
||||
n = MockNetwork()
|
||||
q1 = asyncio.Queue()
|
||||
lnw = MockLNWallet(local_keypair=self.alice, chans=[], tx_queue=q1, name='alice', has_anchors=False)
|
||||
lnw = MockLNWallet(tx_queue=q1, name='alice', has_anchors=False)
|
||||
lnw.node_keypair = self.alice
|
||||
|
||||
self.was_sent = False
|
||||
|
||||
@@ -414,8 +414,8 @@ class TestOnionMessageManager(ElectrumTestCase):
|
||||
self.assertEqual(message_type, 'onion_message')
|
||||
self.assertEqual(payload['onion_message_packet'], kwargs['onion_message_packet'])
|
||||
|
||||
lnw.peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=partial(on_send, 'bob'))
|
||||
lnw.peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(on_send, 'carol'))
|
||||
lnw.lnpeermgr._peers[self.bob.pubkey] = MockPeer(self.bob.pubkey, on_send_message=partial(on_send, 'bob'))
|
||||
lnw.lnpeermgr._peers[self.carol.pubkey] = MockPeer(self.carol.pubkey, on_send_message=partial(on_send, 'carol'))
|
||||
t = OnionMessageManager(lnw)
|
||||
t.start_network(network=n)
|
||||
|
||||
@@ -438,7 +438,8 @@ class TestOnionMessageManager(ElectrumTestCase):
|
||||
async def test_receive_unsolicited(self):
|
||||
n = MockNetwork()
|
||||
q1 = asyncio.Queue()
|
||||
lnw = MockLNWallet(local_keypair=self.dave, chans=[], tx_queue=q1, name='dave', has_anchors=False)
|
||||
lnw = MockLNWallet(tx_queue=q1, name='dave', has_anchors=False)
|
||||
lnw.node_keypair = self.dave
|
||||
|
||||
t = OnionMessageManager(lnw)
|
||||
t.start_network(network=n)
|
||||
|
||||
Reference in New Issue
Block a user