From 7348653ff4fc4e9b2dc24943aabdb57179b1c75a Mon Sep 17 00:00:00 2001 From: Eelco Dolstra Date: Tue, 3 Sep 2019 12:51:35 +0200 Subject: [PATCH] Ensure that Callback is called only once Also, make Callback movable but uncopyable. --- src/libstore/binary-cache-store.cc | 8 +++++--- src/libstore/download.cc | 6 +++--- src/libstore/http-binary-cache-store.cc | 12 +++++++----- src/libstore/store-api.cc | 8 +++++--- src/libutil/util.hh | 19 ++++++++++++++++--- 5 files changed, 36 insertions(+), 17 deletions(-) diff --git a/src/libstore/binary-cache-store.cc b/src/libstore/binary-cache-store.cc index 4527ee6b..e56be625 100644 --- a/src/libstore/binary-cache-store.cc +++ b/src/libstore/binary-cache-store.cc @@ -249,21 +249,23 @@ void BinaryCacheStore::queryPathInfoUncached(const Path & storePath, auto narInfoFile = narInfoFileFor(storePath); + auto callbackPtr = std::make_shared(std::move(callback)); + getFile(narInfoFile, {[=](std::future> fut) { try { auto data = fut.get(); - if (!data) return callback(nullptr); + if (!data) return (*callbackPtr)(nullptr); stats.narInfoRead++; - callback((std::shared_ptr) + (*callbackPtr)((std::shared_ptr) std::make_shared(*this, *data, narInfoFile)); (void) act; // force Activity into this lambda to ensure it stays alive } catch (...) { - callback.rethrow(); + callbackPtr->rethrow(); } }}); } diff --git a/src/libstore/download.cc b/src/libstore/download.cc index a7d05946..cdf56e09 100644 --- a/src/libstore/download.cc +++ b/src/libstore/download.cc @@ -77,13 +77,13 @@ struct CurlDownloader : public Downloader DownloadItem(CurlDownloader & downloader, const DownloadRequest & request, - Callback callback) + Callback && callback) : downloader(downloader) , request(request) , act(*logger, lvlTalkative, actDownload, fmt(request.data ? "uploading '%s'" : "downloading '%s'", request.uri), {request.uri}, request.parentAct) - , callback(callback) + , callback(std::move(callback)) , finalSink([this](const unsigned char * data, size_t len) { if (this->request.dataCallback) { writtenToSink += len; @@ -665,7 +665,7 @@ struct CurlDownloader : public Downloader return; } - enqueueItem(std::make_shared(*this, request, callback)); + enqueueItem(std::make_shared(*this, request, std::move(callback))); } }; diff --git a/src/libstore/http-binary-cache-store.cc b/src/libstore/http-binary-cache-store.cc index df2fb933..e631d95f 100644 --- a/src/libstore/http-binary-cache-store.cc +++ b/src/libstore/http-binary-cache-store.cc @@ -137,17 +137,19 @@ protected: auto request(makeRequest(path)); + auto callbackPtr = std::make_shared(std::move(callback)); + getDownloader()->enqueueDownload(request, - {[callback, this](std::future result) { + {[callbackPtr, this](std::future result) { try { - callback(result.get().data); + (*callbackPtr)(result.get().data); } catch (DownloadError & e) { if (e.error == Downloader::NotFound || e.error == Downloader::Forbidden) - return callback(std::shared_ptr()); + return (*callbackPtr)(std::shared_ptr()); maybeDisable(); - callback.rethrow(); + callbackPtr->rethrow(); } catch (...) { - callback.rethrow(); + callbackPtr->rethrow(); } }}); } diff --git a/src/libstore/store-api.cc b/src/libstore/store-api.cc index 3bb9db0b..88a5b2f4 100644 --- a/src/libstore/store-api.cc +++ b/src/libstore/store-api.cc @@ -365,8 +365,10 @@ void Store::queryPathInfo(const Path & storePath, } catch (...) { return callback.rethrow(); } + auto callbackPtr = std::make_shared(std::move(callback)); + queryPathInfoUncached(storePath, - {[this, storePath, hashPart, callback](std::future> fut) { + {[this, storePath, hashPart, callbackPtr](std::future> fut) { try { auto info = fut.get(); @@ -386,8 +388,8 @@ void Store::queryPathInfo(const Path & storePath, throw InvalidPath("path '%s' is not valid", storePath); } - callback(ref(info)); - } catch (...) { callback.rethrow(); } + (*callbackPtr)(ref(info)); + } catch (...) { callbackPtr->rethrow(); } }}); } diff --git a/src/libutil/util.hh b/src/libutil/util.hh index b538a0b4..686e81d3 100644 --- a/src/libutil/util.hh +++ b/src/libutil/util.hh @@ -445,21 +445,34 @@ string get(const T & map, const string & key, const string & def = "") type T or an exception. (We abuse std::future to pass the value or exception.) */ template -struct Callback +class Callback { std::function)> fun; + std::atomic_flag done = ATOMIC_FLAG_INIT; + +public: Callback(std::function)> fun) : fun(fun) { } - void operator()(T && t) const + Callback(Callback && callback) : fun(std::move(callback.fun)) { + auto prev = callback.done.test_and_set(); + if (prev) done.test_and_set(); + } + + void operator()(T && t) + { + auto prev = done.test_and_set(); + assert(!prev); std::promise promise; promise.set_value(std::move(t)); fun(promise.get_future()); } - void rethrow(const std::exception_ptr & exc = std::current_exception()) const + void rethrow(const std::exception_ptr & exc = std::current_exception()) { + auto prev = done.test_and_set(); + assert(!prev); std::promise promise; promise.set_exception(exc); fun(promise.get_future());