diff --git a/src/libutil/unix-domain-socket.cc b/src/libutil/unix-domain-socket.cc index 8949461d2..05bbb5ba3 100644 --- a/src/libutil/unix-domain-socket.cc +++ b/src/libutil/unix-domain-socket.cc @@ -1,6 +1,7 @@ #include "file-system.hh" #include "processes.hh" #include "unix-domain-socket.hh" +#include "util.hh" #include #include @@ -75,21 +76,35 @@ void connect(int fd, const std::string & path) addr.sun_family = AF_UNIX; if (path.size() + 1 >= sizeof(addr.sun_path)) { + Pipe pipe; + pipe.create(); Pid pid = startProcess([&]() { - Path dir = dirOf(path); - if (chdir(dir.c_str()) == -1) - throw SysError("chdir to '%s' failed", dir); - std::string base(baseNameOf(path)); - if (base.size() + 1 >= sizeof(addr.sun_path)) - throw Error("socket path '%s' is too long", base); - memcpy(addr.sun_path, base.c_str(), base.size() + 1); - if (connect(fd, (struct sockaddr *) &addr, sizeof(addr)) == -1) - throw SysError("cannot connect to socket at '%s'", path); - _exit(0); + try { + pipe.readSide.close(); + Path dir = dirOf(path); + if (chdir(dir.c_str()) == -1) + throw SysError("chdir to '%s' failed", dir); + std::string base(baseNameOf(path)); + if (base.size() + 1 >= sizeof(addr.sun_path)) + throw Error("socket path '%s' is too long", base); + memcpy(addr.sun_path, base.c_str(), base.size() + 1); + if (connect(fd, (struct sockaddr *) &addr, sizeof(addr)) == -1) + throw SysError("cannot connect to socket at '%s'", path); + writeFull(pipe.writeSide.get(), "0\n"); + } catch (SysError & e) { + writeFull(pipe.writeSide.get(), fmt("%d\n", e.errNo)); + } catch (...) { + writeFull(pipe.writeSide.get(), "-1\n"); + } }); - int status = pid.wait(); - if (status != 0) + pipe.writeSide.close(); + auto errNo = string2Int(chomp(drainFD(pipe.readSide.get()))); + if (!errNo || *errNo == -1) throw Error("cannot connect to socket at '%s'", path); + else if (*errNo > 0) { + errno = *errNo; + throw SysError("cannot connect to socket at '%s'", path); + } } else { memcpy(addr.sun_path, path.c_str(), path.size() + 1); if (connect(fd, (struct sockaddr *) &addr, sizeof(addr)) == -1)