diff --git a/common/htlc_wire.c b/common/htlc_wire.c index b50e7c291..49c8211de 100644 --- a/common/htlc_wire.c +++ b/common/htlc_wire.c @@ -78,6 +78,18 @@ struct existing_htlc *new_existing_htlc(const tal_t *ctx, return existing; } +static void towire_len_and_tlvstream(u8 **pptr, struct tlv_field *extra_tlvs) +{ + /* Making a copy is a bit awful, but it's the easiest way to + * get the length */ + u8 *tmp_pptr = tal_arr(tmpctx, u8, 0); + towire_tlvstream_raw(&tmp_pptr, extra_tlvs); + + assert(tal_bytelen(tmp_pptr) == (u16)tal_bytelen(tmp_pptr)); + towire_u16(pptr, tal_bytelen(tmp_pptr)); + towire_u8_array(pptr, tmp_pptr, tal_bytelen(tmp_pptr)); +} + /* FIXME: We could adapt tools/generate-wire.py to generate structures * and code like this. */ void towire_added_htlc(u8 **pptr, const struct added_htlc *added) @@ -94,13 +106,8 @@ void towire_added_htlc(u8 **pptr, const struct added_htlc *added) } else towire_bool(pptr, false); if (added->extra_tlvs) { - u8 *tmp_pptr = tal_arr(tmpctx, u8, 0); - towire_tlvstream_raw(&tmp_pptr, added->extra_tlvs); - towire_bool(pptr, true); - towire_u16(pptr, tal_bytelen(tmp_pptr)); - towire_u8_array(pptr, tmp_pptr, - tal_bytelen(tmp_pptr)); + towire_len_and_tlvstream(pptr, added->extra_tlvs); } else towire_bool(pptr, false); towire_bool(pptr, added->fail_immediate); @@ -131,13 +138,8 @@ void towire_existing_htlc(u8 **pptr, const struct existing_htlc *existing) } else towire_bool(pptr, false); if (existing->extra_tlvs) { - u8 *tmp_pptr = tal_arr(tmpctx, u8, 0); - towire_tlvstream_raw(&tmp_pptr, existing->extra_tlvs); - towire_bool(pptr, true); - towire_u16(pptr, tal_bytelen(tmp_pptr)); - towire_u8_array(pptr, tmp_pptr, - tal_bytelen(tmp_pptr)); + towire_len_and_tlvstream(pptr, existing->extra_tlvs); } else towire_bool(pptr, false); } @@ -192,6 +194,28 @@ void towire_shachain(u8 **pptr, const struct shachain *shachain) } } +static struct tlv_field *fromwire_len_and_tlvstream(const tal_t *ctx, + const u8 **cursor, size_t *max) +{ + struct tlv_field *tlvs = tal_arr(ctx, struct tlv_field, 0); + size_t len = fromwire_u16(cursor, max); + + /* Subtle: we are not using fromwire_tal_arrn here, which + * would do this. */ + if (len > *max) { + fromwire_fail(cursor, max); + return NULL; + } + + /* NOTE: We might consider to be more strict and only allow for + * known tlv types from the tlvs_tlv_update_add_htlc_tlvs + * record. */ + if (!fromwire_tlv(cursor, &len, NULL, 0, cast_const(void *, ctx), + &tlvs, FROMWIRE_TLV_ANY_TYPE, NULL, NULL)) + return tal_free(tlvs); + return tlvs; +} + void fromwire_added_htlc(const u8 **cursor, size_t *max, struct added_htlc *added) { @@ -207,17 +231,7 @@ void fromwire_added_htlc(const u8 **cursor, size_t *max, } else added->path_key = NULL; if (fromwire_bool(cursor, max)) { - size_t tlv_len = fromwire_u16(cursor, max); - /* NOTE: We might consider to be more strict and only allow for - * known tlv types from the tlvs_tlv_update_add_htlc_tlvs - * record. */ - const u64 *allowed = cast_const(u64 *, FROMWIRE_TLV_ANY_TYPE); - added->extra_tlvs = tal_arr(added, struct tlv_field, 0); - if (!fromwire_tlv(cursor, &tlv_len, NULL, 0, added, - &added->extra_tlvs, allowed, NULL, NULL)) { - tal_free(added->extra_tlvs); - added->extra_tlvs = NULL; - } + added->extra_tlvs = fromwire_len_and_tlvstream(added, cursor, max); } else added->extra_tlvs = NULL; added->fail_immediate = fromwire_bool(cursor, max); @@ -250,17 +264,7 @@ struct existing_htlc *fromwire_existing_htlc(const tal_t *ctx, } else existing->path_key = NULL; if (fromwire_bool(cursor, max)) { - size_t tlv_len = fromwire_u16(cursor, max); - /* NOTE: We might consider to be more strict and only allow for - * known tlv types from the tlvs_tlv_update_add_htlc_tlvs - * record. */ - const u64 *allowed = cast_const(u64 *, FROMWIRE_TLV_ANY_TYPE); - existing->extra_tlvs = tal_arr(existing, struct tlv_field, 0); - if (!fromwire_tlv(cursor, &tlv_len, NULL, 0, existing, - &existing->extra_tlvs, allowed, NULL, NULL)) { - tal_free(existing->extra_tlvs); - existing->extra_tlvs = NULL; - } + existing->extra_tlvs = fromwire_len_and_tlvstream(existing, cursor, max); } else existing->extra_tlvs = NULL; return existing;