Merge pull request #9787 from obsidiansystems/bind-proc-syserror

`bind`: give same treatment as `connect` in #8544, dedup
This commit is contained in:
Théophane Hufschmitt 2024-01-18 09:34:15 +01:00 committed by GitHub
commit 28674247ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -38,52 +38,20 @@ AutoCloseFD createUnixDomainSocket(const Path & path, mode_t mode)
return fdSocket;
}
static struct sockaddr* safeSockAddrPointerCast(struct sockaddr_un *addr) {
// Casting between types like these legacy C library interfaces require
// is forbidden in C++.
// To maintain backwards compatibility, the implementation of the
// bind function contains some hints to the compiler that allow for this
static void bindConnectProcHelper(
std::string_view operationName, auto && operation,
int fd, const std::string & path)
{
struct sockaddr_un addr;
addr.sun_family = AF_UNIX;
// Casting between types like these legacy C library interfaces
// require is forbidden in C++. To maintain backwards
// compatibility, the implementation of the bind/connect functions
// contains some hints to the compiler that allow for this
// special case.
return reinterpret_cast<struct sockaddr *>(addr);
}
void bind(int fd, const std::string & path)
{
unlink(path.c_str());
struct sockaddr_un addr;
addr.sun_family = AF_UNIX;
auto psaddr {safeSockAddrPointerCast(&addr)};
if (path.size() + 1 >= sizeof(addr.sun_path)) {
Pid pid = startProcess([&] {
Path dir = dirOf(path);
if (chdir(dir.c_str()) == -1)
throw SysError("chdir to '%s' failed", dir);
std::string base(baseNameOf(path));
if (base.size() + 1 >= sizeof(addr.sun_path))
throw Error("socket path '%s' is too long", base);
memcpy(addr.sun_path, base.c_str(), base.size() + 1);
if (bind(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot bind to socket '%s'", path);
_exit(0);
});
int status = pid.wait();
if (status != 0)
throw Error("cannot bind to socket '%s'", path);
} else {
memcpy(addr.sun_path, path.c_str(), path.size() + 1);
if (bind(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot bind to socket '%s'", path);
}
}
void connect(int fd, const std::string & path)
{
struct sockaddr_un addr;
addr.sun_family = AF_UNIX;
auto psaddr {safeSockAddrPointerCast(&addr)};
auto * psaddr = reinterpret_cast<struct sockaddr *>(&addr);
if (path.size() + 1 >= sizeof(addr.sun_path)) {
Pipe pipe;
@ -98,8 +66,8 @@ void connect(int fd, const std::string & path)
if (base.size() + 1 >= sizeof(addr.sun_path))
throw Error("socket path '%s' is too long", base);
memcpy(addr.sun_path, base.c_str(), base.size() + 1);
if (connect(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot connect to socket at '%s'", path);
if (operation(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot %s to socket at '%s'", operationName, path);
writeFull(pipe.writeSide.get(), "0\n");
} catch (SysError & e) {
writeFull(pipe.writeSide.get(), fmt("%d\n", e.errNo));
@ -110,16 +78,30 @@ void connect(int fd, const std::string & path)
pipe.writeSide.close();
auto errNo = string2Int<int>(chomp(drainFD(pipe.readSide.get())));
if (!errNo || *errNo == -1)
throw Error("cannot connect to socket at '%s'", path);
throw Error("cannot %s to socket at '%s'", operationName, path);
else if (*errNo > 0) {
errno = *errNo;
throw SysError("cannot connect to socket at '%s'", path);
throw SysError("cannot %s to socket at '%s'", operationName, path);
}
} else {
memcpy(addr.sun_path, path.c_str(), path.size() + 1);
if (connect(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot connect to socket at '%s'", path);
if (operation(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot %s to socket at '%s'", operationName, path);
}
}
void bind(int fd, const std::string & path)
{
unlink(path.c_str());
bindConnectProcHelper("bind", ::bind, fd, path);
}
void connect(int fd, const std::string & path)
{
bindConnectProcHelper("connect", ::connect, fd, path);
}
}