diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py index 17c4c2701..d60f219e9 100644 --- a/electrum/address_synchronizer.py +++ b/electrum/address_synchronizer.py @@ -89,9 +89,7 @@ class AddressSynchronizer(Logger, EventListener): # verifier (SPV) and synchronizer are started in start_network self.synchronizer = None self.verifier = None - # locks: if you need to take multiple ones, acquire them in the order they are defined here! self.lock = threading.RLock() - self.transaction_lock = threading.RLock() self.future_tx = {} # type: Dict[str, int] # txid -> wanted (abs) height # Txs the server claims are mined but still pending verification: self.unverified_tx = defaultdict(int) # type: Dict[str, int] # txid -> height. Access with self.lock. @@ -107,12 +105,7 @@ class AddressSynchronizer(Logger, EventListener): def diagnostic_name(self): return self.name or "" - def with_transaction_lock(func): - def func_wrapper(self: 'AddressSynchronizer', *args, **kwargs): - with self.transaction_lock: - return func(self, *args, **kwargs) - return func_wrapper - + @with_lock def load_and_cleanup(self): self.load_local_history() self.check_history() @@ -143,9 +136,7 @@ class AddressSynchronizer(Logger, EventListener): so that only includes txns the server sees. """ h = {} - # we need self.transaction_lock but get_tx_height will take self.lock - # so we need to take that too here, to enforce order of locks - with self.lock, self.transaction_lock: + with self.lock: related_txns = self._history_local.get(addr, set()) for tx_hash in related_txns: tx_height = self.get_tx_height(tx_hash).height @@ -156,6 +147,7 @@ class AddressSynchronizer(Logger, EventListener): """Return number of transactions where address is involved.""" return len(self._history_local.get(addr, ())) + @with_lock def get_txin_address(self, txin: TxInput) -> Optional[str]: if txin.address: return txin.address @@ -170,6 +162,7 @@ class AddressSynchronizer(Logger, EventListener): return tx.outputs()[prevout_n].address return None + @with_lock def get_txin_value(self, txin: TxInput, *, address: str = None) -> Optional[int]: if txin.value_sats() is not None: return txin.value_sats() @@ -189,6 +182,7 @@ class AddressSynchronizer(Logger, EventListener): return tx.outputs()[prevout_n].value return None + @with_lock def load_unverified_transactions(self): # review transactions that are in the history for addr in self.db.get_history(): @@ -208,8 +202,9 @@ class AddressSynchronizer(Logger, EventListener): @event_listener def on_event_blockchain_updated(self, *args): - self._get_balance_cache = {} # invalidate cache - self.db.put('stored_height', self.get_local_height()) + with self.lock: + self._get_balance_cache = {} # invalidate cache + self.db.put('stored_height', self.get_local_height()) async def stop(self): if self.network: @@ -240,7 +235,7 @@ class AddressSynchronizer(Logger, EventListener): conflict (if already in wallet history) """ conflicting_txns = set() - with self.transaction_lock: + with self.lock: for txin in tx.inputs(): if txin.is_coinbase_input(): continue @@ -262,6 +257,7 @@ class AddressSynchronizer(Logger, EventListener): conflicting_txns -= {tx_hash} return conflicting_txns + @with_lock def get_transaction(self, txid: str) -> Optional[Transaction]: tx = self.db.get_transaction(txid) if tx: @@ -287,9 +283,7 @@ class AddressSynchronizer(Logger, EventListener): raise Exception("cannot add tx without txid to wallet history") # For sanity, try to serialize and deserialize tx early: tx_from_any(str(tx)) # see if raises (no-side-effects) - # we need self.transaction_lock but get_tx_height will take self.lock - # so we need to take that too here, to enforce order of locks - with self.lock, self.transaction_lock: + with self.lock: # NOTE: returning if tx in self.transactions might seem like a good idea # BUT we track is_mine inputs in a txn, and during subsequent calls # of add_transaction tx, we might learn of more-and-more inputs of @@ -377,7 +371,7 @@ class AddressSynchronizer(Logger, EventListener): """Removes a transaction AND all its dependents/children from the wallet history. """ - with self.lock, self.transaction_lock: + with self.lock: to_remove = {tx_hash} to_remove |= self.get_depending_transactions(tx_hash) for txid in to_remove: @@ -404,7 +398,7 @@ class AddressSynchronizer(Logger, EventListener): if spending_txid == tx_hash: self.db.remove_spent_outpoint(prevout_hash, prevout_n) - with self.lock, self.transaction_lock: + with self.lock: self.logger.info(f"removing tx from history {tx_hash}") tx = self.db.remove_transaction(tx_hash) remove_from_spent_outpoints() @@ -426,7 +420,7 @@ class AddressSynchronizer(Logger, EventListener): def get_depending_transactions(self, tx_hash: str) -> Set[str]: """Returns all (grand-)children of tx_hash in this wallet.""" - with self.transaction_lock: + with self.lock: children = set() for n in self.db.get_spent_outpoints(tx_hash): other_hash = self.db.get_spent_outpoint(tx_hash, n) @@ -434,6 +428,7 @@ class AddressSynchronizer(Logger, EventListener): children |= self.get_depending_transactions(other_hash) return children + @with_lock def receive_tx_callback(self, tx: Transaction, *, tx_height: Optional[int] = None) -> None: txid = tx.txid() assert txid is not None @@ -442,18 +437,18 @@ class AddressSynchronizer(Logger, EventListener): self.add_unverified_or_unconfirmed_tx(txid, tx_height) self.add_transaction(tx, allow_unrelated=True) + @with_lock def receive_history_callback(self, addr: str, hist, tx_fees: Dict[str, int]): - with self.lock: - old_hist = self.get_address_history(addr) - for tx_hash, height in old_hist.items(): - if (tx_hash, height) not in hist: - # make tx local - self.unverified_tx.pop(tx_hash, None) - self.unconfirmed_tx.pop(tx_hash, None) - self.db.remove_verified_tx(tx_hash) - if self.verifier: - self.verifier.remove_spv_proof_for_tx(tx_hash) - self.db.set_addr_history(addr, hist) + old_hist = self.get_address_history(addr) + for tx_hash, height in old_hist.items(): + if (tx_hash, height) not in hist: + # make tx local + self.unverified_tx.pop(tx_hash, None) + self.unconfirmed_tx.pop(tx_hash, None) + self.db.remove_verified_tx(tx_hash) + if self.verifier: + self.verifier.remove_spv_proof_for_tx(tx_hash) + self.db.set_addr_history(addr, hist) for tx_hash, tx_height in hist: # add it in case it was previously unconfirmed @@ -472,6 +467,7 @@ class AddressSynchronizer(Logger, EventListener): for tx_hash, fee_sat in tx_fees.items(): self.db.add_tx_fee_from_server(tx_hash, fee_sat) + @with_lock @profiler def load_local_history(self): self._history_local = {} # type: Dict[str, Set[str]] # address -> set(txid) @@ -479,6 +475,7 @@ class AddressSynchronizer(Logger, EventListener): for txid in itertools.chain(self.db.list_txi(), self.db.list_txo()): self._add_tx_to_local_history(txid) + @with_lock @profiler def check_history(self): hist_addrs_mine = list(filter(lambda k: self.is_mine(k), self.db.get_history())) @@ -494,6 +491,7 @@ class AddressSynchronizer(Logger, EventListener): if tx is not None: self.add_transaction(tx, allow_unrelated=True) + @with_lock def remove_local_transactions_we_dont_have(self): for txid in itertools.chain(self.db.list_txi(), self.db.list_txo()): tx_height = self.get_tx_height(txid).height @@ -502,10 +500,9 @@ class AddressSynchronizer(Logger, EventListener): def clear_history(self): with self.lock: - with self.transaction_lock: - self.db.clear_history() - self._history_local.clear() - self._get_balance_cache.clear() # invalidate cache + self.db.clear_history() + self._history_local.clear() + self._get_balance_cache.clear() # invalidate cache def _get_tx_sort_key(self, tx_hash: str) -> Tuple[int, int]: """Returns a key to be used for sorting txs.""" @@ -544,7 +541,6 @@ class AddressSynchronizer(Logger, EventListener): return f @with_lock - @with_transaction_lock @with_local_height_cached def get_history(self, domain) -> Sequence[HistoryItem]: domain = set(domain) @@ -582,7 +578,7 @@ class AddressSynchronizer(Logger, EventListener): return h2 def _add_tx_to_local_history(self, txid): - with self.transaction_lock: + with self.lock: for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)): cur_hist = self._history_local.get(addr, set()) cur_hist.add(txid) @@ -590,7 +586,7 @@ class AddressSynchronizer(Logger, EventListener): self._mark_address_history_changed(addr) def _remove_tx_from_local_history(self, txid): - with self.transaction_lock: + with self.lock: for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)): cur_hist = self._history_local.get(addr, set()) try: @@ -621,16 +617,15 @@ class AddressSynchronizer(Logger, EventListener): await self._address_history_changed_events[addr].wait() def add_unverified_or_unconfirmed_tx(self, tx_hash: str, tx_height: int) -> None: - if self.db.is_in_verified_tx(tx_hash): - if tx_height <= 0: - # tx was previously SPV-verified but now in mempool (probably reorg) - with self.lock: + with self.lock: + if self.db.is_in_verified_tx(tx_hash): + if tx_height <= 0: + # tx was previously SPV-verified but now in mempool (probably reorg) self.db.remove_verified_tx(tx_hash) self.unconfirmed_tx[tx_hash] = tx_height - if self.verifier: - self.verifier.remove_spv_proof_for_tx(tx_hash) - else: - with self.lock: + if self.verifier: + self.verifier.remove_spv_proof_for_tx(tx_hash) + else: if tx_height > 0: self.unverified_tx[tx_hash] = tx_height else: @@ -750,7 +745,7 @@ class AddressSynchronizer(Logger, EventListener): nans += n2 return nsent, nans - @with_transaction_lock + @with_lock def get_tx_delta(self, tx_hash: str, address: str) -> int: """effect of tx on address""" delta = 0 @@ -764,6 +759,7 @@ class AddressSynchronizer(Logger, EventListener): delta += v return delta + @with_lock def get_tx_fee(self, txid: str) -> Optional[int]: """Returns tx_fee or None. Use server fee only if tx is unconfirmed and not mine. @@ -799,16 +795,15 @@ class AddressSynchronizer(Logger, EventListener): return None # compute fee if possible v_in = v_out = 0 - with self.lock, self.transaction_lock: - for txin in tx.inputs(): - addr = self.get_txin_address(txin) - value = self.get_txin_value(txin, address=addr) - if value is None: - v_in = None - elif v_in is not None: - v_in += value - for txout in tx.outputs(): - v_out += txout.value + for txin in tx.inputs(): + addr = self.get_txin_address(txin) + value = self.get_txin_value(txin, address=addr) + if value is None: + v_in = None + elif v_in is not None: + v_in += value + for txout in tx.outputs(): + v_out += txout.value if v_in is not None: fee = v_in - v_out else: @@ -819,7 +814,7 @@ class AddressSynchronizer(Logger, EventListener): return fee def get_addr_io(self, address: str): - with self.lock, self.transaction_lock: + with self.lock: h = self.get_address_history(address).items() received = {} sent = {} @@ -868,7 +863,6 @@ class AddressSynchronizer(Logger, EventListener): return sum([value for height, pos, value, is_cb in received.values()]) @with_lock - @with_transaction_lock @with_local_height_cached def get_balance(self, domain, *, excluded_addresses: Set[str] = None, excluded_coins: Set[str] = None) -> Tuple[int, int, int]: @@ -934,6 +928,7 @@ class AddressSynchronizer(Logger, EventListener): self._get_balance_cache[cache_key] = result return result + @with_lock @with_local_height_cached def get_utxos( self, @@ -990,6 +985,7 @@ class AddressSynchronizer(Logger, EventListener): coins = self.get_addr_utxo(address) return not bool(coins) + @with_lock @with_local_height_cached def address_is_old(self, address: str, *, req_conf: int = 3) -> bool: """Returns whether address has any history that is deeply confirmed. @@ -1009,7 +1005,8 @@ class AddressSynchronizer(Logger, EventListener): max_conf = max(max_conf, tx_age) return max_conf >= req_conf - def get_spender(self, outpoint: str) -> str: + @with_lock + def get_spender(self, outpoint: str) -> Optional[str]: """ returns txid spending outpoint. subscribes to addresses as a side effect. @@ -1021,7 +1018,7 @@ class AddressSynchronizer(Logger, EventListener): if tx_mined_status.height in [TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE]: spender_txid = None if not spender_txid: - return + return None spender_tx = self.get_transaction(spender_txid) for i, o in enumerate(spender_tx.outputs()): if o.address is None: diff --git a/electrum/wallet.py b/electrum/wallet.py index 6a805ffbc..f01ffa2fc 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -401,7 +401,6 @@ class Abstract_Wallet(ABC, Logger, EventListener): for addr in self.get_addresses(): self.adb.add_address(addr) self.lock = self.adb.lock - self.transaction_lock = self.adb.transaction_lock self._last_full_history = None self._tx_parents_cache = {} self._default_labels = {} @@ -568,7 +567,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): return is_mine def clear_tx_parents_cache(self): - with self.lock, self.transaction_lock: + with self.lock: self._tx_parents_cache.clear() self._num_parents.clear() self._last_full_history = None @@ -877,7 +876,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): is_relevant = False # "related to wallet?" num_input_ismine = 0 v_in = v_in_mine = v_out = v_out_mine = 0 - with self.lock, self.transaction_lock: + with self.lock: for txin in tx.inputs(): addr = self.adb.get_txin_address(txin) value = self.adb.get_txin_value(txin, address=addr) @@ -1015,7 +1014,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): returns a flat dict: txid -> list of parent txids """ - with self.lock, self.transaction_lock: + with self.lock: if self._last_full_history is None: self._last_full_history = self.get_onchain_history() # populate cache in chronological order (confirmed tx only) @@ -1252,7 +1251,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): if not invoice.is_lightning(): if self.is_onchain_invoice_paid(invoice)[0]: _logger.info("saving invoice... but it is already paid!") - with self.transaction_lock: + with self.lock: for txout in invoice.get_outputs(): self._invoices_from_scriptpubkey_map[txout.scriptpubkey].add(key) self._invoices[key] = invoice @@ -1362,7 +1361,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): relevant_txs = set() is_paid = True conf_needed = None # type: Optional[int] - with self.lock, self.transaction_lock: + with self.lock: for invoice_scriptpubkey, invoice_amt in invoice_amounts.items(): scripthash = bitcoin.script_to_scripthash(invoice_scriptpubkey) prevouts_and_values = self.db.get_prevouts_by_scripthash(scripthash) @@ -2879,7 +2878,7 @@ class Abstract_Wallet(ABC, Logger, EventListener): def get_invoices_and_requests_touched_by_tx(self, tx): request_keys = set() invoice_keys = set() - with self.lock, self.transaction_lock: + with self.lock: for txo in tx.outputs(): addr = txo.address if request := self.get_request_by_addr(addr):