Ensure that Callback is called only once

Also, make Callback movable but uncopyable.
This commit is contained in:
Eelco Dolstra 2019-09-03 12:51:35 +02:00
parent 8c4ea7a451
commit 7348653ff4
No known key found for this signature in database
GPG key ID: 8170B4726D7198DE
5 changed files with 36 additions and 17 deletions

View file

@ -249,21 +249,23 @@ void BinaryCacheStore::queryPathInfoUncached(const Path & storePath,
auto narInfoFile = narInfoFileFor(storePath); auto narInfoFile = narInfoFileFor(storePath);
auto callbackPtr = std::make_shared<decltype(callback)>(std::move(callback));
getFile(narInfoFile, getFile(narInfoFile,
{[=](std::future<std::shared_ptr<std::string>> fut) { {[=](std::future<std::shared_ptr<std::string>> fut) {
try { try {
auto data = fut.get(); auto data = fut.get();
if (!data) return callback(nullptr); if (!data) return (*callbackPtr)(nullptr);
stats.narInfoRead++; stats.narInfoRead++;
callback((std::shared_ptr<ValidPathInfo>) (*callbackPtr)((std::shared_ptr<ValidPathInfo>)
std::make_shared<NarInfo>(*this, *data, narInfoFile)); std::make_shared<NarInfo>(*this, *data, narInfoFile));
(void) act; // force Activity into this lambda to ensure it stays alive (void) act; // force Activity into this lambda to ensure it stays alive
} catch (...) { } catch (...) {
callback.rethrow(); callbackPtr->rethrow();
} }
}}); }});
} }

View file

@ -77,13 +77,13 @@ struct CurlDownloader : public Downloader
DownloadItem(CurlDownloader & downloader, DownloadItem(CurlDownloader & downloader,
const DownloadRequest & request, const DownloadRequest & request,
Callback<DownloadResult> callback) Callback<DownloadResult> && callback)
: downloader(downloader) : downloader(downloader)
, request(request) , request(request)
, act(*logger, lvlTalkative, actDownload, , act(*logger, lvlTalkative, actDownload,
fmt(request.data ? "uploading '%s'" : "downloading '%s'", request.uri), fmt(request.data ? "uploading '%s'" : "downloading '%s'", request.uri),
{request.uri}, request.parentAct) {request.uri}, request.parentAct)
, callback(callback) , callback(std::move(callback))
, finalSink([this](const unsigned char * data, size_t len) { , finalSink([this](const unsigned char * data, size_t len) {
if (this->request.dataCallback) { if (this->request.dataCallback) {
writtenToSink += len; writtenToSink += len;
@ -665,7 +665,7 @@ struct CurlDownloader : public Downloader
return; return;
} }
enqueueItem(std::make_shared<DownloadItem>(*this, request, callback)); enqueueItem(std::make_shared<DownloadItem>(*this, request, std::move(callback)));
} }
}; };

View file

@ -137,17 +137,19 @@ protected:
auto request(makeRequest(path)); auto request(makeRequest(path));
auto callbackPtr = std::make_shared<decltype(callback)>(std::move(callback));
getDownloader()->enqueueDownload(request, getDownloader()->enqueueDownload(request,
{[callback, this](std::future<DownloadResult> result) { {[callbackPtr, this](std::future<DownloadResult> result) {
try { try {
callback(result.get().data); (*callbackPtr)(result.get().data);
} catch (DownloadError & e) { } catch (DownloadError & e) {
if (e.error == Downloader::NotFound || e.error == Downloader::Forbidden) if (e.error == Downloader::NotFound || e.error == Downloader::Forbidden)
return callback(std::shared_ptr<std::string>()); return (*callbackPtr)(std::shared_ptr<std::string>());
maybeDisable(); maybeDisable();
callback.rethrow(); callbackPtr->rethrow();
} catch (...) { } catch (...) {
callback.rethrow(); callbackPtr->rethrow();
} }
}}); }});
} }

View file

@ -365,8 +365,10 @@ void Store::queryPathInfo(const Path & storePath,
} catch (...) { return callback.rethrow(); } } catch (...) { return callback.rethrow(); }
auto callbackPtr = std::make_shared<decltype(callback)>(std::move(callback));
queryPathInfoUncached(storePath, queryPathInfoUncached(storePath,
{[this, storePath, hashPart, callback](std::future<std::shared_ptr<ValidPathInfo>> fut) { {[this, storePath, hashPart, callbackPtr](std::future<std::shared_ptr<ValidPathInfo>> fut) {
try { try {
auto info = fut.get(); auto info = fut.get();
@ -386,8 +388,8 @@ void Store::queryPathInfo(const Path & storePath,
throw InvalidPath("path '%s' is not valid", storePath); throw InvalidPath("path '%s' is not valid", storePath);
} }
callback(ref<ValidPathInfo>(info)); (*callbackPtr)(ref<ValidPathInfo>(info));
} catch (...) { callback.rethrow(); } } catch (...) { callbackPtr->rethrow(); }
}}); }});
} }

View file

@ -445,21 +445,34 @@ string get(const T & map, const string & key, const string & def = "")
type T or an exception. (We abuse std::future<T> to pass the value or type T or an exception. (We abuse std::future<T> to pass the value or
exception.) */ exception.) */
template<typename T> template<typename T>
struct Callback class Callback
{ {
std::function<void(std::future<T>)> fun; std::function<void(std::future<T>)> fun;
std::atomic_flag done = ATOMIC_FLAG_INIT;
public:
Callback(std::function<void(std::future<T>)> fun) : fun(fun) { } Callback(std::function<void(std::future<T>)> fun) : fun(fun) { }
void operator()(T && t) const Callback(Callback && callback) : fun(std::move(callback.fun))
{ {
auto prev = callback.done.test_and_set();
if (prev) done.test_and_set();
}
void operator()(T && t)
{
auto prev = done.test_and_set();
assert(!prev);
std::promise<T> promise; std::promise<T> promise;
promise.set_value(std::move(t)); promise.set_value(std::move(t));
fun(promise.get_future()); fun(promise.get_future());
} }
void rethrow(const std::exception_ptr & exc = std::current_exception()) const void rethrow(const std::exception_ptr & exc = std::current_exception())
{ {
auto prev = done.test_and_set();
assert(!prev);
std::promise<T> promise; std::promise<T> promise;
promise.set_exception(exc); promise.set_exception(exc);
fun(promise.get_future()); fun(promise.get_future());