network: tighten checks of server responses for type/sanity

This commit is contained in:
SomberNight
2020-10-16 19:30:42 +02:00
parent c70484455c
commit c5da22a9dd
5 changed files with 174 additions and 54 deletions

View File

@@ -29,7 +29,7 @@ import sys
import traceback
import asyncio
import socket
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any, Sequence
from collections import defaultdict
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
import itertools
@@ -46,13 +46,14 @@ import certifi
from .util import (ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy,
is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
is_real_number)
is_int_or_float, is_non_negative_int_or_float)
from . import util
from . import x509
from . import pem
from . import version
from . import blockchain
from .blockchain import Blockchain, HEADER_SIZE
from . import bitcoin
from . import constants
from .i18n import _
from .logging import Logger
@@ -96,9 +97,14 @@ def assert_integer(val: Any) -> None:
raise RequestCorrupted(f'{val!r} should be an integer')
def assert_real_number(val: Any, *, as_str: bool = False) -> None:
if not is_real_number(val, as_str=as_str):
raise RequestCorrupted(f'{val!r} should be a number')
def assert_int_or_float(val: Any) -> None:
if not is_int_or_float(val):
raise RequestCorrupted(f'{val!r} should be int or float')
def assert_non_negative_int_or_float(val: Any) -> None:
if not is_non_negative_int_or_float(val):
raise RequestCorrupted(f'{val!r} should be a non-negative int or float')
def assert_hash256_str(val: Any) -> None:
@@ -656,14 +662,13 @@ class Interface(Logger):
async def request_fee_estimates(self):
from .simple_config import FEE_ETA_TARGETS
from .bitcoin import COIN
while True:
async with TaskGroup() as group:
fee_tasks = []
for i in FEE_ETA_TARGETS:
fee_tasks.append((i, await group.spawn(self.session.send_request('blockchain.estimatefee', [i]))))
fee_tasks.append((i, await group.spawn(self.get_estimatefee(i))))
for nblock_target, task in fee_tasks:
fee = int(task.result() * COIN)
fee = task.result()
if fee < 0: continue
self.fee_estimates_eta[nblock_target] = fee
self.network.update_fee_estimates()
@@ -983,6 +988,61 @@ class Interface(Logger):
assert_hash256_str(res)
return res
async def get_fee_histogram(self) -> Sequence[Tuple[Union[float, int], int]]:
# do request
res = await self.session.send_request('mempool.get_fee_histogram')
# check response
assert_list_or_tuple(res)
for fee, s in res:
assert_non_negative_int_or_float(fee)
assert_non_negative_integer(s)
return res
async def get_server_banner(self) -> str:
# do request
res = await self.session.send_request('server.banner')
# check response
if not isinstance(res, str):
raise RequestCorrupted(f'{res!r} should be a str')
return res
async def get_donation_address(self) -> str:
# do request
res = await self.session.send_request('server.donation_address')
# check response
if not res: # ignore empty string
return ''
if not bitcoin.is_address(res):
# note: do not hard-fail -- allow server to use future-type
# bitcoin address we do not recognize
self.logger.info(f"invalid donation address from server: {repr(res)}")
res = ''
return res
async def get_relay_fee(self) -> int:
"""Returns the min relay feerate in sat/kbyte."""
# do request
res = await self.session.send_request('blockchain.relayfee')
# check response
assert_non_negative_int_or_float(res)
relayfee = int(res * bitcoin.COIN)
relayfee = max(0, relayfee)
return relayfee
async def get_estimatefee(self, num_blocks: int) -> int:
"""Returns a feerate estimate for getting confirmed within
num_blocks blocks, in sat/kbyte.
"""
if not is_non_negative_integer(num_blocks):
raise Exception(f"{repr(num_blocks)} is not a num_blocks")
# do request
res = await self.session.send_request('blockchain.estimatefee', [num_blocks])
# check response
if res != -1:
assert_non_negative_int_or_float(res)
res = int(res * bitcoin.COIN)
return res
def _assert_header_does_not_check_against_any_chain(header: dict) -> None:
chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)