diff --git a/common/amount.c b/common/amount.c index 4ded1a2ac..beeaa14b4 100644 --- a/common/amount.c +++ b/common/amount.c @@ -581,11 +581,9 @@ bool amount_msat_fee(struct amount_msat *fee, * - fee_base_msat + ( amount_to_forward * fee_proportional_millionths / 1000000 ) */ fee_base.millisatoshis = fee_base_msat; - - if (mul_overflows_u64(amt.millisatoshis, fee_proportional_millionths)) + if (!amount_msat_mul_div(&fee_prop, amt, fee_proportional_millionths, + 1000000)) return false; - fee_prop.millisatoshis = amt.millisatoshis * fee_proportional_millionths - / 1000000; return amount_msat_add(fee, fee_base, fee_prop); } diff --git a/common/test/run-amount.c b/common/test/run-amount.c index 7b408959d..5f8a96a0b 100644 --- a/common/test/run-amount.c +++ b/common/test/run-amount.c @@ -85,6 +85,23 @@ static void test_amount_sub_fee(struct amount_msat msat, assert(amount_msat_greater(in2, in)); } +static void test_amount_fee(struct amount_msat msat, u32 base, u32 prop, + struct amount_msat expected_msat) +{ + struct amount_msat fee; + assert(amount_msat_fee(&fee, msat, base, prop)); + assert(amount_msat_eq(fee, expected_msat)); +} + +static void test_amount_fee_str(const char *msat_str, u32 base, u32 prop, + u64 expected) +{ + struct amount_msat msat; + struct amount_msat expected_msat = amount_msat(expected); + assert(parse_amount_msat(&msat, msat_str, strlen(msat_str))); + return test_amount_fee(msat, base, prop, expected_msat); +} + static void test_amount_with_fee(void) { for (int basebits = 0; basebits < 33; basebits++) { @@ -93,13 +110,57 @@ static void test_amount_with_fee(void) for (int propbits = 0; propbits < 20; propbits++) { u32 prop = (1ULL << propbits); - for (int amtbits1 = 0; amtbits1 < 42; amtbits1++) { - for (int amtbits2 = 0; amtbits2 < 42; amtbits2++) { + for (int amtbits1 = 0; amtbits1 < 63; amtbits1++) { + for (int amtbits2 = 0; amtbits2 < 63; amtbits2++) { test_amount_sub_fee(amount_msat((1ULL << amtbits1) | (1ULL << amtbits2)), base, prop); } } } } + + for (int basebits = 0; basebits < 33; basebits++) { + u32 base = (1ULL << basebits); + for (int propbits = 0; propbits < 20; propbits++) { + u32 prop = (1ULL << propbits); + test_amount_fee(amount_msat(0), base, prop, + amount_msat(base)); + } + } + for (int basebits = 0; basebits < 33; basebits++) { + u32 base = (1ULL << basebits); + for (int amtbits = 0; amtbits < 63; amtbits++) + test_amount_fee(amount_msat(1ULL << amtbits), base, 0, + amount_msat(base)); + } + for (int amtbits = 0; amtbits < 63; amtbits++) + test_amount_fee(amount_msat(1ULL << amtbits), 0, 0, + amount_msat(0)); + + test_amount_fee_str("1msat", 1, 1, 1); + test_amount_fee_str("1msat", 1, 500000, 1); + test_amount_fee_str("1msat", 1, 1000000, 2); + + test_amount_fee_str("1msat", 1234567890, 1, 1234567890ULL); + test_amount_fee_str("1msat", 1234567890, 500000, 1234567890ULL); + test_amount_fee_str("1msat", 1234567890, 1000000, 1234567891ULL); + + test_amount_fee_str("1btc", 1, 1, 100001ULL); + test_amount_fee_str("1btc", 1, 500000, 50000000001ULL); + test_amount_fee_str("1btc", 1, 1000000, 100000000001ULL); + + test_amount_fee_str("1btc", 1234567890, 1, 1234667890ULL); + test_amount_fee_str("1btc", 1234567890, 500000, 51234567890ULL); + test_amount_fee_str("1btc", 1234567890, 1000000, 101234567890ULL); + + test_amount_fee_str("21000000btc", 1, 1, 2100000000001ULL); + test_amount_fee_str("21000000btc", 1, 500000, 1050000000000000001ULL); + test_amount_fee_str("21000000btc", 1, 1000000, 2100000000000000001ULL); + + test_amount_fee_str("21000000btc", 1234567890, 1, 2101234567890ULL); + test_amount_fee_str("21000000btc", 1234567890, 500000, + 1050000001234567890ULL); + test_amount_fee_str("21000000btc", 1234567890, 1000000, + 2100000001234567890ULL); } #define FAIL_MSAT(msatp, str) \