diff --git a/plugins/lsps-plugin/src/core/tlv.rs b/plugins/lsps-plugin/src/core/tlv.rs index 357ebd283..7bf74442d 100644 --- a/plugins/lsps-plugin/src/core/tlv.rs +++ b/plugins/lsps-plugin/src/core/tlv.rs @@ -1,6 +1,6 @@ -use anyhow::Result; use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize, Serializer}; use std::{convert::TryFrom, fmt}; +use thiserror::Error; pub const TLV_FORWARD_AMT: u64 = 2; pub const TLV_OUTGOING_CLTV: u64 = 4; @@ -16,41 +16,31 @@ pub struct TlvRecord { #[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] pub struct TlvStream(pub Vec); -#[derive(Debug)] +#[derive(Debug, Error)] pub enum TlvError { + #[error("duplicate tlv type {0}")] DuplicateType(u64), + #[error("tlv types are not strictly increasing")] NotSorted, + #[error("length mismatch type {0}: expected {1}, got {2}")] LengthMismatch(u64, usize, usize), + #[error("truncated input")] Truncated, + #[error("non-canonical bigsize encoding")] NonCanonicalBigSize, + #[error("leftover bytes after parsing")] TrailingBytes, - Hex(hex::FromHexError), - Other(String), + #[error("")] + Hex(#[from] hex::FromHexError), + #[error("length overflow")] + Overflow, + #[error("tu64 is not minimal, got a leading zero")] + LeadingZero, + #[error("failed to parse bytes to u64")] + BytesToU64, } -impl fmt::Display for TlvError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TlvError::DuplicateType(t) => write!(f, "duplicate tlv type {}", t), - TlvError::NotSorted => write!(f, "tlv types must be strictly increasing"), - TlvError::LengthMismatch(t, e, g) => { - write!(f, "length mismatch type {}: expected {}, got {}", t, e, g) - } - TlvError::Truncated => write!(f, "truncated input"), - TlvError::NonCanonicalBigSize => write!(f, "non-canonical bigsize encoding"), - TlvError::TrailingBytes => write!(f, "leftover bytes after parsing"), - TlvError::Hex(e) => write!(f, "hex error: {}", e), - TlvError::Other(s) => write!(f, "{}", s), - } - } -} - -impl std::error::Error for TlvError {} -impl From for TlvError { - fn from(e: hex::FromHexError) -> Self { - TlvError::Hex(e) - } -} +type Result = std::result::Result; impl TlvStream { pub fn to_bytes(&mut self) -> Result> { @@ -82,7 +72,7 @@ impl TlvStream { let (len, n2) = decode_bigsize(bytes)?; bytes = &bytes[n2..]; - let l = usize::try_from(len).map_err(|_| TlvError::Other("length too large".into()))?; + let l = usize::try_from(len).map_err(|_| TlvError::Overflow)?; if bytes.len() < l { return Err(TlvError::Truncated.into()); } @@ -111,8 +101,7 @@ impl TlvStream { let (length, length_bytes) = decode_bigsize(bytes)?; let remaining = &bytes[length_bytes..]; - let length_usize = usize::try_from(length) - .map_err(|_| TlvError::Other("length prefix too large".into()))?; + let length_usize = usize::try_from(length).map_err(|_| TlvError::Overflow)?; if remaining.len() != length_usize { return Err(TlvError::LengthMismatch(0, length_usize, remaining.len()).into()); @@ -181,7 +170,7 @@ impl TlvStream { /// Read a `tu64` if present, validating minimal encoding. /// Returns Ok(None) if the type isn't present. - pub fn get_tu64(&self, type_: u64) -> Result, TlvError> { + pub fn get_tu64(&self, type_: u64) -> Result> { if let Some(rec) = self.0.iter().find(|r| r.type_ == type_) { Ok(Some(decode_tu64(&rec.value)?)) } else { @@ -202,13 +191,10 @@ impl TlvStream { } /// Read a `u64` if present.Returns Ok(None) if the type isn't present. - pub fn get_u64(&self, type_: u64) -> Result, TlvError> { + pub fn get_u64(&self, type_: u64) -> Result> { if let Some(rec) = self.0.iter().find(|r| r.type_ == type_) { - let value = u64::from_be_bytes( - rec.value[..] - .try_into() - .map_err(|e| TlvError::Other(format!("failed not decode to u64: {e}")))?, - ); + let value = + u64::from_be_bytes(rec.value[..].try_into().map_err(|_| TlvError::BytesToU64)?); Ok(Some(value)) } else { Ok(None) @@ -217,7 +203,7 @@ impl TlvStream { } impl Serialize for TlvStream { - fn serialize(&self, serializer: S) -> Result { + fn serialize(&self, serializer: S) -> std::result::Result { let mut tmp = self.clone(); let bytes = tmp.to_bytes().map_err(serde::ser::Error::custom)?; serializer.serialize_str(&hex::encode(bytes)) @@ -225,14 +211,14 @@ impl Serialize for TlvStream { } impl<'de> Deserialize<'de> for TlvStream { - fn deserialize>(deserializer: D) -> Result { + fn deserialize>(deserializer: D) -> std::result::Result { struct V; impl<'de> serde::de::Visitor<'de> for V { type Value = TlvStream; fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "a hex string representing a Lightning TLV stream") } - fn visit_str(self, s: &str) -> Result { + fn visit_str(self, s: &str) -> std::result::Result { let bytes = hex::decode(s).map_err(E::custom)?; TlvStream::from_bytes_auto(&bytes).map_err(E::custom) } @@ -242,8 +228,8 @@ impl<'de> Deserialize<'de> for TlvStream { } impl TryFrom<&[u8]> for TlvStream { - type Error = anyhow::Error; - fn try_from(value: &[u8]) -> Result { + type Error = TlvError; + fn try_from(value: &[u8]) -> std::result::Result { TlvStream::from_bytes(value) } } @@ -326,15 +312,15 @@ pub fn encode_tu64(v: u64) -> Vec { /// Decode a BOLT #1 `tu64`, enforcing minimal form. /// Empty slice -> 0. Leading 0x00 or >8 bytes is invalid. -fn decode_tu64(raw: &[u8]) -> Result { +fn decode_tu64(raw: &[u8]) -> Result { if raw.is_empty() { return Ok(0); } if raw.len() > 8 { - return Err(TlvError::Other("tu64 too long".into())); + return Err(TlvError::Overflow); } if raw[0] == 0 { - return Err(TlvError::Other("non-minimal tu64 (leading zero)".into())); + return Err(TlvError::LeadingZero); } let mut buf = [0u8; 8]; buf[8 - raw.len()..].copy_from_slice(raw);