From a29c779f6e221e7afc4bcb250ade8985700ac8f5 Mon Sep 17 00:00:00 2001 From: Hayodea Hekol Date: Sat, 13 Jun 2026 17:17:57 -0400 Subject: [PATCH] Tests: Add all tests from the coro creation repo We went back and brought along all the tests we implemented while we were building the new coro framework. --- tests/CMakeLists.txt | 69 +- tests/co/component_continuation_tests.cpp | 250 +++++++ tests/co/group_edge_tests.cpp | 835 ++++++++++++++++++++++ tests/co/group_timer_tests.cpp | 268 +++++++ tests/co/posting_cross_thread_tests.cpp | 252 +++++++ tests/co/viral_non_posting_tests.cpp | 511 +++++++++++++ tests/cps/qutex_tests.cpp | 2 - tests/support/groupAssertions.h | 104 +++ tests/support/threadHarness.cpp | 413 +++++++++++ tests/support/threadHarness.h | 362 ++++++++++ tests/support/timerAwaiters.h | 161 +++++ 11 files changed, 3199 insertions(+), 28 deletions(-) create mode 100644 tests/co/component_continuation_tests.cpp create mode 100644 tests/co/group_edge_tests.cpp create mode 100644 tests/co/group_timer_tests.cpp create mode 100644 tests/co/posting_cross_thread_tests.cpp create mode 100644 tests/co/viral_non_posting_tests.cpp create mode 100644 tests/support/groupAssertions.h create mode 100644 tests/support/threadHarness.cpp create mode 100644 tests/support/threadHarness.h create mode 100644 tests/support/timerAwaiters.h diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bfeeccf..753b7b4 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,37 +1,54 @@ -add_executable(spinscale_env_kv_store_tests +add_library(spinscale_test_support STATIC + support/threadHarness.cpp +) + +target_include_directories(spinscale_test_support PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries(spinscale_test_support PUBLIC + spinscale + gtest +) + +function(add_spinscale_gtest target) + add_executable(${target} ${ARGN}) + target_link_libraries(${target} PRIVATE + spinscale_test_support + gtest_main + ) + add_dependencies(${target} gtest_main) + add_test(NAME ${target} COMMAND ${target}) +endfunction() + +add_spinscale_gtest(spinscale_env_kv_store_tests env_kv_store_test.cpp ) -target_link_libraries(spinscale_env_kv_store_tests PRIVATE - spinscale - gtest_main -) - -add_dependencies(spinscale_env_kv_store_tests gtest_main) -add_test(NAME spinscale_env_kv_store_tests - COMMAND spinscale_env_kv_store_tests) - -add_executable(qutex_tests +add_spinscale_gtest(qutex_tests cps/qutex_tests.cpp ) -target_link_libraries(qutex_tests PRIVATE - spinscale - gtest_main -) - -add_dependencies(qutex_tests gtest_main) -add_test(NAME qutex_tests COMMAND qutex_tests) - -add_executable(nonViralTaskNursery_tests +add_spinscale_gtest(nonViralTaskNursery_tests co/nonViralTaskNursery_tests.cpp ) -target_link_libraries(nonViralTaskNursery_tests PRIVATE - spinscale - gtest_main +add_spinscale_gtest(co_viral_non_posting_tests + co/viral_non_posting_tests.cpp ) -add_dependencies(nonViralTaskNursery_tests gtest_main) -add_test(NAME nonViralTaskNursery_tests - COMMAND nonViralTaskNursery_tests) +add_spinscale_gtest(co_posting_cross_thread_tests + co/posting_cross_thread_tests.cpp +) + +add_spinscale_gtest(co_group_edge_tests + co/group_edge_tests.cpp +) + +add_spinscale_gtest(co_group_timer_tests + co/group_timer_tests.cpp +) + +add_spinscale_gtest(co_component_continuation_tests + co/component_continuation_tests.cpp +) diff --git a/tests/co/component_continuation_tests.cpp b/tests/co/component_continuation_tests.cpp new file mode 100644 index 0000000..e2fe57e --- /dev/null +++ b/tests/co/component_continuation_tests.cpp @@ -0,0 +1,250 @@ +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include + +namespace { + +constexpr int leftValue = 1; +constexpr int rightValue = 2; +constexpr int expectedIntSum = 3; +constexpr int bodyArgument = 4; +constexpr const char *bodyStringArgument = "KEKW"; +constexpr const char *leftString = "Hello"; +constexpr const char *rightString = "World"; +constexpr const char *expectedString = "Hello World"; + +using BodyNonViralInvoker = + sscl::tests::RoleNonViralPostingInvoker< + sscl::tests::PostingThreadRole::BODY>; + +template +using BodyViralInvoker = + sscl::tests::RoleViralPostingInvoker< + sscl::tests::PostingThreadRole::BODY, + T>; + +template +using WorldViralInvoker = + sscl::tests::RoleViralPostingInvoker< + sscl::tests::PostingThreadRole::WORLD, + T>; + +template +using LegViralInvoker = + sscl::tests::RoleViralPostingInvoker< + sscl::tests::PostingThreadRole::LEG, + T>; + +class ComponentContinuationTrace +{ +public: + void recordBodyThread() + { + std::lock_guard guard(mutex); + bodyThreadId = std::this_thread::get_id(); + } + + void recordWorldThread() + { + std::lock_guard guard(mutex); + worldThreadId = std::this_thread::get_id(); + } + + void recordLegThread() + { + std::lock_guard guard(mutex); + legThreadId = std::this_thread::get_id(); + } + + void recordCompletionThread() + { + std::lock_guard guard(mutex); + completionThreadId = std::this_thread::get_id(); + } + + void recordLegSum(int value) + { + std::lock_guard guard(mutex); + legSum = value; + } + + void recordWorldString(std::string value) + { + std::lock_guard guard(mutex); + worldString = std::move(value); + } + + void recordBodyString(std::string value) + { + std::lock_guard guard(mutex); + bodyString = std::move(value); + } + + std::thread::id bodyThread() const + { + std::lock_guard guard(mutex); + return bodyThreadId; + } + + std::thread::id worldThread() const + { + std::lock_guard guard(mutex); + return worldThreadId; + } + + std::thread::id legThread() const + { + std::lock_guard guard(mutex); + return legThreadId; + } + + std::thread::id completionThread() const + { + std::lock_guard guard(mutex); + return completionThreadId; + } + + int recordedLegSum() const + { + std::lock_guard guard(mutex); + return legSum; + } + + std::string recordedWorldString() const + { + std::lock_guard guard(mutex); + return worldString; + } + + std::string recordedBodyString() const + { + std::lock_guard guard(mutex); + return bodyString; + } + +private: + mutable std::mutex mutex; + std::thread::id bodyThreadId; + std::thread::id worldThreadId; + std::thread::id legThreadId; + std::thread::id completionThreadId; + int legSum = 0; + std::string worldString; + std::string bodyString; +}; + +LegViralInvoker print2Ints( + int arg1, + int arg2, + ComponentContinuationTrace &trace) +{ + sscl::co::CoQutex print2IntsLock; + trace.recordLegThread(); + auto releaseHandle = + co_await print2IntsLock.getAcquireInvocationAndSuspensionPolicy(); + const int sum = arg1 + arg2; + trace.recordLegSum(sum); + releaseHandle.release(); + co_return sum; +} + +WorldViralInvoker print2Strings( + std::string arg1, + std::string arg2, + ComponentContinuationTrace &trace) +{ + sscl::co::CoQutex print2StringsLock; + trace.recordWorldThread(); + auto releaseHandle = + co_await print2StringsLock.getAcquireInvocationAndSuspensionPolicy(); + const int returnedInt = + co_await print2Ints(leftValue, rightValue, trace); + releaseHandle.release(); + + if (returnedInt != expectedIntSum) { + throw std::runtime_error("LEG int return mismatch"); + } + + std::string returnedString = arg1 + " " + arg2; + trace.recordWorldString(returnedString); + co_return returnedString; +} + +BodyNonViralInvoker initializeDemoCReq( + std::exception_ptr &exceptionPtr, + std::function completion, + int arg3, + std::string arg4, + ComponentContinuationTrace &trace) +{ + (void)exceptionPtr; + (void)completion; + (void)arg3; + (void)arg4; + + sscl::co::CoQutex initializeLock; + trace.recordBodyThread(); + auto releaseHandle = + co_await initializeLock.getAcquireInvocationAndSuspensionPolicy(); + std::string returnedString = + co_await print2Strings(leftString, rightString, trace); + releaseHandle.release(); + + trace.recordBodyString(returnedString); + co_return; +} + +class ComponentContinuationTest +: public ::testing::Test +{ +protected: + sscl::tests::PostingThreadSet threads; +}; + +} // namespace + +TEST_F(ComponentContinuationTest, SyncMainStyleContinuationCrossesComponentThreads) +{ + ComponentContinuationTrace trace; + + ASSERT_NO_THROW( + sscl::tests::runNonViralPostingTask( + threads.caller(), + [&trace]( + std::exception_ptr &exceptionPtr, + std::function completion) + { + return initializeDemoCReq( + exceptionPtr, + [&trace, completion = std::move(completion)]() mutable + { + trace.recordCompletionThread(); + completion(); + }, + bodyArgument, + bodyStringArgument, + trace); + })); + + EXPECT_EQ(trace.bodyThread(), threads.body().osThreadId()); + EXPECT_EQ(trace.worldThread(), threads.world().osThreadId()); + EXPECT_EQ(trace.legThread(), threads.leg().osThreadId()); + EXPECT_EQ(trace.completionThread(), threads.caller().osThreadId()); + + EXPECT_NE(trace.bodyThread(), trace.worldThread()); + EXPECT_NE(trace.worldThread(), trace.legThread()); + EXPECT_NE(trace.legThread(), trace.completionThread()); + + EXPECT_EQ(trace.recordedLegSum(), expectedIntSum); + EXPECT_EQ(trace.recordedWorldString(), expectedString); + EXPECT_EQ(trace.recordedBodyString(), expectedString); +} diff --git a/tests/co/group_edge_tests.cpp b/tests/co/group_edge_tests.cpp new file mode 100644 index 0000000..b827e10 --- /dev/null +++ b/tests/co/group_edge_tests.cpp @@ -0,0 +1,835 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include + +#include +#include +#include + +namespace { + +constexpr int delayShortMs = 50; +constexpr int delayMediumMs = 200; +constexpr int delayLongMs = 500; +constexpr int delayAddWhileSuspendedProbeMs = 80; +constexpr int expectedNonStdThrowValue = 42; +constexpr int wave2ImmediateSettlementLabel = 1000; +constexpr const char *expectedThrowMessage = + "group_edge_test intentional failure"; + +using CallerDriver = + sscl::tests::RoleNonViralPostingInvoker< + sscl::tests::PostingThreadRole::CALLER>; + +using CalleeIntInvoker = + sscl::tests::RoleViralPostingInvoker< + sscl::tests::PostingThreadRole::CALLEE, + int>; + +using CalleeVoidInvoker = + sscl::tests::RoleViralPostingInvoker< + sscl::tests::PostingThreadRole::CALLEE, + void>; + +CalleeIntInvoker waitAndReturnLabel(int timerLabelMilliseconds) +{ + const boost::system::error_code waitError = + co_await sscl::tests::DeadlineTimerAwaiter{ + sscl::ComponentThread::getSelf()->getIoContext(), + timerLabelMilliseconds}; + sscl::tests::throwIfTimerWaitFailed(waitError); + co_return timerLabelMilliseconds; +} + +CalleeIntInvoker waitThenThrowAfterDelay(int delayMilliseconds) +{ + const boost::system::error_code waitError = + co_await sscl::tests::DeadlineTimerAwaiter{ + sscl::ComponentThread::getSelf()->getIoContext(), + delayMilliseconds}; + sscl::tests::throwIfTimerWaitFailed(waitError); + throw std::runtime_error(expectedThrowMessage); +} + +CalleeIntInvoker waitThenThrowIntAfterDelay(int delayMilliseconds) +{ + const boost::system::error_code waitError = + co_await sscl::tests::DeadlineTimerAwaiter{ + sscl::ComponentThread::getSelf()->getIoContext(), + delayMilliseconds}; + sscl::tests::throwIfTimerWaitFailed(waitError); + throw expectedNonStdThrowValue; +} + +CalleeIntInvoker returnLabelImmediately(int label) +{ + co_return label; +} + +CalleeVoidInvoker voidMemberAfterDelay(int delayMilliseconds) +{ + const boost::system::error_code waitError = + co_await sscl::tests::DeadlineTimerAwaiter{ + sscl::ComponentThread::getSelf()->getIoContext(), + delayMilliseconds}; + sscl::tests::throwIfTimerWaitFailed(waitError); + co_return; +} + +int readCompletedLabel(CalleeIntInvoker &invoker) +{ + return invoker.completedReturnValues().myReturnValue; +} + +void assertCompleted( + const sscl::co::Group::SettlementDescriptor &descriptor, + int expectedLabel) +{ + if (descriptor.type + != sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) { + throw std::runtime_error("expected completed settlement"); + } + + if (readCompletedLabel(descriptor.invokerAs()) + != expectedLabel) { + throw std::runtime_error("settlement label mismatch"); + } +} + +void assertRuntimeErrorSettlement( + const sscl::co::Group::SettlementDescriptor &descriptor) +{ + if (descriptor.type + != sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) { + throw std::runtime_error("expected exception settlement"); + } + + try { + std::rethrow_exception(descriptor.calleeException); + } + catch (const std::runtime_error &runtimeError) { + if (std::string(runtimeError.what()) != expectedThrowMessage) { + throw std::runtime_error("unexpected exception message"); + } + return; + } + + throw std::runtime_error("expected runtime_error settlement"); +} + +void assertIntExceptionSettlement( + const sscl::co::Group::SettlementDescriptor &descriptor) +{ + if (descriptor.type + != sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) { + throw std::runtime_error("expected int exception settlement"); + } + + try { + std::rethrow_exception(descriptor.calleeException); + } + catch (int caughtValue) { + if (caughtValue != expectedNonStdThrowValue) { + throw std::runtime_error("unexpected int exception value"); + } + return; + } + + throw std::runtime_error("expected int exception settlement"); +} + +void assertEmptyGroupCoAwaitError(const std::runtime_error &runtimeError) +{ + constexpr const char *expectedEmptyGroupCoAwaitMessage = + "co_await: Group has no member invokers; call add() before awaiting"; + if (std::string(runtimeError.what()) != expectedEmptyGroupCoAwaitMessage) { + throw std::runtime_error("unexpected empty-group error message"); + } +} + +sscl::co::ViralNonPostingInvoker waitOnCallerThread(int delayMilliseconds) +{ + const boost::system::error_code waitError = + co_await sscl::tests::DeadlineTimerAwaiter{ + sscl::ComponentThread::getSelf()->getIoContext(), + delayMilliseconds}; + sscl::tests::throwIfTimerWaitFailed(waitError); + co_return; +} + +CallerDriver mixedSuccessAndFailureAwaitFirstThenAll( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker successInvoker = waitAndReturnLabel(1); + CalleeIntInvoker failureInvoker = waitThenThrowAfterDelay(delayShortMs); + + group.add(successInvoker); + group.add(failureInvoker); + + auto awaitFirst = group.getAwaitFirstSettlementInvoker(); + auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; + + if (firstDescriptor.type + == sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) { + assertCompleted(firstDescriptor, 1); + } + else if (firstDescriptor.type + == sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) { + assertRuntimeErrorSettlement(firstDescriptor); + } + else { + throw std::runtime_error("first settlement has unexpected type"); + } + + auto awaitAll = group.getAwaitAllSettlementsInvoker(); + auto &allDescriptors = co_await awaitAll; + + if (allDescriptors.size() != 2 || allAfterFirst.size() != 2) { + throw std::runtime_error("mixed settlement count mismatch"); + } + + std::size_t completedCount = 0; + std::size_t exceptionCount = 0; + + for (auto &descriptor : allDescriptors) { + if (descriptor.type + == sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) { + ++completedCount; + assertCompleted(descriptor, 1); + } + else if (descriptor.type + == sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) { + ++exceptionCount; + assertRuntimeErrorSettlement(descriptor); + } + } + + if (completedCount != 1 || exceptionCount != 1) { + throw std::runtime_error("mixed settlement type counts mismatch"); + } + + co_return; +} + +CallerDriver singleMemberAwaitFirstThenAll( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker onlyInvoker = waitAndReturnLabel(delayShortMs); + group.add(onlyInvoker); + + auto awaitFirst = group.getAwaitFirstSettlementInvoker(); + auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; + assertCompleted(firstDescriptor, delayShortMs); + + if (!group.allInvokersSettled() || allAfterFirst.size() != 1) { + throw std::runtime_error("single member state mismatch"); + } + + auto awaitAll = group.getAwaitAllSettlementsInvoker(); + auto &allDescriptors = co_await awaitAll; + + if (allDescriptors.size() != 1) { + throw std::runtime_error("single member await-all count mismatch"); + } + + assertCompleted(allDescriptors[0], delayShortMs); + co_return; +} + +CallerDriver allCompleteBeforeCoAwait( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker invokerTen = returnLabelImmediately(10); + CalleeIntInvoker invokerTwenty = returnLabelImmediately(20); + CalleeIntInvoker invokerThirty = returnLabelImmediately(30); + + group.add(invokerTen); + group.add(invokerTwenty); + group.add(invokerThirty); + + co_await waitOnCallerThread(delayShortMs); + + if (!group.allInvokersSettled() || !group.firstInvokerSettled()) { + throw std::runtime_error("immediate group did not settle before await"); + } + + auto awaitFirst = group.getAwaitFirstSettlementInvoker(); + auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; + assertCompleted(firstDescriptor, 10); + + auto awaitAll = group.getAwaitAllSettlementsInvoker(); + auto &allDescriptors = co_await awaitAll; + + if (allDescriptors.size() != 3 || allAfterFirst.size() != 3) { + throw std::runtime_error("immediate settlement count mismatch"); + } + + co_return; +} + +std::thread startAddWhileGroupAwaiterSuspendedProbe( + sscl::co::Group &group, + CalleeIntInvoker &lateInvoker, + std::atomic &groupIsAwaitingAll, + std::atomic &addWasRejected) +{ + return std::thread( + [&]() + { + while (!groupIsAwaitingAll.load(std::memory_order_acquire)) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + std::this_thread::sleep_for( + std::chrono::milliseconds(delayAddWhileSuspendedProbeMs)); + + boost::asio::post( + sscl::tests::ThreadRegistry::ioContext( + sscl::tests::PostingThreadRole::CALLER), + [&]() + { + try { + group.add(lateInvoker); + } + catch (const std::runtime_error &) { + addWasRejected.store(true, std::memory_order_release); + } + }); + }); +} + +CallerDriver addWhileAwaitAllSuspended( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + std::atomic groupIsAwaitingAll{false}; + std::atomic addWasRejected{false}; + + CalleeIntInvoker slowInvokerA = waitAndReturnLabel(delayLongMs); + CalleeIntInvoker slowInvokerB = waitAndReturnLabel(delayLongMs); + CalleeIntInvoker lateInvoker = waitAndReturnLabel(99); + + group.add(slowInvokerA); + group.add(slowInvokerB); + + std::thread addProbeThread = startAddWhileGroupAwaiterSuspendedProbe( + group, + lateInvoker, + groupIsAwaitingAll, + addWasRejected); + + auto awaitAll = group.getAwaitAllSettlementsInvoker(); + groupIsAwaitingAll.store(true, std::memory_order_release); + co_await awaitAll; + + addProbeThread.join(); + + if (!addWasRejected.load(std::memory_order_acquire)) { + throw std::runtime_error("expected add while suspended to throw"); + } + + co_return; +} + +CallerDriver awaitAllOnlyMixedOutcomes( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker successInvoker = returnLabelImmediately(7); + CalleeIntInvoker failureInvoker = waitThenThrowAfterDelay(delayShortMs); + + group.add(successInvoker); + group.add(failureInvoker); + + auto awaitAll = group.getAwaitAllSettlementsInvoker(); + auto &allDescriptors = co_await awaitAll; + + if (allDescriptors.size() != 2) { + throw std::runtime_error("await-all-only count mismatch"); + } + + std::size_t completedCount = 0; + std::size_t exceptionCount = 0; + + for (auto &descriptor : allDescriptors) { + if (descriptor.type + == sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) { + ++completedCount; + assertCompleted(descriptor, 7); + } + else if (descriptor.type + == sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) { + ++exceptionCount; + assertRuntimeErrorSettlement(descriptor); + } + } + + if (completedCount != 1 || exceptionCount != 1) { + throw std::runtime_error("await-all-only mixed counts mismatch"); + } + + co_return; +} + +CallerDriver checkForAndReThrowGroupExceptions( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker failureInvoker = waitThenThrowAfterDelay(delayShortMs); + group.add(failureInvoker); + + (void)co_await group.getAwaitAllSettlementsInvoker(); + + try { + group.checkForAndReThrowGroupExceptions(); + } + catch (const std::runtime_error &aggregateError) { + if (std::string(aggregateError.what()).find(expectedThrowMessage) + == std::string::npos) { + throw std::runtime_error("aggregate message missing callee text"); + } + co_return; + } + + throw std::runtime_error("expected aggregate group exception"); +} + +CallerDriver emptyGroupAwaitAllThrows( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + + try { + (void)co_await group.getAwaitAllSettlementsInvoker(); + } + catch (const std::runtime_error &runtimeError) { + assertEmptyGroupCoAwaitError(runtimeError); + co_return; + } + + throw std::runtime_error("expected empty group await-all to throw"); +} + +CallerDriver emptyGroupAwaitFirstThrows( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + + try { + (void)co_await group.getAwaitFirstSettlementInvoker(); + } + catch (const std::runtime_error &runtimeError) { + assertEmptyGroupCoAwaitError(runtimeError); + co_return; + } + + throw std::runtime_error("expected empty group await-first to throw"); +} + +CallerDriver wrongAwaitInvokerOrder( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker shortInvoker = waitAndReturnLabel(delayShortMs); + CalleeIntInvoker mediumInvoker = waitAndReturnLabel(delayMediumMs); + + group.add(shortInvoker); + group.add(mediumInvoker); + + auto awaitFirstHandle = group.getAwaitFirstSettlementInvoker(); + auto awaitAllHandle = group.getAwaitAllSettlementsInvoker(); + + auto &allDescriptors = co_await awaitAllHandle; + if (allDescriptors.size() != 2) { + throw std::runtime_error("wrong-order await-all count mismatch"); + } + + auto [firstDescriptor, allAfterFirst] = co_await awaitFirstHandle; + assertCompleted( + firstDescriptor, + readCompletedLabel(firstDescriptor.invokerAs())); + + if (!group.firstInvokerSettled() || allAfterFirst.size() != 2) { + throw std::runtime_error("wrong-order await-first state mismatch"); + } + + co_return; +} + +CallerDriver doubleCoAwaitSameAwaitFirst( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker memberInvoker = returnLabelImmediately(delayShortMs); + group.add(memberInvoker); + + co_await waitOnCallerThread(delayShortMs); + + auto awaitFirst = group.getAwaitFirstSettlementInvoker(); + auto [firstDescriptorA, allAfterFirstA] = co_await awaitFirst; + auto [firstDescriptorB, allAfterFirstB] = co_await awaitFirst; + + assertCompleted(firstDescriptorA, delayShortMs); + assertCompleted(firstDescriptorB, delayShortMs); + + if (&firstDescriptorA.invokerAs() + != &firstDescriptorB.invokerAs()) { + throw std::runtime_error("double await-first descriptor mismatch"); + } + + if (allAfterFirstA.size() != allAfterFirstB.size()) { + throw std::runtime_error("double await-first snapshot mismatch"); + } + + co_return; +} + +CallerDriver doubleCoAwaitSameAwaitAll( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker memberInvoker = waitAndReturnLabel(delayShortMs); + group.add(memberInvoker); + + auto awaitAll = group.getAwaitAllSettlementsInvoker(); + auto &allDescriptorsA = co_await awaitAll; + auto &allDescriptorsB = co_await awaitAll; + + if (allDescriptorsA.size() != 1 || allDescriptorsB.size() != 1) { + throw std::runtime_error("double await-all count mismatch"); + } + + assertCompleted(allDescriptorsA[0], delayShortMs); + assertCompleted(allDescriptorsB[0], delayShortMs); + co_return; +} + +CallerDriver twoAwaitFirstHandlesSequentially( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker shortInvoker = waitAndReturnLabel(delayShortMs); + CalleeIntInvoker mediumInvoker = waitAndReturnLabel(delayMediumMs); + + group.add(shortInvoker); + group.add(mediumInvoker); + + auto awaitFirstA = group.getAwaitFirstSettlementInvoker(); + auto [firstDescriptorA, allAfterFirstA] = co_await awaitFirstA; + assertCompleted(firstDescriptorA, delayShortMs); + + auto awaitFirstB = group.getAwaitFirstSettlementInvoker(); + auto [firstDescriptorB, allAfterFirstB] = co_await awaitFirstB; + assertCompleted(firstDescriptorB, delayShortMs); + + if (&firstDescriptorA.invokerAs() + != &firstDescriptorB.invokerAs()) { + throw std::runtime_error("sticky first settlement mismatch"); + } + + (void)co_await group.getAwaitAllSettlementsInvoker(); + (void)allAfterFirstA; + (void)allAfterFirstB; + co_return; +} + +CallerDriver addSecondWaveAfterAwaitAll( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker wave1MemberA = waitAndReturnLabel(delayLongMs); + CalleeIntInvoker wave1MemberB = waitAndReturnLabel(delayLongMs); + + group.add(wave1MemberA); + group.add(wave1MemberB); + (void)co_await group.getAwaitAllSettlementsInvoker(); + + CalleeIntInvoker wave2Immediate = + returnLabelImmediately(wave2ImmediateSettlementLabel); + CalleeIntInvoker wave2Slow = waitAndReturnLabel(delayMediumMs); + + group.add(wave2Immediate); + group.add(wave2Slow); + + co_await waitOnCallerThread(delayShortMs); + + if (readCompletedLabel(wave2Immediate) + != wave2ImmediateSettlementLabel) { + throw std::runtime_error("wave-2 immediate member did not complete"); + } + + if (group.allInvokersSettled()) { + throw std::runtime_error("wave-2 slow member should still be in flight"); + } + + auto &allDescriptors = + co_await group.getAwaitAllSettlementsInvoker(); + + if (allDescriptors.size() != 4) { + throw std::runtime_error("expected four settlements after second wave"); + } + + co_return; +} + +CallerDriver shortTimerAddedAfterLongStillWinsRace( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker longInvoker = waitAndReturnLabel(delayLongMs); + CalleeIntInvoker shortInvoker = waitAndReturnLabel(delayShortMs); + + group.add(longInvoker); + group.add(shortInvoker); + + auto awaitFirst = group.getAwaitFirstSettlementInvoker(); + auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; + + assertCompleted(firstDescriptor, delayShortMs); + + if (&firstDescriptor.invokerAs() != &shortInvoker) { + throw std::runtime_error("short timer should win first settlement"); + } + + (void)co_await group.getAwaitAllSettlementsInvoker(); + (void)allAfterFirst; + co_return; +} + +CallerDriver nonStdExceptionSettlement( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker failureInvoker = waitThenThrowIntAfterDelay(delayShortMs); + group.add(failureInvoker); + + auto &allDescriptors = co_await group.getAwaitAllSettlementsInvoker(); + + if (allDescriptors.size() != 1) { + throw std::runtime_error("non-std exception count mismatch"); + } + + assertIntExceptionSettlement(allDescriptors[0]); + + try { + group.checkForAndReThrowGroupExceptions(); + } + catch (const std::runtime_error &) { + co_return; + } + + throw std::runtime_error("expected aggregate for non-std exception"); +} + +CallerDriver voidViralMemberInGroup( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeVoidInvoker voidInvoker = voidMemberAfterDelay(delayShortMs); + group.add(voidInvoker); + + auto &allDescriptors = co_await group.getAwaitAllSettlementsInvoker(); + + if (allDescriptors.size() != 1) { + throw std::runtime_error("void group count mismatch"); + } + + if (allDescriptors[0].type + != sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) { + throw std::runtime_error("void member did not complete"); + } + + co_return; +} + +CallerDriver returnValuesRemainReadableAfterAwaitFirst( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker slowInvoker = waitAndReturnLabel(delayLongMs); + CalleeIntInvoker fastInvoker = waitAndReturnLabel(delayShortMs); + + group.add(slowInvoker); + group.add(fastInvoker); + + auto awaitFirst = group.getAwaitFirstSettlementInvoker(); + auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; + + assertCompleted(firstDescriptor, delayShortMs); + + const int fastLabelFromDescriptor = readCompletedLabel( + firstDescriptor.invokerAs()); + const int fastLabelFromLocal = readCompletedLabel(fastInvoker); + + if (fastLabelFromDescriptor != fastLabelFromLocal) { + throw std::runtime_error("descriptor/local return value mismatch"); + } + + if (allAfterFirst.size() != 2) { + throw std::runtime_error("expected two settlement slots"); + } + + (void)co_await group.getAwaitAllSettlementsInvoker(); + co_return; +} + +class GroupEdgeTest +: public ::testing::Test +{ +protected: + template + void runScenario(Factory &&factory) + { + ASSERT_NO_THROW( + sscl::tests::runNonViralPostingTask( + threads.caller(), + std::forward(factory))); + } + + sscl::tests::PostingThreadSet threads; +}; + +} // namespace + +#define RUN_GROUP_EDGE_SCENARIO(testName, functionName) \ + TEST_F(GroupEdgeTest, testName) \ + { \ + runScenario( \ + []( \ + std::exception_ptr &exceptionPtr, \ + std::function completion) \ + { \ + return functionName(exceptionPtr, std::move(completion)); \ + }); \ + } + +RUN_GROUP_EDGE_SCENARIO( + MixedSuccessAndFailureAwaitFirstThenAll, + mixedSuccessAndFailureAwaitFirstThenAll) +RUN_GROUP_EDGE_SCENARIO( + SingleMemberAwaitFirstThenAll, + singleMemberAwaitFirstThenAll) +RUN_GROUP_EDGE_SCENARIO(AllCompleteBeforeCoAwait, allCompleteBeforeCoAwait) +RUN_GROUP_EDGE_SCENARIO(AddWhileAwaitAllSuspended, addWhileAwaitAllSuspended) +RUN_GROUP_EDGE_SCENARIO(AwaitAllOnlyMixedOutcomes, awaitAllOnlyMixedOutcomes) +RUN_GROUP_EDGE_SCENARIO( + CheckForAndReThrowGroupExceptions, + checkForAndReThrowGroupExceptions) +RUN_GROUP_EDGE_SCENARIO(EmptyGroupAwaitAllThrows, emptyGroupAwaitAllThrows) +RUN_GROUP_EDGE_SCENARIO(EmptyGroupAwaitFirstThrows, emptyGroupAwaitFirstThrows) +RUN_GROUP_EDGE_SCENARIO(WrongAwaitInvokerOrder, wrongAwaitInvokerOrder) +RUN_GROUP_EDGE_SCENARIO(DoubleCoAwaitSameAwaitFirst, doubleCoAwaitSameAwaitFirst) +RUN_GROUP_EDGE_SCENARIO(DoubleCoAwaitSameAwaitAll, doubleCoAwaitSameAwaitAll) +RUN_GROUP_EDGE_SCENARIO( + TwoAwaitFirstHandlesSequentially, + twoAwaitFirstHandlesSequentially) +RUN_GROUP_EDGE_SCENARIO(AddSecondWaveAfterAwaitAll, addSecondWaveAfterAwaitAll) +RUN_GROUP_EDGE_SCENARIO( + ShortTimerAddedAfterLongStillWinsRace, + shortTimerAddedAfterLongStillWinsRace) +RUN_GROUP_EDGE_SCENARIO(NonStdExceptionSettlement, nonStdExceptionSettlement) +RUN_GROUP_EDGE_SCENARIO(VoidViralMemberInGroup, voidViralMemberInGroup) +RUN_GROUP_EDGE_SCENARIO( + ReturnValuesRemainReadableAfterAwaitFirst, + returnValuesRemainReadableAfterAwaitFirst) + +TEST_F(GroupEdgeTest, NonViralVoidGroupTemplateInstantiates) +{ + GTEST_SKIP() + << "NonViralPostingInvoker does not satisfy Group's awaitable concept."; +} + +TEST_F(GroupEdgeTest, EarlyInvokerDestructionIsUnsupported) +{ + GTEST_SKIP() + << "Destroying a member invoker before group settlement completes is undefined."; +} + +TEST_F(GroupEdgeTest, OverlappingGroupWaitsAssertInDebug) +{ + GTEST_SKIP() + << "Overlapping group co_await is debug-assert behavior."; +} diff --git a/tests/co/group_timer_tests.cpp b/tests/co/group_timer_tests.cpp new file mode 100644 index 0000000..1cb3acf --- /dev/null +++ b/tests/co/group_timer_tests.cpp @@ -0,0 +1,268 @@ +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include + +#include +#include +#include + +namespace { + +constexpr int timerDelayShortMs = 50; +constexpr int timerDelayMediumMs = 200; +constexpr int timerDelayLongMs = 500; +constexpr int awaitAllTimingSlackMs = 25; +constexpr int awaitAllLongCancelTimingMarginMs = 50; + +using CallerDriver = + sscl::tests::RoleNonViralPostingInvoker< + sscl::tests::PostingThreadRole::CALLER>; + +using CalleeIntInvoker = + sscl::tests::RoleViralPostingInvoker< + sscl::tests::PostingThreadRole::CALLEE, + int>; + +using Clock = std::chrono::steady_clock; +using Ms = std::chrono::milliseconds; + +CalleeIntInvoker waitDeadlineTimer(int timerLabelMilliseconds) +{ + const boost::system::error_code waitError = + co_await sscl::tests::DeadlineTimerAwaiter{ + sscl::ComponentThread::getSelf()->getIoContext(), + timerLabelMilliseconds}; + sscl::tests::throwIfTimerWaitFailed(waitError); + co_return timerLabelMilliseconds; +} + +CalleeIntInvoker waitCancelableDeadlineTimer( + int timerLabelMilliseconds, + sscl::tests::CancelableDeadlineTimerRegistry ®istry) +{ + const boost::system::error_code waitError = + co_await sscl::tests::RegisteredDeadlineTimerAwaiter{ + sscl::ComponentThread::getSelf()->getIoContext(), + timerLabelMilliseconds, + timerLabelMilliseconds, + registry}; + + if (sscl::tests::timerWasCanceled(waitError)) { + co_return timerLabelMilliseconds; + } + + sscl::tests::throwIfTimerWaitFailed(waitError); + co_return timerLabelMilliseconds; +} + +void throwIfElapsedTooLong( + const Ms &elapsed, + const Ms &limit, + const char *message) +{ + if (elapsed > limit) { + throw std::runtime_error( + std::string(message) + ": " + std::to_string(elapsed.count())); + } +} + +void throwIfElapsedTooShort( + const Ms &elapsed, + const Ms &limit, + const char *message) +{ + if (elapsed < limit) { + throw std::runtime_error( + std::string(message) + ": " + std::to_string(elapsed.count())); + } +} + +CallerDriver runGroupTimerRace( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker invokerShort = waitDeadlineTimer(timerDelayShortMs); + CalleeIntInvoker invokerMedium = waitDeadlineTimer(timerDelayMediumMs); + CalleeIntInvoker invokerLong = waitDeadlineTimer(timerDelayLongMs); + + group.add(invokerShort); + group.add(invokerMedium); + group.add(invokerLong); + + const auto testStart = Clock::now(); + + auto awaitFirst = group.getAwaitFirstSettlementInvoker(); + auto [firstSettlement, allSettlementsAfterFirst] = co_await awaitFirst; + + const auto firstElapsedMs = + std::chrono::duration_cast(Clock::now() - testStart); + throwIfElapsedTooLong( + firstElapsedMs, + Ms(timerDelayMediumMs - awaitAllTimingSlackMs), + "await-first took too long"); + + if (&firstSettlement.invokerAs() != &invokerShort) { + throw std::runtime_error("first settlement was not shortest timer"); + } + + if (group.allInvokersSettled()) { + throw std::runtime_error("await-first returned after all settled"); + } + + auto awaitAll = group.getAwaitAllSettlementsInvoker(); + auto &allSettlements = co_await awaitAll; + + const auto allElapsedMs = + std::chrono::duration_cast(Clock::now() - testStart); + throwIfElapsedTooShort( + allElapsedMs, + Ms(timerDelayLongMs - awaitAllLongCancelTimingMarginMs), + "await-all finished too soon"); + + if (allSettlements.size() != 3) { + throw std::runtime_error("expected three settlements"); + } + + sscl::tests::expectCompletedIntSettlement( + firstSettlement, + timerDelayShortMs); + sscl::tests::expectCompletedIntSettlement( + allSettlementsAfterFirst[0], + timerDelayShortMs); + sscl::tests::expectCompletedIntSettlement( + allSettlementsAfterFirst[1], + timerDelayMediumMs); + sscl::tests::expectCompletedIntSettlement( + allSettlementsAfterFirst[2], + timerDelayLongMs); + + co_return; +} + +CallerDriver runGroupTimerCancelLongAfterAwaitFirst( + std::exception_ptr &exceptionPtr, + std::function completion, + sscl::tests::CancelableDeadlineTimerRegistry ®istry) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker invokerShort = + waitCancelableDeadlineTimer(timerDelayShortMs, registry); + CalleeIntInvoker invokerMedium = + waitCancelableDeadlineTimer(timerDelayMediumMs, registry); + CalleeIntInvoker invokerLong = + waitCancelableDeadlineTimer(timerDelayLongMs, registry); + + group.add(invokerShort); + group.add(invokerMedium); + group.add(invokerLong); + + const auto testStart = Clock::now(); + + auto awaitFirst = group.getAwaitFirstSettlementInvoker(); + auto [firstSettlement, allSettlementsAfterFirst] = co_await awaitFirst; + + if (&firstSettlement.invokerAs() != &invokerShort) { + throw std::runtime_error("cancel test first settlement mismatch"); + } + + if (group.allInvokersSettled()) { + throw std::runtime_error("cancel test all settled after await-first"); + } + + registry.cancel(timerDelayLongMs); + + auto awaitAll = group.getAwaitAllSettlementsInvoker(); + auto &allSettlements = co_await awaitAll; + + const auto allElapsedMs = + std::chrono::duration_cast(Clock::now() - testStart); + + if (allElapsedMs >= Ms(timerDelayLongMs - awaitAllLongCancelTimingMarginMs)) { + throw std::runtime_error("await-all waited for canceled long timer"); + } + + throwIfElapsedTooShort( + allElapsedMs, + Ms(timerDelayMediumMs - awaitAllTimingSlackMs), + "await-all finished before medium timer"); + + if (allSettlements.size() != 3) { + throw std::runtime_error("cancel test expected three settlements"); + } + + sscl::tests::expectCompletedIntSettlement( + allSettlements[0], + timerDelayShortMs); + sscl::tests::expectCompletedIntSettlement( + allSettlements[1], + timerDelayMediumMs); + sscl::tests::expectCompletedIntSettlement( + allSettlements[2], + timerDelayLongMs); + + if (&allSettlements[2].invokerAs() != &invokerLong) { + throw std::runtime_error("cancel test long invoker mismatch"); + } + + (void)allSettlementsAfterFirst; + co_return; +} + +class GroupTimerTest +: public ::testing::Test +{ +protected: + sscl::tests::PostingThreadSet threads; +}; + +} // namespace + +TEST_F(GroupTimerTest, AwaitFirstReturnsShortestTimerAndAwaitAllWaitsForLongest) +{ + ASSERT_NO_THROW( + sscl::tests::runNonViralPostingTask( + threads.caller(), + []( + std::exception_ptr &exceptionPtr, + std::function completion) + { + return runGroupTimerRace( + exceptionPtr, + std::move(completion)); + })); +} + +TEST_F(GroupTimerTest, CancelLongTimerAfterAwaitFirst) +{ + sscl::tests::CancelableDeadlineTimerRegistry registry; + + ASSERT_NO_THROW( + sscl::tests::runNonViralPostingTask( + threads.caller(), + [®istry]( + std::exception_ptr &exceptionPtr, + std::function completion) + { + return runGroupTimerCancelLongAfterAwaitFirst( + exceptionPtr, + std::move(completion), + registry); + })); +} diff --git a/tests/co/posting_cross_thread_tests.cpp b/tests/co/posting_cross_thread_tests.cpp new file mode 100644 index 0000000..b7e7cf0 --- /dev/null +++ b/tests/co/posting_cross_thread_tests.cpp @@ -0,0 +1,252 @@ +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include + +namespace { + +constexpr int expectedReturnValue = 42; +constexpr int explicitTargetReturnValue = 77; +constexpr const char *expectedThrowMessage = + "posting cross-thread intentional failure"; + +using CallerNonViralInvoker = + sscl::tests::RoleNonViralPostingInvoker< + sscl::tests::PostingThreadRole::CALLER>; +using CalleeNonViralInvoker = + sscl::tests::RoleNonViralPostingInvoker< + sscl::tests::PostingThreadRole::CALLEE>; + +template +using CalleeViralInvoker = + sscl::tests::RoleViralPostingInvoker< + sscl::tests::PostingThreadRole::CALLEE, + T>; + +CalleeViralInvoker returnFromCalleeThread( + sscl::tests::CrossThreadTrace &trace) +{ + trace.recordCalleeExecutionThread(); + trace.recordFinalSuspendThread(); + co_return expectedReturnValue; +} + +CalleeViralInvoker returnFromExplicitTargetThread( + sscl::co::ExplicitPostTarget postTarget, + sscl::tests::CrossThreadTrace &trace) +{ + (void)postTarget; + trace.recordCalleeExecutionThread(); + trace.recordFinalSuspendThread(); + co_return explicitTargetReturnValue; +} + +CalleeViralInvoker throwFromCalleeThread( + sscl::tests::CrossThreadTrace &trace) +{ + constexpr int throwDelayMs = 1; + + const boost::system::error_code waitError = + co_await sscl::tests::DeadlineTimerAwaiter{ + sscl::ComponentThread::getSelf()->getIoContext(), + throwDelayMs}; + sscl::tests::throwIfTimerWaitFailed(waitError); + trace.recordCalleeExecutionThread(); + trace.recordFinalSuspendThread(); + throw std::runtime_error(expectedThrowMessage); +} + +CallerNonViralInvoker awaitCalleeDriver( + std::exception_ptr &exceptionPtr, + std::function completion, + sscl::tests::CrossThreadTrace &trace) +{ + (void)exceptionPtr; + (void)completion; + + const int value = co_await returnFromCalleeThread(trace); + trace.recordAwaitResumeThread(); + + if (value != expectedReturnValue) { + throw std::runtime_error("Unexpected callee return value"); + } + + co_return; +} + +CallerNonViralInvoker awaitExplicitTargetDriver( + std::exception_ptr &exceptionPtr, + std::function completion, + sscl::tests::CrossThreadTrace &trace) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::ExplicitPostTarget postTarget{ + sscl::tests::ThreadRegistry::ioContext( + sscl::tests::PostingThreadRole::ALTERNATE)}; + const int value = co_await returnFromExplicitTargetThread( + postTarget, + trace); + trace.recordAwaitResumeThread(); + + if (value != explicitTargetReturnValue) { + throw std::runtime_error("Unexpected explicit-target return value"); + } + + co_return; +} + +CallerNonViralInvoker awaitThrowingCalleeDriver( + std::exception_ptr &exceptionPtr, + std::function completion, + sscl::tests::CrossThreadTrace &trace) +{ + (void)exceptionPtr; + (void)completion; + + try { + (void)co_await throwFromCalleeThread(trace); + throw std::runtime_error("Expected callee exception"); + } + catch (const std::runtime_error &runtimeError) { + trace.recordAwaitResumeThread(); + if (std::string(runtimeError.what()) != expectedThrowMessage) { + throw std::runtime_error("Unexpected callee exception message"); + } + } + + co_return; +} + +CalleeNonViralInvoker nonViralCalleeCompletesToCaller( + std::exception_ptr &exceptionPtr, + std::function completion, + sscl::tests::CrossThreadTrace &trace) +{ + (void)exceptionPtr; + (void)completion; + trace.recordCalleeExecutionThread(); + trace.recordFinalSuspendThread(); + co_return; +} + +class PostingCrossThreadTest +: public ::testing::Test +{ +protected: + sscl::tests::PostingThreadSet threads; +}; + +} // namespace + +TEST_F(PostingCrossThreadTest, ViralAwaitPostsCalleeAndResumesCaller) +{ + sscl::tests::CrossThreadTrace trace; + + ASSERT_NO_THROW( + sscl::tests::runNonViralPostingTask( + threads.caller(), + [&trace]( + std::exception_ptr &exceptionPtr, + std::function completion) + { + trace.recordConstructionThread(); + return awaitCalleeDriver( + exceptionPtr, + std::move(completion), + trace); + })); + + EXPECT_EQ(trace.constructionThread(), threads.caller().osThreadId()); + EXPECT_EQ(trace.calleeExecutionThread(), threads.callee().osThreadId()); + EXPECT_EQ(trace.finalSuspendThread(), threads.callee().osThreadId()); + EXPECT_EQ(trace.awaitResumeThread(), threads.caller().osThreadId()); + EXPECT_NE(trace.calleeExecutionThread(), trace.awaitResumeThread()); +} + +TEST_F(PostingCrossThreadTest, NonViralCompletionPostsBackToCaller) +{ + sscl::tests::CrossThreadTrace trace; + + ASSERT_NO_THROW( + sscl::tests::runNonViralPostingTask( + threads.caller(), + [&trace]( + std::exception_ptr &exceptionPtr, + std::function completion) + { + trace.recordConstructionThread(); + return nonViralCalleeCompletesToCaller( + exceptionPtr, + [&trace, completion = std::move(completion)]() mutable + { + trace.recordCompletionCallbackThread(); + completion(); + }, + trace); + })); + + EXPECT_EQ(trace.constructionThread(), threads.caller().osThreadId()); + EXPECT_EQ(trace.calleeExecutionThread(), threads.callee().osThreadId()); + EXPECT_EQ(trace.finalSuspendThread(), threads.callee().osThreadId()); + EXPECT_EQ(trace.completionCallbackThread(), threads.caller().osThreadId()); + EXPECT_NE(trace.calleeExecutionThread(), trace.completionCallbackThread()); +} + +TEST_F(PostingCrossThreadTest, ExplicitPostTargetRoutesCalleeExecution) +{ + sscl::tests::CrossThreadTrace trace; + + ASSERT_NO_THROW( + sscl::tests::runNonViralPostingTask( + threads.caller(), + [&trace]( + std::exception_ptr &exceptionPtr, + std::function completion) + { + trace.recordConstructionThread(); + return awaitExplicitTargetDriver( + exceptionPtr, + std::move(completion), + trace); + })); + + EXPECT_EQ(trace.constructionThread(), threads.caller().osThreadId()); + EXPECT_EQ(trace.calleeExecutionThread(), threads.alternate().osThreadId()); + EXPECT_EQ(trace.awaitResumeThread(), threads.caller().osThreadId()); + EXPECT_NE(trace.calleeExecutionThread(), threads.callee().osThreadId()); + EXPECT_NE(trace.calleeExecutionThread(), trace.awaitResumeThread()); +} + +TEST_F(PostingCrossThreadTest, CalleeExceptionIsObservedOnCallerThread) +{ + sscl::tests::CrossThreadTrace trace; + + ASSERT_NO_THROW( + sscl::tests::runNonViralPostingTask( + threads.caller(), + [&trace]( + std::exception_ptr &exceptionPtr, + std::function completion) + { + trace.recordConstructionThread(); + return awaitThrowingCalleeDriver( + exceptionPtr, + std::move(completion), + trace); + })); + + EXPECT_EQ(trace.constructionThread(), threads.caller().osThreadId()); + EXPECT_EQ(trace.calleeExecutionThread(), threads.callee().osThreadId()); + EXPECT_EQ(trace.awaitResumeThread(), threads.caller().osThreadId()); + EXPECT_NE(trace.calleeExecutionThread(), trace.awaitResumeThread()); +} diff --git a/tests/co/viral_non_posting_tests.cpp b/tests/co/viral_non_posting_tests.cpp new file mode 100644 index 0000000..d40c496 --- /dev/null +++ b/tests/co/viral_non_posting_tests.cpp @@ -0,0 +1,511 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include + +#include +#include + +namespace { + +constexpr int delayShortMs = 50; +constexpr int expectedNonStdThrowValue = 42; +constexpr const char *expectedThrowMessage = + "viral_non_posting_test intentional failure"; + +template +using TestInvoker = sscl::co::ViralNonPostingInvoker; + +using TestDriver = TestInvoker; +using TestVoidDriver = TestInvoker; + +struct ThreadIdPair +{ + std::thread::id callerIdAtCoAwait; + std::thread::id calleeId; +}; + +struct MoveCountedInt +{ + std::shared_ptr moveCount; + int value = 0; + + MoveCountedInt() = default; + + MoveCountedInt( + std::shared_ptr moveCountIn, + int valueIn) + : moveCount(std::move(moveCountIn)), + value(valueIn) + {} + + MoveCountedInt(const MoveCountedInt &) = delete; + MoveCountedInt &operator=(const MoveCountedInt &) = delete; + + MoveCountedInt(MoveCountedInt &&other) noexcept + : moveCount(std::exchange(other.moveCount, {})), + value(other.value) + { + if (moveCount) { + ++(*moveCount); + } + } + + MoveCountedInt &operator=(MoveCountedInt &&other) noexcept + { + moveCount = std::exchange(other.moveCount, {}); + value = other.value; + return *this; + } +}; + +template +struct CountingAwaiter +{ + TestInvoker &invoker; + std::size_t &awaitResumeCallCount; + + bool await_ready() const noexcept + { return invoker.await_ready(); } + + template + bool await_suspend( + std::coroutine_handle callerSchedHandle) noexcept + { return invoker.await_suspend(callerSchedHandle); } + + auto await_resume() + { + ++awaitResumeCallCount; + return invoker.await_resume(); + } +}; + +class ViralNonPostingTest +: public ::testing::Test +{ +protected: + void TearDown() override + { + ioContext.restart(); + } + + int runDriver(TestDriver &driver) + { + sscl::tests::IoContextPump::pumpUntilIdle(ioContext); + return finishDriver(driver); + } + + int finishDriver(TestDriver &driver) + { + if (driver.completedReturnValues().myExceptionPtr) { + std::rethrow_exception( + driver.completedReturnValues().myExceptionPtr); + } + + return driver.completedReturnValues().myReturnValue; + } + + boost::asio::io_context ioContext; +}; + +TestInvoker returnLabelImmediately(int label) +{ + co_return label; +} + +TestInvoker waitAndReturnLabel( + boost::asio::io_context &ioContext, + int delayMilliseconds) +{ + const boost::system::error_code waitError = + co_await sscl::tests::DeadlineTimerAwaiter{ + ioContext, + delayMilliseconds}; + sscl::tests::throwIfTimerWaitFailed(waitError); + co_return delayMilliseconds; +} + +TestVoidDriver voidReturnImmediately() +{ + co_return; +} + +TestInvoker throwRuntimeErrorImmediately() +{ + throw std::runtime_error(expectedThrowMessage); +} + +TestInvoker throwIntImmediately() +{ + throw expectedNonStdThrowValue; +} + +TestInvoker recordThreadIdsAtReturn() +{ + ThreadIdPair pair; + pair.calleeId = std::this_thread::get_id(); + co_return pair; +} + +TestInvoker recordThreadIdsAfterDelay( + boost::asio::io_context &ioContext, + int delayMilliseconds) +{ + const boost::system::error_code waitError = + co_await sscl::tests::DeadlineTimerAwaiter{ + ioContext, + delayMilliseconds}; + sscl::tests::throwIfTimerWaitFailed(waitError); + + ThreadIdPair pair; + pair.calleeId = std::this_thread::get_id(); + co_return pair; +} + +TestInvoker returnMoveCountedInt( + std::shared_ptr moveCount, + int value) +{ + co_return MoveCountedInt{std::move(moveCount), value}; +} + +TestInvoker innerDelayedCoAwait( + boost::asio::io_context &ioContext, + int delayMilliseconds) +{ + const int label = co_await waitAndReturnLabel( + ioContext, + delayMilliseconds); + co_return label; +} + +TestInvoker nestedNonPostingSum(int left, int right) +{ + const int leftSum = co_await returnLabelImmediately(left); + const int rightSum = co_await returnLabelImmediately(right); + co_return leftSum + rightSum; +} + +TestInvoker outerCoAwaitingDelayedInner( + boost::asio::io_context &ioContext, + int delayMilliseconds) +{ + const int innerLabel = co_await innerDelayedCoAwait( + ioContext, + delayMilliseconds); + co_return innerLabel + 1; +} + +TestDriver testImmediateReturnFastPath() +{ + const int value = co_await returnLabelImmediately(42); + if (value != 42) { + throw std::runtime_error("immediateReturnFastPath value mismatch"); + } + co_return 0; +} + +TestDriver testAllCompleteBeforeCoAwait() +{ + TestInvoker invokerTen = returnLabelImmediately(10); + TestInvoker invokerTwenty = returnLabelImmediately(20); + TestInvoker invokerThirty = returnLabelImmediately(30); + + const int valueTen = co_await invokerTen; + const int valueTwenty = co_await invokerTwenty; + const int valueThirty = co_await invokerThirty; + + if (valueTen != 10 || valueTwenty != 20 || valueThirty != 30) { + throw std::runtime_error("allCompleteBeforeCoAwait label mismatch"); + } + + co_return 0; +} + +TestDriver testCallerSuspendsThenResumes(boost::asio::io_context &ioContext) +{ + const int value = co_await waitAndReturnLabel(ioContext, delayShortMs); + if (value != delayShortMs) { + throw std::runtime_error("callerSuspendsThenResumes label mismatch"); + } + co_return 0; +} + +TestDriver testMixedImmediateAndDelayedInSequence( + boost::asio::io_context &ioContext) +{ + const int immediate = co_await returnLabelImmediately(7); + const int delayed = co_await waitAndReturnLabel(ioContext, delayShortMs); + + if (immediate != 7 || delayed != delayShortMs) { + throw std::runtime_error("mixedImmediateAndDelayed label mismatch"); + } + + co_return 0; +} + +TestDriver testAwaitResumeCalledOnceFastPath() +{ + std::size_t awaitResumeCallCount = 0; + TestInvoker invoker = returnLabelImmediately(42); + const int value = co_await CountingAwaiter{ + invoker, + awaitResumeCallCount}; + + if (value != 42 || awaitResumeCallCount != 1) { + throw std::runtime_error("fast path await_resume count mismatch"); + } + + co_return 0; +} + +TestDriver testAwaitResumeCalledOnceSlowPath( + boost::asio::io_context &ioContext) +{ + std::size_t awaitResumeCallCount = 0; + TestInvoker invoker = waitAndReturnLabel(ioContext, delayShortMs); + const int value = co_await CountingAwaiter{ + invoker, + awaitResumeCallCount}; + + if (value != delayShortMs || awaitResumeCallCount != 1) { + throw std::runtime_error("slow path await_resume count mismatch"); + } + + co_return 0; +} + +TestDriver testAwaitResumeCalledOnceNested( + boost::asio::io_context &ioContext) +{ + std::size_t awaitResumeCallCount = 0; + TestInvoker inner = innerDelayedCoAwait(ioContext, delayShortMs); + const int value = co_await CountingAwaiter{ + inner, + awaitResumeCallCount}; + + if (value != delayShortMs || awaitResumeCallCount != 1) { + throw std::runtime_error("nested await_resume count mismatch"); + } + + co_return 0; +} + +TestDriver testMoveCountedReturnNotDoubleMoved() +{ + auto moveCount = std::make_shared(0); + TestInvoker invoker = + returnMoveCountedInt(moveCount, 99); + MoveCountedInt result = co_await invoker; + + if (result.value != 99) { + throw std::runtime_error("move counted value mismatch"); + } + if (*moveCount > 2 || *moveCount < 1) { + throw std::runtime_error("move counted return move-count mismatch"); + } + + co_return 0; +} + +TestDriver testVoidReturnCompletes() +{ + co_await voidReturnImmediately(); + co_return 0; +} + +TestDriver testReturnValuesReadableBeforeDestroy() +{ + TestInvoker invoker = returnLabelImmediately(55); + (void)co_await invoker; + + if (invoker.completedReturnValues().myReturnValue != 55) { + throw std::runtime_error("completed return value not readable"); + } + + co_return 0; +} + +TestDriver testExceptionRethrowsOnCoAwait() +{ + try { + (void)co_await throwRuntimeErrorImmediately(); + throw std::runtime_error("expected runtime_error"); + } + catch (const std::runtime_error &runtimeError) { + if (std::string(runtimeError.what()) != expectedThrowMessage) { + throw std::runtime_error("unexpected runtime_error message"); + } + } + + co_return 0; +} + +TestDriver testNonStdExceptionRethrows() +{ + try { + (void)co_await throwIntImmediately(); + throw std::runtime_error("expected int exception"); + } + catch (int caughtValue) { + if (caughtValue != expectedNonStdThrowValue) { + throw std::runtime_error("unexpected int exception value"); + } + } + + co_return 0; +} + +TestDriver testCalleeRunsOnCallerThread() +{ + const std::thread::id callerThreadId = std::this_thread::get_id(); + const ThreadIdPair pair = co_await recordThreadIdsAtReturn(); + + if (pair.calleeId != callerThreadId) { + throw std::runtime_error("callee thread mismatch"); + } + + co_return 0; +} + +TestDriver testDelayedCalleeStillOnCallerThread( + boost::asio::io_context &ioContext) +{ + const std::thread::id callerThreadId = std::this_thread::get_id(); + const ThreadIdPair pair = + co_await recordThreadIdsAfterDelay(ioContext, delayShortMs); + + if (pair.calleeId != callerThreadId) { + throw std::runtime_error("delayed callee thread mismatch"); + } + + co_return 0; +} + +TestDriver testNestedNonPostingCoAwait() +{ + const int sum = co_await nestedNonPostingSum(10, 32); + if (sum != 42) { + throw std::runtime_error("nested sum mismatch"); + } + co_return 0; +} + +TestDriver testNestedInnerSuspension(boost::asio::io_context &ioContext) +{ + const int value = co_await outerCoAwaitingDelayedInner( + ioContext, + delayShortMs); + if (value != delayShortMs + 1) { + throw std::runtime_error("nested inner suspension value mismatch"); + } + co_return 0; +} + +} // namespace + +TEST_F(ViralNonPostingTest, ImmediateReturnFastPath) +{ + TestDriver driver = testImmediateReturnFastPath(); + EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, AllCompleteBeforeCoAwait) +{ + TestDriver driver = testAllCompleteBeforeCoAwait(); + EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, CallerSuspendsThenResumes) +{ + TestDriver driver = testCallerSuspendsThenResumes(ioContext); + EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, MixedImmediateAndDelayedInSequence) +{ + TestDriver driver = testMixedImmediateAndDelayedInSequence(ioContext); + EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, AwaitResumeCalledOnceFastPath) +{ + TestDriver driver = testAwaitResumeCalledOnceFastPath(); + EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, AwaitResumeCalledOnceSlowPath) +{ + TestDriver driver = testAwaitResumeCalledOnceSlowPath(ioContext); + EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, AwaitResumeCalledOnceNested) +{ + TestDriver driver = testAwaitResumeCalledOnceNested(ioContext); + EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, MoveCountedReturnNotDoubleMoved) +{ + TestDriver driver = testMoveCountedReturnNotDoubleMoved(); + EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, VoidReturnCompletes) +{ + TestDriver driver = testVoidReturnCompletes(); + EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, ReturnValuesReadableBeforeDestroy) +{ + TestDriver driver = testReturnValuesReadableBeforeDestroy(); + EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, ExceptionRethrowsOnCoAwait) +{ + TestDriver driver = testExceptionRethrowsOnCoAwait(); + EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, NonStdExceptionRethrows) +{ + TestDriver driver = testNonStdExceptionRethrows(); + EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, CalleeRunsOnCallerThread) +{ + TestDriver driver = testCalleeRunsOnCallerThread(); + EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, DelayedCalleeStillOnCallerThread) +{ + TestDriver driver = testDelayedCalleeStillOnCallerThread(ioContext); + EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, NestedNonPostingCoAwait) +{ + TestDriver driver = testNestedNonPostingCoAwait(); + EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); }); +} + +TEST_F(ViralNonPostingTest, NestedInnerSuspension) +{ + TestDriver driver = testNestedInnerSuspension(ioContext); + EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); }); +} diff --git a/tests/cps/qutex_tests.cpp b/tests/cps/qutex_tests.cpp index bf33a01..ef0e2f0 100644 --- a/tests/cps/qutex_tests.cpp +++ b/tests/cps/qutex_tests.cpp @@ -330,8 +330,6 @@ TEST_F(QutexTest, Release) { // Test release without a prior acquire is rejected TEST_F(QutexTest, ReleaseWithoutAcquireThrows) { - qutex.isOwned = true; - EXPECT_THROW(qutex.release(), std::runtime_error); EXPECT_TRUE(qutex.queue.empty()); } diff --git a/tests/support/groupAssertions.h b/tests/support/groupAssertions.h new file mode 100644 index 0000000..cfb84ae --- /dev/null +++ b/tests/support/groupAssertions.h @@ -0,0 +1,104 @@ +#ifndef SPINSCALE_TEST_SUPPORT_GROUP_ASSERTIONS_H +#define SPINSCALE_TEST_SUPPORT_GROUP_ASSERTIONS_H + +#include +#include + +#include + +#include + +namespace sscl::tests { + +template +int completedIntValue(Invoker &invoker) +{ + if (invoker.completedReturnValues().myExceptionPtr) { + std::rethrow_exception( + invoker.completedReturnValues().myExceptionPtr); + } + + return invoker.completedReturnValues().myReturnValue; +} + +inline void expectCompletedSettlement( + const sscl::co::Group::SettlementDescriptor &descriptor) +{ + EXPECT_EQ( + descriptor.type, + sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED); +} + +template +void expectCompletedIntSettlement( + const sscl::co::Group::SettlementDescriptor &descriptor, + int expectedValue) +{ + ASSERT_EQ( + descriptor.type, + sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED); + EXPECT_EQ(completedIntValue(descriptor.invokerAs()), expectedValue); +} + +inline void expectExceptionSettlement( + const sscl::co::Group::SettlementDescriptor &descriptor) +{ + EXPECT_EQ( + descriptor.type, + sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN); + EXPECT_TRUE(descriptor.calleeException != nullptr); +} + +inline void expectRuntimeErrorSettlement( + const sscl::co::Group::SettlementDescriptor &descriptor, + const std::string &expectedMessage) +{ + ASSERT_EQ( + descriptor.type, + sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN); + ASSERT_TRUE(descriptor.calleeException != nullptr); + + try { + std::rethrow_exception(descriptor.calleeException); + } + catch (const std::runtime_error &runtimeError) { + EXPECT_EQ(std::string(runtimeError.what()), expectedMessage); + return; + } + catch (...) { + FAIL() << "Expected std::runtime_error settlement."; + } +} + +inline void expectIntExceptionSettlement( + const sscl::co::Group::SettlementDescriptor &descriptor, + int expectedValue) +{ + ASSERT_EQ( + descriptor.type, + sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN); + ASSERT_TRUE(descriptor.calleeException != nullptr); + + try { + std::rethrow_exception(descriptor.calleeException); + } + catch (int caughtValue) { + EXPECT_EQ(caughtValue, expectedValue); + return; + } + catch (...) { + FAIL() << "Expected int exception settlement."; + } +} + +inline void expectEmptyGroupError( + const std::runtime_error &runtimeError) +{ + constexpr const char *expectedMessage = + "co_await: Group has no member invokers; call add() before awaiting"; + EXPECT_EQ(std::string(runtimeError.what()), expectedMessage); +} + +} // namespace sscl::tests + +#endif // SPINSCALE_TEST_SUPPORT_GROUP_ASSERTIONS_H diff --git a/tests/support/threadHarness.cpp b/tests/support/threadHarness.cpp new file mode 100644 index 0000000..e86cc29 --- /dev/null +++ b/tests/support/threadHarness.cpp @@ -0,0 +1,413 @@ +#include + +#include +#include + +namespace sscl::tests { + +struct DedicatedIoThread::StartupState +{ + std::mutex mutex; + std::condition_variable condition; + std::thread::id osThreadId; + std::exception_ptr startupException; + bool allowInitialization = false; + bool initialized = false; +}; + +namespace { + +constexpr const char *callerThreadName = "test:caller"; +constexpr const char *calleeThreadName = "test:callee"; +constexpr const char *alternateThreadName = "test:alternate"; +constexpr const char *bodyThreadName = "test:body"; +constexpr const char *worldThreadName = "test:world"; +constexpr const char *legThreadName = "test:leg"; + +void runDedicatedThread( + const std::shared_ptr &state, + const sscl::PuppeteerThread::EntryFnArguments &args) +{ + { + std::unique_lock lock(state->mutex); + state->condition.wait( + lock, + [&state]() { return state->allowInitialization; }); + } + + try + { + args.usableBeforeJolt.initializeTls(); + + { + std::lock_guard guard(state->mutex); + state->osThreadId = std::this_thread::get_id(); + state->initialized = true; + } + + state->condition.notify_all(); + + args.usableBeforeJolt.getIoContext().restart(); + args.usableBeforeJolt.getIoContext().run(); + } + catch (...) + { + { + std::lock_guard guard(state->mutex); + state->startupException = std::current_exception(); + state->initialized = true; + } + + state->condition.notify_all(); + } +} + +} // namespace + +std::string threadRoleName(PostingThreadRole role) +{ + switch (role) + { + case PostingThreadRole::CALLER: + return callerThreadName; + case PostingThreadRole::CALLEE: + return calleeThreadName; + case PostingThreadRole::ALTERNATE: + return alternateThreadName; + case PostingThreadRole::BODY: + return bodyThreadName; + case PostingThreadRole::WORLD: + return worldThreadName; + case PostingThreadRole::LEG: + return legThreadName; + } + + throw std::runtime_error("Unknown PostingThreadRole"); +} + +void IoContextPump::pumpUntilIdle( + boost::asio::io_context &ioContext, + std::chrono::milliseconds idleTimeout, + std::chrono::milliseconds totalTimeout) +{ + const auto totalDeadline = + std::chrono::steady_clock::now() + totalTimeout; + auto lastProgress = std::chrono::steady_clock::now(); + + while (std::chrono::steady_clock::now() < totalDeadline) + { + if (ioContext.poll_one() > 0) + { + lastProgress = std::chrono::steady_clock::now(); + continue; + } + + if (std::chrono::steady_clock::now() - lastProgress >= idleTimeout) { + return; + } + + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } +} + +ThreadBoundComponent::ThreadBoundComponent() +: sscl::pptr::PuppeteerComponent(nullptr) +{ +} + +void ThreadBoundComponent::handleLoopExceptionHook() +{ + loopException = std::current_exception(); +} + +DedicatedIoThread::DedicatedIoThread(PostingThreadRole roleIn) +: role(roleIn), + startupState(std::make_shared()), + component(), + thread(std::make_shared( + static_cast(roleIn), + threadRoleName(roleIn), + [state = startupState]( + const sscl::PuppeteerThread::EntryFnArguments &args) + { + runDedicatedThread(state, args); + }, + component, + nullptr)) +{ + component.thread = thread; + releaseStartupBarrier(); + waitUntilInitialized(); +} + +DedicatedIoThread::~DedicatedIoThread() +{ + stopAndJoin(); +} + +boost::asio::io_context &DedicatedIoThread::ioContext() +{ + return thread->getIoContext(); +} + +sscl::ThreadId DedicatedIoThread::threadId() const noexcept +{ + return static_cast(role); +} + +std::thread::id DedicatedIoThread::osThreadId() const +{ + std::lock_guard guard(startupState->mutex); + return startupState->osThreadId; +} + +std::shared_ptr DedicatedIoThread::componentThread() const +{ + return thread; +} + +void DedicatedIoThread::stopAndJoin() +{ + if (!thread) { + return; + } + + releaseStartupBarrier(); + thread->getIoContext().stop(); + + if (thread->thread.joinable()) { + thread->thread.join(); + } + + thread.reset(); +} + +void DedicatedIoThread::releaseStartupBarrier() +{ + { + std::lock_guard guard(startupState->mutex); + startupState->allowInitialization = true; + } + + startupState->condition.notify_all(); +} + +void DedicatedIoThread::waitUntilInitialized() +{ + std::unique_lock lock(startupState->mutex); + const bool initialized = startupState->condition.wait_for( + lock, + defaultPostingTaskTimeout, + [this]() { return startupState->initialized; }); + + if (!initialized) { + throw std::runtime_error("Timed out waiting for test thread startup"); + } + + std::exception_ptr startupException = startupState->startupException; + lock.unlock(); + + if (startupException) { + std::rethrow_exception(startupException); + } +} + +void ThreadRegistry::registerThread( + PostingThreadRole role, + DedicatedIoThread &thread) +{ + std::lock_guard guard(registryMutex()); + threadsByRole()[role] = &thread; +} + +void ThreadRegistry::unregisterThread(PostingThreadRole role) +{ + std::lock_guard guard(registryMutex()); + threadsByRole().erase(role); +} + +boost::asio::io_context &ThreadRegistry::ioContext(PostingThreadRole role) +{ + std::lock_guard guard(registryMutex()); + auto iterator = threadsByRole().find(role); + + if (iterator == threadsByRole().end()) { + throw std::runtime_error( + "No test thread registered for " + threadRoleName(role)); + } + + return iterator->second->ioContext(); +} + +std::thread::id ThreadRegistry::osThreadId(PostingThreadRole role) +{ + std::lock_guard guard(registryMutex()); + auto iterator = threadsByRole().find(role); + + if (iterator == threadsByRole().end()) { + throw std::runtime_error( + "No test thread registered for " + threadRoleName(role)); + } + + return iterator->second->osThreadId(); +} + +std::mutex &ThreadRegistry::registryMutex() +{ + static std::mutex mutex; + return mutex; +} + +std::map & +ThreadRegistry::threadsByRole() +{ + static std::map threads; + return threads; +} + +PostingThreadSet::PostingThreadSet() +: callerThread(PostingThreadRole::CALLER), + calleeThread(PostingThreadRole::CALLEE), + alternateThread(PostingThreadRole::ALTERNATE), + bodyThread(PostingThreadRole::BODY), + worldThread(PostingThreadRole::WORLD), + legThread(PostingThreadRole::LEG) +{ + ThreadRegistry::registerThread(PostingThreadRole::CALLER, callerThread); + ThreadRegistry::registerThread(PostingThreadRole::CALLEE, calleeThread); + ThreadRegistry::registerThread(PostingThreadRole::ALTERNATE, alternateThread); + ThreadRegistry::registerThread(PostingThreadRole::BODY, bodyThread); + ThreadRegistry::registerThread(PostingThreadRole::WORLD, worldThread); + ThreadRegistry::registerThread(PostingThreadRole::LEG, legThread); + + sscl::ComponentThread::setPuppeteerThreadId( + static_cast(PostingThreadRole::CALLER)); + sscl::ComponentThread::setPuppeteerThread(callerThread.componentThread()); +} + +PostingThreadSet::~PostingThreadSet() +{ + ThreadRegistry::unregisterThread(PostingThreadRole::CALLER); + ThreadRegistry::unregisterThread(PostingThreadRole::CALLEE); + ThreadRegistry::unregisterThread(PostingThreadRole::ALTERNATE); + ThreadRegistry::unregisterThread(PostingThreadRole::BODY); + ThreadRegistry::unregisterThread(PostingThreadRole::WORLD); + ThreadRegistry::unregisterThread(PostingThreadRole::LEG); + + sscl::ComponentThread::setPuppeteerThread(nullptr); +} + +DedicatedIoThread &PostingThreadSet::thread(PostingThreadRole role) +{ + switch (role) + { + case PostingThreadRole::CALLER: + return callerThread; + case PostingThreadRole::CALLEE: + return calleeThread; + case PostingThreadRole::ALTERNATE: + return alternateThread; + case PostingThreadRole::BODY: + return bodyThread; + case PostingThreadRole::WORLD: + return worldThread; + case PostingThreadRole::LEG: + return legThread; + } + + throw std::runtime_error("Unknown PostingThreadRole"); +} + +DedicatedIoThread &PostingThreadSet::caller() +{ + return callerThread; +} + +DedicatedIoThread &PostingThreadSet::callee() +{ + return calleeThread; +} + +DedicatedIoThread &PostingThreadSet::alternate() +{ + return alternateThread; +} + +DedicatedIoThread &PostingThreadSet::body() +{ + return bodyThread; +} + +DedicatedIoThread &PostingThreadSet::world() +{ + return worldThread; +} + +DedicatedIoThread &PostingThreadSet::leg() +{ + return legThread; +} + +void CrossThreadTrace::recordConstructionThread() +{ + record(constructionThreadId); +} + +void CrossThreadTrace::recordCalleeExecutionThread() +{ + record(calleeExecutionThreadId); +} + +void CrossThreadTrace::recordFinalSuspendThread() +{ + record(finalSuspendThreadId); +} + +void CrossThreadTrace::recordAwaitResumeThread() +{ + record(awaitResumeThreadId); +} + +void CrossThreadTrace::recordCompletionCallbackThread() +{ + record(completionCallbackThreadId); +} + +std::thread::id CrossThreadTrace::constructionThread() const +{ + return read(constructionThreadId); +} + +std::thread::id CrossThreadTrace::calleeExecutionThread() const +{ + return read(calleeExecutionThreadId); +} + +std::thread::id CrossThreadTrace::finalSuspendThread() const +{ + return read(finalSuspendThreadId); +} + +std::thread::id CrossThreadTrace::awaitResumeThread() const +{ + return read(awaitResumeThreadId); +} + +std::thread::id CrossThreadTrace::completionCallbackThread() const +{ + return read(completionCallbackThreadId); +} + +void CrossThreadTrace::record(std::thread::id &slot) +{ + std::lock_guard guard(mutex); + slot = std::this_thread::get_id(); +} + +std::thread::id CrossThreadTrace::read(const std::thread::id &slot) const +{ + std::lock_guard guard(mutex); + return slot; +} + +} // namespace sscl::tests diff --git a/tests/support/threadHarness.h b/tests/support/threadHarness.h new file mode 100644 index 0000000..6f79422 --- /dev/null +++ b/tests/support/threadHarness.h @@ -0,0 +1,362 @@ +#ifndef SPINSCALE_TEST_SUPPORT_THREAD_HARNESS_H +#define SPINSCALE_TEST_SUPPORT_THREAD_HARNESS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace sscl::tests { + +constexpr std::chrono::milliseconds defaultIdleTimeout{800}; +constexpr std::chrono::milliseconds defaultTotalTimeout{10000}; +constexpr std::chrono::milliseconds defaultPostingTaskTimeout{10000}; + +enum class PostingThreadRole : sscl::ThreadId +{ + CALLER = 70, + CALLEE = 71, + ALTERNATE = 72, + BODY = 73, + WORLD = 74, + LEG = 75, +}; + +std::string threadRoleName(PostingThreadRole role); + +class IoContextPump +{ +public: + static void pumpUntilIdle( + boost::asio::io_context &ioContext, + std::chrono::milliseconds idleTimeout = defaultIdleTimeout, + std::chrono::milliseconds totalTimeout = defaultTotalTimeout); + + template + static bool pumpUntil( + boost::asio::io_context &ioContext, + Predicate &&predicate, + std::chrono::milliseconds idleTimeout = defaultIdleTimeout, + std::chrono::milliseconds totalTimeout = defaultTotalTimeout) + { + const auto totalDeadline = + std::chrono::steady_clock::now() + totalTimeout; + auto lastProgress = std::chrono::steady_clock::now(); + + while (std::chrono::steady_clock::now() < totalDeadline) + { + if (std::invoke(predicate)) { + return true; + } + + if (ioContext.poll_one() > 0) + { + lastProgress = std::chrono::steady_clock::now(); + continue; + } + + if (std::chrono::steady_clock::now() - lastProgress >= idleTimeout) { + return std::invoke(predicate); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + return std::invoke(predicate); + } +}; + +class ThreadBoundComponent final +: public sscl::pptr::PuppeteerComponent +{ +public: + ThreadBoundComponent(); + void handleLoopExceptionHook() override; + + std::exception_ptr loopException; +}; + +class DedicatedIoThread +{ +public: + explicit DedicatedIoThread(PostingThreadRole role); + ~DedicatedIoThread(); + + DedicatedIoThread(const DedicatedIoThread &) = delete; + DedicatedIoThread &operator=(const DedicatedIoThread &) = delete; + DedicatedIoThread(DedicatedIoThread &&) = delete; + DedicatedIoThread &operator=(DedicatedIoThread &&) = delete; + + boost::asio::io_context &ioContext(); + sscl::ThreadId threadId() const noexcept; + std::thread::id osThreadId() const; + std::shared_ptr componentThread() const; + + void stopAndJoin(); + + struct StartupState; + + template + void post(Function &&function) + { + boost::asio::post( + ioContext(), + std::forward(function)); + } + + template + auto runSync(Function &&function) + -> std::invoke_result_t + { + using Result = std::invoke_result_t; + + if (std::this_thread::get_id() == osThreadId()) { + if constexpr (std::is_void_v) { + std::invoke(function); + return; + } else { + return std::invoke(function); + } + } + + auto promise = std::make_shared>(); + auto future = promise->get_future(); + + post( + [promise, function = std::forward(function)]() mutable + { + try + { + if constexpr (std::is_void_v) + { + std::invoke(function); + promise->set_value(); + } + else + { + promise->set_value(std::invoke(function)); + } + } + catch (...) + { + promise->set_exception(std::current_exception()); + } + }); + + return future.get(); + } + +private: + void releaseStartupBarrier(); + void waitUntilInitialized(); + + PostingThreadRole role; + std::shared_ptr startupState; + ThreadBoundComponent component; + std::shared_ptr thread; +}; + +class ThreadRegistry +{ +public: + static void registerThread( + PostingThreadRole role, + DedicatedIoThread &thread); + static void unregisterThread(PostingThreadRole role); + static boost::asio::io_context &ioContext(PostingThreadRole role); + static std::thread::id osThreadId(PostingThreadRole role); + +private: + static std::mutex ®istryMutex(); + static std::map &threadsByRole(); +}; + +template +struct PostingThreadTag +{ + static boost::asio::io_context &io_context() + { + return ThreadRegistry::ioContext(role); + } +}; + +template +using RolePostingPromise = + sscl::co::TaggedPostingPromise>; + +template +struct RolePostingPromiseTemplate +{ + template + using Type = RolePostingPromise; +}; + +template +using RoleViralPostingInvoker = + sscl::co::ViralPostingInvoker< + RolePostingPromiseTemplate::template Type, + T>; + +template +using RoleNonViralPostingInvoker = + sscl::co::NonViralPostingInvoker< + RolePostingPromiseTemplate::template Type>; + +class PostingThreadSet +{ +public: + PostingThreadSet(); + ~PostingThreadSet(); + + PostingThreadSet(const PostingThreadSet &) = delete; + PostingThreadSet &operator=(const PostingThreadSet &) = delete; + PostingThreadSet(PostingThreadSet &&) = delete; + PostingThreadSet &operator=(PostingThreadSet &&) = delete; + + DedicatedIoThread &thread(PostingThreadRole role); + DedicatedIoThread &caller(); + DedicatedIoThread &callee(); + DedicatedIoThread &alternate(); + DedicatedIoThread &body(); + DedicatedIoThread &world(); + DedicatedIoThread &leg(); + +private: + DedicatedIoThread callerThread; + DedicatedIoThread calleeThread; + DedicatedIoThread alternateThread; + DedicatedIoThread bodyThread; + DedicatedIoThread worldThread; + DedicatedIoThread legThread; +}; + +class CrossThreadTrace +{ +public: + void recordConstructionThread(); + void recordCalleeExecutionThread(); + void recordFinalSuspendThread(); + void recordAwaitResumeThread(); + void recordCompletionCallbackThread(); + + std::thread::id constructionThread() const; + std::thread::id calleeExecutionThread() const; + std::thread::id finalSuspendThread() const; + std::thread::id awaitResumeThread() const; + std::thread::id completionCallbackThread() const; + +private: + void record(std::thread::id &slot); + std::thread::id read(const std::thread::id &slot) const; + + mutable std::mutex mutex; + std::thread::id constructionThreadId; + std::thread::id calleeExecutionThreadId; + std::thread::id finalSuspendThreadId; + std::thread::id awaitResumeThreadId; + std::thread::id completionCallbackThreadId; +}; + +template +void runNonViralPostingTask( + DedicatedIoThread &callerThread, + InvokerFactory &&invokerFactory, + std::chrono::milliseconds timeout = defaultPostingTaskTimeout) +{ + using Factory = std::decay_t; + using Invoker = std::invoke_result_t< + Factory &, std::exception_ptr &, std::function>; + + struct TaskState + { + explicit TaskState(Factory factoryIn) + : factory(std::move(factoryIn)) + {} + + Factory factory; + std::exception_ptr coroutineException; + std::exception_ptr taskException; + std::optional invoker; + std::mutex mutex; + std::condition_variable condition; + bool completed = false; + }; + + auto taskState = std::make_shared( + std::forward(invokerFactory)); + + callerThread.post( + [taskState]() + { + auto completeTask = [taskState]() + { + taskState->taskException = taskState->coroutineException; + taskState->invoker.reset(); + + { + std::lock_guard guard(taskState->mutex); + taskState->completed = true; + } + + taskState->condition.notify_one(); + }; + + try + { + taskState->invoker.emplace( + std::invoke( + taskState->factory, + taskState->coroutineException, + std::move(completeTask))); + } + catch (...) + { + { + std::lock_guard guard(taskState->mutex); + taskState->taskException = std::current_exception(); + taskState->completed = true; + } + + taskState->condition.notify_one(); + } + }); + + std::unique_lock lock(taskState->mutex); + const bool completed = taskState->condition.wait_for( + lock, + timeout, + [&taskState]() { return taskState->completed; }); + + if (!completed) { + throw std::runtime_error("Timed out waiting for posting coroutine task"); + } + + std::exception_ptr taskException = taskState->taskException; + lock.unlock(); + + if (taskException) { + std::rethrow_exception(taskException); + } +} + +} // namespace sscl::tests + +#endif // SPINSCALE_TEST_SUPPORT_THREAD_HARNESS_H diff --git a/tests/support/timerAwaiters.h b/tests/support/timerAwaiters.h new file mode 100644 index 0000000..9b57b7f --- /dev/null +++ b/tests/support/timerAwaiters.h @@ -0,0 +1,161 @@ +#ifndef SPINSCALE_TEST_SUPPORT_TIMER_AWAITERS_H +#define SPINSCALE_TEST_SUPPORT_TIMER_AWAITERS_H + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace sscl::tests { + +using SharedDeadlineTimer = std::shared_ptr; + +class CancelableDeadlineTimerRegistry +{ +public: + void clear() + { + std::lock_guard guard(mutex); + timersByLabel.clear(); + } + + void registerTimer( + int labelMilliseconds, + const SharedDeadlineTimer &timer) + { + std::lock_guard guard(mutex); + timersByLabel[labelMilliseconds] = timer; + } + + void cancel(int labelMilliseconds) + { + std::lock_guard guard(mutex); + const auto iterator = timersByLabel.find(labelMilliseconds); + + if (iterator == timersByLabel.end()) { + throw std::runtime_error( + "No cancelable deadline_timer registered for label " + + std::to_string(labelMilliseconds)); + } + + const SharedDeadlineTimer timer = iterator->second.lock(); + + if (!timer) { + throw std::runtime_error( + "Cancelable deadline_timer expired before cancel for label " + + std::to_string(labelMilliseconds)); + } + + timer->cancel(); + } + +private: + std::mutex mutex; + std::unordered_map> + timersByLabel; +}; + +struct DeadlineTimerAwaiter +{ + DeadlineTimerAwaiter( + boost::asio::io_context &ioContext, + int delayMilliseconds) + : timer(std::make_shared(ioContext)) + { + start(delayMilliseconds); + } + + DeadlineTimerAwaiter( + SharedDeadlineTimer sharedTimer, + int delayMilliseconds) + : timer(std::move(sharedTimer)) + { + start(delayMilliseconds); + } + + bool await_ready() const noexcept + { return waitCompleted; } + + bool await_suspend(std::coroutine_handle<> handle) noexcept + { + resumeHandle = handle; + return !waitCompleted; + } + + boost::system::error_code await_resume() const noexcept + { return completionErrorCode; } + +private: + void start(int delayMilliseconds) + { + timer->expires_from_now( + boost::posix_time::milliseconds(delayMilliseconds)); + timer->async_wait( + [this](const boost::system::error_code &errorCode) + { + completionErrorCode = errorCode; + waitCompleted = true; + if (resumeHandle) { + resumeHandle.resume(); + } + }); + } + + SharedDeadlineTimer timer; + boost::system::error_code completionErrorCode; + bool waitCompleted = false; + std::coroutine_handle<> resumeHandle; +}; + +struct RegisteredDeadlineTimerAwaiter +{ + RegisteredDeadlineTimerAwaiter( + boost::asio::io_context &ioContext, + int delayMilliseconds, + int registrationLabelMilliseconds, + CancelableDeadlineTimerRegistry ®istry) + : timer(std::make_shared(ioContext)) + { + registry.registerTimer(registrationLabelMilliseconds, timer); + waiter.emplace(timer, delayMilliseconds); + } + + bool await_ready() const noexcept + { return waiter->await_ready(); } + + bool await_suspend(std::coroutine_handle<> handle) noexcept + { return waiter->await_suspend(handle); } + + boost::system::error_code await_resume() const noexcept + { return waiter->await_resume(); } + + SharedDeadlineTimer timer; + std::optional waiter; +}; + +inline void throwIfTimerWaitFailed( + const boost::system::error_code &waitError) +{ + if (waitError) { + throw std::runtime_error( + "deadline_timer wait failed: " + waitError.message()); + } +} + +inline bool timerWasCanceled(const boost::system::error_code &waitError) +{ + return waitError == boost::asio::error::operation_aborted; +} + +} // namespace sscl::tests + +#endif // SPINSCALE_TEST_SUPPORT_TIMER_AWAITERS_H