diff --git a/include/spinscale/co/group.h b/include/spinscale/co/group.h index 07a71b2..634fdcc 100644 --- a/include/spinscale/co/group.h +++ b/include/spinscale/co/group.h @@ -1,6 +1,7 @@ #ifndef GROUP_H #define GROUP_H +#include #include #include #include @@ -60,6 +61,19 @@ concept AwaitableIface = requires(T &t) { { get_operator_co_await(t) }; } && AwaiterIface()))>; +template +T &asAwaiter(T &t) noexcept +{ + return t; +} + +template +auto asAwaiter(T &t) noexcept(noexcept(get_operator_co_await(t))) + -> decltype(get_operator_co_await(t)) +{ + return get_operator_co_await(t); +} + } // namespace detail template @@ -71,8 +85,27 @@ concept AwaiterIface = detail::AwaiterIface; template concept AwaitableOrAwaiterIface = AwaiterIface || AwaitableIface; -template -requires AwaitableOrAwaiterIface +/** Typical usage — parallel members, then gather: + * + * co::Group group; + * + * auto bodyInit = body.initializeCReq(exceptionPtr, noopCallback); + * auto legInit = leg.initializeCReq(exceptionPtr, noopCallback); + * ViralNonPostingInvoker batch = app.joltAllPuppetThreadsCReq(...); + * + * group.add(bodyInit); + * group.add(legInit); + * group.add(batch); + * + * co_await group.getAwaitAllSettlementsInvoker(); + * group.checkForAndReThrowGroupExceptions(); + * + * (void)bodyInit.completedReturnValues(); + * + * // When walking settlement slots by index: + * settlements[i].invokerAs>() + * .completedReturnValues(); + */ struct Group { enum class AwaitingCondition { @@ -91,9 +124,23 @@ struct Group UNSETTLED, COMPLETED, EXCEPTION_THROWN }; - SettlementDescriptor(Invoker &_invoker) - : invoker(std::ref(_invoker)) - {} + template + void bindMemberRef(Member &member) + { + memberInvokerRef = std::ref(member); + } + + template + Member &invokerAs() const + { + try { + return std::any_cast>( + memberInvokerRef).get(); + } catch (const std::bad_any_cast &) { + throw std::runtime_error( + "Group settlement invoker type mismatch"); + } + } void setSettlementStatus() noexcept { @@ -109,7 +156,7 @@ struct Group TypeE type = TypeE::UNSETTLED; std::exception_ptr calleeException = nullptr; std::exception_ptr adapterException = nullptr; - std::reference_wrapper invoker; + std::any memberInvokerRef; }; struct SettlementAwaitingInvoker; @@ -466,28 +513,17 @@ struct Group * target async fn, and also to convey its results back to the Group class. * It's effectively a go-between coro that provides the outcomes that Invokers * normally provide, without needing, itself, to be co_awaited. + * + * settlementIndex is captured by value (not a vector iterator) so adapter + * coros remain valid if settlements reallocate during concurrent add(). */ - NonAwaitableNonPostingAdapterCoro nonAwaitableAdapterCoro( + template + NonAwaitableNonPostingAdapterCoro memberAdapterCoro( + Member &memberInvoker, std::size_t settlementIndex) noexcept { - /** EXPLANATION: - * It's very convenient that our design for the NonViralPostingInvoker - * coincidentally allows us to supply a lambda that can be used to test - * for the settlement conditions that are being waited on by the Group's - * co_awaiter. - * - * settlementIndex is captured by value (not a vector iterator) so adapter - * coros remain valid if settlements reallocate during concurrent add(). - */ try { - /* Return values remain in the callee promise until the caller-owned - * invoker is destroyed (~PostingInvoker). The group co_awaiter reads - * results via settlements[settlementIndex].invoker after awaiting. - * - * Index settlements[] each time; do not cache a reference across - * co_await because concurrent add() may reallocate the vector. - */ - co_await s.rsrc.settlements[settlementIndex].invoker.get(); + co_await detail::asAwaiter(memberInvoker); } catch (...) { @@ -505,12 +541,8 @@ struct Group co_return; } - /** EXPLANATION: - * Each invoker passed to add() must outlive this Group and the callee frame - * (see ~PostingInvoker). The group co_awaiter reads return values from those - * invokers after awaiting; do not destroy an invoker until reads are done. - */ - void add(Invoker &invoker) + template + void add(Member &memberInvoker) { std::size_t settlementIndex = 0; @@ -525,10 +557,11 @@ struct Group } settlementIndex = s.rsrc.settlements.size(); - s.rsrc.settlements.emplace_back(invoker); + s.rsrc.settlements.emplace_back(); + s.rsrc.settlements[settlementIndex].bindMemberRef(memberInvoker); } - nonAwaitableAdapterCoro(settlementIndex); + memberAdapterCoro(memberInvoker, settlementIndex); } void checkForAndReThrowGroupExceptions() const diff --git a/include/spinscale/puppetApplication.h b/include/spinscale/puppetApplication.h index 7a45fae..b9281fa 100644 --- a/include/spinscale/puppetApplication.h +++ b/include/spinscale/puppetApplication.h @@ -39,7 +39,7 @@ public: protected: using PuppetLifetimeMgmtInvoker = PuppetThread::ViralThreadLifetimeMgmtInvoker; - using PuppetLifetimeMgmtGroup = co::Group; + using PuppetLifetimeMgmtGroup = co::Group; void addAllPuppetLifetimeInvokersToGroup( PuppetLifetimeMgmtGroup &group, diff --git a/src/puppetApplication.cpp b/src/puppetApplication.cpp index 9c58640..3522ff0 100644 --- a/src/puppetApplication.cpp +++ b/src/puppetApplication.cpp @@ -85,9 +85,7 @@ PuppetApplication::joltAllPuppetThreadsCReq( addAllPuppetLifetimeInvokersToGroup( group, invokers, PuppetThread::ThreadOp::JOLT); - PuppetLifetimeMgmtGroup::AwaitAllSettlementsInvoker groupAwaitAll( - group); - co_await groupAwaitAll; + co_await group.getAwaitAllSettlementsInvoker(); group.checkForAndReThrowGroupExceptions(); threadsHaveBeenJolted = true; @@ -111,9 +109,7 @@ PuppetApplication::allPuppetThreadsLifetimeOpCReq( std::vector invokers; addAllPuppetLifetimeInvokersToGroup(group, invokers, threadOp); - PuppetLifetimeMgmtGroup::AwaitAllSettlementsInvoker groupAwaitAll( - group); - co_await groupAwaitAll; + co_await group.getAwaitAllSettlementsInvoker(); group.checkForAndReThrowGroupExceptions(); co_return; @@ -151,8 +147,8 @@ PuppetApplication::resumeAllPuppetThreadsCReq( co::ViralNonPostingInvoker PuppetApplication::exitAllPuppetThreadsCReq( - [[maybe_unused]] std::exception_ptr &exceptionPtr, - [[maybe_unused]] std::function callback) + std::exception_ptr &exceptionPtr, + std::function callback) { if (componentThreads.empty()) { @@ -160,15 +156,10 @@ PuppetApplication::exitAllPuppetThreadsCReq( co_return; } - PuppetLifetimeMgmtGroup group; - std::vector invokers; - - addAllPuppetLifetimeInvokersToGroup( - group, invokers, PuppetThread::ThreadOp::EXIT); - PuppetLifetimeMgmtGroup::AwaitAllSettlementsInvoker groupAwaitAll( - group); - co_await groupAwaitAll; - group.checkForAndReThrowGroupExceptions(); + co_await allPuppetThreadsLifetimeOpCReq( + exceptionPtr, std::move(callback), + PuppetThread::ThreadOp::EXIT, + noPuppetThreadsToExitLogMessage); for (auto &thread : componentThreads) { thread->thread.join();