From e7685a77b41bbd1b8289aeaf75fccaf4bb68a361 Mon Sep 17 00:00:00 2001 From: Lennart Poettering Date: Mon, 26 Feb 2018 15:41:38 +0100 Subject: [PATCH] util: add new safe_close_above_stdio() wrapper At various places we only want to close fds if they are not stdin/stdout/stderr, i.e. fds 0, 1, 2. Let's add a unified helper call for that, and port everything over. --- coccinelle/close-above-stdio.cocci | 36 ++++++++++++++++++++++++++++++ src/basic/fd-util.h | 7 ++++++ src/basic/log.c | 9 +------- src/basic/process-util.c | 3 +-- src/basic/terminal-util.c | 3 +-- src/journal/cat.c | 4 +--- src/nspawn/nspawn-setuid.c | 9 +++----- 7 files changed, 50 insertions(+), 21 deletions(-) create mode 100644 coccinelle/close-above-stdio.cocci diff --git a/coccinelle/close-above-stdio.cocci b/coccinelle/close-above-stdio.cocci new file mode 100644 index 0000000000..44b3b1c9f1 --- /dev/null +++ b/coccinelle/close-above-stdio.cocci @@ -0,0 +1,36 @@ +@@ +expression fd; +@@ +- if (fd > 2) +- safe_close(fd); ++ safe_close_above_stdio(fd); +@@ +expression fd; +@@ +- if (fd > 2) +- fd = safe_close(fd); ++ fd = safe_close_above_stdio(fd); +@@ +expression fd; +@@ +- if (fd >= 3) +- safe_close(fd); ++ safe_close_above_stdio(fd); +@@ +expression fd; +@@ +- if (fd >= 3) +- fd = safe_close(fd); ++ fd = safe_close_above_stdio(fd); +@@ +expression fd; +@@ +- if (fd > STDERR_FILENO) +- safe_close(fd); ++ safe_close_above_stdio(fd); +@@ +expression fd; +@@ +- if (fd > STDERR_FILENO) +- fd = safe_close(fd); ++ fd = safe_close_above_stdio(fd); diff --git a/src/basic/fd-util.h b/src/basic/fd-util.h index 284856ae6d..4e8d9bc40a 100644 --- a/src/basic/fd-util.h +++ b/src/basic/fd-util.h @@ -35,6 +35,13 @@ int close_nointr(int fd); int safe_close(int fd); void safe_close_pair(int p[]); +static inline int safe_close_above_stdio(int fd) { + if (fd < 3) /* Don't close stdin/stdout/stderr, but still invalidate the fd by returning -1 */ + return -1; + + return safe_close(fd); +} + void close_many(const int fds[], unsigned n_fd); int fclose_nointr(FILE *f); diff --git a/src/basic/log.c b/src/basic/log.c index 72b60da6c6..7a7f2cbec1 100644 --- a/src/basic/log.c +++ b/src/basic/log.c @@ -94,14 +94,7 @@ static char *log_abort_msg = NULL; } while (false) static void log_close_console(void) { - - if (console_fd < 0) - return; - - if (console_fd >= 3) - safe_close(console_fd); - - console_fd = -1; + console_fd = safe_close_above_stdio(console_fd); } static int log_open_console(void) { diff --git a/src/basic/process-util.c b/src/basic/process-util.c index aa41b3b686..66a7557fba 100644 --- a/src/basic/process-util.c +++ b/src/basic/process-util.c @@ -1428,8 +1428,7 @@ int fork_agent(const char *name, const int except[], unsigned n_except, pid_t *r _exit(EXIT_FAILURE); } - if (fd > STDERR_FILENO) - close(fd); + safe_close_above_stdio(fd); } /* Count arguments */ diff --git a/src/basic/terminal-util.c b/src/basic/terminal-util.c index cddbb461bd..cdad4cb621 100644 --- a/src/basic/terminal-util.c +++ b/src/basic/terminal-util.c @@ -917,8 +917,7 @@ int make_stdio(int fd) { if (dup2(fd, STDERR_FILENO) < 0 && r >= 0) r = -errno; - if (fd >= 3) - safe_close(fd); + safe_close_above_stdio(fd); /* Explicitly unset O_CLOEXEC, since if fd was < 3, then dup2() was a NOP and the bit hence possibly set. */ stdio_unset_cloexec(); diff --git a/src/journal/cat.c b/src/journal/cat.c index b2f9ed5010..c87a149a4c 100644 --- a/src/journal/cat.c +++ b/src/journal/cat.c @@ -141,9 +141,7 @@ int main(int argc, char *argv[]) { goto finish; } - if (fd >= 3) - safe_close(fd); - fd = -1; + fd = safe_close_above_stdio(fd); if (argc <= optind) (void) execl("/bin/cat", "/bin/cat", NULL); diff --git a/src/nspawn/nspawn-setuid.c b/src/nspawn/nspawn-setuid.c index 8f2359ad91..c4ad172512 100644 --- a/src/nspawn/nspawn-setuid.c +++ b/src/nspawn/nspawn-setuid.c @@ -57,10 +57,8 @@ static int spawn_getent(const char *database, const char *key, pid_t *rpid) { if (dup3(pipe_fds[1], STDOUT_FILENO, 0) < 0) _exit(EXIT_FAILURE); - if (pipe_fds[0] > 2) - safe_close(pipe_fds[0]); - if (pipe_fds[1] > 2) - safe_close(pipe_fds[1]); + safe_close_above_stdio(pipe_fds[0]); + safe_close_above_stdio(pipe_fds[1]); nullfd = open("/dev/null", O_RDWR); if (nullfd < 0) @@ -72,8 +70,7 @@ static int spawn_getent(const char *database, const char *key, pid_t *rpid) { if (dup3(nullfd, STDERR_FILENO, 0) < 0) _exit(EXIT_FAILURE); - if (nullfd > 2) - safe_close(nullfd); + safe_close_above_stdio(nullfd); close_all_fds(NULL, 0);