diff --git a/common/dev_disconnect.c b/common/dev_disconnect.c index 76df764b9..2feaf9278 100644 --- a/common/dev_disconnect.c +++ b/common/dev_disconnect.c @@ -59,6 +59,7 @@ enum dev_disconnect_out dev_disconnect_out(const struct node_id *id, int pkt_typ next_dev_disconnect(); if (!dev_disconnect_line[0] + || dev_disconnect_line[0] == DEV_DISCONNECT_IN_AFTER_RECV || !streq(peer_wire_name(pkt_type), dev_disconnect_line+1)) return DEV_DISCONNECT_OUT_NORMAL; @@ -76,6 +77,32 @@ enum dev_disconnect_out dev_disconnect_out(const struct node_id *id, int pkt_typ return dev_disconnect_line[0]; } +enum dev_disconnect_in dev_disconnect_in(const struct node_id *id, int pkt_type) +{ + if (dev_disconnect_fd == -1) + return DEV_DISCONNECT_IN_NORMAL; + + if (!dev_disconnect_count) + next_dev_disconnect(); + + if (dev_disconnect_line[0] != DEV_DISCONNECT_IN_AFTER_RECV + || !streq(peer_wire_name(pkt_type), dev_disconnect_line+1)) + return DEV_DISCONNECT_IN_NORMAL; + + if (--dev_disconnect_count != 0) { + return DEV_DISCONNECT_IN_NORMAL; + } + + if (lseek(dev_disconnect_fd, dev_disconnect_len+1, SEEK_CUR) < 0) { + err(1, "lseek failure"); + } + + status_peer_debug(id, "dev_disconnect: %s (%s)", + dev_disconnect_line, + peer_wire_name(pkt_type)); + return dev_disconnect_line[0]; +} + void dev_sabotage_fd(int fd, bool close_fd) { int fds[2]; diff --git a/common/dev_disconnect.h b/common/dev_disconnect.h index 333ebf26f..0237cd25f 100644 --- a/common/dev_disconnect.h +++ b/common/dev_disconnect.h @@ -23,6 +23,16 @@ enum dev_disconnect_out { /* Force a close fd before or after a certain packet type */ enum dev_disconnect_out dev_disconnect_out(const struct node_id *id, int pkt_type); +enum dev_disconnect_in { + /* Do nothing. */ + DEV_DISCONNECT_IN_NORMAL = '=', + /* Close connection after receiving packet. */ + DEV_DISCONNECT_IN_AFTER_RECV = '<', +}; + +/* Force a close fd after receiving a certain packet type */ +enum dev_disconnect_in dev_disconnect_in(const struct node_id *id, int pkt_type); + /* Make next write on fd fail as if they'd disconnected. */ void dev_sabotage_fd(int fd, bool close_fd);