From b8cfa2da7c103d612b9608ade580c0a41538450c Mon Sep 17 00:00:00 2001 From: Lennart Poettering Date: Tue, 13 Oct 2020 18:06:45 +0200 Subject: [PATCH] fd-util: port close_all_fds() to close_range() --- src/basic/fd-util.c | 90 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/src/basic/fd-util.c b/src/basic/fd-util.c index db869cbd54..e37b6944a8 100644 --- a/src/basic/fd-util.c +++ b/src/basic/fd-util.c @@ -21,6 +21,7 @@ #include "path-util.h" #include "process-util.h" #include "socket-util.h" +#include "sort-util.h" #include "stat-util.h" #include "stdio-util.h" #include "tmpfile-util.h" @@ -210,13 +211,102 @@ static int get_max_fd(void) { return (int) (m - 1); } +static int cmp_int(const int *a, const int *b) { + return CMP(*a, *b); +} + int close_all_fds(const int except[], size_t n_except) { + static bool have_close_range = true; /* Assume we live in the future */ _cleanup_closedir_ DIR *d = NULL; struct dirent *de; int r = 0; assert(n_except == 0 || except); + if (have_close_range) { + /* In the best case we have close_range() to close all fds between a start and an end fd, + * which we can use on the "inverted" exception array, i.e. all intervals between all + * adjacent pairs from the sorted exception array. This changes loop complexity from O(n) + * where n is number of open fds to O(m⋅log(m)) where m is the number of fds to keep + * open. Given that we assume n ≫ m that's preferable to us. */ + + if (n_except == 0) { + /* Close everything. Yay! */ + + if (close_range(3, -1, 0) >= 0) + return 1; + + if (!ERRNO_IS_NOT_SUPPORTED(errno) && !ERRNO_IS_PRIVILEGE(errno)) + return -errno; + + have_close_range = false; + } else { + _cleanup_free_ int *sorted_malloc = NULL; + size_t n_sorted; + int *sorted; + + assert(n_except < SIZE_MAX); + n_sorted = n_except + 1; + + if (n_sorted > 64) /* Use heap for large numbers of fds, stack otherwise */ + sorted = sorted_malloc = new(int, n_sorted); + else + sorted = newa(int, n_sorted); + + if (sorted) { + int c = 0; + + memcpy(sorted, except, n_except * sizeof(int)); + + /* Let's add fd 2 to the list of fds, to simplify the loop below, as this + * allows us to cover the head of the array the same way as the body */ + sorted[n_sorted-1] = 2; + + typesafe_qsort(sorted, n_sorted, cmp_int); + + for (size_t i = 0; i < n_sorted-1; i++) { + int start, end; + + start = MAX(sorted[i], 2); /* The first three fds shall always remain open */ + end = MAX(sorted[i+1], 2); + + assert(end >= start); + + if (end - start <= 1) + continue; + + /* Close everything between the start and end fds (both of which shall stay open) */ + if (close_range(start + 1, end - 1, 0) < 0) { + if (!ERRNO_IS_NOT_SUPPORTED(errno) && !ERRNO_IS_PRIVILEGE(errno)) + return -errno; + + have_close_range = false; + break; + } + + c += end - start - 1; + } + + if (have_close_range) { + /* The loop succeeded. Let's now close everything beyond the end */ + + if (sorted[n_sorted-1] >= INT_MAX) /* Dont let the addition below overflow */ + return c; + + if (close_range(sorted[n_sorted-1] + 1, -1, 0) >= 0) + return c + 1; + + if (!ERRNO_IS_NOT_SUPPORTED(errno) && !ERRNO_IS_PRIVILEGE(errno)) + return -errno; + + have_close_range = false; + } + } + } + + /* Fallback on OOM or if close_range() is not supported */ + } + d = opendir("/proc/self/fd"); if (!d) { int fd, max_fd;