Adversarial review on test porting plan

This commit is contained in:
2026-06-13 17:59:06 -04:00
parent a29c779f6e
commit 2f31e9a034
7 changed files with 593 additions and 173 deletions
+125 -96
View File
@@ -88,75 +88,17 @@ CalleeVoidInvoker voidMemberAfterDelay(int delayMilliseconds)
co_return; co_return;
} }
int readCompletedLabel(CalleeIntInvoker &invoker) CalleeIntInvoker waitRecordThreadAndReturnLabel(
int timerLabelMilliseconds,
sscl::tests::CrossThreadTrace &trace)
{ {
return invoker.completedReturnValues().myReturnValue; const boost::system::error_code waitError =
} co_await sscl::tests::DeadlineTimerAwaiter{
sscl::ComponentThread::getSelf()->getIoContext(),
void assertCompleted( timerLabelMilliseconds};
const sscl::co::Group::SettlementDescriptor &descriptor, sscl::tests::throwIfTimerWaitFailed(waitError);
int expectedLabel) trace.recordCalleeExecutionThread();
{ co_return timerLabelMilliseconds;
if (descriptor.type
!= sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) {
throw std::runtime_error("expected completed settlement");
}
if (readCompletedLabel(descriptor.invokerAs<CalleeIntInvoker>())
!= expectedLabel) {
throw std::runtime_error("settlement label mismatch");
}
}
void assertRuntimeErrorSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor)
{
if (descriptor.type
!= sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) {
throw std::runtime_error("expected exception settlement");
}
try {
std::rethrow_exception(descriptor.calleeException);
}
catch (const std::runtime_error &runtimeError) {
if (std::string(runtimeError.what()) != expectedThrowMessage) {
throw std::runtime_error("unexpected exception message");
}
return;
}
throw std::runtime_error("expected runtime_error settlement");
}
void assertIntExceptionSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor)
{
if (descriptor.type
!= sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) {
throw std::runtime_error("expected int exception settlement");
}
try {
std::rethrow_exception(descriptor.calleeException);
}
catch (int caughtValue) {
if (caughtValue != expectedNonStdThrowValue) {
throw std::runtime_error("unexpected int exception value");
}
return;
}
throw std::runtime_error("expected int exception settlement");
}
void assertEmptyGroupCoAwaitError(const std::runtime_error &runtimeError)
{
constexpr const char *expectedEmptyGroupCoAwaitMessage =
"co_await: Group has no member invokers; call add() before awaiting";
if (std::string(runtimeError.what()) != expectedEmptyGroupCoAwaitMessage) {
throw std::runtime_error("unexpected empty-group error message");
}
} }
sscl::co::ViralNonPostingInvoker<void> waitOnCallerThread(int delayMilliseconds) sscl::co::ViralNonPostingInvoker<void> waitOnCallerThread(int delayMilliseconds)
@@ -188,11 +130,15 @@ CallerDriver mixedSuccessAndFailureAwaitFirstThenAll(
if (firstDescriptor.type if (firstDescriptor.type
== sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) { == sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) {
assertCompleted(firstDescriptor, 1); sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
firstDescriptor,
1);
} }
else if (firstDescriptor.type else if (firstDescriptor.type
== sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) { == sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) {
assertRuntimeErrorSettlement(firstDescriptor); sscl::tests::requireRuntimeErrorSettlement(
firstDescriptor,
expectedThrowMessage);
} }
else { else {
throw std::runtime_error("first settlement has unexpected type"); throw std::runtime_error("first settlement has unexpected type");
@@ -212,12 +158,16 @@ CallerDriver mixedSuccessAndFailureAwaitFirstThenAll(
if (descriptor.type if (descriptor.type
== sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) { == sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) {
++completedCount; ++completedCount;
assertCompleted(descriptor, 1); sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
descriptor,
1);
} }
else if (descriptor.type else if (descriptor.type
== sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) { == sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) {
++exceptionCount; ++exceptionCount;
assertRuntimeErrorSettlement(descriptor); sscl::tests::requireRuntimeErrorSettlement(
descriptor,
expectedThrowMessage);
} }
} }
@@ -241,7 +191,9 @@ CallerDriver singleMemberAwaitFirstThenAll(
auto awaitFirst = group.getAwaitFirstSettlementInvoker(); auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; auto [firstDescriptor, allAfterFirst] = co_await awaitFirst;
assertCompleted(firstDescriptor, delayShortMs); sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
firstDescriptor,
delayShortMs);
if (!group.allInvokersSettled() || allAfterFirst.size() != 1) { if (!group.allInvokersSettled() || allAfterFirst.size() != 1) {
throw std::runtime_error("single member state mismatch"); throw std::runtime_error("single member state mismatch");
@@ -254,7 +206,9 @@ CallerDriver singleMemberAwaitFirstThenAll(
throw std::runtime_error("single member await-all count mismatch"); throw std::runtime_error("single member await-all count mismatch");
} }
assertCompleted(allDescriptors[0], delayShortMs); sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
allDescriptors[0],
delayShortMs);
co_return; co_return;
} }
@@ -282,7 +236,9 @@ CallerDriver allCompleteBeforeCoAwait(
auto awaitFirst = group.getAwaitFirstSettlementInvoker(); auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; auto [firstDescriptor, allAfterFirst] = co_await awaitFirst;
assertCompleted(firstDescriptor, 10); sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
firstDescriptor,
10);
auto awaitAll = group.getAwaitAllSettlementsInvoker(); auto awaitAll = group.getAwaitAllSettlementsInvoker();
auto &allDescriptors = co_await awaitAll; auto &allDescriptors = co_await awaitAll;
@@ -294,13 +250,13 @@ CallerDriver allCompleteBeforeCoAwait(
co_return; co_return;
} }
std::thread startAddWhileGroupAwaiterSuspendedProbe( std::jthread startAddWhileGroupAwaiterSuspendedProbe(
sscl::co::Group &group, sscl::co::Group &group,
CalleeIntInvoker &lateInvoker, CalleeIntInvoker &lateInvoker,
std::atomic<bool> &groupIsAwaitingAll, std::atomic<bool> &groupIsAwaitingAll,
std::atomic<bool> &addWasRejected) std::atomic<bool> &addWasRejected)
{ {
return std::thread( return std::jthread(
[&]() [&]()
{ {
while (!groupIsAwaitingAll.load(std::memory_order_acquire)) { while (!groupIsAwaitingAll.load(std::memory_order_acquire)) {
@@ -343,7 +299,7 @@ CallerDriver addWhileAwaitAllSuspended(
group.add(slowInvokerA); group.add(slowInvokerA);
group.add(slowInvokerB); group.add(slowInvokerB);
std::thread addProbeThread = startAddWhileGroupAwaiterSuspendedProbe( std::jthread addProbeThread = startAddWhileGroupAwaiterSuspendedProbe(
group, group,
lateInvoker, lateInvoker,
groupIsAwaitingAll, groupIsAwaitingAll,
@@ -390,12 +346,16 @@ CallerDriver awaitAllOnlyMixedOutcomes(
if (descriptor.type if (descriptor.type
== sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) { == sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) {
++completedCount; ++completedCount;
assertCompleted(descriptor, 7); sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
descriptor,
7);
} }
else if (descriptor.type else if (descriptor.type
== sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) { == sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) {
++exceptionCount; ++exceptionCount;
assertRuntimeErrorSettlement(descriptor); sscl::tests::requireRuntimeErrorSettlement(
descriptor,
expectedThrowMessage);
} }
} }
@@ -446,7 +406,7 @@ CallerDriver emptyGroupAwaitAllThrows(
(void)co_await group.getAwaitAllSettlementsInvoker(); (void)co_await group.getAwaitAllSettlementsInvoker();
} }
catch (const std::runtime_error &runtimeError) { catch (const std::runtime_error &runtimeError) {
assertEmptyGroupCoAwaitError(runtimeError); sscl::tests::requireEmptyGroupError(runtimeError);
co_return; co_return;
} }
@@ -466,7 +426,7 @@ CallerDriver emptyGroupAwaitFirstThrows(
(void)co_await group.getAwaitFirstSettlementInvoker(); (void)co_await group.getAwaitFirstSettlementInvoker();
} }
catch (const std::runtime_error &runtimeError) { catch (const std::runtime_error &runtimeError) {
assertEmptyGroupCoAwaitError(runtimeError); sscl::tests::requireEmptyGroupError(runtimeError);
co_return; co_return;
} }
@@ -496,9 +456,10 @@ CallerDriver wrongAwaitInvokerOrder(
} }
auto [firstDescriptor, allAfterFirst] = co_await awaitFirstHandle; auto [firstDescriptor, allAfterFirst] = co_await awaitFirstHandle;
assertCompleted( sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
firstDescriptor, firstDescriptor,
readCompletedLabel(firstDescriptor.invokerAs<CalleeIntInvoker>())); sscl::tests::completedIntValue(
firstDescriptor.invokerAs<CalleeIntInvoker>()));
if (!group.firstInvokerSettled() || allAfterFirst.size() != 2) { if (!group.firstInvokerSettled() || allAfterFirst.size() != 2) {
throw std::runtime_error("wrong-order await-first state mismatch"); throw std::runtime_error("wrong-order await-first state mismatch");
@@ -524,8 +485,12 @@ CallerDriver doubleCoAwaitSameAwaitFirst(
auto [firstDescriptorA, allAfterFirstA] = co_await awaitFirst; auto [firstDescriptorA, allAfterFirstA] = co_await awaitFirst;
auto [firstDescriptorB, allAfterFirstB] = co_await awaitFirst; auto [firstDescriptorB, allAfterFirstB] = co_await awaitFirst;
assertCompleted(firstDescriptorA, delayShortMs); sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
assertCompleted(firstDescriptorB, delayShortMs); firstDescriptorA,
delayShortMs);
sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
firstDescriptorB,
delayShortMs);
if (&firstDescriptorA.invokerAs<CalleeIntInvoker>() if (&firstDescriptorA.invokerAs<CalleeIntInvoker>()
!= &firstDescriptorB.invokerAs<CalleeIntInvoker>()) { != &firstDescriptorB.invokerAs<CalleeIntInvoker>()) {
@@ -558,8 +523,12 @@ CallerDriver doubleCoAwaitSameAwaitAll(
throw std::runtime_error("double await-all count mismatch"); throw std::runtime_error("double await-all count mismatch");
} }
assertCompleted(allDescriptorsA[0], delayShortMs); sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
assertCompleted(allDescriptorsB[0], delayShortMs); allDescriptorsA[0],
delayShortMs);
sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
allDescriptorsB[0],
delayShortMs);
co_return; co_return;
} }
@@ -579,11 +548,15 @@ CallerDriver twoAwaitFirstHandlesSequentially(
auto awaitFirstA = group.getAwaitFirstSettlementInvoker(); auto awaitFirstA = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptorA, allAfterFirstA] = co_await awaitFirstA; auto [firstDescriptorA, allAfterFirstA] = co_await awaitFirstA;
assertCompleted(firstDescriptorA, delayShortMs); sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
firstDescriptorA,
delayShortMs);
auto awaitFirstB = group.getAwaitFirstSettlementInvoker(); auto awaitFirstB = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptorB, allAfterFirstB] = co_await awaitFirstB; auto [firstDescriptorB, allAfterFirstB] = co_await awaitFirstB;
assertCompleted(firstDescriptorB, delayShortMs); sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
firstDescriptorB,
delayShortMs);
if (&firstDescriptorA.invokerAs<CalleeIntInvoker>() if (&firstDescriptorA.invokerAs<CalleeIntInvoker>()
!= &firstDescriptorB.invokerAs<CalleeIntInvoker>()) { != &firstDescriptorB.invokerAs<CalleeIntInvoker>()) {
@@ -620,7 +593,7 @@ CallerDriver addSecondWaveAfterAwaitAll(
co_await waitOnCallerThread(delayShortMs); co_await waitOnCallerThread(delayShortMs);
if (readCompletedLabel(wave2Immediate) if (sscl::tests::completedIntValue(wave2Immediate)
!= wave2ImmediateSettlementLabel) { != wave2ImmediateSettlementLabel) {
throw std::runtime_error("wave-2 immediate member did not complete"); throw std::runtime_error("wave-2 immediate member did not complete");
} }
@@ -656,7 +629,9 @@ CallerDriver shortTimerAddedAfterLongStillWinsRace(
auto awaitFirst = group.getAwaitFirstSettlementInvoker(); auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; auto [firstDescriptor, allAfterFirst] = co_await awaitFirst;
assertCompleted(firstDescriptor, delayShortMs); sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
firstDescriptor,
delayShortMs);
if (&firstDescriptor.invokerAs<CalleeIntInvoker>() != &shortInvoker) { if (&firstDescriptor.invokerAs<CalleeIntInvoker>() != &shortInvoker) {
throw std::runtime_error("short timer should win first settlement"); throw std::runtime_error("short timer should win first settlement");
@@ -684,7 +659,9 @@ CallerDriver nonStdExceptionSettlement(
throw std::runtime_error("non-std exception count mismatch"); throw std::runtime_error("non-std exception count mismatch");
} }
assertIntExceptionSettlement(allDescriptors[0]); sscl::tests::requireIntExceptionSettlement(
allDescriptors[0],
expectedNonStdThrowValue);
try { try {
group.checkForAndReThrowGroupExceptions(); group.checkForAndReThrowGroupExceptions();
@@ -738,11 +715,14 @@ CallerDriver returnValuesRemainReadableAfterAwaitFirst(
auto awaitFirst = group.getAwaitFirstSettlementInvoker(); auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptor, allAfterFirst] = co_await awaitFirst; auto [firstDescriptor, allAfterFirst] = co_await awaitFirst;
assertCompleted(firstDescriptor, delayShortMs); sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
firstDescriptor,
delayShortMs);
const int fastLabelFromDescriptor = readCompletedLabel( const int fastLabelFromDescriptor = sscl::tests::completedIntValue(
firstDescriptor.invokerAs<CalleeIntInvoker>()); firstDescriptor.invokerAs<CalleeIntInvoker>());
const int fastLabelFromLocal = readCompletedLabel(fastInvoker); const int fastLabelFromLocal =
sscl::tests::completedIntValue(fastInvoker);
if (fastLabelFromDescriptor != fastLabelFromLocal) { if (fastLabelFromDescriptor != fastLabelFromLocal) {
throw std::runtime_error("descriptor/local return value mismatch"); throw std::runtime_error("descriptor/local return value mismatch");
@@ -756,6 +736,35 @@ CallerDriver returnValuesRemainReadableAfterAwaitFirst(
co_return; co_return;
} }
CallerDriver groupMemberRunsOnCalleeAndAwaitResumesOnCaller(
std::exception_ptr &exceptionPtr,
std::function<void()> completion,
sscl::tests::CrossThreadTrace &trace)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker memberInvoker = waitRecordThreadAndReturnLabel(
delayShortMs,
trace);
group.add(memberInvoker);
auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptor, allAfterFirst] = co_await awaitFirst;
trace.recordAwaitResumeThread();
sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
firstDescriptor,
delayShortMs);
if (allAfterFirst.size() != 1) {
throw std::runtime_error("cross-thread group trace count mismatch");
}
co_return;
}
class GroupEdgeTest class GroupEdgeTest
: public ::testing::Test : public ::testing::Test
{ {
@@ -816,6 +825,26 @@ RUN_GROUP_EDGE_SCENARIO(
ReturnValuesRemainReadableAfterAwaitFirst, ReturnValuesRemainReadableAfterAwaitFirst,
returnValuesRemainReadableAfterAwaitFirst) returnValuesRemainReadableAfterAwaitFirst)
TEST_F(GroupEdgeTest, SuspendingMemberRunsOnCalleeAndAwaitResumesOnCaller)
{
sscl::tests::CrossThreadTrace trace;
runScenario(
[&trace](
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
return groupMemberRunsOnCalleeAndAwaitResumesOnCaller(
exceptionPtr,
std::move(completion),
trace);
});
EXPECT_EQ(trace.calleeExecutionThread(), threads.callee().osThreadId());
EXPECT_EQ(trace.awaitResumeThread(), threads.caller().osThreadId());
EXPECT_NE(trace.calleeExecutionThread(), trace.awaitResumeThread());
}
TEST_F(GroupEdgeTest, NonViralVoidGroupTemplateInstantiates) TEST_F(GroupEdgeTest, NonViralVoidGroupTemplateInstantiates)
{ {
GTEST_SKIP() GTEST_SKIP()
+121 -21
View File
@@ -1,8 +1,11 @@
#include <chrono> #include <chrono>
#include <exception> #include <exception>
#include <functional> #include <functional>
#include <map>
#include <mutex>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <thread>
#include <gtest/gtest.h> #include <gtest/gtest.h>
@@ -36,19 +39,77 @@ using CalleeIntInvoker =
using Clock = std::chrono::steady_clock; using Clock = std::chrono::steady_clock;
using Ms = std::chrono::milliseconds; using Ms = std::chrono::milliseconds;
CalleeIntInvoker waitDeadlineTimer(int timerLabelMilliseconds) class GroupTimerThreadTrace
{
public:
void recordTimerCompletionThread(int timerLabelMilliseconds)
{
std::lock_guard<std::mutex> guard(mutex);
timerCompletionThreads[timerLabelMilliseconds] =
std::this_thread::get_id();
}
void recordAwaitFirstResumeThread()
{
std::lock_guard<std::mutex> guard(mutex);
awaitFirstResumeThread = std::this_thread::get_id();
}
void recordAwaitAllResumeThread()
{
std::lock_guard<std::mutex> guard(mutex);
awaitAllResumeThread = std::this_thread::get_id();
}
std::thread::id timerCompletionThread(int timerLabelMilliseconds) const
{
std::lock_guard<std::mutex> guard(mutex);
const auto iterator =
timerCompletionThreads.find(timerLabelMilliseconds);
if (iterator == timerCompletionThreads.end()) {
throw std::runtime_error("Missing timer completion thread trace");
}
return iterator->second;
}
std::thread::id awaitFirstThread() const
{
std::lock_guard<std::mutex> guard(mutex);
return awaitFirstResumeThread;
}
std::thread::id awaitAllThread() const
{
std::lock_guard<std::mutex> guard(mutex);
return awaitAllResumeThread;
}
private:
mutable std::mutex mutex;
std::map<int, std::thread::id> timerCompletionThreads;
std::thread::id awaitFirstResumeThread;
std::thread::id awaitAllResumeThread;
};
CalleeIntInvoker waitDeadlineTimer(
int timerLabelMilliseconds,
GroupTimerThreadTrace &trace)
{ {
const boost::system::error_code waitError = const boost::system::error_code waitError =
co_await sscl::tests::DeadlineTimerAwaiter{ co_await sscl::tests::DeadlineTimerAwaiter{
sscl::ComponentThread::getSelf()->getIoContext(), sscl::ComponentThread::getSelf()->getIoContext(),
timerLabelMilliseconds}; timerLabelMilliseconds};
sscl::tests::throwIfTimerWaitFailed(waitError); sscl::tests::throwIfTimerWaitFailed(waitError);
trace.recordTimerCompletionThread(timerLabelMilliseconds);
co_return timerLabelMilliseconds; co_return timerLabelMilliseconds;
} }
CalleeIntInvoker waitCancelableDeadlineTimer( CalleeIntInvoker waitCancelableDeadlineTimer(
int timerLabelMilliseconds, int timerLabelMilliseconds,
sscl::tests::CancelableDeadlineTimerRegistry &registry) sscl::tests::CancelableDeadlineTimerRegistry &registry,
GroupTimerThreadTrace &trace)
{ {
const boost::system::error_code waitError = const boost::system::error_code waitError =
co_await sscl::tests::RegisteredDeadlineTimerAwaiter{ co_await sscl::tests::RegisteredDeadlineTimerAwaiter{
@@ -58,10 +119,12 @@ CalleeIntInvoker waitCancelableDeadlineTimer(
registry}; registry};
if (sscl::tests::timerWasCanceled(waitError)) { if (sscl::tests::timerWasCanceled(waitError)) {
trace.recordTimerCompletionThread(timerLabelMilliseconds);
co_return timerLabelMilliseconds; co_return timerLabelMilliseconds;
} }
sscl::tests::throwIfTimerWaitFailed(waitError); sscl::tests::throwIfTimerWaitFailed(waitError);
trace.recordTimerCompletionThread(timerLabelMilliseconds);
co_return timerLabelMilliseconds; co_return timerLabelMilliseconds;
} }
@@ -89,15 +152,19 @@ void throwIfElapsedTooShort(
CallerDriver runGroupTimerRace( CallerDriver runGroupTimerRace(
std::exception_ptr &exceptionPtr, std::exception_ptr &exceptionPtr,
std::function<void()> completion) std::function<void()> completion,
GroupTimerThreadTrace &trace)
{ {
(void)exceptionPtr; (void)exceptionPtr;
(void)completion; (void)completion;
sscl::co::Group group; sscl::co::Group group;
CalleeIntInvoker invokerShort = waitDeadlineTimer(timerDelayShortMs); CalleeIntInvoker invokerShort =
CalleeIntInvoker invokerMedium = waitDeadlineTimer(timerDelayMediumMs); waitDeadlineTimer(timerDelayShortMs, trace);
CalleeIntInvoker invokerLong = waitDeadlineTimer(timerDelayLongMs); CalleeIntInvoker invokerMedium =
waitDeadlineTimer(timerDelayMediumMs, trace);
CalleeIntInvoker invokerLong =
waitDeadlineTimer(timerDelayLongMs, trace);
group.add(invokerShort); group.add(invokerShort);
group.add(invokerMedium); group.add(invokerMedium);
@@ -107,6 +174,7 @@ CallerDriver runGroupTimerRace(
auto awaitFirst = group.getAwaitFirstSettlementInvoker(); auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstSettlement, allSettlementsAfterFirst] = co_await awaitFirst; auto [firstSettlement, allSettlementsAfterFirst] = co_await awaitFirst;
trace.recordAwaitFirstResumeThread();
const auto firstElapsedMs = const auto firstElapsedMs =
std::chrono::duration_cast<Ms>(Clock::now() - testStart); std::chrono::duration_cast<Ms>(Clock::now() - testStart);
@@ -125,6 +193,7 @@ CallerDriver runGroupTimerRace(
auto awaitAll = group.getAwaitAllSettlementsInvoker(); auto awaitAll = group.getAwaitAllSettlementsInvoker();
auto &allSettlements = co_await awaitAll; auto &allSettlements = co_await awaitAll;
trace.recordAwaitAllResumeThread();
const auto allElapsedMs = const auto allElapsedMs =
std::chrono::duration_cast<Ms>(Clock::now() - testStart); std::chrono::duration_cast<Ms>(Clock::now() - testStart);
@@ -137,16 +206,16 @@ CallerDriver runGroupTimerRace(
throw std::runtime_error("expected three settlements"); throw std::runtime_error("expected three settlements");
} }
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>( sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
firstSettlement, firstSettlement,
timerDelayShortMs); timerDelayShortMs);
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>( sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
allSettlementsAfterFirst[0], allSettlementsAfterFirst[0],
timerDelayShortMs); timerDelayShortMs);
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>( sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
allSettlementsAfterFirst[1], allSettlementsAfterFirst[1],
timerDelayMediumMs); timerDelayMediumMs);
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>( sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
allSettlementsAfterFirst[2], allSettlementsAfterFirst[2],
timerDelayLongMs); timerDelayLongMs);
@@ -156,18 +225,19 @@ CallerDriver runGroupTimerRace(
CallerDriver runGroupTimerCancelLongAfterAwaitFirst( CallerDriver runGroupTimerCancelLongAfterAwaitFirst(
std::exception_ptr &exceptionPtr, std::exception_ptr &exceptionPtr,
std::function<void()> completion, std::function<void()> completion,
sscl::tests::CancelableDeadlineTimerRegistry &registry) sscl::tests::CancelableDeadlineTimerRegistry &registry,
GroupTimerThreadTrace &trace)
{ {
(void)exceptionPtr; (void)exceptionPtr;
(void)completion; (void)completion;
sscl::co::Group group; sscl::co::Group group;
CalleeIntInvoker invokerShort = CalleeIntInvoker invokerShort =
waitCancelableDeadlineTimer(timerDelayShortMs, registry); waitCancelableDeadlineTimer(timerDelayShortMs, registry, trace);
CalleeIntInvoker invokerMedium = CalleeIntInvoker invokerMedium =
waitCancelableDeadlineTimer(timerDelayMediumMs, registry); waitCancelableDeadlineTimer(timerDelayMediumMs, registry, trace);
CalleeIntInvoker invokerLong = CalleeIntInvoker invokerLong =
waitCancelableDeadlineTimer(timerDelayLongMs, registry); waitCancelableDeadlineTimer(timerDelayLongMs, registry, trace);
group.add(invokerShort); group.add(invokerShort);
group.add(invokerMedium); group.add(invokerMedium);
@@ -177,6 +247,7 @@ CallerDriver runGroupTimerCancelLongAfterAwaitFirst(
auto awaitFirst = group.getAwaitFirstSettlementInvoker(); auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstSettlement, allSettlementsAfterFirst] = co_await awaitFirst; auto [firstSettlement, allSettlementsAfterFirst] = co_await awaitFirst;
trace.recordAwaitFirstResumeThread();
if (&firstSettlement.invokerAs<CalleeIntInvoker>() != &invokerShort) { if (&firstSettlement.invokerAs<CalleeIntInvoker>() != &invokerShort) {
throw std::runtime_error("cancel test first settlement mismatch"); throw std::runtime_error("cancel test first settlement mismatch");
@@ -190,6 +261,7 @@ CallerDriver runGroupTimerCancelLongAfterAwaitFirst(
auto awaitAll = group.getAwaitAllSettlementsInvoker(); auto awaitAll = group.getAwaitAllSettlementsInvoker();
auto &allSettlements = co_await awaitAll; auto &allSettlements = co_await awaitAll;
trace.recordAwaitAllResumeThread();
const auto allElapsedMs = const auto allElapsedMs =
std::chrono::duration_cast<Ms>(Clock::now() - testStart); std::chrono::duration_cast<Ms>(Clock::now() - testStart);
@@ -207,13 +279,13 @@ CallerDriver runGroupTimerCancelLongAfterAwaitFirst(
throw std::runtime_error("cancel test expected three settlements"); throw std::runtime_error("cancel test expected three settlements");
} }
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>( sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
allSettlements[0], allSettlements[0],
timerDelayShortMs); timerDelayShortMs);
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>( sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
allSettlements[1], allSettlements[1],
timerDelayMediumMs); timerDelayMediumMs);
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>( sscl::tests::requireCompletedIntSettlement<CalleeIntInvoker>(
allSettlements[2], allSettlements[2],
timerDelayLongMs); timerDelayLongMs);
@@ -229,6 +301,25 @@ class GroupTimerTest
: public ::testing::Test : public ::testing::Test
{ {
protected: protected:
void assertTimerTraceCrossedThreads(
const GroupTimerThreadTrace &trace)
{
EXPECT_EQ(
trace.timerCompletionThread(timerDelayShortMs),
threads.callee().osThreadId());
EXPECT_EQ(
trace.timerCompletionThread(timerDelayMediumMs),
threads.callee().osThreadId());
EXPECT_EQ(
trace.timerCompletionThread(timerDelayLongMs),
threads.callee().osThreadId());
EXPECT_EQ(trace.awaitFirstThread(), threads.caller().osThreadId());
EXPECT_EQ(trace.awaitAllThread(), threads.caller().osThreadId());
EXPECT_NE(
trace.timerCompletionThread(timerDelayShortMs),
trace.awaitFirstThread());
}
sscl::tests::PostingThreadSet threads; sscl::tests::PostingThreadSet threads;
}; };
@@ -236,33 +327,42 @@ protected:
TEST_F(GroupTimerTest, AwaitFirstReturnsShortestTimerAndAwaitAllWaitsForLongest) TEST_F(GroupTimerTest, AwaitFirstReturnsShortestTimerAndAwaitAllWaitsForLongest)
{ {
GroupTimerThreadTrace trace;
ASSERT_NO_THROW( ASSERT_NO_THROW(
sscl::tests::runNonViralPostingTask( sscl::tests::runNonViralPostingTask(
threads.caller(), threads.caller(),
[]( [&trace](
std::exception_ptr &exceptionPtr, std::exception_ptr &exceptionPtr,
std::function<void()> completion) std::function<void()> completion)
{ {
return runGroupTimerRace( return runGroupTimerRace(
exceptionPtr, exceptionPtr,
std::move(completion)); std::move(completion),
trace);
})); }));
assertTimerTraceCrossedThreads(trace);
} }
TEST_F(GroupTimerTest, CancelLongTimerAfterAwaitFirst) TEST_F(GroupTimerTest, CancelLongTimerAfterAwaitFirst)
{ {
sscl::tests::CancelableDeadlineTimerRegistry registry; sscl::tests::CancelableDeadlineTimerRegistry registry;
GroupTimerThreadTrace trace;
ASSERT_NO_THROW( ASSERT_NO_THROW(
sscl::tests::runNonViralPostingTask( sscl::tests::runNonViralPostingTask(
threads.caller(), threads.caller(),
[&registry]( [&registry, &trace](
std::exception_ptr &exceptionPtr, std::exception_ptr &exceptionPtr,
std::function<void()> completion) std::function<void()> completion)
{ {
return runGroupTimerCancelLongAfterAwaitFirst( return runGroupTimerCancelLongAfterAwaitFirst(
exceptionPtr, exceptionPtr,
std::move(completion), std::move(completion),
registry); registry,
trace);
})); }));
assertTimerTraceCrossedThreads(trace);
} }
+130 -8
View File
@@ -12,7 +12,11 @@
#include <boost/system/error_code.hpp> #include <boost/system/error_code.hpp>
#include <spinscale/co/invokers.h> #include <spinscale/co/invokers.h>
#include <spinscale/co/group.h>
#include <spinscale/componentThread.h>
#include <support/coroutineDriver.h>
#include <support/groupAssertions.h>
#include <support/threadHarness.h> #include <support/threadHarness.h>
#include <support/timerAwaiters.h> #include <support/timerAwaiters.h>
@@ -28,6 +32,9 @@ using TestInvoker = sscl::co::ViralNonPostingInvoker<T>;
using TestDriver = TestInvoker<int>; using TestDriver = TestInvoker<int>;
using TestVoidDriver = TestInvoker<void>; using TestVoidDriver = TestInvoker<void>;
using CallerPostingDriver =
sscl::tests::RoleNonViralPostingInvoker<
sscl::tests::PostingThreadRole::CALLER>;
struct ThreadIdPair struct ThreadIdPair
{ {
@@ -101,18 +108,14 @@ protected:
int runDriver(TestDriver &driver) int runDriver(TestDriver &driver)
{ {
sscl::tests::IoContextPump::pumpUntilIdle(ioContext); return sscl::tests::CoroutineDriver::pumpUntilIdleAndReturnValue(
return finishDriver(driver); ioContext,
driver);
} }
int finishDriver(TestDriver &driver) int finishDriver(TestDriver &driver)
{ {
if (driver.completedReturnValues().myExceptionPtr) { return sscl::tests::CoroutineDriver::completedReturnValue(driver);
std::rethrow_exception(
driver.completedReturnValues().myExceptionPtr);
}
return driver.completedReturnValues().myReturnValue;
} }
boost::asio::io_context ioContext; boost::asio::io_context ioContext;
@@ -140,6 +143,18 @@ TestVoidDriver voidReturnImmediately()
co_return; co_return;
} }
TestVoidDriver voidMemberAfterDelay(
boost::asio::io_context &ioContext,
int delayMilliseconds)
{
const boost::system::error_code waitError =
co_await sscl::tests::DeadlineTimerAwaiter{
ioContext,
delayMilliseconds};
sscl::tests::throwIfTimerWaitFailed(waitError);
co_return;
}
TestInvoker<int> throwRuntimeErrorImmediately() TestInvoker<int> throwRuntimeErrorImmediately()
{ {
throw std::runtime_error(expectedThrowMessage); throw std::runtime_error(expectedThrowMessage);
@@ -412,6 +427,79 @@ TestDriver testNestedInnerSuspension(boost::asio::io_context &ioContext)
co_return 0; co_return 0;
} }
CallerPostingDriver nonPostingVoidMemberInGroupDriver(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
TestVoidDriver voidInvoker = voidMemberAfterDelay(
sscl::ComponentThread::getSelf()->getIoContext(),
delayShortMs);
group.add(voidInvoker);
auto &allDescriptors = co_await group.getAwaitAllSettlementsInvoker();
if (allDescriptors.size() != 1) {
throw std::runtime_error("voidMemberInGroup count mismatch");
}
sscl::tests::requireCompletedSettlement(allDescriptors[0]);
co_return;
}
CallerPostingDriver nonPostingGroupMixedImmediateAndDelayedDriver(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
TestInvoker<int> immediateInvoker = returnLabelImmediately(11);
TestInvoker<int> delayedInvoker = waitAndReturnLabel(
sscl::ComponentThread::getSelf()->getIoContext(),
delayShortMs);
group.add(immediateInvoker);
group.add(delayedInvoker);
auto &allDescriptors = co_await group.getAwaitAllSettlementsInvoker();
if (allDescriptors.size() != 2) {
throw std::runtime_error("groupMixedImmediateAndDelayed count mismatch");
}
bool sawImmediate = false;
bool sawDelayed = false;
for (auto &descriptor : allDescriptors) {
sscl::tests::requireCompletedSettlement(descriptor);
const int label = sscl::tests::completedIntValue(
descriptor.invokerAs<TestInvoker<int>>());
if (label == 11) {
sawImmediate = true;
}
else if (label == delayShortMs) {
sawDelayed = true;
}
else {
throw std::runtime_error(
"groupMixedImmediateAndDelayed unexpected label");
}
}
if (!sawImmediate || !sawDelayed) {
throw std::runtime_error(
"groupMixedImmediateAndDelayed missing expected label");
}
co_return;
}
} // namespace } // namespace
TEST_F(ViralNonPostingTest, ImmediateReturnFastPath) TEST_F(ViralNonPostingTest, ImmediateReturnFastPath)
@@ -509,3 +597,37 @@ TEST_F(ViralNonPostingTest, NestedInnerSuspension)
TestDriver driver = testNestedInnerSuspension(ioContext); TestDriver driver = testNestedInnerSuspension(ioContext);
EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); }); EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); });
} }
TEST(ViralNonPostingGroupIntegrationTest, VoidMemberInGroup)
{
sscl::tests::PostingThreadSet threads;
ASSERT_NO_THROW(
sscl::tests::runNonViralPostingTask(
threads.caller(),
[](
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
return nonPostingVoidMemberInGroupDriver(
exceptionPtr,
std::move(completion));
}));
}
TEST(ViralNonPostingGroupIntegrationTest, MixedImmediateAndDelayedInGroup)
{
sscl::tests::PostingThreadSet threads;
ASSERT_NO_THROW(
sscl::tests::runNonViralPostingTask(
threads.caller(),
[](
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
return nonPostingGroupMixedImmediateAndDelayedDriver(
exceptionPtr,
std::move(completion));
}));
}
+38
View File
@@ -0,0 +1,38 @@
#ifndef SPINSCALE_TEST_SUPPORT_COROUTINE_DRIVER_H
#define SPINSCALE_TEST_SUPPORT_COROUTINE_DRIVER_H
#include <exception>
#include <boost/asio/io_context.hpp>
#include <support/threadHarness.h>
namespace sscl::tests {
class CoroutineDriver
{
public:
template <typename Invoker>
static auto completedReturnValue(Invoker &invoker)
{
if (invoker.completedReturnValues().myExceptionPtr) {
std::rethrow_exception(
invoker.completedReturnValues().myExceptionPtr);
}
return invoker.completedReturnValues().myReturnValue;
}
template <typename Invoker>
static auto pumpUntilIdleAndReturnValue(
boost::asio::io_context &ioContext,
Invoker &invoker)
{
IoContextPump::pumpUntilIdle(ioContext);
return completedReturnValue(invoker);
}
};
} // namespace sscl::tests
#endif // SPINSCALE_TEST_SUPPORT_COROUTINE_DRIVER_H
+108 -35
View File
@@ -2,6 +2,7 @@
#define SPINSCALE_TEST_SUPPORT_GROUP_ASSERTIONS_H #define SPINSCALE_TEST_SUPPORT_GROUP_ASSERTIONS_H
#include <exception> #include <exception>
#include <stdexcept>
#include <string> #include <string>
#include <gtest/gtest.h> #include <gtest/gtest.h>
@@ -21,12 +22,31 @@ int completedIntValue(Invoker &invoker)
return invoker.completedReturnValues().myReturnValue; return invoker.completedReturnValues().myReturnValue;
} }
inline void expectCompletedSettlement( inline void requireCompletedSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor) const sscl::co::Group::SettlementDescriptor &descriptor)
{ {
EXPECT_EQ( if (descriptor.type !=
descriptor.type, sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED)
sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED); {
throw std::runtime_error("Expected completed settlement");
}
}
template <typename Invoker>
void requireCompletedIntSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor,
int expectedValue)
{
requireCompletedSettlement(descriptor);
const int actualValue = completedIntValue(descriptor.invokerAs<Invoker>());
if (actualValue != expectedValue) {
throw std::runtime_error(
"Expected completed settlement value "
+ std::to_string(expectedValue)
+ ", got "
+ std::to_string(actualValue));
}
} }
template <typename Invoker> template <typename Invoker>
@@ -34,39 +54,85 @@ void expectCompletedIntSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor, const sscl::co::Group::SettlementDescriptor &descriptor,
int expectedValue) int expectedValue)
{ {
ASSERT_EQ( EXPECT_NO_THROW(
descriptor.type, requireCompletedIntSettlement<Invoker>(
sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED); descriptor,
EXPECT_EQ(completedIntValue(descriptor.invokerAs<Invoker>()), expectedValue); expectedValue));
}
inline void expectCompletedSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor)
{
EXPECT_NO_THROW(requireCompletedSettlement(descriptor));
}
inline void requireExceptionSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor)
{
if (descriptor.type !=
sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN)
{
throw std::runtime_error("Expected exception settlement");
}
if (!descriptor.calleeException) {
throw std::runtime_error("Expected exception pointer in settlement");
}
} }
inline void expectExceptionSettlement( inline void expectExceptionSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor) const sscl::co::Group::SettlementDescriptor &descriptor)
{ {
EXPECT_EQ( EXPECT_NO_THROW(requireExceptionSettlement(descriptor));
descriptor.type,
sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN);
EXPECT_TRUE(descriptor.calleeException != nullptr);
} }
inline void expectRuntimeErrorSettlement( inline void requireRuntimeErrorSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor, const sscl::co::Group::SettlementDescriptor &descriptor,
const std::string &expectedMessage) const std::string &expectedMessage)
{ {
ASSERT_EQ( requireExceptionSettlement(descriptor);
descriptor.type,
sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN);
ASSERT_TRUE(descriptor.calleeException != nullptr);
try { try {
std::rethrow_exception(descriptor.calleeException); std::rethrow_exception(descriptor.calleeException);
} }
catch (const std::runtime_error &runtimeError) { catch (const std::runtime_error &runtimeError) {
EXPECT_EQ(std::string(runtimeError.what()), expectedMessage); const std::string actualMessage = runtimeError.what();
if (actualMessage != expectedMessage) {
throw std::runtime_error(
"Expected runtime_error settlement message \""
+ expectedMessage
+ "\", got \""
+ actualMessage
+ "\"");
}
return; return;
} }
catch (...) { catch (...) {
FAIL() << "Expected std::runtime_error settlement."; throw std::runtime_error("Expected std::runtime_error settlement");
}
}
inline void requireIntExceptionSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor,
int expectedValue)
{
requireExceptionSettlement(descriptor);
try {
std::rethrow_exception(descriptor.calleeException);
}
catch (int caughtValue) {
if (caughtValue != expectedValue) {
throw std::runtime_error(
"Expected int exception settlement value "
+ std::to_string(expectedValue)
+ ", got "
+ std::to_string(caughtValue));
}
return;
}
catch (...) {
throw std::runtime_error("Expected int exception settlement");
} }
} }
@@ -74,29 +140,36 @@ inline void expectIntExceptionSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor, const sscl::co::Group::SettlementDescriptor &descriptor,
int expectedValue) int expectedValue)
{ {
ASSERT_EQ( EXPECT_NO_THROW(
descriptor.type, requireIntExceptionSettlement(
sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN); descriptor,
ASSERT_TRUE(descriptor.calleeException != nullptr); expectedValue));
}
try { inline void expectRuntimeErrorSettlement(
std::rethrow_exception(descriptor.calleeException); const sscl::co::Group::SettlementDescriptor &descriptor,
} const std::string &expectedMessage)
catch (int caughtValue) { {
EXPECT_EQ(caughtValue, expectedValue); EXPECT_NO_THROW(
return; requireRuntimeErrorSettlement(
} descriptor,
catch (...) { expectedMessage));
FAIL() << "Expected int exception settlement."; }
inline void requireEmptyGroupError(
const std::runtime_error &runtimeError)
{
constexpr const char *expectedMessage =
"co_await: Group has no member invokers; call add() before awaiting";
if (std::string(runtimeError.what()) != expectedMessage) {
throw std::runtime_error("Unexpected empty group error message");
} }
} }
inline void expectEmptyGroupError( inline void expectEmptyGroupError(
const std::runtime_error &runtimeError) const std::runtime_error &runtimeError)
{ {
constexpr const char *expectedMessage = EXPECT_NO_THROW(requireEmptyGroupError(runtimeError));
"co_await: Group has no member invokers; call add() before awaiting";
EXPECT_EQ(std::string(runtimeError.what()), expectedMessage);
} }
} // namespace sscl::tests } // namespace sscl::tests
+54 -12
View File
@@ -217,13 +217,32 @@ void ThreadRegistry::registerThread(
DedicatedIoThread &thread) DedicatedIoThread &thread)
{ {
std::lock_guard<std::mutex> guard(registryMutex()); std::lock_guard<std::mutex> guard(registryMutex());
threadsByRole()[role] = &thread; auto [iterator, inserted] = threadsByRole().emplace(role, &thread);
if (!inserted) {
throw std::runtime_error(
"Test thread role already registered for " + threadRoleName(role));
}
} }
void ThreadRegistry::unregisterThread(PostingThreadRole role) void ThreadRegistry::unregisterThread(
PostingThreadRole role,
DedicatedIoThread &expectedThread)
{ {
std::lock_guard<std::mutex> guard(registryMutex()); std::lock_guard<std::mutex> guard(registryMutex());
threadsByRole().erase(role); auto iterator = threadsByRole().find(role);
if (iterator == threadsByRole().end()) {
return;
}
if (iterator->second != &expectedThread) {
throw std::runtime_error(
"Test thread role registered to a different thread for "
+ threadRoleName(role));
}
threadsByRole().erase(iterator);
} }
boost::asio::io_context &ThreadRegistry::ioContext(PostingThreadRole role) boost::asio::io_context &ThreadRegistry::ioContext(PostingThreadRole role)
@@ -272,6 +291,20 @@ PostingThreadSet::PostingThreadSet()
bodyThread(PostingThreadRole::BODY), bodyThread(PostingThreadRole::BODY),
worldThread(PostingThreadRole::WORLD), worldThread(PostingThreadRole::WORLD),
legThread(PostingThreadRole::LEG) legThread(PostingThreadRole::LEG)
{
previousPuppeteerThread = sscl::ComponentThread::getPptr();
previousPuppeteerThreadId = sscl::pptr::puppeteerThreadId;
registerAllThreads();
installCallerAsPuppeteer();
}
PostingThreadSet::~PostingThreadSet()
{
restorePreviousPuppeteer();
unregisterAllThreads();
}
void PostingThreadSet::registerAllThreads()
{ {
ThreadRegistry::registerThread(PostingThreadRole::CALLER, callerThread); ThreadRegistry::registerThread(PostingThreadRole::CALLER, callerThread);
ThreadRegistry::registerThread(PostingThreadRole::CALLEE, calleeThread); ThreadRegistry::registerThread(PostingThreadRole::CALLEE, calleeThread);
@@ -279,22 +312,31 @@ PostingThreadSet::PostingThreadSet()
ThreadRegistry::registerThread(PostingThreadRole::BODY, bodyThread); ThreadRegistry::registerThread(PostingThreadRole::BODY, bodyThread);
ThreadRegistry::registerThread(PostingThreadRole::WORLD, worldThread); ThreadRegistry::registerThread(PostingThreadRole::WORLD, worldThread);
ThreadRegistry::registerThread(PostingThreadRole::LEG, legThread); ThreadRegistry::registerThread(PostingThreadRole::LEG, legThread);
}
void PostingThreadSet::unregisterAllThreads()
{
ThreadRegistry::unregisterThread(PostingThreadRole::CALLER, callerThread);
ThreadRegistry::unregisterThread(PostingThreadRole::CALLEE, calleeThread);
ThreadRegistry::unregisterThread(
PostingThreadRole::ALTERNATE,
alternateThread);
ThreadRegistry::unregisterThread(PostingThreadRole::BODY, bodyThread);
ThreadRegistry::unregisterThread(PostingThreadRole::WORLD, worldThread);
ThreadRegistry::unregisterThread(PostingThreadRole::LEG, legThread);
}
void PostingThreadSet::installCallerAsPuppeteer()
{
sscl::ComponentThread::setPuppeteerThreadId( sscl::ComponentThread::setPuppeteerThreadId(
static_cast<sscl::ThreadId>(PostingThreadRole::CALLER)); static_cast<sscl::ThreadId>(PostingThreadRole::CALLER));
sscl::ComponentThread::setPuppeteerThread(callerThread.componentThread()); sscl::ComponentThread::setPuppeteerThread(callerThread.componentThread());
} }
PostingThreadSet::~PostingThreadSet() void PostingThreadSet::restorePreviousPuppeteer()
{ {
ThreadRegistry::unregisterThread(PostingThreadRole::CALLER); sscl::ComponentThread::setPuppeteerThreadId(previousPuppeteerThreadId);
ThreadRegistry::unregisterThread(PostingThreadRole::CALLEE); sscl::ComponentThread::setPuppeteerThread(previousPuppeteerThread);
ThreadRegistry::unregisterThread(PostingThreadRole::ALTERNATE);
ThreadRegistry::unregisterThread(PostingThreadRole::BODY);
ThreadRegistry::unregisterThread(PostingThreadRole::WORLD);
ThreadRegistry::unregisterThread(PostingThreadRole::LEG);
sscl::ComponentThread::setPuppeteerThread(nullptr);
} }
DedicatedIoThread &PostingThreadSet::thread(PostingThreadRole role) DedicatedIoThread &PostingThreadSet::thread(PostingThreadRole role)
+17 -1
View File
@@ -180,7 +180,9 @@ public:
static void registerThread( static void registerThread(
PostingThreadRole role, PostingThreadRole role,
DedicatedIoThread &thread); DedicatedIoThread &thread);
static void unregisterThread(PostingThreadRole role); static void unregisterThread(
PostingThreadRole role,
DedicatedIoThread &expectedThread);
static boost::asio::io_context &ioContext(PostingThreadRole role); static boost::asio::io_context &ioContext(PostingThreadRole role);
static std::thread::id osThreadId(PostingThreadRole role); static std::thread::id osThreadId(PostingThreadRole role);
@@ -240,14 +242,28 @@ public:
DedicatedIoThread &leg(); DedicatedIoThread &leg();
private: private:
void registerAllThreads();
void unregisterAllThreads();
void installCallerAsPuppeteer();
void restorePreviousPuppeteer();
DedicatedIoThread callerThread; DedicatedIoThread callerThread;
DedicatedIoThread calleeThread; DedicatedIoThread calleeThread;
DedicatedIoThread alternateThread; DedicatedIoThread alternateThread;
DedicatedIoThread bodyThread; DedicatedIoThread bodyThread;
DedicatedIoThread worldThread; DedicatedIoThread worldThread;
DedicatedIoThread legThread; DedicatedIoThread legThread;
std::shared_ptr<sscl::PuppeteerThread> previousPuppeteerThread;
sscl::ThreadId previousPuppeteerThreadId = 0;
}; };
template <typename Function>
auto RunOnThread(DedicatedIoThread &thread, Function &&function)
-> std::invoke_result_t<Function &>
{
return thread.runSync(std::forward<Function>(function));
}
class CrossThreadTrace class CrossThreadTrace
{ {
public: public: