diff --git a/electrum/onion_message.py b/electrum/onion_message.py index a2f53606e..a98775eb7 100644 --- a/electrum/onion_message.py +++ b/electrum/onion_message.py @@ -27,9 +27,6 @@ import io import os import threading import time -import dataclasses -from random import random -from types import MappingProxyType from typing import TYPE_CHECKING, Optional, Sequence, NamedTuple, Tuple, Union @@ -44,7 +41,7 @@ from electrum.lnonion import (get_bolt04_onion_key, OnionPacket, process_onion_p OnionHopsDataSingle, decrypt_onionmsg_data_tlv, encrypt_onionmsg_data_tlv, get_shared_secrets_along_route, new_onion_packet, encrypt_hops_recipient_data) from electrum.lnutil import LnFeatures, MIN_FINAL_CLTV_DELTA_ACCEPTED, MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED -from electrum.util import OldTaskGroup, log_exceptions +from electrum.util import OldTaskGroup, log_exceptions, random_shuffled_copy def now(): @@ -240,17 +237,7 @@ def create_route_to_introduction_point( ) hops_data.append(final_hop_pre_ip) - # encrypt encrypted_data_tlv here - for i, hop in enumerate(hops_data): - encrypted_recipient_data = encrypt_onionmsg_data_tlv( - shared_secret=hop_shared_secrets[i], - **hop.blind_fields) - payload = dict(hop.payload) - payload['encrypted_recipient_data'] = { - 'encrypted_recipient_data': encrypted_recipient_data - } - hops_data[i] = dataclasses.replace(hop, payload=payload) - + encrypt_hops_recipient_data(tlv_stream_name='onionmsg_tlv', hops_data=hops_data, hop_shared_secrets=hop_shared_secrets) path_key = ecc.ECPrivkey(session_key).get_public_key_bytes() return peer, path_key, hops_data, blinded_node_ids @@ -340,8 +327,6 @@ def send_onion_message_to( hops_data.append(hop) payment_path_pubkeys = blinded_node_ids + blinded_path_blinded_ids - hop_shared_secrets, _ = get_shared_secrets_along_route(payment_path_pubkeys, session_key) - encrypt_hops_recipient_data('onionmsg_tlv', hops_data, hop_shared_secrets) packet = new_onion_packet(payment_path_pubkeys, session_key, hops_data, onion_message=True) packet_b = packet.to_bytes() @@ -437,8 +422,7 @@ def get_blinded_paths_to_me( local_height = lnwallet.network.get_local_height() if len(my_channels): - # randomize list - rchans = sorted(my_channels, key=lambda x: random()) + rchans = random_shuffled_copy(my_channels) for chan in rchans[:max_paths]: hop_extras = None if not onion_message: # add hop_extras and payinfo, assumption: len(blinded_path) == 2 (us and peer) @@ -494,8 +478,7 @@ def get_blinded_paths_to_me( my_onionmsg_peers = [peer for peer in lnwallet.lnpeermgr.peers.values() if peer.their_features.supports(LnFeatures.OPTION_ONION_MESSAGE_OPT)] if len(my_onionmsg_peers): - # randomize list - rpeers = sorted(my_onionmsg_peers, key=lambda x: random()) + rpeers = random_shuffled_copy(my_onionmsg_peers) for peer in rpeers[:max_paths]: blinded_path = create_blinded_path(os.urandom(32), [peer.pubkey, mynodeid], final_recipient_data) result.append(blinded_path) @@ -532,6 +515,17 @@ class OnionMessageManager(Logger): self.node_id_or_blinded_paths = node_id_or_blinded_paths self.current_index: int = 0 + # ensure node_id_or_blinded_paths is list + if isinstance(self.node_id_or_blinded_paths, bytes): + self.node_id_or_blinded_paths = [self.node_id_or_blinded_paths] + + def get_next_destination(self) -> bytes: + """get next path (round-robin)""" + dests = self.node_id_or_blinded_paths + dest = dests[self.current_index] + self.current_index = (self.current_index + 1) % len(dests) + return dest + def __init__(self, lnwallet: 'LNWallet'): Logger.__init__(self) self.network = None # type: Optional['Network'] @@ -680,15 +674,9 @@ class OnionMessageManager(Logger): def _send_pending_message(self, key: bytes) -> None: """adds reply_path to payload""" - req = self.pending.get(key) + req = self.pending[key] payload = req.payload - - # get next path (round robin) - dests = req.node_id_or_blinded_paths - if isinstance(req.node_id_or_blinded_paths, bytes): - dests = [req.node_id_or_blinded_paths] - dest = dests[req.current_index] - req.current_index = (req.current_index + 1) % len(dests) + dest = req.get_next_destination() self.logger.debug(f'send_pending_message {key=} {payload=} {dest=}')