import os
import csv
import io
from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional, Mapping
from types import MappingProxyType
from collections import OrderedDict

from .lnutil import OnionFailureCodeMetaFlag


class FailedToParseMsg(Exception):
    """Root of the parsing-error hierarchy for this module."""
    # When known, filled in by callers so errors can be attributed to a message type.
    msg_type_int: Optional[int] = None
    msg_type_name: Optional[str] = None


class UnknownMsgType(FailedToParseMsg): pass
class UnknownOptionalMsgType(UnknownMsgType): pass
class UnknownMandatoryMsgType(UnknownMsgType): pass


class MalformedMsg(FailedToParseMsg): pass
class UnknownMsgFieldType(MalformedMsg): pass
class UnexpectedEndOfStream(MalformedMsg): pass
class FieldEncodingNotMinimal(MalformedMsg): pass
class UnknownMandatoryTLVRecordType(MalformedMsg): pass
class MsgTrailingGarbage(MalformedMsg): pass
class MsgInvalidFieldOrder(MalformedMsg): pass
class UnexpectedFieldSizeForEncoder(MalformedMsg): pass


def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
    """Return the number of bytes between the current position and EOF."""
    pos = fd.tell()
    end = fd.seek(0, io.SEEK_END)
    fd.seek(pos)  # restore the caller's position
    return end - pos


def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
    """Raise UnexpectedEndOfStream unless at least n bytes remain in fd."""
    # note: it's faster to read n bytes and then check if we read n, than
    # to assert we can read at least n and then read n bytes.
    nremaining = _num_remaining_bytes_to_read(fd)
    if nremaining < n:
        raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left")


def write_bigsize_int(i: int) -> bytes:
    """Encode a non-negative integer in the variable-length BigSize format.

    Values below 0xfd fit in one byte; larger values get a 1-byte marker
    (0xfd/0xfe/0xff) followed by a 2/4/8-byte big-endian payload.
    """
    assert i >= 0, i
    if i < 0xfd:
        return i.to_bytes(1, byteorder="big", signed=False)
    if i < 0x1_0000:
        return b"\xfd" + i.to_bytes(2, byteorder="big", signed=False)
    if i < 0x1_0000_0000:
        return b"\xfe" + i.to_bytes(4, byteorder="big", signed=False)
    return b"\xff" + i.to_bytes(8, byteorder="big", signed=False)


def read_bigsize_int(fd: io.BytesIO) -> Optional[int]:
    """Decode one BigSize integer from fd; return None at end of stream.

    Raises:
        UnexpectedEndOfStream: payload shorter than the marker promised.
        FieldEncodingNotMinimal: value could have used a shorter encoding.
    """
    try:
        prefix = fd.read(1)[0]
    except IndexError:
        return None  # end of file
    if prefix < 0xfd:
        return prefix
    # marker byte -> (payload length, smallest value that requires this length)
    nbytes, minimum = {
        0xfd: (2, 0xfd),
        0xfe: (4, 0x1_0000),
        0xff: (8, 0x1_0000_0000),
    }[prefix]
    payload = fd.read(nbytes)
    if len(payload) != nbytes:
        raise UnexpectedEndOfStream()
    value = int.from_bytes(payload, byteorder="big", signed=False)
    if value < minimum:
        # a shorter encoding would have sufficed -> reject non-canonical form
        raise FieldEncodingNotMinimal()
    return value


# TODO: maybe if field_type is not "byte", we could return a list of type_len sized chunks?
#       if field_type is a numeric, we could return a list of ints?
def _read_primitive_field(
        *,
        fd: io.BytesIO,
        field_type: str,
        count: Union[int, str]
) -> Union[bytes, int]:
    """Read a single primitive wire field of type `field_type` from fd.

    `count` is a non-negative int, or the string "..." meaning "read until
    the end of the stream".  Fixed-size unsigned ints, truncated ints and
    bigsize fields are returned as int; every other field type is returned
    as raw bytes.

    Raises:
        UnexpectedEndOfStream: fewer bytes available than the field needs.
        FieldEncodingNotMinimal: non-minimal truncated-int/bigsize encoding.
        UnknownMsgFieldType: `field_type` is not recognised.
    """
    if not fd: raise Exception()
    if isinstance(count, int):
        assert count >= 0, f"{count!r} must be non-neg int"
    elif count == "...":
        pass
    else:
        raise Exception(f"unexpected field count: {count!r}")
    if count == 0:
        return b""
    type_len = None
    if field_type == 'byte':
        type_len = 1
    elif field_type in ('u8', 'u16', 'u32', 'u64'):
        if field_type == 'u8':
            type_len = 1
        elif field_type == 'u16':
            type_len = 2
        elif field_type == 'u32':
            type_len = 4
        else:
            assert field_type == 'u64'
            type_len = 8
        assert count == 1, count
        buf = fd.read(type_len)
        if len(buf) != type_len:
            raise UnexpectedEndOfStream()
        return int.from_bytes(buf, byteorder="big", signed=False)
    elif field_type in ('tu16', 'tu32', 'tu64'):
        # truncated ints: may occupy fewer bytes than their nominal size,
        # but a leading zero byte means a shorter encoding was possible
        if field_type == 'tu16':
            type_len = 2
        elif field_type == 'tu32':
            type_len = 4
        else:
            assert field_type == 'tu64'
            type_len = 8
        assert count == 1, count
        raw = fd.read(type_len)
        if len(raw) > 0 and raw[0] == 0x00:
            raise FieldEncodingNotMinimal()
        return int.from_bytes(raw, byteorder="big", signed=False)
    elif field_type == 'bigsize':
        assert count == 1, count
        val = read_bigsize_int(fd)
        if val is None:
            raise UnexpectedEndOfStream()
        return val
    elif field_type == 'chain_hash':
        type_len = 32
    elif field_type == 'channel_id':
        type_len = 32
    elif field_type == 'sha256':
        type_len = 32
    elif field_type == 'signature':
        type_len = 64
    elif field_type == 'point':
        type_len = 33
    elif field_type == 'short_channel_id':
        type_len = 8
    elif field_type == 'sciddir_or_pubkey':
        # the first byte selects the payload: 0/1 -> short_channel_id (9 bytes
        # total), 2/3 -> compressed pubkey (33 bytes total)
        buf = fd.read(1)
        if buf[0] in [0, 1]:
            type_len = 9
        elif buf[0] in [2, 3]:
            type_len = 33
        else:
            raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
        buf += fd.read(type_len - 1)
        if len(buf) != type_len:
            raise UnexpectedEndOfStream()
        return buf
    if count == "...":
        total_len = -1  # read all; BytesIO.read(-1) returns everything remaining
    else:
        if type_len is None:
            raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
        total_len = count * type_len
    buf = fd.read(total_len)
    if total_len >= 0 and len(buf) != total_len:
        raise UnexpectedEndOfStream()
    return buf


# TODO: maybe for "value" we could accept a list with len "count" of appropriate items
def _write_primitive_field(
        *,
        fd: io.BytesIO,
        field_type: str,
        count: Union[int, str],
        value: Union[bytes, int]
) -> None:
    """Serialize `value` as a primitive wire field of type `field_type` into fd.

    `count` is a non-negative int, or "..." meaning the value's own length
    is used as-is.  Ints are converted to big-endian bytes where the field
    type allows it; otherwise `value` must already be bytes.

    Raises:
        UnknownMsgFieldType: `field_type` is not recognised.
        UnexpectedFieldSizeForEncoder: `value` has the wrong byte length.
    """
    if not fd: raise Exception()
    if isinstance(count, int):
        assert count >= 0, f"{count!r} must be non-neg int"
    elif count == "...":
        pass
    else:
        raise Exception(f"unexpected field count: {count!r}")
    if count == 0:
        return
    type_len = None
    if field_type == 'byte':
        type_len = 1
    elif field_type == 'u8':
        type_len = 1
    elif field_type == 'u16':
        type_len = 2
    elif field_type == 'u32':
        type_len = 4
    elif field_type == 'u64':
        type_len = 8
    elif field_type in ('tu16', 'tu32', 'tu64'):
        if field_type == 'tu16':
            type_len = 2
        elif field_type == 'tu32':
            type_len = 4
        else:
            assert field_type == 'tu64'
            type_len = 8
        assert count == 1, count
        if isinstance(value, int):
            value = int.to_bytes(value, length=type_len, byteorder="big", signed=False)
        if not isinstance(value, (bytes, bytearray)):
            raise Exception(f"can only write bytes into fd. got: {value!r}")
        # truncated ints drop leading zero bytes (minimal encoding)
        while len(value) > 0 and value[0] == 0x00:
            value = value[1:]
        nbytes_written = fd.write(value)
        if nbytes_written != len(value):
            raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
        return
    elif field_type == 'bigsize':
        assert count == 1, count
        if isinstance(value, int):
            value = write_bigsize_int(value)
        if not isinstance(value, (bytes, bytearray)):
            raise Exception(f"can only write bytes into fd. got: {value!r}")
        nbytes_written = fd.write(value)
        if nbytes_written != len(value):
            raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
        return
    elif field_type == 'chain_hash':
        type_len = 32
    elif field_type == 'channel_id':
        type_len = 32
    elif field_type == 'sha256':
        type_len = 32
    elif field_type == 'signature':
        type_len = 64
    elif field_type == 'point':
        type_len = 33
    elif field_type == 'short_channel_id':
        type_len = 8
    elif field_type == 'sciddir_or_pubkey':
        # expected size depends on the first byte of the already-encoded value
        assert isinstance(value, bytes)
        if value[0] in [0, 1]:
            type_len = 9  # short_channel_id
        elif value[0] in [2, 3]:
            type_len = 33  # point
        else:
            raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
    total_len = -1
    if count != "...":
        if type_len is None:
            raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
        total_len = count * type_len
        # only single values (or byte arrays) can be auto-converted from int
        if isinstance(value, int) and (count == 1 or field_type == 'byte'):
            value = int.to_bytes(value, length=total_len, byteorder="big", signed=False)
    if not isinstance(value, (bytes, bytearray)):
        raise Exception(f"can only write bytes into fd. got: {value!r}")
    if count != "..." and total_len != len(value):
        raise UnexpectedFieldSizeForEncoder(f"expected: {total_len}, got {len(value)}")
    nbytes_written = fd.write(value)
    if nbytes_written != len(value):
        raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")


def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]:
    """Read one TLV record (bigsize type, bigsize length, raw value) from fd."""
    if not fd: raise Exception()
    tlv_type = _read_primitive_field(fd=fd, field_type="bigsize", count=1)
    tlv_len = _read_primitive_field(fd=fd, field_type="bigsize", count=1)
    tlv_val = _read_primitive_field(fd=fd, field_type="byte", count=tlv_len)
    return tlv_type, tlv_val


def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
    """Write one TLV record (bigsize type, bigsize length, raw value) into fd."""
    if not fd: raise Exception()
    tlv_len = len(tlv_val)
    _write_primitive_field(fd=fd, field_type="bigsize", count=1, value=tlv_type)
    _write_primitive_field(fd=fd, field_type="bigsize", count=1, value=tlv_len)
    _write_primitive_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val)


def _resolve_field_count(field_count_str: str, *, vars_dict: Mapping, allow_any=False) -> Union[int, str]:
    """Returns an evaluated field count, typically an int.

    An empty string means a count of 1.  A non-numeric string is looked up
    in `vars_dict` (values stored as bytes are interpreted big-endian).
    If allow_any is True, the return value can be a str with value=="...".
    """
    if field_count_str == "":
        field_count = 1
    elif field_count_str == "...":
        if not allow_any:
            raise Exception("field count is '...' but allow_any is False")
        return field_count_str
    else:
        try:
            field_count = int(field_count_str)
        except ValueError:
            # count refers to a previously parsed field by name
            field_count = vars_dict[field_count_str]
            if isinstance(field_count, (bytes, bytearray)):
                field_count = int.from_bytes(field_count, byteorder="big")
    assert isinstance(field_count, int)
    return field_count


def _parse_msgtype_intvalue_for_onion_wire(value: str) -> int:
    """OR together the "|"-separated components of a msgtype value.

    Each component is either a decimal int or the name of an
    OnionFailureCodeMetaFlag (e.g. "PERM").
    """
    msg_type_int = 0
    for component in value.split("|"):
        try:
            msg_type_int |= int(component)
        except ValueError:
            msg_type_int |= OnionFailureCodeMetaFlag[component]
    return msg_type_int


class LNSerializer:

    def __init__(self, *, for_onion_wire: bool = False):
        # TODO msg_type could be 'int' everywhere...
        # Lookup tables populated below from the wire-format CSV definitions.
        self.msg_scheme_from_type = {}  # type: Dict[bytes, List[Sequence[str]]]
        self.msg_type_from_name = {}  # type: Dict[str, bytes]
        self.in_tlv_stream_get_tlv_record_scheme_from_type = {}  # type: Dict[str, Dict[int, List[Sequence[str]]]]
        self.in_tlv_stream_get_record_type_from_name = {}  # type: Dict[str, Dict[str, int]]
        self.in_tlv_stream_get_record_name_from_type = {}  # type: Dict[str, Dict[int, str]]
        self.subtypes = {}  # type: Dict[str, Dict[str, Sequence[str]]]
        if for_onion_wire:
            path = os.path.join(os.path.dirname(__file__), "lnwire", "onion_wire.csv")
        else:
            path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
        with open(path, newline='') as f:
            csvreader = csv.reader(f)
            for row in csvreader:
                #print(f">>> {row!r}")
                if row[0] == "msgtype":
                    # msgtype,<msgname>,<value>[,<option>]