interface: small clean-up. intro ChainResolutionMode.
- type hints - minor API changes - no functional changes
This commit is contained in:
@@ -38,6 +38,7 @@ import logging
|
||||
import hashlib
|
||||
import functools
|
||||
import random
|
||||
import enum
|
||||
|
||||
import aiorpcx
|
||||
from aiorpcx import RPCSession, Notification, NetAddress, NewlineFramer
|
||||
@@ -132,6 +133,14 @@ def assert_list_or_tuple(val: Any) -> None:
|
||||
raise RequestCorrupted(f'{val!r} should be a list or tuple')
|
||||
|
||||
|
||||
class ChainResolutionMode(enum.Enum):
|
||||
CATCHUP = enum.auto()
|
||||
BACKWARD = enum.auto()
|
||||
BINARY = enum.auto()
|
||||
FORK = enum.auto()
|
||||
NO_FORK = enum.auto()
|
||||
|
||||
|
||||
class NotificationSession(RPCSession):
|
||||
|
||||
def __init__(self, *args, interface: 'Interface', **kwargs):
|
||||
@@ -510,7 +519,7 @@ class Interface(Logger):
|
||||
# Note that these values are updated before they are verified.
|
||||
# Especially during initial header sync, verification can take a long time.
|
||||
# Failing verification will get the interface closed.
|
||||
self.tip_header = None
|
||||
self.tip_header = None # type: Optional[dict]
|
||||
self.tip = 0
|
||||
|
||||
self.fee_estimates_eta = {} # type: Dict[int, int]
|
||||
@@ -543,13 +552,13 @@ class Interface(Logger):
|
||||
def __str__(self):
|
||||
return f"<Interface {self.diagnostic_name()}>"
|
||||
|
||||
async def is_server_ca_signed(self, ca_ssl_context):
|
||||
async def is_server_ca_signed(self, ca_ssl_context: ssl.SSLContext) -> bool:
|
||||
"""Given a CA enforcing SSL context, returns True if the connection
|
||||
can be established. Returns False if the server has a self-signed
|
||||
certificate but otherwise is okay. Any other failures raise.
|
||||
"""
|
||||
try:
|
||||
await self.open_session(ca_ssl_context, exit_early=True)
|
||||
await self.open_session(ssl_context=ca_ssl_context, exit_early=True)
|
||||
except ConnectError as e:
|
||||
cause = e.__cause__
|
||||
if (isinstance(cause, ssl.SSLCertVerificationError)
|
||||
@@ -562,7 +571,7 @@ class Interface(Logger):
|
||||
# Good. We will use this server as CA-signed.
|
||||
return True
|
||||
|
||||
async def _try_saving_ssl_cert_for_first_time(self, ca_ssl_context):
|
||||
async def _try_saving_ssl_cert_for_first_time(self, ca_ssl_context: ssl.SSLContext) -> None:
|
||||
ca_signed = await self.is_server_ca_signed(ca_ssl_context)
|
||||
if ca_signed:
|
||||
if self._get_expected_fingerprint():
|
||||
@@ -599,10 +608,10 @@ class Interface(Logger):
|
||||
self.logger.info(f"certificate has expired: {e}")
|
||||
os.unlink(self.cert_path) # delete pinned cert only in this case
|
||||
return False
|
||||
self._verify_certificate_fingerprint(bytearray(b))
|
||||
self._verify_certificate_fingerprint(bytes(b))
|
||||
return True
|
||||
|
||||
async def _get_ssl_context(self):
|
||||
async def _get_ssl_context(self) -> Optional[ssl.SSLContext]:
|
||||
if self.protocol != 's':
|
||||
# using plaintext TCP
|
||||
return None
|
||||
@@ -658,7 +667,7 @@ class Interface(Logger):
|
||||
self.logger.info(f'disconnecting due to: {repr(e)}')
|
||||
return
|
||||
try:
|
||||
await self.open_session(ssl_context)
|
||||
await self.open_session(ssl_context=ssl_context)
|
||||
except (asyncio.CancelledError, ConnectError, aiorpcx.socks.SOCKSError) as e:
|
||||
# make SSL errors for main interface more visible (to help servers ops debug cert pinning issues)
|
||||
if (isinstance(e, ConnectError) and isinstance(e.__cause__, ssl.SSLError)
|
||||
@@ -731,8 +740,9 @@ class Interface(Logger):
|
||||
def _get_expected_fingerprint(self) -> Optional[str]:
|
||||
if self.is_main_server():
|
||||
return self.network.config.NETWORK_SERVERFINGERPRINT
|
||||
return None
|
||||
|
||||
def _verify_certificate_fingerprint(self, certificate):
|
||||
def _verify_certificate_fingerprint(self, certificate: bytes) -> None:
|
||||
expected_fingerprint = self._get_expected_fingerprint()
|
||||
if not expected_fingerprint:
|
||||
return
|
||||
@@ -743,21 +753,27 @@ class Interface(Logger):
|
||||
raise ErrorSSLCertFingerprintMismatch('Refusing to connect to server due to cert fingerprint mismatch')
|
||||
self.logger.info("cert fingerprint verification passed")
|
||||
|
||||
async def get_block_header(self, height, assert_mode):
|
||||
async def get_block_header(self, height: int, *, mode: ChainResolutionMode) -> dict:
|
||||
if not is_non_negative_integer(height):
|
||||
raise Exception(f"{repr(height)} is not a block height")
|
||||
self.logger.info(f'requesting block header {height} in mode {assert_mode}')
|
||||
self.logger.info(f'requesting block header {height} in {mode=}')
|
||||
# use lower timeout as we usually have network.bhi_lock here
|
||||
timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent)
|
||||
res = await self.session.send_request('blockchain.block.header', [height], timeout=timeout)
|
||||
return blockchain.deserialize_header(bytes.fromhex(res), height)
|
||||
|
||||
async def request_chunk(self, height: int, tip=None, *, can_return_early=False):
|
||||
async def request_chunk(
|
||||
self,
|
||||
height: int,
|
||||
*,
|
||||
tip: Optional[int] = None,
|
||||
can_return_early: bool = False,
|
||||
) -> Optional[Tuple[bool, int]]:
|
||||
if not is_non_negative_integer(height):
|
||||
raise Exception(f"{repr(height)} is not a block height")
|
||||
index = height // 2016
|
||||
if can_return_early and index in self._requested_chunks:
|
||||
return
|
||||
return None
|
||||
self.logger.info(f"requesting chunk from height {height}")
|
||||
size = 2016
|
||||
if tip is not None:
|
||||
@@ -790,12 +806,17 @@ class Interface(Logger):
|
||||
return (self.network.interface == self or
|
||||
self.network.interface is None and self.network.default_server == self.server)
|
||||
|
||||
async def open_session(self, sslc, exit_early=False):
|
||||
async def open_session(
|
||||
self,
|
||||
*,
|
||||
ssl_context: Optional[ssl.SSLContext],
|
||||
exit_early: bool = False,
|
||||
):
|
||||
session_factory = lambda *args, iface=self, **kwargs: NotificationSession(*args, **kwargs, interface=iface)
|
||||
async with _RSClient(
|
||||
session_factory=session_factory,
|
||||
host=self.host, port=self.port,
|
||||
ssl=sslc,
|
||||
ssl=ssl_context,
|
||||
proxy=self.proxy,
|
||||
transport=PaddedRSTransport,
|
||||
) as session:
|
||||
@@ -918,20 +939,25 @@ class Interface(Logger):
|
||||
if self.blockchain.height() >= height and self.blockchain.check_header(header):
|
||||
# another interface amended the blockchain
|
||||
return False
|
||||
_, height = await self.step(height, header)
|
||||
_, height = await self.step(height, header=header)
|
||||
# in the simple case, height == self.tip+1
|
||||
if height <= self.tip:
|
||||
await self.sync_until(height)
|
||||
return True
|
||||
|
||||
async def sync_until(self, height, next_height=None):
|
||||
async def sync_until(
|
||||
self,
|
||||
height: int,
|
||||
*,
|
||||
next_height: Optional[int] = None,
|
||||
) -> Tuple[ChainResolutionMode, int]:
|
||||
if next_height is None:
|
||||
next_height = self.tip
|
||||
last = None
|
||||
last = None # type: Optional[ChainResolutionMode]
|
||||
while last is None or height <= next_height:
|
||||
prev_last, prev_height = last, height
|
||||
if next_height > height + 10:
|
||||
could_connect, num_headers = await self.request_chunk(height, next_height)
|
||||
if next_height > height + 10: # TODO make smarter. the protocol allows asking for n headers
|
||||
could_connect, num_headers = await self.request_chunk(height, tip=next_height)
|
||||
if not could_connect:
|
||||
if height <= constants.net.max_checkpoint():
|
||||
raise GracefulDisconnect('server chain conflicts with checkpoints or genesis')
|
||||
@@ -941,16 +967,21 @@ class Interface(Logger):
|
||||
util.trigger_callback('network_updated')
|
||||
height = (height // 2016 * 2016) + num_headers
|
||||
assert height <= next_height+1, (height, self.tip)
|
||||
last = 'catchup'
|
||||
last = ChainResolutionMode.CATCHUP
|
||||
else:
|
||||
last, height = await self.step(height)
|
||||
assert (prev_last, prev_height) != (last, height), 'had to prevent infinite loop in interface.sync_until'
|
||||
return last, height
|
||||
|
||||
async def step(self, height, header=None):
|
||||
async def step(
|
||||
self,
|
||||
height: int,
|
||||
*,
|
||||
header: Optional[dict] = None, # at 'height'
|
||||
) -> Tuple[ChainResolutionMode, int]:
|
||||
assert 0 <= height <= self.tip, (height, self.tip)
|
||||
if header is None:
|
||||
header = await self.get_block_header(height, 'catchup')
|
||||
header = await self.get_block_header(height, mode=ChainResolutionMode.CATCHUP)
|
||||
|
||||
chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
|
||||
if chain:
|
||||
@@ -959,12 +990,12 @@ class Interface(Logger):
|
||||
# we might know the blockhash (enough for check_header) but
|
||||
# not have the header itself. e.g. regtest chain with only genesis.
|
||||
# this situation resolves itself on the next block
|
||||
return 'catchup', height+1
|
||||
return ChainResolutionMode.CATCHUP, height+1
|
||||
|
||||
can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
|
||||
if not can_connect:
|
||||
self.logger.info(f"can't connect new block: {height=}")
|
||||
height, header, bad, bad_header = await self._search_headers_backwards(height, header)
|
||||
height, header, bad, bad_header = await self._search_headers_backwards(height, header=header)
|
||||
chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
|
||||
can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
|
||||
assert chain or can_connect
|
||||
@@ -974,12 +1005,18 @@ class Interface(Logger):
|
||||
if isinstance(can_connect, Blockchain): # not when mocking
|
||||
self.blockchain = can_connect
|
||||
self.blockchain.save_header(header)
|
||||
return 'catchup', height
|
||||
return ChainResolutionMode.CATCHUP, height
|
||||
|
||||
good, bad, bad_header = await self._search_headers_binary(height, bad, bad_header, chain)
|
||||
return await self._resolve_potential_chain_fork_given_forkpoint(good, bad, bad_header)
|
||||
|
||||
async def _search_headers_binary(self, height, bad, bad_header, chain):
|
||||
async def _search_headers_binary(
|
||||
self,
|
||||
height: int,
|
||||
bad: int,
|
||||
bad_header: dict,
|
||||
chain: Optional[Blockchain],
|
||||
) -> Tuple[int, int, dict]:
|
||||
assert bad == bad_header['block_height']
|
||||
_assert_header_does_not_check_against_any_chain(bad_header)
|
||||
|
||||
@@ -989,7 +1026,7 @@ class Interface(Logger):
|
||||
assert good < bad, (good, bad)
|
||||
height = (good + bad) // 2
|
||||
self.logger.info(f"binary step. good {good}, bad {bad}, height {height}")
|
||||
header = await self.get_block_header(height, 'binary')
|
||||
header = await self.get_block_header(height, mode=ChainResolutionMode.BINARY)
|
||||
chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
|
||||
if chain:
|
||||
self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
|
||||
@@ -1009,7 +1046,12 @@ class Interface(Logger):
|
||||
self.logger.info(f"binary search exited. good {good}, bad {bad}")
|
||||
return good, bad, bad_header
|
||||
|
||||
async def _resolve_potential_chain_fork_given_forkpoint(self, good, bad, bad_header):
|
||||
async def _resolve_potential_chain_fork_given_forkpoint(
|
||||
self,
|
||||
good: int,
|
||||
bad: int,
|
||||
bad_header: dict,
|
||||
) -> Tuple[ChainResolutionMode, int]:
|
||||
assert good + 1 == bad
|
||||
assert bad == bad_header['block_height']
|
||||
_assert_header_does_not_check_against_any_chain(bad_header)
|
||||
@@ -1021,7 +1063,7 @@ class Interface(Logger):
|
||||
if bh == good:
|
||||
height = good + 1
|
||||
self.logger.info(f"catching up from {height}")
|
||||
return 'no_fork', height
|
||||
return ChainResolutionMode.NO_FORK, height
|
||||
|
||||
# this is a new fork we don't yet have
|
||||
height = bad + 1
|
||||
@@ -1030,16 +1072,21 @@ class Interface(Logger):
|
||||
b = forkfun(bad_header) # type: Blockchain
|
||||
self.blockchain = b
|
||||
assert b.forkpoint == bad
|
||||
return 'fork', height
|
||||
return ChainResolutionMode.FORK, height
|
||||
|
||||
async def _search_headers_backwards(self, height, header):
|
||||
async def _search_headers_backwards(
|
||||
self,
|
||||
height: int,
|
||||
*,
|
||||
header: dict,
|
||||
) -> Tuple[int, dict, int, dict]:
|
||||
async def iterate():
|
||||
nonlocal height, header
|
||||
checkp = False
|
||||
if height <= constants.net.max_checkpoint():
|
||||
height = constants.net.max_checkpoint()
|
||||
checkp = True
|
||||
header = await self.get_block_header(height, 'backward')
|
||||
header = await self.get_block_header(height, mode=ChainResolutionMode.BACKWARD)
|
||||
chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
|
||||
can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
|
||||
if chain or can_connect:
|
||||
|
||||
Reference in New Issue
Block a user