diff --git a/include/spinscale/co/nonPostingPromise.h b/include/spinscale/co/nonPostingPromise.h index ab6b328..9c623bc 100644 --- a/include/spinscale/co/nonPostingPromise.h +++ b/include/spinscale/co/nonPostingPromise.h @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -126,12 +127,8 @@ struct NonPostingPromise << std::this_thread::get_id() << " Non-viral non-posting: invoking callerLambda directly.\n"; #endif - if (calleePromise.returnValues.myExceptionPtr) { - std::rethrow_exception( - calleePromise.returnValues.myExceptionPtr); - } - - calleePromise.callerLambda(); + auto callerLambda = std::move(calleePromise.callerLambda); + callerLambda(); return std::noop_coroutine(); } diff --git a/include/spinscale/co/nonViralCompletion.h b/include/spinscale/co/nonViralCompletion.h new file mode 100644 index 0000000..b000d50 --- /dev/null +++ b/include/spinscale/co/nonViralCompletion.h @@ -0,0 +1,40 @@ +#ifndef NON_VIRAL_COMPLETION_H +#define NON_VIRAL_COMPLETION_H + +#include +#include + +namespace sscl::co { + +class NonViralCompletion +{ +public: + explicit NonViralCompletion(std::exception_ptr &exceptionPtr) + : exceptionPtr(exceptionPtr) + {} + + bool hasException() const noexcept + { + return exceptionPtr != nullptr; + } + + void checkAndRethrowException() const + { + if (exceptionPtr) + { + std::rethrow_exception(exceptionPtr); + } + } + + std::exception_ptr releaseException() noexcept + { + return std::exchange(exceptionPtr, nullptr); + } + +private: + std::exception_ptr &exceptionPtr; +}; + +} // namespace sscl::co + +#endif // NON_VIRAL_COMPLETION_H diff --git a/include/spinscale/co/postingPromise.h b/include/spinscale/co/postingPromise.h index a7161f5..1f6079e 100644 --- a/include/spinscale/co/postingPromise.h +++ b/include/spinscale/co/postingPromise.h @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -164,15 +165,12 @@ struct PostingPromise std::cout << "final_suspend" << ": " << std::this_thread::get_id() << " Non-viral: posting callerLambda completion to callerIoContext.\n"; #endif + auto callerLambda = std::move(calleePromise.callerLambda); boost::asio::post( calleePromise.callerIoContext, - [&calleeRef = calleePromise]() + [callerLambda = std::move(callerLambda)]() mutable { - if (calleeRef.returnValues.myExceptionPtr) { - std::rethrow_exception(calleeRef.returnValues.myExceptionPtr); - } - - calleeRef.callerLambda(); + callerLambda(); }); } else