use amount_msat_mul_div operation to compute fees

Changelog-None

Signed-off-by: Lagrang3 <lagrang3@protonmail.com>
This commit is contained in:
Lagrang3
2025-03-02 11:26:59 +01:00
committed by Rusty Russell
parent b379500d21
commit a899dea3e1
2 changed files with 65 additions and 6 deletions

View File

@@ -581,11 +581,9 @@ bool amount_msat_fee(struct amount_msat *fee,
* - fee_base_msat + ( amount_to_forward * fee_proportional_millionths / 1000000 )
*/
fee_base.millisatoshis = fee_base_msat;
if (mul_overflows_u64(amt.millisatoshis, fee_proportional_millionths))
if (!amount_msat_mul_div(&fee_prop, amt, fee_proportional_millionths,
1000000))
return false;
fee_prop.millisatoshis = amt.millisatoshis * fee_proportional_millionths
/ 1000000;
return amount_msat_add(fee, fee_base, fee_prop);
}

View File

@@ -85,6 +85,23 @@ static void test_amount_sub_fee(struct amount_msat msat,
assert(amount_msat_greater(in2, in));
}
static void test_amount_fee(struct amount_msat msat, u32 base, u32 prop,
struct amount_msat expected_msat)
{
struct amount_msat fee;
assert(amount_msat_fee(&fee, msat, base, prop));
assert(amount_msat_eq(fee, expected_msat));
}
static void test_amount_fee_str(const char *msat_str, u32 base, u32 prop,
u64 expected)
{
struct amount_msat msat;
struct amount_msat expected_msat = amount_msat(expected);
assert(parse_amount_msat(&msat, msat_str, strlen(msat_str)));
return test_amount_fee(msat, base, prop, expected_msat);
}
static void test_amount_with_fee(void)
{
for (int basebits = 0; basebits < 33; basebits++) {
@@ -93,13 +110,57 @@ static void test_amount_with_fee(void)
for (int propbits = 0; propbits < 20; propbits++) {
u32 prop = (1ULL << propbits);
for (int amtbits1 = 0; amtbits1 < 42; amtbits1++) {
for (int amtbits2 = 0; amtbits2 < 42; amtbits2++) {
for (int amtbits1 = 0; amtbits1 < 63; amtbits1++) {
for (int amtbits2 = 0; amtbits2 < 63; amtbits2++) {
test_amount_sub_fee(amount_msat((1ULL << amtbits1) | (1ULL << amtbits2)), base, prop);
}
}
}
}
for (int basebits = 0; basebits < 33; basebits++) {
u32 base = (1ULL << basebits);
for (int propbits = 0; propbits < 20; propbits++) {
u32 prop = (1ULL << propbits);
test_amount_fee(amount_msat(0), base, prop,
amount_msat(base));
}
}
for (int basebits = 0; basebits < 33; basebits++) {
u32 base = (1ULL << basebits);
for (int amtbits = 0; amtbits < 63; amtbits++)
test_amount_fee(amount_msat(1ULL << amtbits), base, 0,
amount_msat(base));
}
for (int amtbits = 0; amtbits < 63; amtbits++)
test_amount_fee(amount_msat(1ULL << amtbits), 0, 0,
amount_msat(0));
test_amount_fee_str("1msat", 1, 1, 1);
test_amount_fee_str("1msat", 1, 500000, 1);
test_amount_fee_str("1msat", 1, 1000000, 2);
test_amount_fee_str("1msat", 1234567890, 1, 1234567890ULL);
test_amount_fee_str("1msat", 1234567890, 500000, 1234567890ULL);
test_amount_fee_str("1msat", 1234567890, 1000000, 1234567891ULL);
test_amount_fee_str("1btc", 1, 1, 100001ULL);
test_amount_fee_str("1btc", 1, 500000, 50000000001ULL);
test_amount_fee_str("1btc", 1, 1000000, 100000000001ULL);
test_amount_fee_str("1btc", 1234567890, 1, 1234667890ULL);
test_amount_fee_str("1btc", 1234567890, 500000, 51234567890ULL);
test_amount_fee_str("1btc", 1234567890, 1000000, 101234567890ULL);
test_amount_fee_str("21000000btc", 1, 1, 2100000000001ULL);
test_amount_fee_str("21000000btc", 1, 500000, 1050000000000000001ULL);
test_amount_fee_str("21000000btc", 1, 1000000, 2100000000000000001ULL);
test_amount_fee_str("21000000btc", 1234567890, 1, 2101234567890ULL);
test_amount_fee_str("21000000btc", 1234567890, 500000,
1050000001234567890ULL);
test_amount_fee_str("21000000btc", 1234567890, 1000000,
2100000001234567890ULL);
}
#define FAIL_MSAT(msatp, str) \