diff --git a/include/spinscale/puppetApplication.h b/include/spinscale/puppetApplication.h index fbd7956..ac09fad 100644 --- a/include/spinscale/puppetApplication.h +++ b/include/spinscale/puppetApplication.h @@ -2,10 +2,10 @@ #define PUPPET_APPLICATION_H #include +#include #include #include #include -#include #include #include @@ -19,18 +19,16 @@ public: const std::vector> &threads); ~PuppetApplication() = default; - // Thread management methods - typedef std::function puppetThreadLifetimeMgmtOpCbFn; - NonViralNonPostingInvoker joltAllPuppetThreadsCReq( - cps::Callback callback); - NonViralNonPostingInvoker startAllPuppetThreadsCReq( - cps::Callback callback); - NonViralNonPostingInvoker pauseAllPuppetThreadsCReq( - cps::Callback callback); - NonViralNonPostingInvoker resumeAllPuppetThreadsCReq( - cps::Callback callback); - NonViralNonPostingInvoker exitAllPuppetThreadsCReq( - cps::Callback callback); + co::NonViralNonPostingInvoker joltAllPuppetThreadsCReq( + std::exception_ptr &exceptionPtr, std::function callback); + co::NonViralNonPostingInvoker startAllPuppetThreadsCReq( + std::exception_ptr &exceptionPtr, std::function callback); + co::NonViralNonPostingInvoker pauseAllPuppetThreadsCReq( + std::exception_ptr &exceptionPtr, std::function callback); + co::NonViralNonPostingInvoker resumeAllPuppetThreadsCReq( + std::exception_ptr &exceptionPtr, std::function callback); + co::NonViralNonPostingInvoker exitAllPuppetThreadsCReq( + std::exception_ptr &exceptionPtr, std::function callback); // CPU distribution method void distributeAndPinThreadsAcrossCpus(); @@ -59,9 +57,6 @@ protected: * a synchronization point for the entire system initialization. */ bool threadsHaveBeenJolted = false; - -private: - class PuppetThreadLifetimeMgmtOp; }; } // namespace sscl diff --git a/src/puppetApplication.cpp b/src/puppetApplication.cpp index 9780aa3..6dac615 100644 --- a/src/puppetApplication.cpp +++ b/src/puppetApplication.cpp @@ -1,205 +1,178 @@ #include -#include -#include -#include +#include +#include + +#include #include #include namespace sscl { +namespace puppet_application_detail { + +constexpr std::string_view noPuppetThreadsToStartLogMessage = + "Mrntt: No puppet threads to start"; +constexpr std::string_view noPuppetThreadsToPauseLogMessage = + "Mrntt: No puppet threads to pause"; +constexpr std::string_view noPuppetThreadsToResumeLogMessage = + "Mrntt: No puppet threads to resume"; +constexpr std::string_view noPuppetThreadsToExitLogMessage = + "Mrntt: No puppet threads to exit"; + +using PuppetLifetimeInvoker = PuppetThread::ViralThreadLifetimeMgmtInvoker; +using PuppetLifetimeGroup = co::Group; + +void addAllPuppetLifetimeInvokersToGroup( + PuppetLifetimeGroup &group, + std::vector &invokers, + const std::vector> &componentThreads, + PuppetThread::ThreadOp threadOp) +{ + invokers.reserve(componentThreads.size()); + + for (const auto &thread : componentThreads) + { + switch (threadOp) + { + case PuppetThread::ThreadOp::START: + invokers.emplace_back(thread->startThreadAReq()); + break; + case PuppetThread::ThreadOp::PAUSE: + invokers.emplace_back(thread->pauseThreadAReq()); + break; + case PuppetThread::ThreadOp::RESUME: + invokers.emplace_back(thread->resumeThreadAReq()); + break; + case PuppetThread::ThreadOp::EXIT: + invokers.emplace_back(thread->exitThreadAReq()); + break; + case PuppetThread::ThreadOp::JOLT: + invokers.emplace_back(thread->joltThreadAReq(thread)); + break; + default: + throw std::runtime_error( + std::string(__func__) + ": Invalid thread operation"); + } + + group.add(invokers.back()); + } +} + +co::NonViralNonPostingInvoker genericAllPuppetThreadsLifetimeOpCReq( + const std::vector> &componentThreads, + PuppetThread::ThreadOp threadOp, + std::string_view emptyThreadsLogMessage, + [[maybe_unused]] std::exception_ptr &exceptionPtr, + [[maybe_unused]] std::function callback) +{ + if (componentThreads.empty()) + { + std::cout << emptyThreadsLogMessage << "\n"; + co_return; + } + + PuppetLifetimeGroup group; + std::vector invokers; + + addAllPuppetLifetimeInvokersToGroup( + group, invokers, componentThreads, threadOp); + + co_await group.getAwaitAllSettlementsInvoker(); + group.checkForAndReThrowGroupExceptions(); + + co_return; +} + +} // namespace puppet_application_detail + PuppetApplication::PuppetApplication( const std::vector> &threads) : componentThreads(threads) { } -class PuppetApplication::PuppetThreadLifetimeMgmtOp -: public cps::NonPostedAsynchronousContinuation -{ -public: - PuppetThreadLifetimeMgmtOp( - PuppetApplication &parent, unsigned int nThreads, - cps::Callback callback) - : cps::NonPostedAsynchronousContinuation(callback), - loop(nThreads), - parent(parent) - {} - -public: - AsynchronousLoop loop; - PuppetApplication &parent; - -public: - void joltAllPuppetThreadsReq1( - [[maybe_unused]] std::shared_ptr context - ) - { - loop.incrementSuccessOrFailureDueTo(true); - if (!loop.isComplete()) { - return; - } - - parent.threadsHaveBeenJolted = true; - callOriginalCb(); - } - - void executeGenericOpOnAllPuppetThreadsReq1( - [[maybe_unused]] std::shared_ptr context - ) - { - loop.incrementSuccessOrFailureDueTo(true); - if (!loop.isComplete()) { - return; - } - - callOriginalCb(); - } - - void exitAllPuppetThreadsReq1( - [[maybe_unused]] std::shared_ptr context - ) - { - loop.incrementSuccessOrFailureDueTo(true); - if (!loop.isComplete()) { - return; - } - - for (auto& thread : parent.componentThreads) { - thread->thread.join(); - } - - callOriginalCb(); - } -}; - -void PuppetApplication::joltAllPuppetThreadsCReq( - cps::Callback callback - ) +co::NonViralNonPostingInvoker PuppetApplication::joltAllPuppetThreadsCReq( + [[maybe_unused]] std::exception_ptr &exceptionPtr, + [[maybe_unused]] std::function callback) { if (threadsHaveBeenJolted) { std::cout << "Mrntt: All puppet threads already JOLTed. " << "Skipping JOLT request." << "\n"; - callback.callbackFn(); - return; + co_return; } - // If no threads, set flag and call callback immediately - if (componentThreads.size() == 0 && callback.callbackFn) + if (componentThreads.empty()) { threadsHaveBeenJolted = true; - callback.callbackFn(); - return; + co_return; } - // Create a counter to track when all threads have been jolted - auto request = std::make_shared( - *this, componentThreads.size(), callback); + puppet_application_detail::PuppetLifetimeGroup group; + std::vector invokers; - for (auto& thread : componentThreads) - { - thread->joltThreadReq( - thread, - {request, std::bind( - &PuppetThreadLifetimeMgmtOp::joltAllPuppetThreadsReq1, - request.get(), request)}); - } + puppet_application_detail::addAllPuppetLifetimeInvokersToGroup( + group, invokers, componentThreads, PuppetThread::ThreadOp::JOLT); + + co_await group.getAwaitAllSettlementsInvoker(); + group.checkForAndReThrowGroupExceptions(); + + threadsHaveBeenJolted = true; + co_return; } -void PuppetApplication::startAllPuppetThreadsCReq( - cps::Callback callback - ) +co::NonViralNonPostingInvoker PuppetApplication::startAllPuppetThreadsCReq( + std::exception_ptr &exceptionPtr, std::function callback) { - // If no threads, call callback immediately - if (componentThreads.size() == 0 && callback.callbackFn) - { - callback.callbackFn(); - return; - } - - // Create a counter to track when all threads have started - auto request = std::make_shared( - *this, componentThreads.size(), callback); - - for (auto& thread : componentThreads) - { - thread->startThreadReq( - {request, std::bind( - &PuppetThreadLifetimeMgmtOp::executeGenericOpOnAllPuppetThreadsReq1, - request.get(), request)}); - } + return puppet_application_detail::genericAllPuppetThreadsLifetimeOpCReq( + componentThreads, PuppetThread::ThreadOp::START, + puppet_application_detail::noPuppetThreadsToStartLogMessage, + exceptionPtr, callback); } -void PuppetApplication::pauseAllPuppetThreadsCReq( - cps::Callback callback - ) +co::NonViralNonPostingInvoker PuppetApplication::pauseAllPuppetThreadsCReq( + std::exception_ptr &exceptionPtr, std::function callback) { - // If no threads, call callback immediately - if (componentThreads.size() == 0 && callback.callbackFn) - { - callback.callbackFn(); - return; - } - - // Create a counter to track when all threads have paused - auto request = std::make_shared( - *this, componentThreads.size(), callback); - - for (auto& thread : componentThreads) - { - thread->pauseThreadReq( - {request, std::bind( - &PuppetThreadLifetimeMgmtOp::executeGenericOpOnAllPuppetThreadsReq1, - request.get(), request)}); - } + return puppet_application_detail::genericAllPuppetThreadsLifetimeOpCReq( + componentThreads, PuppetThread::ThreadOp::PAUSE, + puppet_application_detail::noPuppetThreadsToPauseLogMessage, + exceptionPtr, callback); } -void PuppetApplication::resumeAllPuppetThreadsCReq( - cps::Callback callback - ) +co::NonViralNonPostingInvoker PuppetApplication::resumeAllPuppetThreadsCReq( + std::exception_ptr &exceptionPtr, std::function callback) { - // If no threads, call callback immediately - if (componentThreads.size() == 0 && callback.callbackFn) - { - callback.callbackFn(); - return; - } - - // Create a counter to track when all threads have resumed - auto request = std::make_shared( - *this, componentThreads.size(), callback); - - for (auto& thread : componentThreads) - { - thread->resumeThreadReq( - {request, std::bind( - &PuppetThreadLifetimeMgmtOp::executeGenericOpOnAllPuppetThreadsReq1, - request.get(), request)}); - } + return puppet_application_detail::genericAllPuppetThreadsLifetimeOpCReq( + componentThreads, PuppetThread::ThreadOp::RESUME, + puppet_application_detail::noPuppetThreadsToResumeLogMessage, + exceptionPtr, callback); } -void PuppetApplication::exitAllPuppetThreadsCReq( - cps::Callback callback - ) +co::NonViralNonPostingInvoker PuppetApplication::exitAllPuppetThreadsCReq( + [[maybe_unused]] std::exception_ptr &exceptionPtr, + [[maybe_unused]] std::function callback) { - // If no threads, call callback immediately - if (componentThreads.size() == 0 && callback.callbackFn) + if (componentThreads.empty()) { - callback.callbackFn(); - return; + std::cout << puppet_application_detail::noPuppetThreadsToExitLogMessage + << "\n"; + co_return; } - // Create a counter to track when all threads have exited - auto request = std::make_shared( - *this, componentThreads.size(), callback); + puppet_application_detail::PuppetLifetimeGroup group; + std::vector invokers; - for (auto& thread : componentThreads) - { - thread->exitThreadReq( - {request, std::bind( - &PuppetThreadLifetimeMgmtOp::exitAllPuppetThreadsReq1, - request.get(), request)}); + puppet_application_detail::addAllPuppetLifetimeInvokersToGroup( + group, invokers, componentThreads, PuppetThread::ThreadOp::EXIT); + + co_await group.getAwaitAllSettlementsInvoker(); + group.checkForAndReThrowGroupExceptions(); + + for (auto &thread : componentThreads) { + thread->thread.join(); } + + co_return; } void PuppetApplication::distributeAndPinThreadsAcrossCpus()