diff --git a/src/basic/socket-util.c b/src/basic/socket-util.c index a913102e13..1da0ed6616 100644 --- a/src/basic/socket-util.c +++ b/src/basic/socket-util.c @@ -1003,9 +1003,10 @@ int getpeergroups(int fd, gid_t **ret) { return (int) n; } -int send_one_fd_sa( +ssize_t send_one_fd_iov_sa( int transport_fd, int fd, + struct iovec *iov, size_t iovlen, const struct sockaddr *sa, socklen_t len, int flags) { @@ -1016,28 +1017,58 @@ int send_one_fd_sa( struct msghdr mh = { .msg_name = (struct sockaddr*) sa, .msg_namelen = len, - .msg_control = &control, - .msg_controllen = sizeof(control), + .msg_iov = iov, + .msg_iovlen = iovlen, }; - struct cmsghdr *cmsg; + ssize_t k; assert(transport_fd >= 0); - assert(fd >= 0); - cmsg = CMSG_FIRSTHDR(&mh); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - cmsg->cmsg_len = CMSG_LEN(sizeof(int)); - memcpy(CMSG_DATA(cmsg), &fd, sizeof(int)); + /* + * We need either an FD or data to send. + * If there's nothing, return an error. + */ + if (fd < 0 && !iov) + return -EINVAL; - mh.msg_controllen = CMSG_SPACE(sizeof(int)); - if (sendmsg(transport_fd, &mh, MSG_NOSIGNAL | flags) < 0) - return -errno; + if (fd >= 0) { + struct cmsghdr *cmsg; - return 0; + mh.msg_control = &control; + mh.msg_controllen = sizeof(control); + + cmsg = CMSG_FIRSTHDR(&mh); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + cmsg->cmsg_len = CMSG_LEN(sizeof(int)); + memcpy(CMSG_DATA(cmsg), &fd, sizeof(int)); + + mh.msg_controllen = CMSG_SPACE(sizeof(int)); + } + k = sendmsg(transport_fd, &mh, MSG_NOSIGNAL | flags); + if (k < 0) + return (ssize_t) -errno; + + return k; } -int receive_one_fd(int transport_fd, int flags) { +int send_one_fd_sa( + int transport_fd, + int fd, + const struct sockaddr *sa, socklen_t len, + int flags) { + + assert(fd >= 0); + + return (int) send_one_fd_iov_sa(transport_fd, fd, NULL, 0, sa, len, flags); +} + +ssize_t receive_one_fd_iov( + int transport_fd, + struct iovec *iov, size_t iovlen, + int flags, + int *ret_fd) { + union { struct cmsghdr cmsghdr; uint8_t buf[CMSG_SPACE(sizeof(int))]; @@ -1045,10 +1076,14 @@ int receive_one_fd(int transport_fd, int flags) { struct msghdr mh = { .msg_control = &control, .msg_controllen = sizeof(control), + .msg_iov = iov, + .msg_iovlen = iovlen, }; struct cmsghdr *cmsg, *found = NULL; + ssize_t k; assert(transport_fd >= 0); + assert(ret_fd); /* * Receive a single FD via @transport_fd. We don't care for @@ -1058,8 +1093,9 @@ int receive_one_fd(int transport_fd, int flags) { * combination with send_one_fd(). */ - if (recvmsg(transport_fd, &mh, MSG_CMSG_CLOEXEC | flags) < 0) - return -errno; + k = recvmsg(transport_fd, &mh, MSG_CMSG_CLOEXEC | flags); + if (k < 0) + return (ssize_t) -errno; CMSG_FOREACH(cmsg, &mh) { if (cmsg->cmsg_level == SOL_SOCKET && @@ -1071,12 +1107,33 @@ int receive_one_fd(int transport_fd, int flags) { } } - if (!found) { + if (!found) cmsg_close_all(&mh); - return -EIO; - } - return *(int*) CMSG_DATA(found); + /* If didn't receive an FD or any data, return an error. */ + if (k == 0 && !found) + return -EIO; + + if (found) + *ret_fd = *(int*) CMSG_DATA(found); + else + *ret_fd = -1; + + return k; +} + +int receive_one_fd(int transport_fd, int flags) { + int fd; + ssize_t k; + + k = receive_one_fd_iov(transport_fd, NULL, 0, flags, &fd); + if (k == 0) + return fd; + + /* k must be negative, since receive_one_fd_iov() only returns + * a positive value if data was received through the iov. */ + assert(k < 0); + return (int) k; } ssize_t next_datagram_size_fd(int fd) { diff --git a/src/basic/socket-util.h b/src/basic/socket-util.h index 8e23cf2dbd..82781a0de1 100644 --- a/src/basic/socket-util.h +++ b/src/basic/socket-util.h @@ -130,11 +130,19 @@ int getpeercred(int fd, struct ucred *ucred); int getpeersec(int fd, char **ret); int getpeergroups(int fd, gid_t **ret); +ssize_t send_one_fd_iov_sa( + int transport_fd, + int fd, + struct iovec *iov, size_t iovlen, + const struct sockaddr *sa, socklen_t len, + int flags); int send_one_fd_sa(int transport_fd, int fd, const struct sockaddr *sa, socklen_t len, int flags); -#define send_one_fd(transport_fd, fd, flags) send_one_fd_sa(transport_fd, fd, NULL, 0, flags) +#define send_one_fd_iov(transport_fd, fd, iov, iovlen, flags) send_one_fd_iov_sa(transport_fd, fd, iov, iovlen, NULL, 0, flags) +#define send_one_fd(transport_fd, fd, flags) send_one_fd_iov_sa(transport_fd, fd, NULL, 0, NULL, 0, flags) +ssize_t receive_one_fd_iov(int transport_fd, struct iovec *iov, size_t iovlen, int flags, int *ret_fd); int receive_one_fd(int transport_fd, int flags); ssize_t next_datagram_size_fd(int fd); diff --git a/src/core/dynamic-user.c b/src/core/dynamic-user.c index f380db553e..5e20783102 100644 --- a/src/core/dynamic-user.c +++ b/src/core/dynamic-user.c @@ -312,20 +312,8 @@ static int pick_uid(char **suggested_paths, const char *name, uid_t *ret_uid) { static int dynamic_user_pop(DynamicUser *d, uid_t *ret_uid, int *ret_lock_fd) { uid_t uid = UID_INVALID; struct iovec iov = IOVEC_INIT(&uid, sizeof(uid)); - union { - struct cmsghdr cmsghdr; - uint8_t buf[CMSG_SPACE(sizeof(int))]; - } control = {}; - struct msghdr mh = { - .msg_control = &control, - .msg_controllen = sizeof(control), - .msg_iov = &iov, - .msg_iovlen = 1, - }; - struct cmsghdr *cmsg; - + int lock_fd; ssize_t k; - int lock_fd = -1; assert(d); assert(ret_uid); @@ -334,15 +322,9 @@ static int dynamic_user_pop(DynamicUser *d, uid_t *ret_uid, int *ret_lock_fd) { /* Read the UID and lock fd that is stored in the storage AF_UNIX socket. This should be called with the lock * on the socket taken. */ - k = recvmsg(d->storage_socket[0], &mh, MSG_DONTWAIT|MSG_CMSG_CLOEXEC); + k = receive_one_fd_iov(d->storage_socket[0], &iov, 1, MSG_DONTWAIT, &lock_fd); if (k < 0) - return -errno; - - cmsg = cmsg_find(&mh, SOL_SOCKET, SCM_RIGHTS, CMSG_LEN(sizeof(int))); - if (cmsg) - lock_fd = *(int*) CMSG_DATA(cmsg); - else - cmsg_close_all(&mh); /* just in case... */ + return (int) k; *ret_uid = uid; *ret_lock_fd = lock_fd; @@ -352,42 +334,11 @@ static int dynamic_user_pop(DynamicUser *d, uid_t *ret_uid, int *ret_lock_fd) { static int dynamic_user_push(DynamicUser *d, uid_t uid, int lock_fd) { struct iovec iov = IOVEC_INIT(&uid, sizeof(uid)); - union { - struct cmsghdr cmsghdr; - uint8_t buf[CMSG_SPACE(sizeof(int))]; - } control = {}; - struct msghdr mh = { - .msg_control = &control, - .msg_controllen = sizeof(control), - .msg_iov = &iov, - .msg_iovlen = 1, - }; - ssize_t k; assert(d); /* Store the UID and lock_fd in the storage socket. This should be called with the socket pair lock taken. */ - - if (lock_fd >= 0) { - struct cmsghdr *cmsg; - - cmsg = CMSG_FIRSTHDR(&mh); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - cmsg->cmsg_len = CMSG_LEN(sizeof(int)); - memcpy(CMSG_DATA(cmsg), &lock_fd, sizeof(int)); - - mh.msg_controllen = CMSG_SPACE(sizeof(int)); - } else { - mh.msg_control = NULL; - mh.msg_controllen = 0; - } - - k = sendmsg(d->storage_socket[1], &mh, MSG_DONTWAIT|MSG_NOSIGNAL); - if (k < 0) - return -errno; - - return 0; + return send_one_fd_iov(d->storage_socket[1], lock_fd, &iov, 1, MSG_DONTWAIT); } static void unlink_uid_lock(int lock_fd, uid_t uid, const char *name) { diff --git a/src/test/test-socket-util.c b/src/test/test-socket-util.c index ac2ea52a5c..8099f13703 100644 --- a/src/test/test-socket-util.c +++ b/src/test/test-socket-util.c @@ -6,8 +6,11 @@ #include "alloc-util.h" #include "async.h" +#include "exit-status.h" #include "fd-util.h" +#include "fileio.h" #include "in-addr-util.h" +#include "io-util.h" #include "log.h" #include "macro.h" #include "process-util.h" @@ -481,9 +484,215 @@ static void test_getpeercred_getpeergroups(void) { } safe_close_pair(pair); + _exit(EXIT_SUCCESS); } } +static void test_passfd_read(void) { + static const char file_contents[] = "test contents for passfd"; + _cleanup_close_pair_ int pair[2] = { -1, -1 }; + int r; + + assert_se(socketpair(AF_UNIX, SOCK_DGRAM, 0, pair) >= 0); + + r = safe_fork("(passfd_read)", FORK_DEATHSIG|FORK_LOG|FORK_WAIT, NULL); + assert_se(r >= 0); + + if (r == 0) { + /* Child */ + char tmpfile[] = "/tmp/test-socket-util-passfd-read-XXXXXX"; + _cleanup_close_ int tmpfd = -1; + + pair[0] = safe_close(pair[0]); + + tmpfd = mkostemp_safe(tmpfile); + assert_se(tmpfd >= 0); + assert_se(write(tmpfd, file_contents, strlen(file_contents)) == (ssize_t) strlen(file_contents)); + tmpfd = safe_close(tmpfd); + + tmpfd = open(tmpfile, O_RDONLY); + assert_se(tmpfd >= 0); + assert_se(unlink(tmpfile) == 0); + + assert_se(send_one_fd(pair[1], tmpfd, MSG_DONTWAIT) == 0); + _exit(EXIT_SUCCESS); + } + + /* Parent */ + char buf[64]; + struct iovec iov = IOVEC_INIT(buf, sizeof(buf)-1); + _cleanup_close_ int fd = -1; + + pair[1] = safe_close(pair[1]); + + assert_se(receive_one_fd_iov(pair[0], &iov, 1, MSG_DONTWAIT, &fd) == 0); + + assert_se(fd >= 0); + r = read(fd, buf, sizeof(buf)-1); + assert_se(r >= 0); + buf[r] = 0; + assert_se(streq(buf, file_contents)); +} + +static void test_passfd_contents_read(void) { + _cleanup_close_pair_ int pair[2] = { -1, -1 }; + static const char file_contents[] = "test contents in the file"; + static const char wire_contents[] = "test contents on the wire"; + int r; + + assert_se(socketpair(AF_UNIX, SOCK_DGRAM, 0, pair) >= 0); + + r = safe_fork("(passfd_contents_read)", FORK_DEATHSIG|FORK_LOG|FORK_WAIT, NULL); + assert_se(r >= 0); + + if (r == 0) { + /* Child */ + struct iovec iov = IOVEC_INIT_STRING(wire_contents); + char tmpfile[] = "/tmp/test-socket-util-passfd-contents-read-XXXXXX"; + _cleanup_close_ int tmpfd = -1; + + pair[0] = safe_close(pair[0]); + + tmpfd = mkostemp_safe(tmpfile); + assert_se(tmpfd >= 0); + assert_se(write(tmpfd, file_contents, strlen(file_contents)) == (ssize_t) strlen(file_contents)); + tmpfd = safe_close(tmpfd); + + tmpfd = open(tmpfile, O_RDONLY); + assert_se(tmpfd >= 0); + assert_se(unlink(tmpfile) == 0); + + assert_se(send_one_fd_iov(pair[1], tmpfd, &iov, 1, MSG_DONTWAIT) > 0); + _exit(EXIT_SUCCESS); + } + + /* Parent */ + char buf[64]; + struct iovec iov = IOVEC_INIT(buf, sizeof(buf)-1); + _cleanup_close_ int fd = -1; + ssize_t k; + + pair[1] = safe_close(pair[1]); + + k = receive_one_fd_iov(pair[0], &iov, 1, MSG_DONTWAIT, &fd); + assert_se(k > 0); + buf[k] = 0; + assert_se(streq(buf, wire_contents)); + + assert_se(fd >= 0); + r = read(fd, buf, sizeof(buf)-1); + assert_se(r >= 0); + buf[r] = 0; + assert_se(streq(buf, file_contents)); +} + +static void test_receive_nopassfd(void) { + _cleanup_close_pair_ int pair[2] = { -1, -1 }; + static const char wire_contents[] = "no fd passed here"; + int r; + + assert_se(socketpair(AF_UNIX, SOCK_DGRAM, 0, pair) >= 0); + + r = safe_fork("(receive_nopassfd)", FORK_DEATHSIG|FORK_LOG|FORK_WAIT, NULL); + assert_se(r >= 0); + + if (r == 0) { + /* Child */ + struct iovec iov = IOVEC_INIT_STRING(wire_contents); + + pair[0] = safe_close(pair[0]); + + assert_se(send_one_fd_iov(pair[1], -1, &iov, 1, MSG_DONTWAIT) > 0); + _exit(EXIT_SUCCESS); + } + + /* Parent */ + char buf[64]; + struct iovec iov = IOVEC_INIT(buf, sizeof(buf)-1); + int fd = -999; + ssize_t k; + + pair[1] = safe_close(pair[1]); + + k = receive_one_fd_iov(pair[0], &iov, 1, MSG_DONTWAIT, &fd); + assert_se(k > 0); + buf[k] = 0; + assert_se(streq(buf, wire_contents)); + + /* no fd passed here, confirm it was reset */ + assert_se(fd == -1); +} + +static void test_send_nodata_nofd(void) { + _cleanup_close_pair_ int pair[2] = { -1, -1 }; + int r; + + assert_se(socketpair(AF_UNIX, SOCK_DGRAM, 0, pair) >= 0); + + r = safe_fork("(send_nodata_nofd)", FORK_DEATHSIG|FORK_LOG|FORK_WAIT, NULL); + assert_se(r >= 0); + + if (r == 0) { + /* Child */ + pair[0] = safe_close(pair[0]); + + assert_se(send_one_fd_iov(pair[1], -1, NULL, 0, MSG_DONTWAIT) == -EINVAL); + _exit(EXIT_SUCCESS); + } + + /* Parent */ + char buf[64]; + struct iovec iov = IOVEC_INIT(buf, sizeof(buf)-1); + int fd = -999; + ssize_t k; + + pair[1] = safe_close(pair[1]); + + k = receive_one_fd_iov(pair[0], &iov, 1, MSG_DONTWAIT, &fd); + /* recvmsg() will return errno EAGAIN if nothing was sent */ + assert_se(k == -EAGAIN); + + /* receive_one_fd_iov returned error, so confirm &fd wasn't touched */ + assert_se(fd == -999); +} + +static void test_send_emptydata(void) { + _cleanup_close_pair_ int pair[2] = { -1, -1 }; + int r; + + assert_se(socketpair(AF_UNIX, SOCK_DGRAM, 0, pair) >= 0); + + r = safe_fork("(send_emptydata)", FORK_DEATHSIG|FORK_LOG|FORK_WAIT, NULL); + assert_se(r >= 0); + + if (r == 0) { + /* Child */ + struct iovec iov = IOVEC_INIT_STRING(""); /* zero-length iov */ + assert_se(iov.iov_len == 0); + + pair[0] = safe_close(pair[0]); + + /* This will succeed, since iov is set. */ + assert_se(send_one_fd_iov(pair[1], -1, &iov, 1, MSG_DONTWAIT) == 0); + _exit(EXIT_SUCCESS); + } + + /* Parent */ + char buf[64]; + struct iovec iov = IOVEC_INIT(buf, sizeof(buf)-1); + int fd = -999; + ssize_t k; + + pair[1] = safe_close(pair[1]); + + k = receive_one_fd_iov(pair[0], &iov, 1, MSG_DONTWAIT, &fd); + /* receive_one_fd_iov() returns -EIO if an fd is not found and no data was returned. */ + assert_se(k == -EIO); + + /* receive_one_fd_iov returned error, so confirm &fd wasn't touched */ + assert_se(fd == -999); +} + int main(int argc, char *argv[]) { log_set_max_level(LOG_DEBUG); @@ -512,5 +721,11 @@ int main(int argc, char *argv[]) { test_getpeercred_getpeergroups(); + test_passfd_read(); + test_passfd_contents_read(); + test_receive_nopassfd(); + test_send_nodata_nofd(); + test_send_emptydata(); + return 0; }