We pass the private edges to lnrouter, and let it find routes end-to-end. Previously, the edge_cost heuristics did not apply to the private edges: we just randomly picked one of the route hints and used it. So, e.g., cheaper private edges were not preferred, but they are now. PathEdge now stores both start_node and end_node, not just end_node.
411 lines
18 KiB
Python
411 lines
18 KiB
Python
# -*- coding: utf-8 -*-
|
|
#
|
|
# Electrum - lightweight Bitcoin client
|
|
# Copyright (C) 2018 The Electrum developers
|
|
#
|
|
# Permission is hereby granted, free of charge, to any person
|
|
# obtaining a copy of this software and associated documentation files
|
|
# (the "Software"), to deal in the Software without restriction,
|
|
# including without limitation the rights to use, copy, modify, merge,
|
|
# publish, distribute, sublicense, and/or sell copies of the Software,
|
|
# and to permit persons to whom the Software is furnished to do so,
|
|
# subject to the following conditions:
|
|
#
|
|
# The above copyright notice and this permission notice shall be
|
|
# included in all copies or substantial portions of the Software.
|
|
#
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
|
|
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
|
|
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
# SOFTWARE.
|
|
|
|
import queue
|
|
from collections import defaultdict
|
|
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
|
|
import time
|
|
import attr
|
|
|
|
from .util import bh2u, profiler
|
|
from .logging import Logger
|
|
from .lnutil import (NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, LnFeatures,
|
|
NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE)
|
|
from .channel_db import ChannelDB, Policy, NodeInfo
|
|
|
|
if TYPE_CHECKING:
|
|
from .lnchannel import Channel
|
|
|
|
|
|
class NoChannelPolicy(Exception):
    """Error signalling that no channel policy could be found for a channel."""

    def __init__(self, short_channel_id: bytes):
        # normalize before formatting, so the message contains the normalized id
        short_channel_id = ShortChannelID.normalize(short_channel_id)
        message = f'cannot find channel policy for short_channel_id: {short_channel_id}'
        super().__init__(message)
|
|
|
|
|
|
class LNPathInconsistent(Exception):
    """Error signalling that the edges of a payment path are inconsistent."""
|
|
|
|
|
|
def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_proportional_millionths: int) -> int:
    """Return the fee for forwarding an amount: base fee plus a proportional part.

    The proportional part is forwarded_amount_msat * fee_proportional_millionths
    divided by one million, rounded down.
    """
    proportional_fee_msat = forwarded_amount_msat * fee_proportional_millionths // 1_000_000
    return fee_base_msat + proportional_fee_msat
|
|
|
|
|
|
@attr.s(slots=True)
class PathEdge:
    """One hop of a payment path: from start_node to end_node over short_channel_id."""
    start_node = attr.ib(type=bytes, kw_only=True, repr=lambda node: node.hex())
    end_node = attr.ib(type=bytes, kw_only=True, repr=lambda node: node.hex())
    short_channel_id = attr.ib(type=ShortChannelID, kw_only=True, repr=str)

    @property
    def node_id(self) -> bytes:
        """Alias for end_node."""
        # legacy compat # TODO rm
        return self.end_node
|
|
|
|
@attr.s
class RouteEdge(PathEdge):
    """A PathEdge that additionally carries fee and CLTV parameters and node features."""
    fee_base_msat = attr.ib(type=int, kw_only=True)
    fee_proportional_millionths = attr.ib(type=int, kw_only=True)
    cltv_expiry_delta = attr.ib(type=int, kw_only=True)
    node_features = attr.ib(type=int, kw_only=True, repr=lambda feats: str(int(feats)))  # note: for end node!

    def fee_for_edge(self, amount_msat: int) -> int:
        """Return the fee (in msat) of this edge for forwarding amount_msat."""
        return fee_for_edge_msat(
            forwarded_amount_msat=amount_msat,
            fee_base_msat=self.fee_base_msat,
            fee_proportional_millionths=self.fee_proportional_millionths,
        )

    @classmethod
    def from_channel_policy(
            cls,
            *,
            channel_policy: 'Policy',
            short_channel_id: bytes,
            start_node: bytes,
            end_node: bytes,
            node_info: Optional[NodeInfo],  # for end_node
    ) -> 'RouteEdge':
        """Build a RouteEdge from a channel policy and (optional) end-node info.

        If node_info is falsy, node_features is set to 0.
        """
        assert isinstance(short_channel_id, bytes)
        assert type(start_node) is bytes
        assert type(end_node) is bytes
        # the dict is filled in the same order the fields are declared
        fields = dict(
            start_node=start_node,
            end_node=end_node,
            short_channel_id=ShortChannelID.normalize(short_channel_id),
            fee_base_msat=channel_policy.fee_base_msat,
            fee_proportional_millionths=channel_policy.fee_proportional_millionths,
            cltv_expiry_delta=channel_policy.cltv_expiry_delta,
            node_features=node_info.features if node_info else 0,
        )
        return RouteEdge(**fields)

    def is_sane_to_use(self, amount_msat: int) -> bool:
        """Return whether this edge passes the cltv and fee sanity heuristics."""
        # TODO revise ad-hoc heuristics
        # cltv cannot be more than 2 weeks
        if self.cltv_expiry_delta > 14 * 144:
            return False
        edge_fee = self.fee_for_edge(amount_msat)
        return is_fee_sane(edge_fee, payment_amount_msat=amount_msat)

    def has_feature_varonion(self) -> bool:
        """Return whether node_features supports LnFeatures.VAR_ONION_OPT."""
        return LnFeatures(self.node_features).supports(LnFeatures.VAR_ONION_OPT)

    def is_trampoline(self) -> bool:
        """A plain RouteEdge is never a trampoline edge."""
        return False
|
|
|
|
@attr.s
class TrampolineEdge(RouteEdge):
    """RouteEdge subclass that reports itself as a trampoline edge."""
    invoice_routing_info = attr.ib(type=bytes, default=None)
    invoice_features = attr.ib(type=int, default=None)
    # this is re-defined from parent just to specify a default value:
    short_channel_id = attr.ib(default=ShortChannelID(8), repr=str)

    def is_trampoline(self):
        """Always True for this class."""
        return True
|
|
|
|
|
|
# Type alias: a payment path is a sequence of PathEdge hops.
LNPaymentPath = Sequence[PathEdge]
|
|
# Type alias: a payment route is a sequence of RouteEdge hops (edges with fee/cltv data).
LNPaymentRoute = Sequence[RouteEdge]
|
|
|
|
|
|
def is_route_sane_to_use(route: LNPaymentRoute, invoice_amount_msat: int, min_final_cltv_expiry: int) -> bool:
    """Run some sanity checks on the whole route, before attempting to use it.
    called when we are paying; so e.g. lower cltv is better

    Returns False if the route has too many edges, if any edge except the
    first fails its own sanity check, or if the accumulated cltv or
    accumulated fee is not acceptable.
    """
    if len(route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
        return False
    running_amt = invoice_amount_msat
    running_cltv = min_final_cltv_expiry
    # walk backwards from the last edge; the first edge of the route is skipped
    for hop in reversed(route[1:]):
        if not hop.is_sane_to_use(running_amt):
            return False
        running_amt += hop.fee_for_edge(running_amt)
        running_cltv += hop.cltv_expiry_delta
    # TODO revise ad-hoc heuristics
    if running_cltv > NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE:
        return False
    accumulated_fee = running_amt - invoice_amount_msat
    return is_fee_sane(accumulated_fee, payment_amount_msat=invoice_amount_msat)
|
|
|
|
|
|
def is_fee_sane(fee_msat: int, *, payment_amount_msat: int) -> bool:
    """Return whether a fee is acceptable for a payment of the given amount.

    A fee is accepted if it is at most 5 sat (5000 msat), or at most 1 % of
    the payment amount.
    """
    return fee_msat <= 5_000 or 100 * fee_msat <= payment_amount_msat
|
|
|
|
|
|
|
|
class LNPathFinder(Logger):
    """Finds payment paths and routes using the channel graph of a ChannelDB."""

    def __init__(self, channel_db: ChannelDB):
        Logger.__init__(self)
        self.channel_db = channel_db

    def _edge_cost(
            self,
            *,
            short_channel_id: bytes,
            start_node: bytes,
            end_node: bytes,
            payment_amt_msat: int,
            ignore_costs=False,
            is_mine=False,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Tuple[float, int]:
        """Heuristic cost (distance metric) of going through a channel.
        Returns (heuristic_cost, fee_for_edge_msat).

        An unusable edge is reported as (float('inf'), 0).
        If ignore_costs is set, a usable edge costs only the base cost and
        its fee is reported as 0.
        """
        if private_route_edges is None:
            private_route_edges = {}
        channel_info = self.channel_db.get_channel_info(
            short_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
        if channel_info is None:
            return float('inf'), 0
        channel_policy = self.channel_db.get_policy_for_node(
            short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges)
        if channel_policy is None:
            return float('inf'), 0
        # channels that did not publish both policies often return temporary channel failure
        channel_policy_backwards = self.channel_db.get_policy_for_node(
            short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges)
        if (channel_policy_backwards is None
                and not is_mine
                and short_channel_id not in private_route_edges):
            return float('inf'), 0
        if channel_policy.is_disabled():
            return float('inf'), 0
        if payment_amt_msat < channel_policy.htlc_minimum_msat:
            return float('inf'), 0  # payment amount too little
        if channel_info.capacity_sat is not None and \
                payment_amt_msat // 1000 > channel_info.capacity_sat:
            return float('inf'), 0  # payment amount too large
        if channel_policy.htlc_maximum_msat is not None and \
                payment_amt_msat > channel_policy.htlc_maximum_msat:
            return float('inf'), 0  # payment amount too large
        # private edges come with their own fee/cltv data; otherwise derive it from the policy
        route_edge = private_route_edges.get(short_channel_id, None)
        if route_edge is None:
            node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
            route_edge = RouteEdge.from_channel_policy(
                channel_policy=channel_policy,
                short_channel_id=short_channel_id,
                start_node=start_node,
                end_node=end_node,
                node_info=node_info)
        if not route_edge.is_sane_to_use(payment_amt_msat):
            return float('inf'), 0  # thanks but no thanks

        # Distance metric notes:  # TODO constants are ad-hoc
        # ( somewhat based on https://github.com/lightningnetwork/lnd/pull/1358 )
        # - Edges have a base cost. (more edges -> less likely none will fail)
        # - The larger the payment amount, and the longer the CLTV,
        #   the more irritating it is if the HTLC gets stuck.
        # - Paying lower fees is better. :)
        base_cost = 500  # one more edge ~ paying 500 msat more fees
        if ignore_costs:
            return base_cost, 0
        fee_msat = route_edge.fee_for_edge(payment_amt_msat)
        cltv_cost = route_edge.cltv_expiry_delta * payment_amt_msat * 15 / 1_000_000_000
        overall_cost = base_cost + fee_msat + cltv_cost
        return overall_cost, fee_msat

    def get_distances(
            self,
            *,
            nodeA: bytes,
            nodeB: bytes,
            invoice_amount_msat: int,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            blacklist: Set[ShortChannelID] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Dict[bytes, PathEdge]:
        """Search the graph backwards from nodeB towards nodeA.

        Returns a dict mapping each reached node to the PathEdge that starts
        at that node and leads one hop closer to nodeB. nodeA is a key of the
        result only if a path was found.
        """
        # note: we don't lock self.channel_db, so while the path finding runs,
        #       the underlying graph could potentially change... (not good but maybe ~OK?)
        if my_channels is None:
            # the default must support the membership tests and lookups below
            my_channels = {}

        # run Dijkstra
        # The search is run in the REVERSE direction, from nodeB to nodeA,
        # to properly calculate compound routing fees.
        distance_from_start = defaultdict(lambda: float('inf'))
        distance_from_start[nodeB] = 0
        prev_node = {}  # type: Dict[bytes, PathEdge]
        nodes_to_explore = queue.PriorityQueue()
        nodes_to_explore.put((0, invoice_amount_msat, nodeB))  # order of fields (in tuple) matters!

        # main loop of search
        while nodes_to_explore.qsize() > 0:
            dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
            if edge_endnode == nodeA:
                break
            if dist_to_edge_endnode != distance_from_start[edge_endnode]:
                # queue.PriorityQueue does not implement decrease_priority,
                # so instead of decreasing priorities, we add items again into the queue.
                # so there are duplicates in the queue, that we discard now:
                continue
            for edge_channel_id in self.channel_db.get_channels_for_node(
                    edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges):
                assert isinstance(edge_channel_id, bytes)
                if blacklist and edge_channel_id in blacklist:
                    continue
                channel_info = self.channel_db.get_channel_info(
                    edge_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
                if channel_info is None:
                    # the graph is not locked, so the channel may have vanished
                    # since it was listed; skip it instead of crashing
                    continue
                edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
                is_mine = edge_channel_id in my_channels
                if is_mine:
                    if edge_startnode == nodeA:  # payment outgoing, on our channel
                        if not my_channels[edge_channel_id].can_pay(amount_msat, check_frozen=True):
                            continue
                    else:  # payment incoming, on our channel. (funny business, cycle weirdness)
                        assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode))
                        if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True):
                            continue
                edge_cost, fee_msat = self._edge_cost(
                    short_channel_id=edge_channel_id,
                    start_node=edge_startnode,
                    end_node=edge_endnode,
                    payment_amt_msat=amount_msat,
                    ignore_costs=(edge_startnode == nodeA),
                    is_mine=is_mine,
                    my_channels=my_channels,
                    private_route_edges=private_route_edges)
                alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
                if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
                    distance_from_start[edge_startnode] = alt_dist_to_neighbour
                    prev_node[edge_startnode] = PathEdge(
                        start_node=edge_startnode,
                        end_node=edge_endnode,
                        short_channel_id=ShortChannelID(edge_channel_id))
                    # the node one hop earlier has to forward the amount plus this edge's fee
                    amount_to_forward_msat = amount_msat + fee_msat
                    nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))

        return prev_node

    @profiler
    def find_path_for_payment(
            self,
            *,
            nodeA: bytes,
            nodeB: bytes,
            invoice_amount_msat: int,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            blacklist: Set[ShortChannelID] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Optional[LNPaymentPath]:
        """Return a path from nodeA to nodeB.

        Returns None if no path was found.
        """
        assert type(nodeA) is bytes
        assert type(nodeB) is bytes
        assert type(invoice_amount_msat) is int
        if my_channels is None:
            my_channels = {}

        prev_node = self.get_distances(
            nodeA=nodeA,
            nodeB=nodeB,
            invoice_amount_msat=invoice_amount_msat,
            my_channels=my_channels,
            blacklist=blacklist,
            private_route_edges=private_route_edges)

        if nodeA not in prev_node:
            return None  # no path found

        # backtrack from search_end (nodeA) to search_start (nodeB)
        # FIXME paths cannot be longer than 20 edges (onion packet)...
        edge_startnode = nodeA
        path = []
        while edge_startnode != nodeB:
            edge = prev_node[edge_startnode]
            path.append(edge)
            edge_startnode = edge.end_node
        return path

    def create_route_from_path(
            self,
            path: Optional[LNPaymentPath],
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> LNPaymentRoute:
        """Convert a path into a route, i.e. attach fee/cltv data to every edge.

        Edges found in private_route_edges are used as-is; for all others the
        RouteEdge is built from the channel policy of the edge's start node.

        Raises:
            Exception: if path is None.
            LNPathInconsistent: if an edge's endpoints do not match the known
                endpoints of its channel, or consecutive edges do not chain.
            NoChannelPolicy: if no policy is known for an edge.
        """
        if path is None:
            raise Exception('cannot create route from None path')
        if private_route_edges is None:
            private_route_edges = {}
        route = []
        prev_end_node = path[0].start_node
        for path_edge in path:
            short_channel_id = path_edge.short_channel_id
            _endnodes = self.channel_db.get_endnodes_for_chan(short_channel_id, my_channels=my_channels)
            if _endnodes and sorted(_endnodes) != sorted([path_edge.start_node, path_edge.end_node]):
                raise LNPathInconsistent("endpoints of edge inconsistent with short_channel_id")
            if path_edge.start_node != prev_end_node:
                raise LNPathInconsistent("edges do not chain together")
            route_edge = private_route_edges.get(short_channel_id, None)
            if route_edge is None:
                channel_policy = self.channel_db.get_policy_for_node(
                    short_channel_id=short_channel_id,
                    node_id=path_edge.start_node,
                    my_channels=my_channels)
                if channel_policy is None:
                    raise NoChannelPolicy(short_channel_id)
                node_info = self.channel_db.get_node_info_for_node_id(node_id=path_edge.end_node)
                route_edge = RouteEdge.from_channel_policy(
                    channel_policy=channel_policy,
                    short_channel_id=short_channel_id,
                    start_node=path_edge.start_node,
                    end_node=path_edge.end_node,
                    node_info=node_info)
            route.append(route_edge)
            prev_end_node = path_edge.end_node
        return route

    def find_route(
            self,
            *,
            nodeA: bytes,
            nodeB: bytes,
            invoice_amount_msat: int,
            path=None,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            blacklist: Set[ShortChannelID] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Optional[LNPaymentRoute]:
        """Return a route from nodeA to nodeB, or None if there is no path.

        If a non-empty path is given, it is converted to a route directly;
        otherwise a path is searched for first.
        """
        route = None
        if not path:
            path = self.find_path_for_payment(
                nodeA=nodeA,
                nodeB=nodeB,
                invoice_amount_msat=invoice_amount_msat,
                my_channels=my_channels,
                blacklist=blacklist,
                private_route_edges=private_route_edges)
        if path:
            route = self.create_route_from_path(
                path, my_channels=my_channels, private_route_edges=private_route_edges)
        return route
|