diff --git a/plugins/lsps-plugin/src/lsps2/cln.rs b/plugins/lsps-plugin/src/lsps2/cln.rs new file mode 100644 index 000000000..6e3d6d232 --- /dev/null +++ b/plugins/lsps-plugin/src/lsps2/cln.rs @@ -0,0 +1,727 @@ +//! Backfill structs for missing or incomplete Core Lightning types. +//! +//! This module provides struct implementations that are not available or +//! fully accessible in the core-lightning crate, enabling better compatibility +//! and interoperability with Core Lightning's RPC interface. +use cln_rpc::primitives::{Amount, ShortChannelId}; +use hex::FromHex; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::lsps2::cln::tlv::TlvStream; + +pub const TLV_FORWARD_AMT: u64 = 2; +pub const TLV_OUTGOING_CLTV: u64 = 4; +pub const TLV_SHORT_CHANNEL_ID: u64 = 6; +pub const TLV_PAYMENT_SECRET: u64 = 8; + +#[derive(Debug, Deserialize)] +#[allow(unused)] +pub struct Onion { + pub forward_msat: Option, + #[serde(deserialize_with = "from_hex")] + pub next_onion: Vec, + pub outgoing_cltv_value: Option, + pub payload: TlvStream, + // pub payload: TlvStream, + #[serde(deserialize_with = "from_hex")] + pub shared_secret: Vec, + pub short_channel_id: Option, + pub total_msat: Option, + #[serde(rename = "type")] + pub type_: Option, +} + +#[derive(Debug, Deserialize)] +#[allow(unused)] +pub struct Htlc { + pub amount_msat: Amount, + pub cltv_expiry: u32, + pub cltv_expiry_relative: u16, + pub id: u64, + #[serde(deserialize_with = "from_hex")] + pub payment_hash: Vec, + pub short_channel_id: ShortChannelId, + pub extra_tlvs: Option, +} + +#[derive(Debug, Serialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum HtlcAcceptedResult { + Continue, + Fail, + Resolve, +} + +impl std::fmt::Display for HtlcAcceptedResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + HtlcAcceptedResult::Continue => "continue", + HtlcAcceptedResult::Fail => "fail", + HtlcAcceptedResult::Resolve => "resolve", + }; + write!(f, "{s}") + } +} + +#[derive(Debug, Deserialize)] +pub struct HtlcAcceptedRequest { + pub htlc: Htlc, + pub onion: Onion, + pub forward_to: Option, +} + +#[derive(Debug, Serialize)] +pub struct HtlcAcceptedResponse { + pub result: HtlcAcceptedResult, + #[serde(skip_serializing_if = "Option::is_none")] + pub payment_key: Option, + #[serde(skip_serializing_if = "Option::is_none", serialize_with = "to_hex")] + pub payload: Option>, + #[serde(skip_serializing_if = "Option::is_none", serialize_with = "to_hex")] + pub forward_to: Option>, + #[serde(skip_serializing_if = "Option::is_none", serialize_with = "to_hex")] + pub extra_tlvs: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub failure_message: Option, + #[serde(skip_serializing_if = "Option::is_none", serialize_with = "to_hex")] + pub failure_onion: Option>, +} + +impl HtlcAcceptedResponse { + pub fn continue_( + payload: Option>, + forward_to: Option>, + extra_tlvs: Option>, + ) -> Self { + Self { + result: HtlcAcceptedResult::Continue, + payment_key: None, + payload, + forward_to, + extra_tlvs, + failure_message: None, + failure_onion: None, + } + } + + pub fn fail(failure_message: Option, failure_onion: Option>) -> Self { + Self { + result: HtlcAcceptedResult::Fail, + payment_key: None, + payload: None, + forward_to: None, + extra_tlvs: None, + failure_message, + failure_onion, + } + } +} + +/// Deserializes a lowercase hex string to a `Vec`. +pub fn from_hex<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + use serde::de::Error; + String::deserialize(deserializer) + .and_then(|string| Vec::from_hex(string).map_err(|err| Error::custom(err.to_string()))) +} + +pub fn to_hex(bytes: &Option>, serializer: S) -> Result +where + S: Serializer, +{ + match bytes { + Some(data) => serializer.serialize_str(&hex::encode(data)), + None => serializer.serialize_none(), + } +} + +pub mod tlv { + use anyhow::Result; + use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize, Serializer}; + use std::{convert::TryFrom, fmt}; + + #[derive(Clone, Debug)] + pub struct TlvRecord { + pub type_: u64, + pub value: Vec, + } + + #[derive(Clone, Debug, Default)] + pub struct TlvStream(pub Vec); + + #[derive(Debug)] + pub enum TlvError { + DuplicateType(u64), + NotSorted, + LengthMismatch(u64, usize, usize), + Truncated, + NonCanonicalBigSize, + TrailingBytes, + Hex(hex::FromHexError), + Other(String), + } + + impl fmt::Display for TlvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TlvError::DuplicateType(t) => write!(f, "duplicate tlv type {}", t), + TlvError::NotSorted => write!(f, "tlv types must be strictly increasing"), + TlvError::LengthMismatch(t, e, g) => { + write!(f, "length mismatch type {}: expected {}, got {}", t, e, g) + } + TlvError::Truncated => write!(f, "truncated input"), + TlvError::NonCanonicalBigSize => write!(f, "non-canonical bigsize encoding"), + TlvError::TrailingBytes => write!(f, "leftover bytes after parsing"), + TlvError::Hex(e) => write!(f, "hex error: {}", e), + TlvError::Other(s) => write!(f, "{}", s), + } + } + } + + impl std::error::Error for TlvError {} + impl From for TlvError { + fn from(e: hex::FromHexError) -> Self { + TlvError::Hex(e) + } + } + + impl TlvStream { + pub fn to_bytes(&mut self) -> Result> { + self.0.sort_by_key(|r| r.type_); + for w in self.0.windows(2) { + if w[0].type_ == w[1].type_ { + return Err(TlvError::DuplicateType(w[0].type_).into()); + } + if w[0].type_ > w[1].type_ { + return Err(TlvError::NotSorted.into()); + } + } + let mut out = Vec::new(); + for rec in &self.0 { + out.extend(encode_bigsize(rec.type_)); + out.extend(encode_bigsize(rec.value.len() as u64)); + out.extend(&rec.value); + } + Ok(out) + } + + pub fn from_bytes(mut bytes: &[u8]) -> Result { + let mut recs = Vec::new(); + let mut last_type: Option = None; + + while !bytes.is_empty() { + let (t, n1) = decode_bigsize(bytes)?; + bytes = &bytes[n1..]; + let (len, n2) = decode_bigsize(bytes)?; + bytes = &bytes[n2..]; + + let l = + usize::try_from(len).map_err(|_| TlvError::Other("length too large".into()))?; + if bytes.len() < l { + return Err(TlvError::Truncated.into()); + } + let v = bytes[..l].to_vec(); + bytes = &bytes[l..]; + + if let Some(prev) = last_type { + if t == prev { + return Err(TlvError::DuplicateType(t).into()); + } + if t < prev { + return Err(TlvError::NotSorted.into()); + } + } + last_type = Some(t); + recs.push(TlvRecord { type_: t, value: v }); + } + Ok(TlvStream(recs)) + } + + pub fn from_bytes_with_length_prefix(bytes: &[u8]) -> Result { + if bytes.is_empty() { + return Err(TlvError::Truncated.into()); + } + + let (length, length_bytes) = decode_bigsize(bytes)?; + let remaining = &bytes[length_bytes..]; + + let length_usize = usize::try_from(length) + .map_err(|_| TlvError::Other("length prefix too large".into()))?; + + if remaining.len() != length_usize { + return Err(TlvError::LengthMismatch(0, length_usize, remaining.len()).into()); + } + + Self::from_bytes(remaining) + } + + /// Attempt to auto-detect whether the input has a length prefix or not + /// First tries to parse as length-prefixed, then falls back to raw TLV + /// parsing. + pub fn from_bytes_auto(bytes: &[u8]) -> Result { + // Try length-prefixed first + if let Ok(stream) = Self::from_bytes_with_length_prefix(bytes) { + return Ok(stream); + } + + // Fall back to raw TLV parsing + Self::from_bytes(bytes) + } + + /// Get a reference to the value of a TLV record by type. + pub fn get(&self, type_: u64) -> Option<&[u8]> { + self.0 + .iter() + .find(|rec| rec.type_ == type_) + .map(|rec| rec.value.as_slice()) + } + + /// Insert a TLV record (replaces if type already exists). + pub fn insert(&mut self, type_: u64, value: Vec) { + // If the type already exists, replace its value. + if let Some(rec) = self.0.iter_mut().find(|rec| rec.type_ == type_) { + rec.value = value; + return; + } + // Otherwise push and re-sort to maintain canonical order. + self.0.push(TlvRecord { type_, value }); + self.0.sort_by_key(|r| r.type_); + } + + /// Remove a record by type. + pub fn remove(&mut self, type_: u64) -> Option> { + if let Some(pos) = self.0.iter().position(|rec| rec.type_ == type_) { + Some(self.0.remove(pos).value) + } else { + None + } + } + + /// Check if a type exists. + pub fn contains(&self, type_: u64) -> bool { + self.0.iter().any(|rec| rec.type_ == type_) + } + + /// Insert or override a `tu64` value for `type_` (keeps canonical TLV order). + pub fn set_tu64(&mut self, type_: u64, value: u64) { + let enc = encode_tu64(value); + if let Some(rec) = self.0.iter_mut().find(|r| r.type_ == type_) { + rec.value = enc; + } else { + self.0.push(TlvRecord { type_, value: enc }); + self.0.sort_by_key(|r| r.type_); + } + } + + /// Read a `tu64` if present, validating minimal encoding. + /// Returns Ok(None) if the type isn't present. + pub fn get_tu64(&self, type_: u64) -> Result, TlvError> { + if let Some(rec) = self.0.iter().find(|r| r.type_ == type_) { + Ok(Some(decode_tu64(&rec.value)?)) + } else { + Ok(None) + } + } + } + + impl Serialize for TlvStream { + fn serialize(&self, serializer: S) -> Result { + let mut tmp = self.clone(); + let bytes = tmp.to_bytes().map_err(serde::ser::Error::custom)?; + serializer.serialize_str(&hex::encode(bytes)) + } + } + + impl<'de> Deserialize<'de> for TlvStream { + fn deserialize>(deserializer: D) -> Result { + struct V; + impl<'de> serde::de::Visitor<'de> for V { + type Value = TlvStream; + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "a hex string representing a Lightning TLV stream") + } + fn visit_str(self, s: &str) -> Result { + let bytes = hex::decode(s).map_err(E::custom)?; + TlvStream::from_bytes_auto(&bytes).map_err(E::custom) + } + } + deserializer.deserialize_str(V) + } + } + + impl TryFrom<&[u8]> for TlvStream { + type Error = anyhow::Error; + fn try_from(value: &[u8]) -> Result { + TlvStream::from_bytes(value) + } + } + + impl From> for TlvStream { + fn from(v: Vec) -> Self { + TlvStream(v) + } + } + + /// BOLT #1 BigSize encoding + fn encode_bigsize(x: u64) -> Vec { + let mut out = Vec::new(); + if x < 0xfd { + out.push(x as u8); + } else if x <= 0xffff { + out.push(0xfd); + out.extend_from_slice(&(x as u16).to_be_bytes()); + } else if x <= 0xffff_ffff { + out.push(0xfe); + out.extend_from_slice(&(x as u32).to_be_bytes()); + } else { + out.push(0xff); + out.extend_from_slice(&x.to_be_bytes()); + } + out + } + + fn decode_bigsize(input: &[u8]) -> Result<(u64, usize)> { + if input.is_empty() { + return Err(TlvError::Truncated.into()); + } + match input[0] { + n @ 0x00..=0xfc => Ok((n as u64, 1)), + 0xfd => { + if input.len() < 3 { + return Err(TlvError::Truncated.into()); + } + let v = u16::from_be_bytes([input[1], input[2]]) as u64; + if v < 0xfd { + return Err(TlvError::NonCanonicalBigSize.into()); + } + Ok((v, 3)) + } + 0xfe => { + if input.len() < 5 { + return Err(TlvError::Truncated.into()); + } + let v = u32::from_be_bytes([input[1], input[2], input[3], input[4]]) as u64; + if v <= 0xffff { + return Err(TlvError::NonCanonicalBigSize.into()); + } + Ok((v, 5)) + } + 0xff => { + if input.len() < 9 { + return Err(TlvError::Truncated.into()); + } + let v = u64::from_be_bytes([ + input[1], input[2], input[3], input[4], input[5], input[6], input[7], input[8], + ]); + if v <= 0xffff_ffff { + return Err(TlvError::NonCanonicalBigSize.into()); + } + Ok((v, 9)) + } + } + } + + /// Encode a BOLT #1 `tu64`: big-endian, minimal length (no leading 0x00). + /// Value 0 is encoded as zero-length. + pub fn encode_tu64(v: u64) -> Vec { + if v == 0 { + return Vec::new(); + } + let bytes = v.to_be_bytes(); + let first = bytes.iter().position(|&b| b != 0).unwrap(); // safe: v != 0 + bytes[first..].to_vec() + } + + /// Decode a BOLT #1 `tu64`, enforcing minimal form. + /// Empty slice -> 0. Leading 0x00 or >8 bytes is invalid. + fn decode_tu64(raw: &[u8]) -> Result { + if raw.is_empty() { + return Ok(0); + } + if raw.len() > 8 { + return Err(TlvError::Other("tu64 too long".into())); + } + if raw[0] == 0 { + return Err(TlvError::Other("non-minimal tu64 (leading zero)".into())); + } + let mut buf = [0u8; 8]; + buf[8 - raw.len()..].copy_from_slice(raw); + Ok(u64::from_be_bytes(buf)) + } + + #[cfg(test)] + mod tests { + use super::*; + use anyhow::Result; + + // Small helpers to keep tests readable + fn rec(type_: u64, value: &[u8]) -> TlvRecord { + TlvRecord { + type_, + value: value.to_vec(), + } + } + + fn build_bytes(type_: u64, value: &[u8]) -> Vec { + let mut v = Vec::new(); + v.extend(super::encode_bigsize(type_)); + v.extend(super::encode_bigsize(value.len() as u64)); + v.extend(value); + v + } + + #[test] + fn encode_then_decode_roundtrip() -> Result<()> { + let mut stream = TlvStream(vec![rec(1, &[0x01, 0x02]), rec(5, &[0xaa])]); + + // Encode + let bytes = stream.to_bytes()?; + // Expect exact TLV sequence: + // type=1 -> 0x01, len=2 -> 0x02, value=0x01 0x02 + // type=5 -> 0x05, len=1 -> 0x01, value=0xaa + assert_eq!(hex::encode(&bytes), "010201020501aa"); + + // Decode back + let decoded = TlvStream::from_bytes(&bytes)?; + assert_eq!(decoded.0.len(), 2); + assert_eq!(decoded.0[0].type_, 1); + assert_eq!(decoded.0[0].value, vec![0x01, 0x02]); + assert_eq!(decoded.0[1].type_, 5); + assert_eq!(decoded.0[1].value, vec![0xaa]); + + Ok(()) + } + + #[test] + fn json_hex_roundtrip() -> Result<()> { + let stream = TlvStream(vec![rec(1, &[0x01, 0x02]), rec(5, &[0xaa])]); + + // Serialize to hex string in JSON + let json = serde_json::to_string(&stream)?; + // It's a quoted hex string; check inner value + let s: String = serde_json::from_str(&json)?; + assert_eq!(s, "010201020501aa"); + + // And back from JSON hex + let back: TlvStream = serde_json::from_str(&json)?; + assert_eq!(back.0.len(), 2); + assert_eq!(back.0[0].type_, 1); + assert_eq!(back.0[0].value, vec![0x01, 0x02]); + assert_eq!(back.0[1].type_, 5); + assert_eq!(back.0[1].value, vec![0xaa]); + + Ok(()) + } + + #[test] + fn decode_with_len_prefix() -> Result<()> { + let payload = "1202039896800401760608000073000f2c0007"; + let stream = TlvStream::from_bytes_with_length_prefix(&hex::decode(payload).unwrap())?; + // let stream: TlvStream = serde_json::from_str(payload)?; + println!("TLV {:?}", stream.0); + + Ok(()) + } + + #[test] + fn bigsize_boundary_minimal_encodings() -> Result<()> { + // Types at 0xfc, 0xfd, 0x10000 to exercise size switches + let mut stream = TlvStream(vec![ + rec(0x00fc, &[0x11]), // single-byte bigsize + rec(0x00fd, &[0x22]), // 0xfd prefix + u16 + rec(0x0001_0000, &[0x33]), // 0xfe prefix + u32 + ]); + + let bytes = stream.to_bytes()?; // just ensure it encodes + // Decode back to confirm roundtrip/canonical encodings accepted + let back = TlvStream::from_bytes(&bytes)?; + assert_eq!(back.0[0].type_, 0x00fc); + assert_eq!(back.0[1].type_, 0x00fd); + assert_eq!(back.0[2].type_, 0x0001_0000); + Ok(()) + } + + #[test] + fn decode_rejects_non_canonical_bigsize() { + // (1) Non-canonical: 0xfd 00 fc encodes 0xfc but should be a single byte + let mut bytes = Vec::new(); + bytes.extend([0xfd, 0x00, 0xfc]); // non-canonical type + bytes.extend([0x01]); // len = 1 + bytes.extend([0x00]); // value + let err = TlvStream::from_bytes(&bytes).unwrap_err(); + assert!(format!("{}", err).contains("non-canonical")); + + // (2) Non-canonical: 0xfe 00 00 00 ff encodes 0xff but should be 0xfd-form + let mut bytes = Vec::new(); + bytes.extend([0xfe, 0x00, 0x00, 0x00, 0xff]); + bytes.extend([0x01]); + bytes.extend([0x00]); + let err = TlvStream::from_bytes(&bytes).unwrap_err(); + assert!(format!("{}", err).contains("non-canonical")); + + // (3) Non-canonical: 0xff 00..01 encodes 1, which should be single byte + let mut bytes = Vec::new(); + bytes.extend([0xff, 0, 0, 0, 0, 0, 0, 0, 1]); + bytes.extend([0x01]); + bytes.extend([0x00]); + let err = TlvStream::from_bytes(&bytes).unwrap_err(); + assert!(format!("{}", err).contains("non-canonical")); + } + + #[test] + fn decode_rejects_out_of_order_types() { + // Build two TLVs but put type 5 before type 1 + let mut bad = Vec::new(); + bad.extend(build_bytes(5, &[0xaa])); + bad.extend(build_bytes(1, &[0x00])); + + let err = TlvStream::from_bytes(&bad).unwrap_err(); + assert!( + format!("{}", err).contains("increasing") || format!("{}", err).contains("sorted"), + "expected ordering error, got: {err}" + ); + } + + #[test] + fn decode_rejects_duplicate_types() { + // Two records with same type=1 + let mut bad = Vec::new(); + bad.extend(build_bytes(1, &[0x01])); + bad.extend(build_bytes(1, &[0x02])); + let err = TlvStream::from_bytes(&bad).unwrap_err(); + assert!( + format!("{}", err).contains("duplicate"), + "expected duplicate error, got: {err}" + ); + } + + #[test] + fn encode_rejects_duplicate_types() { + // insert duplicate types and expect encode to fail + let mut s = TlvStream(vec![rec(1, &[0x01]), rec(1, &[0x02])]); + let err = s.to_bytes().unwrap_err(); + assert!( + format!("{}", err).contains("duplicate"), + "expected duplicate error, got: {err}" + ); + } + + #[test] + fn decode_truncated_value() { + // type=1, len=2 but only 1 byte of value provided + let mut bytes = Vec::new(); + bytes.extend(encode_bigsize(1)); + bytes.extend(encode_bigsize(2)); + bytes.push(0x00); // missing one more byte + let err = TlvStream::from_bytes(&bytes).unwrap_err(); + assert!( + format!("{}", err).contains("truncated"), + "expected truncated error, got: {err}" + ); + } + + #[test] + fn set_and_get_tu64_basic() -> Result<()> { + let mut s = TlvStream::default(); + s.set_tu64(42, 123456789); + assert_eq!(s.get_tu64(42)?, Some(123456789)); + Ok(()) + } + + #[test] + fn set_tu64_overwrite_keeps_order() -> Result<()> { + let mut s = TlvStream(vec![ + TlvRecord { + type_: 1, + value: vec![0xaa], + }, + TlvRecord { + type_: 10, + value: vec![0xbb], + }, + ]); + + // insert between 1 and 10 + s.set_tu64(5, 7); + assert_eq!( + s.0.iter().map(|r| r.type_).collect::>(), + vec![1, 5, 10] + ); + assert_eq!(s.get_tu64(5)?, Some(7)); + + // overwrite existing 5 (no duplicate, order preserved) + s.set_tu64(5, 9); + let types: Vec = s.0.iter().map(|r| r.type_).collect(); + assert_eq!(types, vec![1, 5, 10]); + assert_eq!(s.0.iter().filter(|r| r.type_ == 5).count(), 1); + assert_eq!(s.get_tu64(5)?, Some(9)); + Ok(()) + } + + #[test] + fn tu64_zero_encodes_empty_and_roundtrips() -> Result<()> { + let mut s = TlvStream::default(); + s.set_tu64(3, 0); + + // stored value is zero-length + let rec = s.0.iter().find(|r| r.type_ == 3).unwrap(); + assert!(rec.value.is_empty()); + + // wire round-trip + let mut sc = s.clone(); + let bytes = sc.to_bytes()?; + let s2 = TlvStream::from_bytes(&bytes)?; + assert_eq!(s2.get_tu64(3)?, Some(0)); + Ok(()) + } + + #[test] + fn get_tu64_missing_returns_none() -> Result<()> { + let s = TlvStream::default(); + assert_eq!(s.get_tu64(999)?, None); + Ok(()) + } + + #[test] + fn get_tu64_rejects_non_minimal_and_too_long() { + // non-minimal: leading zero + let mut s = TlvStream::default(); + s.0.push(TlvRecord { + type_: 9, + value: vec![0x00, 0x01], + }); + assert!(s.get_tu64(9).is_err()); + + // too long: 9 bytes + let mut s2 = TlvStream::default(); + s2.0.push(TlvRecord { + type_: 9, + value: vec![0; 9], + }); + assert!(s2.get_tu64(9).is_err()); + } + + #[test] + fn tu64_multi_roundtrip_bytes_and_json() -> Result<()> { + let mut s = TlvStream::default(); + s.set_tu64(42, 0); + s.set_tu64(7, 256); + + // wire roundtrip + let mut sc = s.clone(); + let bytes = sc.to_bytes()?; + let s2 = TlvStream::from_bytes(&bytes)?; + assert_eq!(s2.get_tu64(42)?, Some(0)); + assert_eq!(s2.get_tu64(7)?, Some(256)); + + // json hex roundtrip (custom Serialize/Deserialize) + let json = serde_json::to_string(&s)?; + let s3: TlvStream = serde_json::from_str(&json)?; + assert_eq!(s3.get_tu64(42)?, Some(0)); + assert_eq!(s3.get_tu64(7)?, Some(256)); + Ok(()) + } + } +} diff --git a/plugins/lsps-plugin/src/lsps2/handler.rs b/plugins/lsps-plugin/src/lsps2/handler.rs index a28587951..02e8a62f3 100644 --- a/plugins/lsps-plugin/src/lsps2/handler.rs +++ b/plugins/lsps-plugin/src/lsps2/handler.rs @@ -1,11 +1,15 @@ use crate::{ jsonrpc::{server::RequestHandler, JsonRpcResponse as _, RequestObject, RpcError}, - lsps0::primitives::ShortChannelId, + lsps0::primitives::{Msat, ShortChannelId}, lsps2::{ + cln::{HtlcAcceptedRequest, HtlcAcceptedResponse, TLV_FORWARD_AMT}, model::{ + compute_opening_fee, + failure_codes::{TEMPORARY_CHANNEL_FAILURE, UNKNOWN_NEXT_PEER}, DatastoreEntry, Lsps2BuyRequest, Lsps2BuyResponse, Lsps2GetInfoRequest, - Lsps2GetInfoResponse, Lsps2PolicyGetInfoRequest, Lsps2PolicyGetInfoResponse, - OpeningFeeParams, Promise, + Lsps2GetInfoResponse, Lsps2PolicyGetChannelCapacityRequest, + Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoRequest, + Lsps2PolicyGetInfoResponse, OpeningFeeParams, Promise, }, DS_MAIN_KEY, DS_SUB_KEY, }, @@ -13,16 +17,25 @@ use crate::{ }; use anyhow::{Context, Result as AnyResult}; use async_trait::async_trait; +use bitcoin::hashes::Hash as _; +use chrono::Utc; use cln_rpc::{ model::{ - requests::{DatastoreMode, DatastoreRequest, GetinfoRequest}, - responses::{DatastoreResponse, GetinfoResponse}, + requests::{ + DatastoreMode, DatastoreRequest, DeldatastoreRequest, FundchannelRequest, + GetinfoRequest, ListdatastoreRequest, ListpeerchannelsRequest, + }, + responses::{ + DatastoreResponse, DeldatastoreResponse, FundchannelResponse, GetinfoResponse, + ListdatastoreResponse, ListpeerchannelsResponse, + }, }, + primitives::{Amount, AmountOrAll, ChannelState}, ClnRpc, }; -use log::warn; +use log::{debug, warn}; use rand::{rng, Rng as _}; -use std::path::PathBuf; +use std::{fmt, path::PathBuf, time::Duration}; #[async_trait] pub trait ClnApi: Send + Sync { @@ -31,9 +44,31 @@ pub trait ClnApi: Send + Sync { params: &Lsps2PolicyGetInfoRequest, ) -> AnyResult; + async fn lsps2_getchannelcapacity( + &self, + params: &Lsps2PolicyGetChannelCapacityRequest, + ) -> AnyResult; + async fn cln_getinfo(&self, params: &GetinfoRequest) -> AnyResult; async fn cln_datastore(&self, params: &DatastoreRequest) -> AnyResult; + + async fn cln_listdatastore( + &self, + params: &ListdatastoreRequest, + ) -> AnyResult; + + async fn cln_deldatastore( + &self, + params: &DeldatastoreRequest, + ) -> AnyResult; + + async fn cln_fundchannel(&self, params: &FundchannelRequest) -> AnyResult; + + async fn cln_listpeerchannels( + &self, + params: &ListpeerchannelsRequest, + ) -> AnyResult; } const DEFAULT_CLTV_EXPIRY_DELTA: u32 = 144; @@ -66,6 +101,17 @@ impl ClnApi for ClnApiRpc { .with_context(|| "calling dev-lsps2-getpolicy") } + async fn lsps2_getchannelcapacity( + &self, + params: &Lsps2PolicyGetChannelCapacityRequest, + ) -> AnyResult { + let mut rpc = self.create_rpc().await?; + rpc.call_raw("dev-lsps2-getchannelcapacity", params) + .await + .map_err(anyhow::Error::new) + .with_context(|| "calling dev-lsps2-getchannelcapacity") + } + async fn cln_getinfo(&self, params: &GetinfoRequest) -> AnyResult { let mut rpc = self.create_rpc().await?; rpc.call_typed(params) @@ -81,6 +127,47 @@ impl ClnApi for ClnApiRpc { .map_err(anyhow::Error::new) .with_context(|| "calling datastore") } + + async fn cln_listdatastore( + &self, + params: &ListdatastoreRequest, + ) -> AnyResult { + let mut rpc = self.create_rpc().await?; + rpc.call_typed(params) + .await + .map_err(anyhow::Error::new) + .with_context(|| "calling listdatastore") + } + + async fn cln_deldatastore( + &self, + params: &DeldatastoreRequest, + ) -> AnyResult { + let mut rpc = self.create_rpc().await?; + rpc.call_typed(params) + .await + .map_err(anyhow::Error::new) + .with_context(|| "calling deldatastore") + } + + async fn cln_fundchannel(&self, params: &FundchannelRequest) -> AnyResult { + let mut rpc = self.create_rpc().await?; + rpc.call_typed(params) + .await + .map_err(anyhow::Error::new) + .with_context(|| "calling fundchannel") + } + + async fn cln_listpeerchannels( + &self, + params: &ListpeerchannelsRequest, + ) -> AnyResult { + let mut rpc = self.create_rpc().await?; + rpc.call_typed(params) + .await + .map_err(anyhow::Error::new) + .with_context(|| "calling listpeerchannels") + } } /// Handler for the `lsps2.get_info` method. @@ -259,6 +346,330 @@ fn generate_jit_scid(best_blockheigt: u32) -> u64 { ((block as u64) << 40) | ((tx_idx as u64) << 16) | (output_idx as u64) } +pub struct HtlcAcceptedHookHandler { + api: A, + htlc_minimum_msat: u64, + backoff_listpeerchannels: Duration, +} + +impl HtlcAcceptedHookHandler { + pub fn new(api: A, htlc_minimum_msat: u64) -> Self { + Self { + api, + htlc_minimum_msat, + backoff_listpeerchannels: Duration::from_secs(10), + } + } + + pub async fn handle(&self, req: HtlcAcceptedRequest) -> AnyResult { + let scid = match req.onion.short_channel_id { + Some(scid) => scid, + None => { + // We are the final destination of this htlc. + return Ok(HtlcAcceptedResponse::continue_(None, None, None)); + } + }; + + // A) Is this SCID one that we care about? + let ds_req = ListdatastoreRequest { + key: Some(scid_ds_key(scid)), + }; + let ds_res = self.api.cln_listdatastore(&ds_req).await.map_err(|e| { + warn!("Failed to listpeerchannels via rpc {}", e); + RpcError::internal_error("Internal error") + })?; + + let (ds_rec, ds_gen) = match deserialize_by_key(&ds_res, scid_ds_key(scid)) { + Ok(r) => r, + Err(DsError::NotFound { .. }) => { + // We don't know the scid, continue. + return Ok(HtlcAcceptedResponse::continue_(None, None, None)); + } + Err(e @ DsError::MissingValue { .. }) + | Err(e @ DsError::HexDecode { .. }) + | Err(e @ DsError::JsonParse { .. }) => { + // We have a data issue, log and continue. + // Note: We may want to actually reject the htlc here or throw + // an error alltogether but we will try to fulfill this htlc for + // now. + warn!("datastore issue: {}", e); + return Ok(HtlcAcceptedResponse::continue_(None, None, None)); + } + }; + + // Fixme: Check that we don't have a channel yet with the peer that we await to + // become READY to use. + // --- + + // Fixme: We only accept no-mpp for now, mpp and other flows will be added later on + if ds_rec.expected_payment_size.is_some() { + warn!("mpp payments are not implemented yet"); + return Ok(HtlcAcceptedResponse::fail( + Some(UNKNOWN_NEXT_PEER.to_string()), + None, + )); + } + + // B) Is the fee option menu still valid? + let now = Utc::now(); + if now >= ds_rec.opening_fee_params.valid_until { + // Not valid anymore, remove from DS and fail HTLC. + let ds_req = DeldatastoreRequest { + generation: ds_gen, + key: scid_ds_key(scid), + }; + match self.api.cln_deldatastore(&ds_req).await { + Ok(_) => debug!("removed datastore for scid: {}, wasn't valid anymore", scid), + Err(e) => warn!("could not remove datastore for scid: {}: {}", scid, e), + }; + return Ok(HtlcAcceptedResponse::fail( + Some(TEMPORARY_CHANNEL_FAILURE.to_string()), + None, + )); + } + + // C) Is the amount in the boundaries of the fee menu? + if req.htlc.amount_msat.msat() < ds_rec.opening_fee_params.min_fee_msat.msat() + || req.htlc.amount_msat.msat() > ds_rec.opening_fee_params.max_payment_size_msat.msat() + { + // No! reject the HTLC. + debug!("amount_msat for scid: {}, was too low or to high", scid); + return Ok(HtlcAcceptedResponse::fail( + Some(UNKNOWN_NEXT_PEER.to_string()), + None, + )); + } + + // D) Check that the amount_msat covers the opening fee (only for non-mpp right now) + let opening_fee = if let Some(opening_fee) = compute_opening_fee( + req.htlc.amount_msat.msat(), + ds_rec.opening_fee_params.min_fee_msat.msat(), + ds_rec.opening_fee_params.proportional.ppm() as u64, + ) { + if opening_fee + self.htlc_minimum_msat >= req.htlc.amount_msat.msat() { + debug!("amount_msat for scid: {}, does not cover opening fee", scid); + return Ok(HtlcAcceptedResponse::fail( + Some(UNKNOWN_NEXT_PEER.to_string()), + None, + )); + } + opening_fee + } else { + // The computation overflowed. + debug!("amount_msat for scid: {}, was too low or to high", scid); + return Ok(HtlcAcceptedResponse::fail( + Some(UNKNOWN_NEXT_PEER.to_string()), + None, + )); + }; + + // E) We made it, open a channel to the peer. + let ch_cap_req = Lsps2PolicyGetChannelCapacityRequest { + opening_fee_params: ds_rec.opening_fee_params, + init_payment_size: Msat::from_msat(req.htlc.amount_msat.msat()), + scid, + }; + let ch_cap_res = match self.api.lsps2_getchannelcapacity(&ch_cap_req).await { + Ok(r) => r, + Err(e) => { + warn!("failed to get channel capacity for scid {}: {}", scid, e); + return Ok(HtlcAcceptedResponse::fail( + Some(UNKNOWN_NEXT_PEER.to_string()), + None, + )); + } + }; + + let cap = match ch_cap_res.channel_capacity_msat { + Some(c) => c, + None => { + debug!("policy giver does not allow channel for scid {}", scid); + return Ok(HtlcAcceptedResponse::fail( + Some(UNKNOWN_NEXT_PEER.to_string()), + None, + )); + } + }; + + // We take the policy-giver seriously, if the capacity is too low, we + // still try to open the channel. + // Fixme: We may check that the capacity is ge than the + // (amount_msat - opening fee) in the future. + // Fixme: Make this configurable, maybe return the whole request from + // the policy giver? + let fund_ch_req = FundchannelRequest { + announce: Some(false), + close_to: None, + compact_lease: None, + feerate: None, + minconf: None, + mindepth: Some(0), + push_msat: None, + request_amt: None, + reserve: None, + channel_type: None, // Fimxe: Core-Lightning is complaining that it doesn't support these channel_types + // channel_type: Some(vec![46, 50]), // Sets `option_zeroconf` and `option_scid_alias` + utxos: None, + amount: AmountOrAll::Amount(Amount::from_msat(cap)), + id: ds_rec.peer_id, + }; + + let fund_ch_res = match self.api.cln_fundchannel(&fund_ch_req).await { + Ok(r) => r, + Err(e) => { + // Fixme: Retry to fund the channel. + warn!("could not fund jit channel for scid {}: {}", scid, e); + return Ok(HtlcAcceptedResponse::fail( + Some(UNKNOWN_NEXT_PEER.to_string()), + None, + )); + } + }; + + // F) Wait for the peer to send `channel_ready`. + // Fixme: Use event to check for channel ready, + // Fixme: Check for htlc timeout if peer refuses to send "ready". + // Fixme: handle unexpected channel states. + let mut is_active = false; + while !is_active { + let ls_ch_req = ListpeerchannelsRequest { + id: Some(ds_rec.peer_id), + short_channel_id: None, + }; + let ls_ch_res = match self.api.cln_listpeerchannels(&ls_ch_req).await { + Ok(r) => r, + Err(e) => { + warn!("failed to fetch peer channels for scid {}: {}", scid, e); + tokio::time::sleep(self.backoff_listpeerchannels).await; + continue; + } + }; + let chs = ls_ch_res + .channels + .iter() + .find(|&ch| ch.channel_id.is_some_and(|id| id == fund_ch_res.channel_id)); + if let Some(ch) = chs { + debug!("jit channel for scid {} has state {:?}", scid, ch.state); + if ch.state == ChannelState::CHANNELD_NORMAL { + is_active = true; + } + } + tokio::time::sleep(self.backoff_listpeerchannels).await; + } + + // G) We got a working channel, deduct fee and forward htlc. + let deducted_amt_msat = req.htlc.amount_msat.msat() - opening_fee; + let mut payload = req.onion.payload.clone(); + payload.set_tu64(TLV_FORWARD_AMT, deducted_amt_msat); + + // It is okay to unwrap the next line as we do not have duplicate entries. + let payload_bytes = payload.to_bytes().unwrap(); + debug!("about to send payload: {:02x?}", &payload_bytes); + + let mut extra_tlvs = req.htlc.extra_tlvs.unwrap_or_default().clone(); + extra_tlvs.set_u64(65537, opening_fee); + let extra_tlvs_bytes = extra_tlvs.to_bytes().unwrap(); + debug!("extra_tlv: {:02x?}", extra_tlvs_bytes); + + Ok(HtlcAcceptedResponse::continue_( + Some(payload_bytes), + Some(fund_ch_res.channel_id.as_byte_array().to_vec()), + Some(extra_tlvs_bytes), + )) + } +} + +#[derive(Debug)] +pub enum DsError { + /// No datastore entry with this exact key. + NotFound { key: Vec }, + /// Entry existed but had neither `string` nor `hex`. + MissingValue { key: Vec }, + /// JSON parse failed (from `string` or decoded `hex`). + JsonParse { + key: Vec, + source: serde_json::Error, + }, + /// Hex decode failed. + HexDecode { + key: Vec, + source: hex::FromHexError, + }, +} + +impl fmt::Display for DsError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + DsError::NotFound { key } => write!(f, "no datastore entry for key {:?}", key), + DsError::MissingValue { key } => write!( + f, + "datastore entry had neither `string` nor `hex` for key {:?}", + key + ), + DsError::JsonParse { key, source } => { + write!(f, "failed to parse JSON at key {:?}: {}", key, source) + } + DsError::HexDecode { key, source } => { + write!(f, "failed to decode hex at key {:?}: {}", key, source) + } + } + } +} + +impl std::error::Error for DsError {} + +fn scid_ds_key(scid: ShortChannelId) -> Vec { + vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + scid.to_string(), + ] +} + +pub fn deserialize_by_key( + resp: &ListdatastoreResponse, + key: K, +) -> std::result::Result<(DatastoreEntry, Option), DsError> +where + K: AsRef<[String]>, +{ + let wanted: &[String] = key.as_ref(); + + let ds = resp + .datastore + .iter() + .find(|d| d.key.as_slice() == wanted) + .ok_or_else(|| DsError::NotFound { + key: wanted.to_vec(), + })?; + + // Prefer `string`, fall back to `hex` + if let Some(s) = &ds.string { + let value = serde_json::from_str::(s).map_err(|e| DsError::JsonParse { + key: ds.key.clone(), + source: e, + })?; + return Ok((value, ds.generation)); + } + + if let Some(hx) = &ds.hex { + let bytes = hex::decode(hx).map_err(|e| DsError::HexDecode { + key: ds.key.clone(), + source: e, + })?; + let value = + serde_json::from_slice::(&bytes).map_err(|e| DsError::JsonParse { + key: ds.key.clone(), + source: e, + })?; + return Ok((value, ds.generation)); + } + + Err(DsError::MissingValue { + key: ds.key.clone(), + }) +} + #[cfg(test)] mod tests { use std::sync::{Arc, Mutex}; @@ -267,12 +678,18 @@ mod tests { use crate::{ jsonrpc::{JsonRpcRequest, ResponseObject}, lsps0::primitives::{Msat, Ppm}, - lsps2::model::PolicyOpeningFeeParams, + lsps2::{ + cln::{tlv::TlvStream, HtlcAcceptedResult}, + model::PolicyOpeningFeeParams, + }, util::wrap_payload_with_peer_id, }; use chrono::{TimeZone, Utc}; - use cln_rpc::primitives::{Amount, PublicKey}; - use cln_rpc::RpcError as ClnRpcError; + use cln_rpc::{model::responses::ListdatastoreDatastore, RpcError as ClnRpcError}; + use cln_rpc::{ + model::responses::ListpeerchannelsChannels, + primitives::{Amount, PublicKey, Sha256}, + }; use serde::Serialize; const PUBKEY: [u8; 33] = [ @@ -324,6 +741,17 @@ mod tests { cln_getinfo_error: Arc>>, cln_datastore_response: Arc>>, cln_datastore_error: Arc>>, + cln_listdatastore_response: Arc>>, + cln_listdatastore_error: Arc>>, + cln_deldatastore_response: Arc>>, + cln_deldatastore_error: Arc>>, + cln_fundchannel_response: Arc>>, + cln_fundchannel_error: Arc>>, + cln_listpeerchannels_response: Arc>>, + cln_listpeerchannels_error: Arc>>, + lsps2_getchannelcapacity_response: + Arc>>, + lsps2_getchannelcapacity_error: Arc>>, } #[async_trait] @@ -341,6 +769,24 @@ mod tests { panic!("No lsps2 response defined"); } + async fn lsps2_getchannelcapacity( + &self, + _params: &Lsps2PolicyGetChannelCapacityRequest, + ) -> AnyResult { + if let Some(err) = self.lsps2_getchannelcapacity_error.lock().unwrap().take() { + return Err(anyhow::Error::new(err).context("from fake api")); + } + if let Some(res) = self + .lsps2_getchannelcapacity_response + .lock() + .unwrap() + .take() + { + return Ok(res); + } + panic!("No lsps2 getchannelcapacity response defined"); + } + async fn cln_getinfo( &self, _params: &GetinfoRequest, @@ -366,6 +812,168 @@ mod tests { }; panic!("No cln datastore response defined"); } + + async fn cln_listdatastore( + &self, + _params: &ListdatastoreRequest, + ) -> AnyResult { + if let Some(err) = self.cln_listdatastore_error.lock().unwrap().take() { + return Err(anyhow::Error::new(err).context("from fake api")); + } + if let Some(res) = self.cln_listdatastore_response.lock().unwrap().take() { + return Ok(res); + } + panic!("No cln listdatastore response defined"); + } + + async fn cln_deldatastore( + &self, + _params: &DeldatastoreRequest, + ) -> AnyResult { + if let Some(err) = self.cln_deldatastore_error.lock().unwrap().take() { + return Err(anyhow::Error::new(err).context("from fake api")); + } + if let Some(res) = self.cln_deldatastore_response.lock().unwrap().take() { + return Ok(res); + } + panic!("No cln deldatastore response defined"); + } + + async fn cln_fundchannel( + &self, + _params: &FundchannelRequest, + ) -> AnyResult { + if let Some(err) = self.cln_fundchannel_error.lock().unwrap().take() { + return Err(anyhow::Error::new(err).context("from fake api")); + } + if let Some(res) = self.cln_fundchannel_response.lock().unwrap().take() { + return Ok(res); + } + panic!("No cln fundchannel response defined"); + } + + async fn cln_listpeerchannels( + &self, + _params: &ListpeerchannelsRequest, + ) -> AnyResult { + if let Some(err) = self.cln_listpeerchannels_error.lock().unwrap().take() { + return Err(anyhow::Error::new(err).context("from fake api")); + } + + if let Some(res) = self.cln_listpeerchannels_response.lock().unwrap().take() { + return Ok(res); + } + + // Default: return a ready channel + let channel = ListpeerchannelsChannels { + channel_id: Some(*Sha256::from_bytes_ref(&[1u8; 32])), + state: ChannelState::CHANNELD_NORMAL, + peer_id: create_peer_id(), + peer_connected: true, + alias: None, + closer: None, + funding: None, + funding_outnum: None, + funding_txid: None, + htlcs: None, + in_offered_msat: None, + initial_feerate: None, + last_feerate: None, + last_stable_connection: None, + last_tx_fee_msat: None, + lost_state: None, + max_accepted_htlcs: None, + minimum_htlc_in_msat: None, + next_feerate: None, + next_fee_step: None, + out_fulfilled_msat: None, + out_offered_msat: None, + owner: None, + private: None, + receivable_msat: None, + reestablished: None, + scratch_txid: None, + short_channel_id: None, + spendable_msat: None, + status: None, + their_reserve_msat: None, + to_us_msat: None, + total_msat: None, + close_to: None, + close_to_addr: None, + direction: None, + dust_limit_msat: None, + fee_base_msat: None, + fee_proportional_millionths: None, + feerate: None, + ignore_fee_limits: None, + in_fulfilled_msat: None, + in_payments_fulfilled: None, + in_payments_offered: None, + max_to_us_msat: None, + maximum_htlc_out_msat: None, + min_to_us_msat: None, + minimum_htlc_out_msat: None, + our_max_htlc_value_in_flight_msat: None, + our_reserve_msat: None, + our_to_self_delay: None, + out_payments_fulfilled: None, + out_payments_offered: None, + their_max_htlc_value_in_flight_msat: None, + their_to_self_delay: None, + updates: None, + inflight: None, + #[allow(deprecated)] + max_total_htlc_in_msat: None, + opener: cln_rpc::primitives::ChannelSide::LOCAL, + }; + + Ok(ListpeerchannelsResponse { + channels: vec![channel], + }) + } + } + + fn create_test_htlc_request( + scid: Option, + amount_msat: u64, + ) -> HtlcAcceptedRequest { + let payload = TlvStream::default(); + + HtlcAcceptedRequest { + onion: crate::lsps2::cln::Onion { + short_channel_id: scid, + payload, + next_onion: vec![], + forward_msat: None, + outgoing_cltv_value: None, + shared_secret: vec![], + total_msat: None, + type_: None, + }, + htlc: crate::lsps2::cln::Htlc { + amount_msat: Amount::from_msat(amount_msat), + cltv_expiry: 100, + cltv_expiry_relative: 10, + payment_hash: vec![], + extra_tlvs: None, + short_channel_id: ShortChannelId::from(123456789u64), + id: 0, + }, + forward_to: None, + } + } + + fn create_test_datastore_entry( + peer_id: PublicKey, + expected_payment_size: Option, + ) -> DatastoreEntry { + let (_, policy) = params_with_promise(&[0u8; 32]); + DatastoreEntry { + peer_id, + opening_fee_params: policy, + expected_payment_size, + } } fn minimal_getinfo(height: u32) -> GetinfoResponse { @@ -627,4 +1235,357 @@ mod tests { .unwrap_err(); assert_eq!(err.code, 203); } + #[tokio::test] + async fn test_htlc_no_scid_continues() { + let fake = FakeCln::default(); + let handler = HtlcAcceptedHookHandler::new(fake, 1000); + + // HTLC with no short_channel_id (final destination) + let req = create_test_htlc_request(None, 1000000); + + let result = handler.handle(req).await.unwrap(); + assert_eq!(result.result, HtlcAcceptedResult::Continue); + } + + #[tokio::test] + async fn test_htlc_unknown_scid_continues() { + let fake = FakeCln::default(); + let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); + let scid = ShortChannelId::from(123456789u64); + + // Return empty datastore response (SCID not found) + *fake.cln_listdatastore_response.lock().unwrap() = + Some(ListdatastoreResponse { datastore: vec![] }); + + let req = create_test_htlc_request(Some(scid), 1000000); + + let result = handler.handle(req).await.unwrap(); + assert_eq!(result.result, HtlcAcceptedResult::Continue); + } + + #[tokio::test] + async fn test_htlc_expired_fee_menu_fails() { + let fake = FakeCln::default(); + let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); + let peer_id = create_peer_id(); + let scid = ShortChannelId::from(123456789u64); + + // Create datastore entry with expired fee menu + let mut ds_entry = create_test_datastore_entry(peer_id, None); + ds_entry.opening_fee_params.valid_until = + Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap(); // expired + + let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); + *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { + datastore: vec![ListdatastoreDatastore { + key: scid_ds_key(scid), + generation: Some(1), + hex: None, + string: Some(ds_entry_json), + }], + }); + + // Mock successful deletion + *fake.cln_deldatastore_response.lock().unwrap() = Some(DeldatastoreResponse { + generation: Some(1), + hex: None, + string: None, + key: scid_ds_key(scid), + }); + + let req = create_test_htlc_request(Some(scid), 1000000); + + let result = handler.handle(req).await.unwrap(); + assert_eq!(result.result, HtlcAcceptedResult::Fail); + assert_eq!( + result.failure_message.unwrap(), + TEMPORARY_CHANNEL_FAILURE.to_string() + ); + } + + #[tokio::test] + async fn test_htlc_amount_too_low_fails() { + let fake = FakeCln::default(); + let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); + let peer_id = create_peer_id(); + let scid = ShortChannelId::from(123456789u64); + + let ds_entry = create_test_datastore_entry(peer_id, None); + let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); + + *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { + datastore: vec![ListdatastoreDatastore { + key: scid_ds_key(scid), + generation: Some(1), + hex: None, + string: Some(ds_entry_json), + }], + }); + + // HTLC amount below minimum + let req = create_test_htlc_request(Some(scid), 100); + + let result = handler.handle(req).await.unwrap(); + assert_eq!(result.result, HtlcAcceptedResult::Fail); + assert_eq!( + result.failure_message.unwrap(), + UNKNOWN_NEXT_PEER.to_string() + ); + } + + #[tokio::test] + async fn test_htlc_amount_too_high_fails() { + let fake = FakeCln::default(); + let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); + let peer_id = create_peer_id(); + let scid = ShortChannelId::from(123456789u64); + + let ds_entry = create_test_datastore_entry(peer_id, None); + let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); + + *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { + datastore: vec![ListdatastoreDatastore { + key: scid_ds_key(scid), + generation: Some(1), + hex: None, + string: Some(ds_entry_json), + }], + }); + + // HTLC amount above maximum + let req = create_test_htlc_request(Some(scid), 200_000_000); + + let result = handler.handle(req).await.unwrap(); + assert_eq!(result.result, HtlcAcceptedResult::Fail); + assert_eq!( + result.failure_message.unwrap(), + UNKNOWN_NEXT_PEER.to_string() + ); + } + + #[tokio::test] + async fn test_htlc_amount_doesnt_cover_fee_fails() { + let fake = FakeCln::default(); + let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); + let peer_id = create_peer_id(); + let scid = ShortChannelId::from(123456789u64); + + let ds_entry = create_test_datastore_entry(peer_id, None); + let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); + + *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { + datastore: vec![ListdatastoreDatastore { + key: scid_ds_key(scid), + generation: Some(1), + hex: None, + string: Some(ds_entry_json), + }], + }); + + // HTLC amount just barely covers minimum fee but not minimum HTLC + let req = create_test_htlc_request(Some(scid), 2500); // min_fee is 2000, htlc_minimum is 1000 + + let result = handler.handle(req).await.unwrap(); + assert_eq!(result.result, HtlcAcceptedResult::Fail); + assert_eq!( + result.failure_message.unwrap(), + UNKNOWN_NEXT_PEER.to_string() + ); + } + + #[tokio::test] + async fn test_htlc_channel_capacity_request_fails() { + let fake = FakeCln::default(); + let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); + let peer_id = create_peer_id(); + let scid = ShortChannelId::from(123456789u64); + + let ds_entry = create_test_datastore_entry(peer_id, None); + let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); + + *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { + datastore: vec![ListdatastoreDatastore { + key: scid_ds_key(scid), + generation: Some(1), + hex: None, + string: Some(ds_entry_json), + }], + }); + + *fake.lsps2_getchannelcapacity_error.lock().unwrap() = Some(ClnRpcError { + code: Some(-1), + message: "capacity check failed".to_string(), + data: None, + }); + + let req = create_test_htlc_request(Some(scid), 10_000_000); + + let result = handler.handle(req).await.unwrap(); + assert_eq!(result.result, HtlcAcceptedResult::Fail); + assert_eq!( + result.failure_message.unwrap(), + UNKNOWN_NEXT_PEER.to_string() + ); + } + + #[tokio::test] + async fn test_htlc_policy_denies_channel() { + let fake = FakeCln::default(); + let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); + let peer_id = create_peer_id(); + let scid = ShortChannelId::from(123456789u64); + + let ds_entry = create_test_datastore_entry(peer_id, None); + let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); + + *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { + datastore: vec![ListdatastoreDatastore { + key: scid_ds_key(scid), + generation: Some(1), + hex: None, + string: Some(ds_entry_json), + }], + }); + + // Policy response with no channel capacity (denied) + *fake.lsps2_getchannelcapacity_response.lock().unwrap() = + Some(Lsps2PolicyGetChannelCapacityResponse { + channel_capacity_msat: None, + }); + + let req = create_test_htlc_request(Some(scid), 10_000_000); + + let result = handler.handle(req).await.unwrap(); + assert_eq!(result.result, HtlcAcceptedResult::Fail); + assert_eq!( + result.failure_message.unwrap(), + UNKNOWN_NEXT_PEER.to_string() + ); + } + + #[tokio::test] + async fn test_htlc_fund_channel_fails() { + let fake = FakeCln::default(); + let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); + let peer_id = create_peer_id(); + let scid = ShortChannelId::from(123456789u64); + + let ds_entry = create_test_datastore_entry(peer_id, None); + let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); + + *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { + datastore: vec![ListdatastoreDatastore { + key: scid_ds_key(scid), + generation: Some(1), + hex: None, + string: Some(ds_entry_json), + }], + }); + + *fake.lsps2_getchannelcapacity_response.lock().unwrap() = + Some(Lsps2PolicyGetChannelCapacityResponse { + channel_capacity_msat: Some(50_000_000), + }); + + *fake.cln_fundchannel_error.lock().unwrap() = Some(ClnRpcError { + code: Some(-1), + message: "insufficient funds".to_string(), + data: None, + }); + + let req = create_test_htlc_request(Some(scid), 10_000_000); + + let result = handler.handle(req).await.unwrap(); + assert_eq!(result.result, HtlcAcceptedResult::Fail); + assert_eq!( + result.failure_message.unwrap(), + UNKNOWN_NEXT_PEER.to_string() + ); + } + + #[tokio::test] + async fn test_htlc_successful_flow() { + let fake = FakeCln::default(); + let handler = HtlcAcceptedHookHandler { + api: fake.clone(), + htlc_minimum_msat: 1000, + backoff_listpeerchannels: Duration::from_millis(10), + }; + let peer_id = create_peer_id(); + let scid = ShortChannelId::from(123456789u64); + + let ds_entry = create_test_datastore_entry(peer_id, None); + let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); + + *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { + datastore: vec![ListdatastoreDatastore { + key: scid_ds_key(scid), + generation: Some(1), + hex: None, + string: Some(ds_entry_json), + }], + }); + + *fake.lsps2_getchannelcapacity_response.lock().unwrap() = + Some(Lsps2PolicyGetChannelCapacityResponse { + channel_capacity_msat: Some(50_000_000), + }); + + *fake.cln_fundchannel_response.lock().unwrap() = Some(FundchannelResponse { + channel_id: *Sha256::from_bytes_ref(&[1u8; 32]), + outnum: 0, + txid: String::default(), + channel_type: None, + close_to: None, + mindepth: None, + tx: String::default(), + }); + + let req = create_test_htlc_request(Some(scid), 10_000_000); + + let result = handler.handle(req).await.unwrap(); + assert_eq!(result.result, HtlcAcceptedResult::Continue); + + assert!(result.payload.is_some()); + assert!(result.extra_tlvs.is_some()); + assert!(result.forward_to.is_some()); + + // The payload should have the deducted amount + let payload_bytes = result.payload.unwrap(); + let payload_tlv = TlvStream::from_bytes(&payload_bytes).unwrap(); + + // Should contain forward amount. + assert!(payload_tlv.get(TLV_FORWARD_AMT).is_some()); + } + + #[tokio::test] + async fn test_htlc_mpp_not_implemented() { + let fake = FakeCln::default(); + let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); + let peer_id = create_peer_id(); + let scid = ShortChannelId::from(123456789u64); + + // Create entry with expected_payment_size (MPP mode) + let mut ds_entry = create_test_datastore_entry(peer_id, None); + ds_entry.expected_payment_size = Some(Msat::from_msat(1000000)); + let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); + + *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { + datastore: vec![ListdatastoreDatastore { + key: scid_ds_key(scid), + generation: Some(1), + hex: None, + string: Some(ds_entry_json), + }], + }); + + let req = create_test_htlc_request(Some(scid), 10_000_000); + + let result = handler.handle(req).await.unwrap(); + assert_eq!(result.result, HtlcAcceptedResult::Fail); + assert_eq!( + result.failure_message.unwrap(), + UNKNOWN_NEXT_PEER.to_string() + ); + } } diff --git a/plugins/lsps-plugin/src/lsps2/mod.rs b/plugins/lsps-plugin/src/lsps2/mod.rs index 7ed8e7446..2b98aa1df 100644 --- a/plugins/lsps-plugin/src/lsps2/mod.rs +++ b/plugins/lsps-plugin/src/lsps2/mod.rs @@ -1,5 +1,6 @@ use cln_plugin::options; +pub mod cln; pub mod handler; pub mod model; diff --git a/plugins/lsps-plugin/src/lsps2/model.rs b/plugins/lsps-plugin/src/lsps2/model.rs index 7a186db0d..bcf8de079 100644 --- a/plugins/lsps-plugin/src/lsps2/model.rs +++ b/plugins/lsps-plugin/src/lsps2/model.rs @@ -7,6 +7,11 @@ use chrono::Utc; use log::debug; use serde::{Deserialize, Serialize}; +pub mod failure_codes { + pub const TEMPORARY_CHANNEL_FAILURE: &'static str = "1007"; + pub const UNKNOWN_NEXT_PEER: &'static str = "4010"; +} + #[derive(Clone, Debug, PartialEq)] pub enum Error { InvalidOpeningFeeParams, @@ -256,6 +261,18 @@ pub struct Lsps2PolicyGetInfoResponse { pub policy_opening_fee_params_menu: Vec, } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Lsps2PolicyGetChannelCapacityRequest { + pub opening_fee_params: OpeningFeeParams, + pub init_payment_size: Msat, + pub scid: ShortChannelId, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Lsps2PolicyGetChannelCapacityResponse { + pub channel_capacity_msat: Option, +} + /// An internal representation of a policy of parameters for calculating the /// opening fee for a JIT channel. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] diff --git a/plugins/lsps-plugin/src/service.rs b/plugins/lsps-plugin/src/service.rs index ca620f9e5..b17c8e83b 100644 --- a/plugins/lsps-plugin/src/service.rs +++ b/plugins/lsps-plugin/src/service.rs @@ -6,6 +6,8 @@ use cln_lsps::jsonrpc::{server::JsonRpcServer, JsonRpcRequest}; use cln_lsps::lsps0::handler::Lsps0ListProtocolsHandler; use cln_lsps::lsps0::model::Lsps0listProtocolsRequest; use cln_lsps::lsps0::transport::{self, CustomMsg}; +use cln_lsps::lsps2::cln::{HtlcAcceptedRequest, HtlcAcceptedResponse}; +use cln_lsps::lsps2::handler::{ClnApiRpc, HtlcAcceptedHookHandler}; use cln_lsps::lsps2::model::{Lsps2BuyRequest, Lsps2GetInfoRequest}; use cln_lsps::util::wrap_payload_with_peer_id; use cln_lsps::{lsps0, lsps2, util, LSP_FEATURE_BIT}; @@ -27,6 +29,7 @@ const OPTION_ENABLED: options::FlagConfigOption = ConfigOption::new_flag( #[derive(Clone)] struct State { lsps_service: JsonRpcServer, + lsps2_enabled: bool, } #[tokio::main] @@ -44,6 +47,7 @@ async fn main() -> Result<(), anyhow::Error> { util::feature_bit_to_hex(LSP_FEATURE_BIT), ) .hook("custommsg", on_custommsg) + .hook("htlc_accepted", on_htlc_accepted) .configure() .await? { @@ -63,7 +67,7 @@ async fn main() -> Result<(), anyhow::Error> { }), ); - if plugin.option(&lsps2::OPTION_ENABLED)? { + let lsps2_enabled = if plugin.option(&lsps2::OPTION_ENABLED)? { log::debug!("lsps2 enabled"); let secret_hex = plugin.option(&lsps2::OPTION_PROMISE_SECRET)?; if let Some(secret_hex) = secret_hex { @@ -104,11 +108,17 @@ async fn main() -> Result<(), anyhow::Error> { ) .with_handler(Lsps2BuyRequest::METHOD.to_string(), Arc::new(buy_handler)); } - } + true + } else { + false + }; let lsps_service = lsps_builder.build(); - let state = State { lsps_service }; + let state = State { + lsps_service, + lsps2_enabled, + }; let plugin = plugin.start(state).await?; plugin.join().await } else { @@ -116,6 +126,27 @@ async fn main() -> Result<(), anyhow::Error> { } } +async fn on_htlc_accepted( + p: Plugin, + v: serde_json::Value, +) -> Result { + if !p.state().lsps2_enabled { + // just continue. + // Fixme: Add forward and extra tlvs from incoming. + let res = serde_json::to_value(&HtlcAcceptedResponse::continue_(None, None, None))?; + return Ok(res); + } + + let req: HtlcAcceptedRequest = serde_json::from_value(v)?; + let rpc_path = Path::new(&p.configuration().lightning_dir).join(&p.configuration().rpc_file); + let api = ClnApiRpc::new(rpc_path); + // Fixme: Use real htlc_minimum_amount. + let handler = HtlcAcceptedHookHandler::new(api, 1000); + let res = handler.handle(req).await?; + let res_val = serde_json::to_value(&res)?; + Ok(res_val) +} + async fn on_custommsg( p: Plugin, v: serde_json::Value, diff --git a/tests/plugins/lsps2_policy.py b/tests/plugins/lsps2_policy.py index e16eb4ae1..7588712df 100755 --- a/tests/plugins/lsps2_policy.py +++ b/tests/plugins/lsps2_policy.py @@ -41,5 +41,16 @@ def lsps2_getpolicy(request): ] } +@plugin.method("dev-lsps2-getchannelcapacity") +def lsps2_getchannelcapacity(request, init_payment_size, scid, opening_fee_params): + """ Returns an opening fee menu for the LSPS2 plugin. + """ + now = datetime.now(timezone.utc) + + # Is ISO 8601 format "YYYY-MM-DDThh:mm:ss.uuuZ" + valid_until = (now + timedelta(hours=1)).isoformat().replace('+00:00', 'Z') + + return { "channel_capacity_msat": 100000000 } + plugin.run()