common/htlc_wire: add towire/fromwire helpers for wrapped tlv streams.

And make sure we check the length properly in fromwire!

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell
2025-08-14 11:40:46 +09:30
parent 63065aa99c
commit 6fbc5d02ca

View File

@@ -78,6 +78,18 @@ struct existing_htlc *new_existing_htlc(const tal_t *ctx,
return existing;
}
static void towire_len_and_tlvstream(u8 **pptr, struct tlv_field *extra_tlvs)
{
/* Making a copy is a bit awful, but it's the easiest way to
* get the length */
u8 *tmp_pptr = tal_arr(tmpctx, u8, 0);
towire_tlvstream_raw(&tmp_pptr, extra_tlvs);
assert(tal_bytelen(tmp_pptr) == (u16)tal_bytelen(tmp_pptr));
towire_u16(pptr, tal_bytelen(tmp_pptr));
towire_u8_array(pptr, tmp_pptr, tal_bytelen(tmp_pptr));
}
/* FIXME: We could adapt tools/generate-wire.py to generate structures
* and code like this. */
void towire_added_htlc(u8 **pptr, const struct added_htlc *added)
@@ -94,13 +106,8 @@ void towire_added_htlc(u8 **pptr, const struct added_htlc *added)
} else
towire_bool(pptr, false);
if (added->extra_tlvs) {
u8 *tmp_pptr = tal_arr(tmpctx, u8, 0);
towire_tlvstream_raw(&tmp_pptr, added->extra_tlvs);
towire_bool(pptr, true);
towire_u16(pptr, tal_bytelen(tmp_pptr));
towire_u8_array(pptr, tmp_pptr,
tal_bytelen(tmp_pptr));
towire_len_and_tlvstream(pptr, added->extra_tlvs);
} else
towire_bool(pptr, false);
towire_bool(pptr, added->fail_immediate);
@@ -131,13 +138,8 @@ void towire_existing_htlc(u8 **pptr, const struct existing_htlc *existing)
} else
towire_bool(pptr, false);
if (existing->extra_tlvs) {
u8 *tmp_pptr = tal_arr(tmpctx, u8, 0);
towire_tlvstream_raw(&tmp_pptr, existing->extra_tlvs);
towire_bool(pptr, true);
towire_u16(pptr, tal_bytelen(tmp_pptr));
towire_u8_array(pptr, tmp_pptr,
tal_bytelen(tmp_pptr));
towire_len_and_tlvstream(pptr, existing->extra_tlvs);
} else
towire_bool(pptr, false);
}
@@ -192,6 +194,28 @@ void towire_shachain(u8 **pptr, const struct shachain *shachain)
}
}
static struct tlv_field *fromwire_len_and_tlvstream(const tal_t *ctx,
const u8 **cursor, size_t *max)
{
struct tlv_field *tlvs = tal_arr(ctx, struct tlv_field, 0);
size_t len = fromwire_u16(cursor, max);
/* Subtle: we are not using fromwire_tal_arrn here, which
* would do this. */
if (len > *max) {
fromwire_fail(cursor, max);
return NULL;
}
/* NOTE: We might consider to be more strict and only allow for
* known tlv types from the tlvs_tlv_update_add_htlc_tlvs
* record. */
if (!fromwire_tlv(cursor, &len, NULL, 0, cast_const(void *, ctx),
&tlvs, FROMWIRE_TLV_ANY_TYPE, NULL, NULL))
return tal_free(tlvs);
return tlvs;
}
void fromwire_added_htlc(const u8 **cursor, size_t *max,
struct added_htlc *added)
{
@@ -207,17 +231,7 @@ void fromwire_added_htlc(const u8 **cursor, size_t *max,
} else
added->path_key = NULL;
if (fromwire_bool(cursor, max)) {
size_t tlv_len = fromwire_u16(cursor, max);
/* NOTE: We might consider to be more strict and only allow for
* known tlv types from the tlvs_tlv_update_add_htlc_tlvs
* record. */
const u64 *allowed = cast_const(u64 *, FROMWIRE_TLV_ANY_TYPE);
added->extra_tlvs = tal_arr(added, struct tlv_field, 0);
if (!fromwire_tlv(cursor, &tlv_len, NULL, 0, added,
&added->extra_tlvs, allowed, NULL, NULL)) {
tal_free(added->extra_tlvs);
added->extra_tlvs = NULL;
}
added->extra_tlvs = fromwire_len_and_tlvstream(added, cursor, max);
} else
added->extra_tlvs = NULL;
added->fail_immediate = fromwire_bool(cursor, max);
@@ -250,17 +264,7 @@ struct existing_htlc *fromwire_existing_htlc(const tal_t *ctx,
} else
existing->path_key = NULL;
if (fromwire_bool(cursor, max)) {
size_t tlv_len = fromwire_u16(cursor, max);
/* NOTE: We might consider to be more strict and only allow for
* known tlv types from the tlvs_tlv_update_add_htlc_tlvs
* record. */
const u64 *allowed = cast_const(u64 *, FROMWIRE_TLV_ANY_TYPE);
existing->extra_tlvs = tal_arr(existing, struct tlv_field, 0);
if (!fromwire_tlv(cursor, &tlv_len, NULL, 0, existing,
&existing->extra_tlvs, allowed, NULL, NULL)) {
tal_free(existing->extra_tlvs);
existing->extra_tlvs = NULL;
}
existing->extra_tlvs = fromwire_len_and_tlvstream(existing, cursor, max);
} else
existing->extra_tlvs = NULL;
return existing;