diff --git a/src/libstore/legacy-ssh-store.cc b/src/libstore/legacy-ssh-store.cc index 06bef9d08..4f020c452 100644 --- a/src/libstore/legacy-ssh-store.cc +++ b/src/libstore/legacy-ssh-store.cc @@ -22,45 +22,10 @@ std::string LegacySSHStoreConfig::doc() } -struct LegacySSHStore::Connection +struct LegacySSHStore::Connection : public ServeProto::BasicClientConnection { std::unique_ptr sshConn; - FdSink to; - FdSource from; - ServeProto::Version remoteVersion; bool good = true; - - /** - * Coercion to `ServeProto::ReadConn`. This makes it easy to use the - * factored out serve protocol searlizers with a - * `LegacySSHStore::Connection`. - * - * The serve protocol connection types are unidirectional, unlike - * this type. - */ - operator ServeProto::ReadConn () - { - return ServeProto::ReadConn { - .from = from, - .version = remoteVersion, - }; - } - - /* - * Coercion to `ServeProto::WriteConn`. This makes it easy to use the - * factored out serve protocol searlizers with a - * `LegacySSHStore::Connection`. - * - * The serve protocol connection types are unidirectional, unlike - * this type. - */ - operator ServeProto::WriteConn () - { - return ServeProto::WriteConn { - .to = to, - .version = remoteVersion, - }; - } }; @@ -96,28 +61,20 @@ ref LegacySSHStore::openConnection() conn->to = FdSink(conn->sshConn->in.get()); conn->from = FdSource(conn->sshConn->out.get()); + StringSink saved; + TeeSource tee(conn->from, saved); try { - conn->to << SERVE_MAGIC_1 << SERVE_PROTOCOL_VERSION; - conn->to.flush(); - - StringSink saved; - try { - TeeSource tee(conn->from, saved); - unsigned int magic = readInt(tee); - if (magic != SERVE_MAGIC_2) - throw Error("'nix-store --serve' protocol mismatch from '%s'", host); - } catch (SerialisationError & e) { - /* In case the other side is waiting for our input, - close it. */ - conn->sshConn->in.close(); - auto msg = conn->from.drain(); - throw Error("'nix-store --serve' protocol mismatch from '%s', got '%s'", - host, chomp(saved.s + msg)); + conn->remoteVersion = ServeProto::BasicClientConnection::handshake( + conn->to, tee, SERVE_PROTOCOL_VERSION, host); + } catch (SerialisationError & e) { + // in.close(): Don't let the remote block on us not writing. + conn->sshConn->in.close(); + { + NullSink nullSink; + conn->from.drainInto(nullSink); } - conn->remoteVersion = readInt(conn->from); - if (GET_PROTOCOL_MAJOR(conn->remoteVersion) != 0x200) - throw Error("unsupported 'nix-store --serve' protocol version on '%s'", host); - + throw Error("'nix-store --serve' protocol mismatch from '%s', got '%s'", + host, chomp(saved.s)); } catch (EndOfFile & e) { throw Error("cannot connect to '%1%'", host); } @@ -232,16 +189,16 @@ void LegacySSHStore::narFromPath(const StorePath & path, Sink & sink) } -void LegacySSHStore::putBuildSettings(Connection & conn) +static ServeProto::BuildOptions buildSettings() { - ServeProto::write(*this, conn, ServeProto::BuildOptions { + return { .maxSilentTime = settings.maxSilentTime, .buildTimeout = settings.buildTimeout, .maxLogSize = settings.maxLogSize, .nrRepeats = 0, // buildRepeat hasn't worked for ages anyway .enforceDeterminism = 0, .keepFailed = settings.keepFailed, - }); + }; } @@ -250,14 +207,7 @@ BuildResult LegacySSHStore::buildDerivation(const StorePath & drvPath, const Bas { auto conn(connections->get()); - conn->to - << ServeProto::Command::BuildDerivation - << printStorePath(drvPath); - writeDerivation(conn->to, *this, drv); - - putBuildSettings(*conn); - - conn->to.flush(); + conn->putBuildDerivationRequest(*this, drvPath, drv, buildSettings()); return ServeProto::Serialise::read(*this, *conn); } @@ -288,7 +238,7 @@ void LegacySSHStore::buildPaths(const std::vector & drvPaths, Build } conn->to << ss; - putBuildSettings(*conn); + ServeProto::write(*this, *conn, buildSettings()); conn->to.flush(); @@ -328,15 +278,8 @@ StorePathSet LegacySSHStore::queryValidPaths(const StorePathSet & paths, SubstituteFlag maybeSubstitute) { auto conn(connections->get()); - - conn->to - << ServeProto::Command::QueryValidPaths - << false // lock - << maybeSubstitute; - ServeProto::write(*this, *conn, paths); - conn->to.flush(); - - return ServeProto::Serialise::read(*this, *conn); + return conn->queryValidPaths(*this, + false, paths, maybeSubstitute); } diff --git a/src/libstore/legacy-ssh-store.hh b/src/libstore/legacy-ssh-store.hh index 7cee31d66..bdf79eab3 100644 --- a/src/libstore/legacy-ssh-store.hh +++ b/src/libstore/legacy-ssh-store.hh @@ -78,10 +78,6 @@ struct LegacySSHStore : public virtual LegacySSHStoreConfig, public virtual Stor RepairFlag repair = NoRepair) override { unsupported("addToStore"); } -private: - - void putBuildSettings(Connection & conn); - public: BuildResult buildDerivation(const StorePath & drvPath, const BasicDerivation & drv, diff --git a/src/libstore/serve-protocol-impl.cc b/src/libstore/serve-protocol-impl.cc new file mode 100644 index 000000000..b39212884 --- /dev/null +++ b/src/libstore/serve-protocol-impl.cc @@ -0,0 +1,69 @@ +#include "serve-protocol-impl.hh" +#include "build-result.hh" +#include "derivations.hh" + +namespace nix { + +ServeProto::Version ServeProto::BasicClientConnection::handshake( + BufferedSink & to, + Source & from, + ServeProto::Version localVersion, + std::string_view host) +{ + to << SERVE_MAGIC_1 << localVersion; + to.flush(); + + unsigned int magic = readInt(from); + if (magic != SERVE_MAGIC_2) + throw Error("'nix-store --serve' protocol mismatch from '%s'", host); + auto remoteVersion = readInt(from); + if (GET_PROTOCOL_MAJOR(remoteVersion) != 0x200) + throw Error("unsupported 'nix-store --serve' protocol version on '%s'", host); + return remoteVersion; +} + +ServeProto::Version ServeProto::BasicServerConnection::handshake( + BufferedSink & to, + Source & from, + ServeProto::Version localVersion) +{ + unsigned int magic = readInt(from); + if (magic != SERVE_MAGIC_1) throw Error("protocol mismatch"); + to << SERVE_MAGIC_2 << localVersion; + to.flush(); + return readInt(from); +} + + +StorePathSet ServeProto::BasicClientConnection::queryValidPaths( + const Store & store, + bool lock, const StorePathSet & paths, + SubstituteFlag maybeSubstitute) +{ + to + << ServeProto::Command::QueryValidPaths + << lock + << maybeSubstitute; + write(store, *this, paths); + to.flush(); + + return Serialise::read(store, *this); +} + + +void ServeProto::BasicClientConnection::putBuildDerivationRequest( + const Store & store, + const StorePath & drvPath, const BasicDerivation & drv, + const ServeProto::BuildOptions & options) +{ + to + << ServeProto::Command::BuildDerivation + << store.printStorePath(drvPath); + writeDerivation(to, store, drv); + + ServeProto::write(store, *this, options); + + to.flush(); +} + +} diff --git a/src/libstore/serve-protocol-impl.hh b/src/libstore/serve-protocol-impl.hh index 6f3b177ac..fd8d94697 100644 --- a/src/libstore/serve-protocol-impl.hh +++ b/src/libstore/serve-protocol-impl.hh @@ -10,6 +10,7 @@ #include "serve-protocol.hh" #include "length-prefixed-protocol-helper.hh" +#include "store-api.hh" namespace nix { @@ -56,4 +57,101 @@ struct ServeProto::Serialise /* protocol-specific templates */ +struct ServeProto::BasicClientConnection +{ + FdSink to; + FdSource from; + ServeProto::Version remoteVersion; + + /** + * Establishes connection, negotiating version. + * + * @return the version provided by the other side of the + * connection. + * + * @param to Taken by reference to allow for various error handling + * mechanisms. + * + * @param from Taken by reference to allow for various error + * handling mechanisms. + * + * @param localVersion Our version which is sent over + * + * @param host Just used to add context to thrown exceptions. + */ + static ServeProto::Version handshake( + BufferedSink & to, + Source & from, + ServeProto::Version localVersion, + std::string_view host); + + /** + * Coercion to `ServeProto::ReadConn`. This makes it easy to use the + * factored out serve protocol serializers with a + * `LegacySSHStore::Connection`. + * + * The serve protocol connection types are unidirectional, unlike + * this type. + */ + operator ServeProto::ReadConn () + { + return ServeProto::ReadConn { + .from = from, + .version = remoteVersion, + }; + } + + /** + * Coercion to `ServeProto::WriteConn`. This makes it easy to use the + * factored out serve protocol serializers with a + * `LegacySSHStore::Connection`. + * + * The serve protocol connection types are unidirectional, unlike + * this type. + */ + operator ServeProto::WriteConn () + { + return ServeProto::WriteConn { + .to = to, + .version = remoteVersion, + }; + } + + StorePathSet queryValidPaths( + const Store & remoteStore, + bool lock, const StorePathSet & paths, + SubstituteFlag maybeSubstitute); + + /** + * Just the request half, because Hydra may do other things between + * issuing the request and reading the `BuildResult` response. + */ + void putBuildDerivationRequest( + const Store & store, + const StorePath & drvPath, const BasicDerivation & drv, + const ServeProto::BuildOptions & options); +}; + +struct ServeProto::BasicServerConnection +{ + /** + * Establishes connection, negotiating version. + * + * @return the version provided by the other side of the + * connection. + * + * @param to Taken by reference to allow for various error handling + * mechanisms. + * + * @param from Taken by reference to allow for various error + * handling mechanisms. + * + * @param localVersion Our version which is sent over + */ + static ServeProto::Version handshake( + BufferedSink & to, + Source & from, + ServeProto::Version localVersion); +}; + } diff --git a/src/libstore/serve-protocol.hh b/src/libstore/serve-protocol.hh index 1665b935f..8c112bb74 100644 --- a/src/libstore/serve-protocol.hh +++ b/src/libstore/serve-protocol.hh @@ -59,6 +59,14 @@ struct ServeProto Version version; }; + /** + * Stripped down serialization logic suitable for sharing with Hydra. + * + * @todo remove once Hydra uses Store abstraction consistently. + */ + struct BasicClientConnection; + struct BasicServerConnection; + /** * Data type for canonical pairs of serialisers for the serve protocol. * diff --git a/src/nix-store/nix-store.cc b/src/nix-store/nix-store.cc index 0a0a3ab1a..40378e123 100644 --- a/src/nix-store/nix-store.cc +++ b/src/nix-store/nix-store.cc @@ -828,11 +828,9 @@ static void opServe(Strings opFlags, Strings opArgs) FdSink out(STDOUT_FILENO); /* Exchange the greeting. */ - unsigned int magic = readInt(in); - if (magic != SERVE_MAGIC_1) throw Error("protocol mismatch"); - out << SERVE_MAGIC_2 << SERVE_PROTOCOL_VERSION; - out.flush(); - ServeProto::Version clientVersion = readInt(in); + ServeProto::Version clientVersion = + ServeProto::BasicServerConnection::handshake( + out, in, SERVE_PROTOCOL_VERSION); ServeProto::ReadConn rconn { .from = in, diff --git a/tests/unit/libstore/data/serve-protocol/handshake-to-client.bin b/tests/unit/libstore/data/serve-protocol/handshake-to-client.bin new file mode 100644 index 000000000..15ba4b5e3 Binary files /dev/null and b/tests/unit/libstore/data/serve-protocol/handshake-to-client.bin differ diff --git a/tests/unit/libstore/serve-protocol.cc b/tests/unit/libstore/serve-protocol.cc index 8f256d1e6..597c0b570 100644 --- a/tests/unit/libstore/serve-protocol.cc +++ b/tests/unit/libstore/serve-protocol.cc @@ -1,3 +1,4 @@ +#include #include #include @@ -6,6 +7,7 @@ #include "serve-protocol.hh" #include "serve-protocol-impl.hh" #include "build-result.hh" +#include "file-descriptor.hh" #include "tests/protocol.hh" #include "tests/characterization.hh" @@ -401,4 +403,112 @@ VERSIONED_CHARACTERIZATION_TEST( }, })) +TEST_F(ServeProtoTest, handshake_log) +{ + CharacterizationTest::writeTest("handshake-to-client", [&]() -> std::string { + StringSink toClientLog; + + Pipe toClient, toServer; + toClient.create(); + toServer.create(); + + ServeProto::Version clientResult, serverResult; + + auto thread = std::thread([&]() { + FdSink out { toServer.writeSide.get() }; + FdSource in0 { toClient.readSide.get() }; + TeeSource in { in0, toClientLog }; + clientResult = ServeProto::BasicClientConnection::handshake( + out, in, defaultVersion, "blah"); + }); + + { + FdSink out { toClient.writeSide.get() }; + FdSource in { toServer.readSide.get() }; + serverResult = ServeProto::BasicServerConnection::handshake( + out, in, defaultVersion); + }; + + thread.join(); + + return std::move(toClientLog.s); + }); +} + +/// Has to be a `BufferedSink` for handshake. +struct NullBufferedSink : BufferedSink { + void writeUnbuffered(std::string_view data) override { } +}; + +TEST_F(ServeProtoTest, handshake_client_replay) +{ + CharacterizationTest::readTest("handshake-to-client", [&](std::string toClientLog) { + NullBufferedSink nullSink; + + StringSource in { toClientLog }; + auto clientResult = ServeProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion, "blah"); + + EXPECT_EQ(clientResult, defaultVersion); + }); +} + +TEST_F(ServeProtoTest, handshake_client_truncated_replay_throws) +{ + CharacterizationTest::readTest("handshake-to-client", [&](std::string toClientLog) { + for (size_t len = 0; len < toClientLog.size(); ++len) { + NullBufferedSink nullSink; + StringSource in { + // truncate + toClientLog.substr(0, len) + }; + if (len < 8) { + EXPECT_THROW( + ServeProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion, "blah"), + EndOfFile); + } else { + // Not sure why cannot keep on checking for `EndOfFile`. + EXPECT_THROW( + ServeProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion, "blah"), + Error); + } + } + }); +} + +TEST_F(ServeProtoTest, handshake_client_corrupted_throws) +{ + CharacterizationTest::readTest("handshake-to-client", [&](const std::string toClientLog) { + for (size_t idx = 0; idx < toClientLog.size(); ++idx) { + // corrupt a copy + std::string toClientLogCorrupt = toClientLog; + toClientLogCorrupt[idx] *= 4; + ++toClientLogCorrupt[idx]; + + NullBufferedSink nullSink; + StringSource in { toClientLogCorrupt }; + + if (idx < 4 || idx == 9) { + // magic bytes don't match + EXPECT_THROW( + ServeProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion, "blah"), + Error); + } else if (idx < 8 || idx >= 12) { + // Number out of bounds + EXPECT_THROW( + ServeProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion, "blah"), + SerialisationError); + } else { + auto ver = ServeProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion, "blah"); + EXPECT_NE(ver, defaultVersion); + } + } + }); +} + }