diff --git a/common/msg_queue.c b/common/msg_queue.c index f3926ab67..ef562914e 100644 --- a/common/msg_queue.c +++ b/common/msg_queue.c @@ -1,5 +1,6 @@ #include "config.h" #include +#include #include #include #include @@ -9,11 +10,35 @@ struct msg_queue { const u8 **q; }; +static int extract_fd(const u8 *msg) +{ + const u8 *p = msg + sizeof(u16); + size_t len = tal_count(msg) - sizeof(u16); + + if (fromwire_peektype(msg) != MSG_PASS_FD) + return -1; + + return fromwire_u32(&p, &len); +} + +/* Close any fds left in queue! */ +static void destroy_msg_queue(struct msg_queue *q) +{ + for (size_t i = 0; i < tal_count(q->q); i++) { + int fd = extract_fd(q->q[i]); + if (fd != -1) + close(fd); + } +} + struct msg_queue *msg_queue_new(const tal_t *ctx, bool fd_passing) { struct msg_queue *q = tal(ctx, struct msg_queue); q->fd_passing = fd_passing; q->q = tal_arr(q, const u8 *, 0); + + if (q->fd_passing) + tal_add_destructor(q, destroy_msg_queue); return q; } @@ -62,14 +87,9 @@ const u8 *msg_dequeue(struct msg_queue *q) int msg_extract_fd(const struct msg_queue *q, const u8 *msg) { - const u8 *p = msg + sizeof(u16); - size_t len = tal_count(msg) - sizeof(u16); - assert(q->fd_passing); - if (fromwire_peektype(msg) != MSG_PASS_FD) - return -1; - return fromwire_u32(&p, &len); + return extract_fd(msg); } void msg_wake(const struct msg_queue *q)