diff --git a/src/basic/socket-util.h b/src/basic/socket-util.h index 95643135df..fba4efef81 100644 --- a/src/basic/socket-util.h +++ b/src/basic/socket-util.h @@ -158,6 +158,14 @@ int flush_accept(int fd); struct cmsghdr* cmsg_find(struct msghdr *mh, int level, int type, socklen_t length); +/* Type-safe, dereferencing version of cmsg_find() */ +#define CMSG_FIND_DATA(mh, level, type, ctype) \ + ({ \ + struct cmsghdr *_found; \ + _found = cmsg_find(mh, level, type, CMSG_LEN(sizeof(ctype))); \ + (ctype*) (_found ? CMSG_DATA(_found) : NULL); \ + }) + /* * Certain hardware address types (e.g Infiniband) do not fit into sll_addr * (8 bytes) and run over the structure. This macro returns the correct size that diff --git a/src/import/importd.c b/src/import/importd.c index 340b12ecd5..417e686bb9 100644 --- a/src/import/importd.c +++ b/src/import/importd.c @@ -557,9 +557,8 @@ static int manager_on_notify(sd_event_source *s, int fd, uint32_t revents, void .msg_control = &control, .msg_controllen = sizeof(control), }; - struct ucred *ucred = NULL; + struct ucred *ucred; Manager *m = userdata; - struct cmsghdr *cmsg; char *p, *e; Transfer *t; Iterator i; @@ -574,17 +573,12 @@ static int manager_on_notify(sd_event_source *s, int fd, uint32_t revents, void cmsg_close_all(&msghdr); - CMSG_FOREACH(cmsg, &msghdr) - if (cmsg->cmsg_level == SOL_SOCKET && - cmsg->cmsg_type == SCM_CREDENTIALS && - cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) - ucred = (struct ucred*) CMSG_DATA(cmsg); - if (msghdr.msg_flags & MSG_TRUNC) { log_warning("Got overly long notification datagram, ignoring."); return 0; } + ucred = CMSG_FIND_DATA(&msghdr, SOL_SOCKET, SCM_CREDENTIALS, struct ucred); if (!ucred || ucred->pid <= 0) { log_warning("Got notification datagram lacking credential information, ignoring."); return 0; diff --git a/src/journal/journald-stream.c b/src/journal/journald-stream.c index ec6dad62e8..202ac3cda2 100644 --- a/src/journal/journald-stream.c +++ b/src/journal/journald-stream.c @@ -491,8 +491,7 @@ static int stdout_stream_scan(StdoutStream *s, bool force_flush) { static int stdout_stream_process(sd_event_source *es, int fd, uint32_t revents, void *userdata) { uint8_t buf[CMSG_SPACE(sizeof(struct ucred))]; StdoutStream *s = userdata; - struct ucred *ucred = NULL; - struct cmsghdr *cmsg; + struct ucred *ucred; struct iovec iovec; size_t limit; ssize_t l; @@ -541,25 +540,14 @@ static int stdout_stream_process(sd_event_source *es, int fd, uint32_t revents, goto terminate; } - CMSG_FOREACH(cmsg, &msghdr) - if (cmsg->cmsg_level == SOL_SOCKET && - cmsg->cmsg_type == SCM_CREDENTIALS && - cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) { - assert(!ucred); - ucred = (struct ucred *)CMSG_DATA(cmsg); - break; - } - - /* Invalidate the context if the pid of the sender changed. - * This happens when a forked process inherits stdout / stderr - * from a parent. In this case getpeercred returns the ucred - * of the parent, which can be invalid if the parent has exited - * in the meantime. + /* Invalidate the context if the pid of the sender changed. This happens when a forked process + * inherits stdout / stderr from a parent. In this case getpeercred returns the ucred of the parent, + * which can be invalid if the parent has exited in the meantime. */ + ucred = CMSG_FIND_DATA(&msghdr, SOL_SOCKET, SCM_CREDENTIALS, struct ucred); if (ucred && ucred->pid != s->ucred.pid) { - /* force out any previously half-written lines from a - * different process, before we switch to the new ucred - * structure for everything we just added */ + /* force out any previously half-written lines from a different process, before we switch to + * the new ucred structure for everything we just added */ r = stdout_stream_scan(s, true); if (r < 0) goto terminate; diff --git a/src/nspawn/nspawn.c b/src/nspawn/nspawn.c index cc2f1521ff..c2148596b7 100644 --- a/src/nspawn/nspawn.c +++ b/src/nspawn/nspawn.c @@ -3698,8 +3698,7 @@ static int nspawn_dispatch_notify_fd(sd_event_source *source, int fd, uint32_t r .msg_control = &control, .msg_controllen = sizeof(control), }; - struct cmsghdr *cmsg; - struct ucred *ucred = NULL; + struct ucred *ucred; ssize_t n; pid_t inner_child_pid; _cleanup_strv_free_ char **tags = NULL; @@ -3721,15 +3720,7 @@ static int nspawn_dispatch_notify_fd(sd_event_source *source, int fd, uint32_t r cmsg_close_all(&msghdr); - CMSG_FOREACH(cmsg, &msghdr) { - if (cmsg->cmsg_level == SOL_SOCKET && - cmsg->cmsg_type == SCM_CREDENTIALS && - cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) { - - ucred = (struct ucred*) CMSG_DATA(cmsg); - } - } - + ucred = CMSG_FIND_DATA(&msghdr, SOL_SOCKET, SCM_CREDENTIALS, struct ucred); if (!ucred || ucred->pid != inner_child_pid) { log_debug("Received notify message without valid credentials. Ignoring."); return 0; diff --git a/src/shared/ask-password-api.c b/src/shared/ask-password-api.c index b7b7426058..7574a510c2 100644 --- a/src/shared/ask-password-api.c +++ b/src/shared/ask-password-api.c @@ -940,15 +940,12 @@ int ask_password_agent( continue; } - if (msghdr.msg_controllen < CMSG_LEN(sizeof(struct ucred)) || - control.cmsghdr.cmsg_level != SOL_SOCKET || - control.cmsghdr.cmsg_type != SCM_CREDENTIALS || - control.cmsghdr.cmsg_len != CMSG_LEN(sizeof(struct ucred))) { + ucred = CMSG_FIND_DATA(&msghdr, SOL_SOCKET, SCM_CREDENTIALS, struct ucred); + if (!ucred) { log_debug("Received message without credentials. Ignoring."); continue; } - ucred = (struct ucred*) CMSG_DATA(&control.cmsghdr); if (ucred->uid != 0) { log_debug("Got request from unprivileged user. Ignoring."); continue; diff --git a/src/udev/udevd.c b/src/udev/udevd.c index 4b15159c5b..cfda47f849 100644 --- a/src/udev/udevd.c +++ b/src/udev/udevd.c @@ -916,9 +916,8 @@ static int on_worker(sd_event_source *s, int fd, uint32_t revents, void *userdat .msg_control = &control, .msg_controllen = sizeof(control), }; - struct cmsghdr *cmsg; ssize_t size; - struct ucred *ucred = NULL; + struct ucred *ucred; struct worker *worker; size = recvmsg_safe(fd, &msghdr, MSG_DONTWAIT); @@ -937,12 +936,7 @@ static int on_worker(sd_event_source *s, int fd, uint32_t revents, void *userdat continue; } - CMSG_FOREACH(cmsg, &msghdr) - if (cmsg->cmsg_level == SOL_SOCKET && - cmsg->cmsg_type == SCM_CREDENTIALS && - cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) - ucred = (struct ucred*) CMSG_DATA(cmsg); - + ucred = CMSG_FIND_DATA(&msghdr, SOL_SOCKET, SCM_CREDENTIALS, struct ucred); if (!ucred || ucred->pid <= 0) { log_warning("Ignoring worker message without valid PID"); continue;