From 2f31e9a034cf08c3286359f6aaa083eb6c6397e3 Mon Sep 17 00:00:00 2001 From: Hayodea Hekol Date: Sat, 13 Jun 2026 17:59:06 -0400 Subject: [PATCH] Adversarial review on test porting plan --- tests/co/group_edge_tests.cpp | 221 +++++++++++++++------------ tests/co/group_timer_tests.cpp | 142 ++++++++++++++--- tests/co/viral_non_posting_tests.cpp | 138 ++++++++++++++++- tests/support/coroutineDriver.h | 38 +++++ tests/support/groupAssertions.h | 143 ++++++++++++----- tests/support/threadHarness.cpp | 66 ++++++-- tests/support/threadHarness.h | 18 ++- 7 files changed, 593 insertions(+), 173 deletions(-) create mode 100644 tests/support/coroutineDriver.h diff --git a/tests/co/group_edge_tests.cpp b/tests/co/group_edge_tests.cpp index b827e10..81afa8d 100644 --- a/tests/co/group_edge_tests.cpp +++ b/tests/co/group_edge_tests.cpp @@ -88,75 +88,17 @@ CalleeVoidInvoker voidMemberAfterDelay(int delayMilliseconds) co_return; } -int readCompletedLabel(CalleeIntInvoker &invoker) +CalleeIntInvoker waitRecordThreadAndReturnLabel( + int timerLabelMilliseconds, + sscl::tests::CrossThreadTrace &trace) { - return invoker.completedReturnValues().myReturnValue; -} - -void assertCompleted( - const sscl::co::Group::SettlementDescriptor &descriptor, - int expectedLabel) -{ - if (descriptor.type - != sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) { - throw std::runtime_error("expected completed settlement"); - } - - if (readCompletedLabel(descriptor.invokerAs()) - != expectedLabel) { - throw std::runtime_error("settlement label mismatch"); - } -} - -void assertRuntimeErrorSettlement( - const sscl::co::Group::SettlementDescriptor &descriptor) -{ - if (descriptor.type - != sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) { - throw std::runtime_error("expected exception settlement"); - } - - try { - std::rethrow_exception(descriptor.calleeException); - } - catch (const std::runtime_error &runtimeError) { - if (std::string(runtimeError.what()) != expectedThrowMessage) { - throw std::runtime_error("unexpected exception message"); - } - return; - } - - throw std::runtime_error("expected runtime_error settlement"); -} - -void assertIntExceptionSettlement( - const sscl::co::Group::SettlementDescriptor &descriptor) -{ - if (descriptor.type - != sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) { - throw std::runtime_error("expected int exception settlement"); - } - - try { - std::rethrow_exception(descriptor.calleeException); - } - catch (int caughtValue) { - if (caughtValue != expectedNonStdThrowValue) { - throw std::runtime_error("unexpected int exception value"); - } - return; - } - - throw std::runtime_error("expected int exception settlement"); -} - -void assertEmptyGroupCoAwaitError(const std::runtime_error &runtimeError) -{ - constexpr const char *expectedEmptyGroupCoAwaitMessage = - "co_await: Group has no member invokers; call add() before awaiting"; - if (std::string(runtimeError.what()) != expectedEmptyGroupCoAwaitMessage) { - throw std::runtime_error("unexpected empty-group error message"); - } + const boost::system::error_code waitError = + co_await sscl::tests::DeadlineTimerAwaiter{ + sscl::ComponentThread::getSelf()->getIoContext(), + timerLabelMilliseconds}; + sscl::tests::throwIfTimerWaitFailed(waitError); + trace.recordCalleeExecutionThread(); + co_return timerLabelMilliseconds; } sscl::co::ViralNonPostingInvoker waitOnCallerThread(int delayMilliseconds) @@ -188,11 +130,15 @@ CallerDriver mixedSuccessAndFailureAwaitFirstThenAll( if (firstDescriptor.type == sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) { - assertCompleted(firstDescriptor, 1); + sscl::tests::requireCompletedIntSettlement( + firstDescriptor, + 1); } else if (firstDescriptor.type == sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) { - assertRuntimeErrorSettlement(firstDescriptor); + sscl::tests::requireRuntimeErrorSettlement( + firstDescriptor, + expectedThrowMessage); } else { throw std::runtime_error("first settlement has unexpected type"); @@ -212,12 +158,16 @@ CallerDriver mixedSuccessAndFailureAwaitFirstThenAll( if (descriptor.type == sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) { ++completedCount; - assertCompleted(descriptor, 1); + sscl::tests::requireCompletedIntSettlement( + descriptor, + 1); } else if (descriptor.type == sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) { ++exceptionCount; - assertRuntimeErrorSettlement(descriptor); + sscl::tests::requireRuntimeErrorSettlement( + descriptor, + expectedThrowMessage); } } @@ -241,7 +191,9 @@ CallerDriver singleMemberAwaitFirstThenAll( auto awaitFirst = group.getAwaitFirstSettlementInvoker(); auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; - assertCompleted(firstDescriptor, delayShortMs); + sscl::tests::requireCompletedIntSettlement( + firstDescriptor, + delayShortMs); if (!group.allInvokersSettled() || allAfterFirst.size() != 1) { throw std::runtime_error("single member state mismatch"); @@ -254,7 +206,9 @@ CallerDriver singleMemberAwaitFirstThenAll( throw std::runtime_error("single member await-all count mismatch"); } - assertCompleted(allDescriptors[0], delayShortMs); + sscl::tests::requireCompletedIntSettlement( + allDescriptors[0], + delayShortMs); co_return; } @@ -282,7 +236,9 @@ CallerDriver allCompleteBeforeCoAwait( auto awaitFirst = group.getAwaitFirstSettlementInvoker(); auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; - assertCompleted(firstDescriptor, 10); + sscl::tests::requireCompletedIntSettlement( + firstDescriptor, + 10); auto awaitAll = group.getAwaitAllSettlementsInvoker(); auto &allDescriptors = co_await awaitAll; @@ -294,13 +250,13 @@ CallerDriver allCompleteBeforeCoAwait( co_return; } -std::thread startAddWhileGroupAwaiterSuspendedProbe( +std::jthread startAddWhileGroupAwaiterSuspendedProbe( sscl::co::Group &group, CalleeIntInvoker &lateInvoker, std::atomic &groupIsAwaitingAll, std::atomic &addWasRejected) { - return std::thread( + return std::jthread( [&]() { while (!groupIsAwaitingAll.load(std::memory_order_acquire)) { @@ -343,7 +299,7 @@ CallerDriver addWhileAwaitAllSuspended( group.add(slowInvokerA); group.add(slowInvokerB); - std::thread addProbeThread = startAddWhileGroupAwaiterSuspendedProbe( + std::jthread addProbeThread = startAddWhileGroupAwaiterSuspendedProbe( group, lateInvoker, groupIsAwaitingAll, @@ -390,12 +346,16 @@ CallerDriver awaitAllOnlyMixedOutcomes( if (descriptor.type == sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) { ++completedCount; - assertCompleted(descriptor, 7); + sscl::tests::requireCompletedIntSettlement( + descriptor, + 7); } else if (descriptor.type == sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) { ++exceptionCount; - assertRuntimeErrorSettlement(descriptor); + sscl::tests::requireRuntimeErrorSettlement( + descriptor, + expectedThrowMessage); } } @@ -446,7 +406,7 @@ CallerDriver emptyGroupAwaitAllThrows( (void)co_await group.getAwaitAllSettlementsInvoker(); } catch (const std::runtime_error &runtimeError) { - assertEmptyGroupCoAwaitError(runtimeError); + sscl::tests::requireEmptyGroupError(runtimeError); co_return; } @@ -466,7 +426,7 @@ CallerDriver emptyGroupAwaitFirstThrows( (void)co_await group.getAwaitFirstSettlementInvoker(); } catch (const std::runtime_error &runtimeError) { - assertEmptyGroupCoAwaitError(runtimeError); + sscl::tests::requireEmptyGroupError(runtimeError); co_return; } @@ -496,9 +456,10 @@ CallerDriver wrongAwaitInvokerOrder( } auto [firstDescriptor, allAfterFirst] = co_await awaitFirstHandle; - assertCompleted( + sscl::tests::requireCompletedIntSettlement( firstDescriptor, - readCompletedLabel(firstDescriptor.invokerAs())); + sscl::tests::completedIntValue( + firstDescriptor.invokerAs())); if (!group.firstInvokerSettled() || allAfterFirst.size() != 2) { throw std::runtime_error("wrong-order await-first state mismatch"); @@ -524,8 +485,12 @@ CallerDriver doubleCoAwaitSameAwaitFirst( auto [firstDescriptorA, allAfterFirstA] = co_await awaitFirst; auto [firstDescriptorB, allAfterFirstB] = co_await awaitFirst; - assertCompleted(firstDescriptorA, delayShortMs); - assertCompleted(firstDescriptorB, delayShortMs); + sscl::tests::requireCompletedIntSettlement( + firstDescriptorA, + delayShortMs); + sscl::tests::requireCompletedIntSettlement( + firstDescriptorB, + delayShortMs); if (&firstDescriptorA.invokerAs() != &firstDescriptorB.invokerAs()) { @@ -558,8 +523,12 @@ CallerDriver doubleCoAwaitSameAwaitAll( throw std::runtime_error("double await-all count mismatch"); } - assertCompleted(allDescriptorsA[0], delayShortMs); - assertCompleted(allDescriptorsB[0], delayShortMs); + sscl::tests::requireCompletedIntSettlement( + allDescriptorsA[0], + delayShortMs); + sscl::tests::requireCompletedIntSettlement( + allDescriptorsB[0], + delayShortMs); co_return; } @@ -579,11 +548,15 @@ CallerDriver twoAwaitFirstHandlesSequentially( auto awaitFirstA = group.getAwaitFirstSettlementInvoker(); auto [firstDescriptorA, allAfterFirstA] = co_await awaitFirstA; - assertCompleted(firstDescriptorA, delayShortMs); + sscl::tests::requireCompletedIntSettlement( + firstDescriptorA, + delayShortMs); auto awaitFirstB = group.getAwaitFirstSettlementInvoker(); auto [firstDescriptorB, allAfterFirstB] = co_await awaitFirstB; - assertCompleted(firstDescriptorB, delayShortMs); + sscl::tests::requireCompletedIntSettlement( + firstDescriptorB, + delayShortMs); if (&firstDescriptorA.invokerAs() != &firstDescriptorB.invokerAs()) { @@ -620,7 +593,7 @@ CallerDriver addSecondWaveAfterAwaitAll( co_await waitOnCallerThread(delayShortMs); - if (readCompletedLabel(wave2Immediate) + if (sscl::tests::completedIntValue(wave2Immediate) != wave2ImmediateSettlementLabel) { throw std::runtime_error("wave-2 immediate member did not complete"); } @@ -656,7 +629,9 @@ CallerDriver shortTimerAddedAfterLongStillWinsRace( auto awaitFirst = group.getAwaitFirstSettlementInvoker(); auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; - assertCompleted(firstDescriptor, delayShortMs); + sscl::tests::requireCompletedIntSettlement( + firstDescriptor, + delayShortMs); if (&firstDescriptor.invokerAs() != &shortInvoker) { throw std::runtime_error("short timer should win first settlement"); @@ -684,7 +659,9 @@ CallerDriver nonStdExceptionSettlement( throw std::runtime_error("non-std exception count mismatch"); } - assertIntExceptionSettlement(allDescriptors[0]); + sscl::tests::requireIntExceptionSettlement( + allDescriptors[0], + expectedNonStdThrowValue); try { group.checkForAndReThrowGroupExceptions(); @@ -738,11 +715,14 @@ CallerDriver returnValuesRemainReadableAfterAwaitFirst( auto awaitFirst = group.getAwaitFirstSettlementInvoker(); auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; - assertCompleted(firstDescriptor, delayShortMs); + sscl::tests::requireCompletedIntSettlement( + firstDescriptor, + delayShortMs); - const int fastLabelFromDescriptor = readCompletedLabel( + const int fastLabelFromDescriptor = sscl::tests::completedIntValue( firstDescriptor.invokerAs()); - const int fastLabelFromLocal = readCompletedLabel(fastInvoker); + const int fastLabelFromLocal = + sscl::tests::completedIntValue(fastInvoker); if (fastLabelFromDescriptor != fastLabelFromLocal) { throw std::runtime_error("descriptor/local return value mismatch"); @@ -756,6 +736,35 @@ CallerDriver returnValuesRemainReadableAfterAwaitFirst( co_return; } +CallerDriver groupMemberRunsOnCalleeAndAwaitResumesOnCaller( + std::exception_ptr &exceptionPtr, + std::function completion, + sscl::tests::CrossThreadTrace &trace) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + CalleeIntInvoker memberInvoker = waitRecordThreadAndReturnLabel( + delayShortMs, + trace); + group.add(memberInvoker); + + auto awaitFirst = group.getAwaitFirstSettlementInvoker(); + auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; + trace.recordAwaitResumeThread(); + + sscl::tests::requireCompletedIntSettlement( + firstDescriptor, + delayShortMs); + + if (allAfterFirst.size() != 1) { + throw std::runtime_error("cross-thread group trace count mismatch"); + } + + co_return; +} + class GroupEdgeTest : public ::testing::Test { @@ -816,6 +825,26 @@ RUN_GROUP_EDGE_SCENARIO( ReturnValuesRemainReadableAfterAwaitFirst, returnValuesRemainReadableAfterAwaitFirst) +TEST_F(GroupEdgeTest, SuspendingMemberRunsOnCalleeAndAwaitResumesOnCaller) +{ + sscl::tests::CrossThreadTrace trace; + + runScenario( + [&trace]( + std::exception_ptr &exceptionPtr, + std::function completion) + { + return groupMemberRunsOnCalleeAndAwaitResumesOnCaller( + exceptionPtr, + std::move(completion), + trace); + }); + + EXPECT_EQ(trace.calleeExecutionThread(), threads.callee().osThreadId()); + EXPECT_EQ(trace.awaitResumeThread(), threads.caller().osThreadId()); + EXPECT_NE(trace.calleeExecutionThread(), trace.awaitResumeThread()); +} + TEST_F(GroupEdgeTest, NonViralVoidGroupTemplateInstantiates) { GTEST_SKIP() diff --git a/tests/co/group_timer_tests.cpp b/tests/co/group_timer_tests.cpp index 1cb3acf..702b6f0 100644 --- a/tests/co/group_timer_tests.cpp +++ b/tests/co/group_timer_tests.cpp @@ -1,8 +1,11 @@ #include #include #include +#include +#include #include #include +#include #include @@ -36,19 +39,77 @@ using CalleeIntInvoker = using Clock = std::chrono::steady_clock; using Ms = std::chrono::milliseconds; -CalleeIntInvoker waitDeadlineTimer(int timerLabelMilliseconds) +class GroupTimerThreadTrace +{ +public: + void recordTimerCompletionThread(int timerLabelMilliseconds) + { + std::lock_guard guard(mutex); + timerCompletionThreads[timerLabelMilliseconds] = + std::this_thread::get_id(); + } + + void recordAwaitFirstResumeThread() + { + std::lock_guard guard(mutex); + awaitFirstResumeThread = std::this_thread::get_id(); + } + + void recordAwaitAllResumeThread() + { + std::lock_guard guard(mutex); + awaitAllResumeThread = std::this_thread::get_id(); + } + + std::thread::id timerCompletionThread(int timerLabelMilliseconds) const + { + std::lock_guard guard(mutex); + const auto iterator = + timerCompletionThreads.find(timerLabelMilliseconds); + + if (iterator == timerCompletionThreads.end()) { + throw std::runtime_error("Missing timer completion thread trace"); + } + + return iterator->second; + } + + std::thread::id awaitFirstThread() const + { + std::lock_guard guard(mutex); + return awaitFirstResumeThread; + } + + std::thread::id awaitAllThread() const + { + std::lock_guard guard(mutex); + return awaitAllResumeThread; + } + +private: + mutable std::mutex mutex; + std::map timerCompletionThreads; + std::thread::id awaitFirstResumeThread; + std::thread::id awaitAllResumeThread; +}; + +CalleeIntInvoker waitDeadlineTimer( + int timerLabelMilliseconds, + GroupTimerThreadTrace &trace) { const boost::system::error_code waitError = co_await sscl::tests::DeadlineTimerAwaiter{ sscl::ComponentThread::getSelf()->getIoContext(), timerLabelMilliseconds}; sscl::tests::throwIfTimerWaitFailed(waitError); + trace.recordTimerCompletionThread(timerLabelMilliseconds); co_return timerLabelMilliseconds; } CalleeIntInvoker waitCancelableDeadlineTimer( int timerLabelMilliseconds, - sscl::tests::CancelableDeadlineTimerRegistry ®istry) + sscl::tests::CancelableDeadlineTimerRegistry ®istry, + GroupTimerThreadTrace &trace) { const boost::system::error_code waitError = co_await sscl::tests::RegisteredDeadlineTimerAwaiter{ @@ -58,10 +119,12 @@ CalleeIntInvoker waitCancelableDeadlineTimer( registry}; if (sscl::tests::timerWasCanceled(waitError)) { + trace.recordTimerCompletionThread(timerLabelMilliseconds); co_return timerLabelMilliseconds; } sscl::tests::throwIfTimerWaitFailed(waitError); + trace.recordTimerCompletionThread(timerLabelMilliseconds); co_return timerLabelMilliseconds; } @@ -89,15 +152,19 @@ void throwIfElapsedTooShort( CallerDriver runGroupTimerRace( std::exception_ptr &exceptionPtr, - std::function completion) + std::function completion, + GroupTimerThreadTrace &trace) { (void)exceptionPtr; (void)completion; sscl::co::Group group; - CalleeIntInvoker invokerShort = waitDeadlineTimer(timerDelayShortMs); - CalleeIntInvoker invokerMedium = waitDeadlineTimer(timerDelayMediumMs); - CalleeIntInvoker invokerLong = waitDeadlineTimer(timerDelayLongMs); + CalleeIntInvoker invokerShort = + waitDeadlineTimer(timerDelayShortMs, trace); + CalleeIntInvoker invokerMedium = + waitDeadlineTimer(timerDelayMediumMs, trace); + CalleeIntInvoker invokerLong = + waitDeadlineTimer(timerDelayLongMs, trace); group.add(invokerShort); group.add(invokerMedium); @@ -107,6 +174,7 @@ CallerDriver runGroupTimerRace( auto awaitFirst = group.getAwaitFirstSettlementInvoker(); auto [firstSettlement, allSettlementsAfterFirst] = co_await awaitFirst; + trace.recordAwaitFirstResumeThread(); const auto firstElapsedMs = std::chrono::duration_cast(Clock::now() - testStart); @@ -125,6 +193,7 @@ CallerDriver runGroupTimerRace( auto awaitAll = group.getAwaitAllSettlementsInvoker(); auto &allSettlements = co_await awaitAll; + trace.recordAwaitAllResumeThread(); const auto allElapsedMs = std::chrono::duration_cast(Clock::now() - testStart); @@ -137,16 +206,16 @@ CallerDriver runGroupTimerRace( throw std::runtime_error("expected three settlements"); } - sscl::tests::expectCompletedIntSettlement( + sscl::tests::requireCompletedIntSettlement( firstSettlement, timerDelayShortMs); - sscl::tests::expectCompletedIntSettlement( + sscl::tests::requireCompletedIntSettlement( allSettlementsAfterFirst[0], timerDelayShortMs); - sscl::tests::expectCompletedIntSettlement( + sscl::tests::requireCompletedIntSettlement( allSettlementsAfterFirst[1], timerDelayMediumMs); - sscl::tests::expectCompletedIntSettlement( + sscl::tests::requireCompletedIntSettlement( allSettlementsAfterFirst[2], timerDelayLongMs); @@ -156,18 +225,19 @@ CallerDriver runGroupTimerRace( CallerDriver runGroupTimerCancelLongAfterAwaitFirst( std::exception_ptr &exceptionPtr, std::function completion, - sscl::tests::CancelableDeadlineTimerRegistry ®istry) + sscl::tests::CancelableDeadlineTimerRegistry ®istry, + GroupTimerThreadTrace &trace) { (void)exceptionPtr; (void)completion; sscl::co::Group group; CalleeIntInvoker invokerShort = - waitCancelableDeadlineTimer(timerDelayShortMs, registry); + waitCancelableDeadlineTimer(timerDelayShortMs, registry, trace); CalleeIntInvoker invokerMedium = - waitCancelableDeadlineTimer(timerDelayMediumMs, registry); + waitCancelableDeadlineTimer(timerDelayMediumMs, registry, trace); CalleeIntInvoker invokerLong = - waitCancelableDeadlineTimer(timerDelayLongMs, registry); + waitCancelableDeadlineTimer(timerDelayLongMs, registry, trace); group.add(invokerShort); group.add(invokerMedium); @@ -177,6 +247,7 @@ CallerDriver runGroupTimerCancelLongAfterAwaitFirst( auto awaitFirst = group.getAwaitFirstSettlementInvoker(); auto [firstSettlement, allSettlementsAfterFirst] = co_await awaitFirst; + trace.recordAwaitFirstResumeThread(); if (&firstSettlement.invokerAs() != &invokerShort) { throw std::runtime_error("cancel test first settlement mismatch"); @@ -190,6 +261,7 @@ CallerDriver runGroupTimerCancelLongAfterAwaitFirst( auto awaitAll = group.getAwaitAllSettlementsInvoker(); auto &allSettlements = co_await awaitAll; + trace.recordAwaitAllResumeThread(); const auto allElapsedMs = std::chrono::duration_cast(Clock::now() - testStart); @@ -207,13 +279,13 @@ CallerDriver runGroupTimerCancelLongAfterAwaitFirst( throw std::runtime_error("cancel test expected three settlements"); } - sscl::tests::expectCompletedIntSettlement( + sscl::tests::requireCompletedIntSettlement( allSettlements[0], timerDelayShortMs); - sscl::tests::expectCompletedIntSettlement( + sscl::tests::requireCompletedIntSettlement( allSettlements[1], timerDelayMediumMs); - sscl::tests::expectCompletedIntSettlement( + sscl::tests::requireCompletedIntSettlement( allSettlements[2], timerDelayLongMs); @@ -229,6 +301,25 @@ class GroupTimerTest : public ::testing::Test { protected: + void assertTimerTraceCrossedThreads( + const GroupTimerThreadTrace &trace) + { + EXPECT_EQ( + trace.timerCompletionThread(timerDelayShortMs), + threads.callee().osThreadId()); + EXPECT_EQ( + trace.timerCompletionThread(timerDelayMediumMs), + threads.callee().osThreadId()); + EXPECT_EQ( + trace.timerCompletionThread(timerDelayLongMs), + threads.callee().osThreadId()); + EXPECT_EQ(trace.awaitFirstThread(), threads.caller().osThreadId()); + EXPECT_EQ(trace.awaitAllThread(), threads.caller().osThreadId()); + EXPECT_NE( + trace.timerCompletionThread(timerDelayShortMs), + trace.awaitFirstThread()); + } + sscl::tests::PostingThreadSet threads; }; @@ -236,33 +327,42 @@ protected: TEST_F(GroupTimerTest, AwaitFirstReturnsShortestTimerAndAwaitAllWaitsForLongest) { + GroupTimerThreadTrace trace; + ASSERT_NO_THROW( sscl::tests::runNonViralPostingTask( threads.caller(), - []( + [&trace]( std::exception_ptr &exceptionPtr, std::function completion) { return runGroupTimerRace( exceptionPtr, - std::move(completion)); + std::move(completion), + trace); })); + + assertTimerTraceCrossedThreads(trace); } TEST_F(GroupTimerTest, CancelLongTimerAfterAwaitFirst) { sscl::tests::CancelableDeadlineTimerRegistry registry; + GroupTimerThreadTrace trace; ASSERT_NO_THROW( sscl::tests::runNonViralPostingTask( threads.caller(), - [®istry]( + [®istry, &trace]( std::exception_ptr &exceptionPtr, std::function completion) { return runGroupTimerCancelLongAfterAwaitFirst( exceptionPtr, std::move(completion), - registry); + registry, + trace); })); + + assertTimerTraceCrossedThreads(trace); } diff --git a/tests/co/viral_non_posting_tests.cpp b/tests/co/viral_non_posting_tests.cpp index d40c496..c730f34 100644 --- a/tests/co/viral_non_posting_tests.cpp +++ b/tests/co/viral_non_posting_tests.cpp @@ -12,7 +12,11 @@ #include #include +#include +#include +#include +#include #include #include @@ -28,6 +32,9 @@ using TestInvoker = sscl::co::ViralNonPostingInvoker; using TestDriver = TestInvoker; using TestVoidDriver = TestInvoker; +using CallerPostingDriver = + sscl::tests::RoleNonViralPostingInvoker< + sscl::tests::PostingThreadRole::CALLER>; struct ThreadIdPair { @@ -101,18 +108,14 @@ protected: int runDriver(TestDriver &driver) { - sscl::tests::IoContextPump::pumpUntilIdle(ioContext); - return finishDriver(driver); + return sscl::tests::CoroutineDriver::pumpUntilIdleAndReturnValue( + ioContext, + driver); } int finishDriver(TestDriver &driver) { - if (driver.completedReturnValues().myExceptionPtr) { - std::rethrow_exception( - driver.completedReturnValues().myExceptionPtr); - } - - return driver.completedReturnValues().myReturnValue; + return sscl::tests::CoroutineDriver::completedReturnValue(driver); } boost::asio::io_context ioContext; @@ -140,6 +143,18 @@ TestVoidDriver voidReturnImmediately() co_return; } +TestVoidDriver voidMemberAfterDelay( + boost::asio::io_context &ioContext, + int delayMilliseconds) +{ + const boost::system::error_code waitError = + co_await sscl::tests::DeadlineTimerAwaiter{ + ioContext, + delayMilliseconds}; + sscl::tests::throwIfTimerWaitFailed(waitError); + co_return; +} + TestInvoker throwRuntimeErrorImmediately() { throw std::runtime_error(expectedThrowMessage); @@ -412,6 +427,79 @@ TestDriver testNestedInnerSuspension(boost::asio::io_context &ioContext) co_return 0; } +CallerPostingDriver nonPostingVoidMemberInGroupDriver( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + TestVoidDriver voidInvoker = voidMemberAfterDelay( + sscl::ComponentThread::getSelf()->getIoContext(), + delayShortMs); + group.add(voidInvoker); + + auto &allDescriptors = co_await group.getAwaitAllSettlementsInvoker(); + + if (allDescriptors.size() != 1) { + throw std::runtime_error("voidMemberInGroup count mismatch"); + } + + sscl::tests::requireCompletedSettlement(allDescriptors[0]); + + co_return; +} + +CallerPostingDriver nonPostingGroupMixedImmediateAndDelayedDriver( + std::exception_ptr &exceptionPtr, + std::function completion) +{ + (void)exceptionPtr; + (void)completion; + + sscl::co::Group group; + TestInvoker immediateInvoker = returnLabelImmediately(11); + TestInvoker delayedInvoker = waitAndReturnLabel( + sscl::ComponentThread::getSelf()->getIoContext(), + delayShortMs); + + group.add(immediateInvoker); + group.add(delayedInvoker); + + auto &allDescriptors = co_await group.getAwaitAllSettlementsInvoker(); + + if (allDescriptors.size() != 2) { + throw std::runtime_error("groupMixedImmediateAndDelayed count mismatch"); + } + + bool sawImmediate = false; + bool sawDelayed = false; + + for (auto &descriptor : allDescriptors) { + sscl::tests::requireCompletedSettlement(descriptor); + const int label = sscl::tests::completedIntValue( + descriptor.invokerAs>()); + if (label == 11) { + sawImmediate = true; + } + else if (label == delayShortMs) { + sawDelayed = true; + } + else { + throw std::runtime_error( + "groupMixedImmediateAndDelayed unexpected label"); + } + } + + if (!sawImmediate || !sawDelayed) { + throw std::runtime_error( + "groupMixedImmediateAndDelayed missing expected label"); + } + + co_return; +} + } // namespace TEST_F(ViralNonPostingTest, ImmediateReturnFastPath) @@ -509,3 +597,37 @@ TEST_F(ViralNonPostingTest, NestedInnerSuspension) TestDriver driver = testNestedInnerSuspension(ioContext); EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); }); } + +TEST(ViralNonPostingGroupIntegrationTest, VoidMemberInGroup) +{ + sscl::tests::PostingThreadSet threads; + + ASSERT_NO_THROW( + sscl::tests::runNonViralPostingTask( + threads.caller(), + []( + std::exception_ptr &exceptionPtr, + std::function completion) + { + return nonPostingVoidMemberInGroupDriver( + exceptionPtr, + std::move(completion)); + })); +} + +TEST(ViralNonPostingGroupIntegrationTest, MixedImmediateAndDelayedInGroup) +{ + sscl::tests::PostingThreadSet threads; + + ASSERT_NO_THROW( + sscl::tests::runNonViralPostingTask( + threads.caller(), + []( + std::exception_ptr &exceptionPtr, + std::function completion) + { + return nonPostingGroupMixedImmediateAndDelayedDriver( + exceptionPtr, + std::move(completion)); + })); +} diff --git a/tests/support/coroutineDriver.h b/tests/support/coroutineDriver.h new file mode 100644 index 0000000..e7a8dab --- /dev/null +++ b/tests/support/coroutineDriver.h @@ -0,0 +1,38 @@ +#ifndef SPINSCALE_TEST_SUPPORT_COROUTINE_DRIVER_H +#define SPINSCALE_TEST_SUPPORT_COROUTINE_DRIVER_H + +#include + +#include + +#include + +namespace sscl::tests { + +class CoroutineDriver +{ +public: + template + static auto completedReturnValue(Invoker &invoker) + { + if (invoker.completedReturnValues().myExceptionPtr) { + std::rethrow_exception( + invoker.completedReturnValues().myExceptionPtr); + } + + return invoker.completedReturnValues().myReturnValue; + } + + template + static auto pumpUntilIdleAndReturnValue( + boost::asio::io_context &ioContext, + Invoker &invoker) + { + IoContextPump::pumpUntilIdle(ioContext); + return completedReturnValue(invoker); + } +}; + +} // namespace sscl::tests + +#endif // SPINSCALE_TEST_SUPPORT_COROUTINE_DRIVER_H diff --git a/tests/support/groupAssertions.h b/tests/support/groupAssertions.h index cfb84ae..db66d70 100644 --- a/tests/support/groupAssertions.h +++ b/tests/support/groupAssertions.h @@ -2,6 +2,7 @@ #define SPINSCALE_TEST_SUPPORT_GROUP_ASSERTIONS_H #include +#include #include #include @@ -21,12 +22,31 @@ int completedIntValue(Invoker &invoker) return invoker.completedReturnValues().myReturnValue; } -inline void expectCompletedSettlement( +inline void requireCompletedSettlement( const sscl::co::Group::SettlementDescriptor &descriptor) { - EXPECT_EQ( - descriptor.type, - sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED); + if (descriptor.type != + sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) + { + throw std::runtime_error("Expected completed settlement"); + } +} + +template +void requireCompletedIntSettlement( + const sscl::co::Group::SettlementDescriptor &descriptor, + int expectedValue) +{ + requireCompletedSettlement(descriptor); + + const int actualValue = completedIntValue(descriptor.invokerAs()); + if (actualValue != expectedValue) { + throw std::runtime_error( + "Expected completed settlement value " + + std::to_string(expectedValue) + + ", got " + + std::to_string(actualValue)); + } } template @@ -34,39 +54,85 @@ void expectCompletedIntSettlement( const sscl::co::Group::SettlementDescriptor &descriptor, int expectedValue) { - ASSERT_EQ( - descriptor.type, - sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED); - EXPECT_EQ(completedIntValue(descriptor.invokerAs()), expectedValue); + EXPECT_NO_THROW( + requireCompletedIntSettlement( + descriptor, + expectedValue)); +} + +inline void expectCompletedSettlement( + const sscl::co::Group::SettlementDescriptor &descriptor) +{ + EXPECT_NO_THROW(requireCompletedSettlement(descriptor)); +} + +inline void requireExceptionSettlement( + const sscl::co::Group::SettlementDescriptor &descriptor) +{ + if (descriptor.type != + sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) + { + throw std::runtime_error("Expected exception settlement"); + } + + if (!descriptor.calleeException) { + throw std::runtime_error("Expected exception pointer in settlement"); + } } inline void expectExceptionSettlement( const sscl::co::Group::SettlementDescriptor &descriptor) { - EXPECT_EQ( - descriptor.type, - sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN); - EXPECT_TRUE(descriptor.calleeException != nullptr); + EXPECT_NO_THROW(requireExceptionSettlement(descriptor)); } -inline void expectRuntimeErrorSettlement( +inline void requireRuntimeErrorSettlement( const sscl::co::Group::SettlementDescriptor &descriptor, const std::string &expectedMessage) { - ASSERT_EQ( - descriptor.type, - sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN); - ASSERT_TRUE(descriptor.calleeException != nullptr); + requireExceptionSettlement(descriptor); try { std::rethrow_exception(descriptor.calleeException); } catch (const std::runtime_error &runtimeError) { - EXPECT_EQ(std::string(runtimeError.what()), expectedMessage); + const std::string actualMessage = runtimeError.what(); + if (actualMessage != expectedMessage) { + throw std::runtime_error( + "Expected runtime_error settlement message \"" + + expectedMessage + + "\", got \"" + + actualMessage + + "\""); + } return; } catch (...) { - FAIL() << "Expected std::runtime_error settlement."; + throw std::runtime_error("Expected std::runtime_error settlement"); + } +} + +inline void requireIntExceptionSettlement( + const sscl::co::Group::SettlementDescriptor &descriptor, + int expectedValue) +{ + requireExceptionSettlement(descriptor); + + try { + std::rethrow_exception(descriptor.calleeException); + } + catch (int caughtValue) { + if (caughtValue != expectedValue) { + throw std::runtime_error( + "Expected int exception settlement value " + + std::to_string(expectedValue) + + ", got " + + std::to_string(caughtValue)); + } + return; + } + catch (...) { + throw std::runtime_error("Expected int exception settlement"); } } @@ -74,29 +140,36 @@ inline void expectIntExceptionSettlement( const sscl::co::Group::SettlementDescriptor &descriptor, int expectedValue) { - ASSERT_EQ( - descriptor.type, - sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN); - ASSERT_TRUE(descriptor.calleeException != nullptr); + EXPECT_NO_THROW( + requireIntExceptionSettlement( + descriptor, + expectedValue)); +} - try { - std::rethrow_exception(descriptor.calleeException); - } - catch (int caughtValue) { - EXPECT_EQ(caughtValue, expectedValue); - return; - } - catch (...) { - FAIL() << "Expected int exception settlement."; +inline void expectRuntimeErrorSettlement( + const sscl::co::Group::SettlementDescriptor &descriptor, + const std::string &expectedMessage) +{ + EXPECT_NO_THROW( + requireRuntimeErrorSettlement( + descriptor, + expectedMessage)); +} + +inline void requireEmptyGroupError( + const std::runtime_error &runtimeError) +{ + constexpr const char *expectedMessage = + "co_await: Group has no member invokers; call add() before awaiting"; + if (std::string(runtimeError.what()) != expectedMessage) { + throw std::runtime_error("Unexpected empty group error message"); } } inline void expectEmptyGroupError( const std::runtime_error &runtimeError) { - constexpr const char *expectedMessage = - "co_await: Group has no member invokers; call add() before awaiting"; - EXPECT_EQ(std::string(runtimeError.what()), expectedMessage); + EXPECT_NO_THROW(requireEmptyGroupError(runtimeError)); } } // namespace sscl::tests diff --git a/tests/support/threadHarness.cpp b/tests/support/threadHarness.cpp index e86cc29..e473226 100644 --- a/tests/support/threadHarness.cpp +++ b/tests/support/threadHarness.cpp @@ -217,13 +217,32 @@ void ThreadRegistry::registerThread( DedicatedIoThread &thread) { std::lock_guard guard(registryMutex()); - threadsByRole()[role] = &thread; + auto [iterator, inserted] = threadsByRole().emplace(role, &thread); + + if (!inserted) { + throw std::runtime_error( + "Test thread role already registered for " + threadRoleName(role)); + } } -void ThreadRegistry::unregisterThread(PostingThreadRole role) +void ThreadRegistry::unregisterThread( + PostingThreadRole role, + DedicatedIoThread &expectedThread) { std::lock_guard guard(registryMutex()); - threadsByRole().erase(role); + auto iterator = threadsByRole().find(role); + + if (iterator == threadsByRole().end()) { + return; + } + + if (iterator->second != &expectedThread) { + throw std::runtime_error( + "Test thread role registered to a different thread for " + + threadRoleName(role)); + } + + threadsByRole().erase(iterator); } boost::asio::io_context &ThreadRegistry::ioContext(PostingThreadRole role) @@ -272,6 +291,20 @@ PostingThreadSet::PostingThreadSet() bodyThread(PostingThreadRole::BODY), worldThread(PostingThreadRole::WORLD), legThread(PostingThreadRole::LEG) +{ + previousPuppeteerThread = sscl::ComponentThread::getPptr(); + previousPuppeteerThreadId = sscl::pptr::puppeteerThreadId; + registerAllThreads(); + installCallerAsPuppeteer(); +} + +PostingThreadSet::~PostingThreadSet() +{ + restorePreviousPuppeteer(); + unregisterAllThreads(); +} + +void PostingThreadSet::registerAllThreads() { ThreadRegistry::registerThread(PostingThreadRole::CALLER, callerThread); ThreadRegistry::registerThread(PostingThreadRole::CALLEE, calleeThread); @@ -279,22 +312,31 @@ PostingThreadSet::PostingThreadSet() ThreadRegistry::registerThread(PostingThreadRole::BODY, bodyThread); ThreadRegistry::registerThread(PostingThreadRole::WORLD, worldThread); ThreadRegistry::registerThread(PostingThreadRole::LEG, legThread); +} +void PostingThreadSet::unregisterAllThreads() +{ + ThreadRegistry::unregisterThread(PostingThreadRole::CALLER, callerThread); + ThreadRegistry::unregisterThread(PostingThreadRole::CALLEE, calleeThread); + ThreadRegistry::unregisterThread( + PostingThreadRole::ALTERNATE, + alternateThread); + ThreadRegistry::unregisterThread(PostingThreadRole::BODY, bodyThread); + ThreadRegistry::unregisterThread(PostingThreadRole::WORLD, worldThread); + ThreadRegistry::unregisterThread(PostingThreadRole::LEG, legThread); +} + +void PostingThreadSet::installCallerAsPuppeteer() +{ sscl::ComponentThread::setPuppeteerThreadId( static_cast(PostingThreadRole::CALLER)); sscl::ComponentThread::setPuppeteerThread(callerThread.componentThread()); } -PostingThreadSet::~PostingThreadSet() +void PostingThreadSet::restorePreviousPuppeteer() { - ThreadRegistry::unregisterThread(PostingThreadRole::CALLER); - ThreadRegistry::unregisterThread(PostingThreadRole::CALLEE); - ThreadRegistry::unregisterThread(PostingThreadRole::ALTERNATE); - ThreadRegistry::unregisterThread(PostingThreadRole::BODY); - ThreadRegistry::unregisterThread(PostingThreadRole::WORLD); - ThreadRegistry::unregisterThread(PostingThreadRole::LEG); - - sscl::ComponentThread::setPuppeteerThread(nullptr); + sscl::ComponentThread::setPuppeteerThreadId(previousPuppeteerThreadId); + sscl::ComponentThread::setPuppeteerThread(previousPuppeteerThread); } DedicatedIoThread &PostingThreadSet::thread(PostingThreadRole role) diff --git a/tests/support/threadHarness.h b/tests/support/threadHarness.h index 6f79422..45ccd23 100644 --- a/tests/support/threadHarness.h +++ b/tests/support/threadHarness.h @@ -180,7 +180,9 @@ public: static void registerThread( PostingThreadRole role, DedicatedIoThread &thread); - static void unregisterThread(PostingThreadRole role); + static void unregisterThread( + PostingThreadRole role, + DedicatedIoThread &expectedThread); static boost::asio::io_context &ioContext(PostingThreadRole role); static std::thread::id osThreadId(PostingThreadRole role); @@ -240,14 +242,28 @@ public: DedicatedIoThread &leg(); private: + void registerAllThreads(); + void unregisterAllThreads(); + void installCallerAsPuppeteer(); + void restorePreviousPuppeteer(); + DedicatedIoThread callerThread; DedicatedIoThread calleeThread; DedicatedIoThread alternateThread; DedicatedIoThread bodyThread; DedicatedIoThread worldThread; DedicatedIoThread legThread; + std::shared_ptr previousPuppeteerThread; + sscl::ThreadId previousPuppeteerThreadId = 0; }; +template +auto RunOnThread(DedicatedIoThread &thread, Function &&function) + -> std::invoke_result_t +{ + return thread.runSync(std::forward(function)); +} + class CrossThreadTrace { public: