Tests: Add all tests from the coro creation repo

We went back and brought along all the tests we implemented while
we were building the new coro framework.
This commit is contained in:
2026-06-13 17:17:57 -04:00
parent 1763685c0e
commit a29c779f6e
11 changed files with 3199 additions and 28 deletions
+43 -26
View File
@@ -1,37 +1,54 @@
add_executable(spinscale_env_kv_store_tests
add_library(spinscale_test_support STATIC
support/threadHarness.cpp
)
target_include_directories(spinscale_test_support PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}
)
target_link_libraries(spinscale_test_support PUBLIC
spinscale
gtest
)
function(add_spinscale_gtest target)
add_executable(${target} ${ARGN})
target_link_libraries(${target} PRIVATE
spinscale_test_support
gtest_main
)
add_dependencies(${target} gtest_main)
add_test(NAME ${target} COMMAND ${target})
endfunction()
add_spinscale_gtest(spinscale_env_kv_store_tests
env_kv_store_test.cpp
)
target_link_libraries(spinscale_env_kv_store_tests PRIVATE
spinscale
gtest_main
)
add_dependencies(spinscale_env_kv_store_tests gtest_main)
add_test(NAME spinscale_env_kv_store_tests
COMMAND spinscale_env_kv_store_tests)
add_executable(qutex_tests
add_spinscale_gtest(qutex_tests
cps/qutex_tests.cpp
)
target_link_libraries(qutex_tests PRIVATE
spinscale
gtest_main
)
add_dependencies(qutex_tests gtest_main)
add_test(NAME qutex_tests COMMAND qutex_tests)
add_executable(nonViralTaskNursery_tests
add_spinscale_gtest(nonViralTaskNursery_tests
co/nonViralTaskNursery_tests.cpp
)
target_link_libraries(nonViralTaskNursery_tests PRIVATE
spinscale
gtest_main
add_spinscale_gtest(co_viral_non_posting_tests
co/viral_non_posting_tests.cpp
)
add_dependencies(nonViralTaskNursery_tests gtest_main)
add_test(NAME nonViralTaskNursery_tests
COMMAND nonViralTaskNursery_tests)
add_spinscale_gtest(co_posting_cross_thread_tests
co/posting_cross_thread_tests.cpp
)
add_spinscale_gtest(co_group_edge_tests
co/group_edge_tests.cpp
)
add_spinscale_gtest(co_group_timer_tests
co/group_timer_tests.cpp
)
add_spinscale_gtest(co_component_continuation_tests
co/component_continuation_tests.cpp
)
+250
View File
@@ -0,0 +1,250 @@
#include <exception>
#include <functional>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>
#include <gtest/gtest.h>
#include <spinscale/co/coQutex.h>
#include <support/threadHarness.h>
namespace {
constexpr int leftValue = 1;
constexpr int rightValue = 2;
constexpr int expectedIntSum = 3;
constexpr int bodyArgument = 4;
constexpr const char *bodyStringArgument = "KEKW";
constexpr const char *leftString = "Hello";
constexpr const char *rightString = "World";
constexpr const char *expectedString = "Hello World";
using BodyNonViralInvoker =
sscl::tests::RoleNonViralPostingInvoker<
sscl::tests::PostingThreadRole::BODY>;
template <typename T>
using BodyViralInvoker =
sscl::tests::RoleViralPostingInvoker<
sscl::tests::PostingThreadRole::BODY,
T>;
template <typename T>
using WorldViralInvoker =
sscl::tests::RoleViralPostingInvoker<
sscl::tests::PostingThreadRole::WORLD,
T>;
template <typename T>
using LegViralInvoker =
sscl::tests::RoleViralPostingInvoker<
sscl::tests::PostingThreadRole::LEG,
T>;
class ComponentContinuationTrace
{
public:
void recordBodyThread()
{
std::lock_guard<std::mutex> guard(mutex);
bodyThreadId = std::this_thread::get_id();
}
void recordWorldThread()
{
std::lock_guard<std::mutex> guard(mutex);
worldThreadId = std::this_thread::get_id();
}
void recordLegThread()
{
std::lock_guard<std::mutex> guard(mutex);
legThreadId = std::this_thread::get_id();
}
void recordCompletionThread()
{
std::lock_guard<std::mutex> guard(mutex);
completionThreadId = std::this_thread::get_id();
}
void recordLegSum(int value)
{
std::lock_guard<std::mutex> guard(mutex);
legSum = value;
}
void recordWorldString(std::string value)
{
std::lock_guard<std::mutex> guard(mutex);
worldString = std::move(value);
}
void recordBodyString(std::string value)
{
std::lock_guard<std::mutex> guard(mutex);
bodyString = std::move(value);
}
std::thread::id bodyThread() const
{
std::lock_guard<std::mutex> guard(mutex);
return bodyThreadId;
}
std::thread::id worldThread() const
{
std::lock_guard<std::mutex> guard(mutex);
return worldThreadId;
}
std::thread::id legThread() const
{
std::lock_guard<std::mutex> guard(mutex);
return legThreadId;
}
std::thread::id completionThread() const
{
std::lock_guard<std::mutex> guard(mutex);
return completionThreadId;
}
int recordedLegSum() const
{
std::lock_guard<std::mutex> guard(mutex);
return legSum;
}
std::string recordedWorldString() const
{
std::lock_guard<std::mutex> guard(mutex);
return worldString;
}
std::string recordedBodyString() const
{
std::lock_guard<std::mutex> guard(mutex);
return bodyString;
}
private:
mutable std::mutex mutex;
std::thread::id bodyThreadId;
std::thread::id worldThreadId;
std::thread::id legThreadId;
std::thread::id completionThreadId;
int legSum = 0;
std::string worldString;
std::string bodyString;
};
LegViralInvoker<int> print2Ints(
int arg1,
int arg2,
ComponentContinuationTrace &trace)
{
sscl::co::CoQutex print2IntsLock;
trace.recordLegThread();
auto releaseHandle =
co_await print2IntsLock.getAcquireInvocationAndSuspensionPolicy();
const int sum = arg1 + arg2;
trace.recordLegSum(sum);
releaseHandle.release();
co_return sum;
}
WorldViralInvoker<std::string> print2Strings(
std::string arg1,
std::string arg2,
ComponentContinuationTrace &trace)
{
sscl::co::CoQutex print2StringsLock;
trace.recordWorldThread();
auto releaseHandle =
co_await print2StringsLock.getAcquireInvocationAndSuspensionPolicy();
const int returnedInt =
co_await print2Ints(leftValue, rightValue, trace);
releaseHandle.release();
if (returnedInt != expectedIntSum) {
throw std::runtime_error("LEG int return mismatch");
}
std::string returnedString = arg1 + " " + arg2;
trace.recordWorldString(returnedString);
co_return returnedString;
}
BodyNonViralInvoker initializeDemoCReq(
std::exception_ptr &exceptionPtr,
std::function<void()> completion,
int arg3,
std::string arg4,
ComponentContinuationTrace &trace)
{
(void)exceptionPtr;
(void)completion;
(void)arg3;
(void)arg4;
sscl::co::CoQutex initializeLock;
trace.recordBodyThread();
auto releaseHandle =
co_await initializeLock.getAcquireInvocationAndSuspensionPolicy();
std::string returnedString =
co_await print2Strings(leftString, rightString, trace);
releaseHandle.release();
trace.recordBodyString(returnedString);
co_return;
}
class ComponentContinuationTest
: public ::testing::Test
{
protected:
sscl::tests::PostingThreadSet threads;
};
} // namespace
TEST_F(ComponentContinuationTest, SyncMainStyleContinuationCrossesComponentThreads)
{
ComponentContinuationTrace trace;
ASSERT_NO_THROW(
sscl::tests::runNonViralPostingTask(
threads.caller(),
[&trace](
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
return initializeDemoCReq(
exceptionPtr,
[&trace, completion = std::move(completion)]() mutable
{
trace.recordCompletionThread();
completion();
},
bodyArgument,
bodyStringArgument,
trace);
}));
EXPECT_EQ(trace.bodyThread(), threads.body().osThreadId());
EXPECT_EQ(trace.worldThread(), threads.world().osThreadId());
EXPECT_EQ(trace.legThread(), threads.leg().osThreadId());
EXPECT_EQ(trace.completionThread(), threads.caller().osThreadId());
EXPECT_NE(trace.bodyThread(), trace.worldThread());
EXPECT_NE(trace.worldThread(), trace.legThread());
EXPECT_NE(trace.legThread(), trace.completionThread());
EXPECT_EQ(trace.recordedLegSum(), expectedIntSum);
EXPECT_EQ(trace.recordedWorldString(), expectedString);
EXPECT_EQ(trace.recordedBodyString(), expectedString);
}
+835
View File
@@ -0,0 +1,835 @@
#include <atomic>
#include <chrono>
#include <exception>
#include <functional>
#include <stdexcept>
#include <string>
#include <thread>
#include <gtest/gtest.h>
#include <boost/asio/post.hpp>
#include <boost/system/error_code.hpp>
#include <spinscale/co/group.h>
#include <spinscale/componentThread.h>
#include <support/groupAssertions.h>
#include <support/threadHarness.h>
#include <support/timerAwaiters.h>
namespace {
constexpr int delayShortMs = 50;
constexpr int delayMediumMs = 200;
constexpr int delayLongMs = 500;
constexpr int delayAddWhileSuspendedProbeMs = 80;
constexpr int expectedNonStdThrowValue = 42;
constexpr int wave2ImmediateSettlementLabel = 1000;
constexpr const char *expectedThrowMessage =
"group_edge_test intentional failure";
using CallerDriver =
sscl::tests::RoleNonViralPostingInvoker<
sscl::tests::PostingThreadRole::CALLER>;
using CalleeIntInvoker =
sscl::tests::RoleViralPostingInvoker<
sscl::tests::PostingThreadRole::CALLEE,
int>;
using CalleeVoidInvoker =
sscl::tests::RoleViralPostingInvoker<
sscl::tests::PostingThreadRole::CALLEE,
void>;
CalleeIntInvoker waitAndReturnLabel(int timerLabelMilliseconds)
{
const boost::system::error_code waitError =
co_await sscl::tests::DeadlineTimerAwaiter{
sscl::ComponentThread::getSelf()->getIoContext(),
timerLabelMilliseconds};
sscl::tests::throwIfTimerWaitFailed(waitError);
co_return timerLabelMilliseconds;
}
CalleeIntInvoker waitThenThrowAfterDelay(int delayMilliseconds)
{
const boost::system::error_code waitError =
co_await sscl::tests::DeadlineTimerAwaiter{
sscl::ComponentThread::getSelf()->getIoContext(),
delayMilliseconds};
sscl::tests::throwIfTimerWaitFailed(waitError);
throw std::runtime_error(expectedThrowMessage);
}
CalleeIntInvoker waitThenThrowIntAfterDelay(int delayMilliseconds)
{
const boost::system::error_code waitError =
co_await sscl::tests::DeadlineTimerAwaiter{
sscl::ComponentThread::getSelf()->getIoContext(),
delayMilliseconds};
sscl::tests::throwIfTimerWaitFailed(waitError);
throw expectedNonStdThrowValue;
}
CalleeIntInvoker returnLabelImmediately(int label)
{
co_return label;
}
CalleeVoidInvoker voidMemberAfterDelay(int delayMilliseconds)
{
const boost::system::error_code waitError =
co_await sscl::tests::DeadlineTimerAwaiter{
sscl::ComponentThread::getSelf()->getIoContext(),
delayMilliseconds};
sscl::tests::throwIfTimerWaitFailed(waitError);
co_return;
}
int readCompletedLabel(CalleeIntInvoker &invoker)
{
return invoker.completedReturnValues().myReturnValue;
}
void assertCompleted(
const sscl::co::Group::SettlementDescriptor &descriptor,
int expectedLabel)
{
if (descriptor.type
!= sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) {
throw std::runtime_error("expected completed settlement");
}
if (readCompletedLabel(descriptor.invokerAs<CalleeIntInvoker>())
!= expectedLabel) {
throw std::runtime_error("settlement label mismatch");
}
}
void assertRuntimeErrorSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor)
{
if (descriptor.type
!= sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) {
throw std::runtime_error("expected exception settlement");
}
try {
std::rethrow_exception(descriptor.calleeException);
}
catch (const std::runtime_error &runtimeError) {
if (std::string(runtimeError.what()) != expectedThrowMessage) {
throw std::runtime_error("unexpected exception message");
}
return;
}
throw std::runtime_error("expected runtime_error settlement");
}
void assertIntExceptionSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor)
{
if (descriptor.type
!= sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) {
throw std::runtime_error("expected int exception settlement");
}
try {
std::rethrow_exception(descriptor.calleeException);
}
catch (int caughtValue) {
if (caughtValue != expectedNonStdThrowValue) {
throw std::runtime_error("unexpected int exception value");
}
return;
}
throw std::runtime_error("expected int exception settlement");
}
void assertEmptyGroupCoAwaitError(const std::runtime_error &runtimeError)
{
constexpr const char *expectedEmptyGroupCoAwaitMessage =
"co_await: Group has no member invokers; call add() before awaiting";
if (std::string(runtimeError.what()) != expectedEmptyGroupCoAwaitMessage) {
throw std::runtime_error("unexpected empty-group error message");
}
}
sscl::co::ViralNonPostingInvoker<void> waitOnCallerThread(int delayMilliseconds)
{
const boost::system::error_code waitError =
co_await sscl::tests::DeadlineTimerAwaiter{
sscl::ComponentThread::getSelf()->getIoContext(),
delayMilliseconds};
sscl::tests::throwIfTimerWaitFailed(waitError);
co_return;
}
CallerDriver mixedSuccessAndFailureAwaitFirstThenAll(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker successInvoker = waitAndReturnLabel(1);
CalleeIntInvoker failureInvoker = waitThenThrowAfterDelay(delayShortMs);
group.add(successInvoker);
group.add(failureInvoker);
auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptor, allAfterFirst] = co_await awaitFirst;
if (firstDescriptor.type
== sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) {
assertCompleted(firstDescriptor, 1);
}
else if (firstDescriptor.type
== sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) {
assertRuntimeErrorSettlement(firstDescriptor);
}
else {
throw std::runtime_error("first settlement has unexpected type");
}
auto awaitAll = group.getAwaitAllSettlementsInvoker();
auto &allDescriptors = co_await awaitAll;
if (allDescriptors.size() != 2 || allAfterFirst.size() != 2) {
throw std::runtime_error("mixed settlement count mismatch");
}
std::size_t completedCount = 0;
std::size_t exceptionCount = 0;
for (auto &descriptor : allDescriptors) {
if (descriptor.type
== sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) {
++completedCount;
assertCompleted(descriptor, 1);
}
else if (descriptor.type
== sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) {
++exceptionCount;
assertRuntimeErrorSettlement(descriptor);
}
}
if (completedCount != 1 || exceptionCount != 1) {
throw std::runtime_error("mixed settlement type counts mismatch");
}
co_return;
}
CallerDriver singleMemberAwaitFirstThenAll(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker onlyInvoker = waitAndReturnLabel(delayShortMs);
group.add(onlyInvoker);
auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptor, allAfterFirst] = co_await awaitFirst;
assertCompleted(firstDescriptor, delayShortMs);
if (!group.allInvokersSettled() || allAfterFirst.size() != 1) {
throw std::runtime_error("single member state mismatch");
}
auto awaitAll = group.getAwaitAllSettlementsInvoker();
auto &allDescriptors = co_await awaitAll;
if (allDescriptors.size() != 1) {
throw std::runtime_error("single member await-all count mismatch");
}
assertCompleted(allDescriptors[0], delayShortMs);
co_return;
}
CallerDriver allCompleteBeforeCoAwait(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker invokerTen = returnLabelImmediately(10);
CalleeIntInvoker invokerTwenty = returnLabelImmediately(20);
CalleeIntInvoker invokerThirty = returnLabelImmediately(30);
group.add(invokerTen);
group.add(invokerTwenty);
group.add(invokerThirty);
co_await waitOnCallerThread(delayShortMs);
if (!group.allInvokersSettled() || !group.firstInvokerSettled()) {
throw std::runtime_error("immediate group did not settle before await");
}
auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptor, allAfterFirst] = co_await awaitFirst;
assertCompleted(firstDescriptor, 10);
auto awaitAll = group.getAwaitAllSettlementsInvoker();
auto &allDescriptors = co_await awaitAll;
if (allDescriptors.size() != 3 || allAfterFirst.size() != 3) {
throw std::runtime_error("immediate settlement count mismatch");
}
co_return;
}
std::thread startAddWhileGroupAwaiterSuspendedProbe(
sscl::co::Group &group,
CalleeIntInvoker &lateInvoker,
std::atomic<bool> &groupIsAwaitingAll,
std::atomic<bool> &addWasRejected)
{
return std::thread(
[&]()
{
while (!groupIsAwaitingAll.load(std::memory_order_acquire)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
std::this_thread::sleep_for(
std::chrono::milliseconds(delayAddWhileSuspendedProbeMs));
boost::asio::post(
sscl::tests::ThreadRegistry::ioContext(
sscl::tests::PostingThreadRole::CALLER),
[&]()
{
try {
group.add(lateInvoker);
}
catch (const std::runtime_error &) {
addWasRejected.store(true, std::memory_order_release);
}
});
});
}
CallerDriver addWhileAwaitAllSuspended(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
std::atomic<bool> groupIsAwaitingAll{false};
std::atomic<bool> addWasRejected{false};
CalleeIntInvoker slowInvokerA = waitAndReturnLabel(delayLongMs);
CalleeIntInvoker slowInvokerB = waitAndReturnLabel(delayLongMs);
CalleeIntInvoker lateInvoker = waitAndReturnLabel(99);
group.add(slowInvokerA);
group.add(slowInvokerB);
std::thread addProbeThread = startAddWhileGroupAwaiterSuspendedProbe(
group,
lateInvoker,
groupIsAwaitingAll,
addWasRejected);
auto awaitAll = group.getAwaitAllSettlementsInvoker();
groupIsAwaitingAll.store(true, std::memory_order_release);
co_await awaitAll;
addProbeThread.join();
if (!addWasRejected.load(std::memory_order_acquire)) {
throw std::runtime_error("expected add while suspended to throw");
}
co_return;
}
CallerDriver awaitAllOnlyMixedOutcomes(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker successInvoker = returnLabelImmediately(7);
CalleeIntInvoker failureInvoker = waitThenThrowAfterDelay(delayShortMs);
group.add(successInvoker);
group.add(failureInvoker);
auto awaitAll = group.getAwaitAllSettlementsInvoker();
auto &allDescriptors = co_await awaitAll;
if (allDescriptors.size() != 2) {
throw std::runtime_error("await-all-only count mismatch");
}
std::size_t completedCount = 0;
std::size_t exceptionCount = 0;
for (auto &descriptor : allDescriptors) {
if (descriptor.type
== sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) {
++completedCount;
assertCompleted(descriptor, 7);
}
else if (descriptor.type
== sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN) {
++exceptionCount;
assertRuntimeErrorSettlement(descriptor);
}
}
if (completedCount != 1 || exceptionCount != 1) {
throw std::runtime_error("await-all-only mixed counts mismatch");
}
co_return;
}
CallerDriver checkForAndReThrowGroupExceptions(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker failureInvoker = waitThenThrowAfterDelay(delayShortMs);
group.add(failureInvoker);
(void)co_await group.getAwaitAllSettlementsInvoker();
try {
group.checkForAndReThrowGroupExceptions();
}
catch (const std::runtime_error &aggregateError) {
if (std::string(aggregateError.what()).find(expectedThrowMessage)
== std::string::npos) {
throw std::runtime_error("aggregate message missing callee text");
}
co_return;
}
throw std::runtime_error("expected aggregate group exception");
}
CallerDriver emptyGroupAwaitAllThrows(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
try {
(void)co_await group.getAwaitAllSettlementsInvoker();
}
catch (const std::runtime_error &runtimeError) {
assertEmptyGroupCoAwaitError(runtimeError);
co_return;
}
throw std::runtime_error("expected empty group await-all to throw");
}
CallerDriver emptyGroupAwaitFirstThrows(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
try {
(void)co_await group.getAwaitFirstSettlementInvoker();
}
catch (const std::runtime_error &runtimeError) {
assertEmptyGroupCoAwaitError(runtimeError);
co_return;
}
throw std::runtime_error("expected empty group await-first to throw");
}
CallerDriver wrongAwaitInvokerOrder(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker shortInvoker = waitAndReturnLabel(delayShortMs);
CalleeIntInvoker mediumInvoker = waitAndReturnLabel(delayMediumMs);
group.add(shortInvoker);
group.add(mediumInvoker);
auto awaitFirstHandle = group.getAwaitFirstSettlementInvoker();
auto awaitAllHandle = group.getAwaitAllSettlementsInvoker();
auto &allDescriptors = co_await awaitAllHandle;
if (allDescriptors.size() != 2) {
throw std::runtime_error("wrong-order await-all count mismatch");
}
auto [firstDescriptor, allAfterFirst] = co_await awaitFirstHandle;
assertCompleted(
firstDescriptor,
readCompletedLabel(firstDescriptor.invokerAs<CalleeIntInvoker>()));
if (!group.firstInvokerSettled() || allAfterFirst.size() != 2) {
throw std::runtime_error("wrong-order await-first state mismatch");
}
co_return;
}
CallerDriver doubleCoAwaitSameAwaitFirst(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker memberInvoker = returnLabelImmediately(delayShortMs);
group.add(memberInvoker);
co_await waitOnCallerThread(delayShortMs);
auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptorA, allAfterFirstA] = co_await awaitFirst;
auto [firstDescriptorB, allAfterFirstB] = co_await awaitFirst;
assertCompleted(firstDescriptorA, delayShortMs);
assertCompleted(firstDescriptorB, delayShortMs);
if (&firstDescriptorA.invokerAs<CalleeIntInvoker>()
!= &firstDescriptorB.invokerAs<CalleeIntInvoker>()) {
throw std::runtime_error("double await-first descriptor mismatch");
}
if (allAfterFirstA.size() != allAfterFirstB.size()) {
throw std::runtime_error("double await-first snapshot mismatch");
}
co_return;
}
CallerDriver doubleCoAwaitSameAwaitAll(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker memberInvoker = waitAndReturnLabel(delayShortMs);
group.add(memberInvoker);
auto awaitAll = group.getAwaitAllSettlementsInvoker();
auto &allDescriptorsA = co_await awaitAll;
auto &allDescriptorsB = co_await awaitAll;
if (allDescriptorsA.size() != 1 || allDescriptorsB.size() != 1) {
throw std::runtime_error("double await-all count mismatch");
}
assertCompleted(allDescriptorsA[0], delayShortMs);
assertCompleted(allDescriptorsB[0], delayShortMs);
co_return;
}
CallerDriver twoAwaitFirstHandlesSequentially(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker shortInvoker = waitAndReturnLabel(delayShortMs);
CalleeIntInvoker mediumInvoker = waitAndReturnLabel(delayMediumMs);
group.add(shortInvoker);
group.add(mediumInvoker);
auto awaitFirstA = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptorA, allAfterFirstA] = co_await awaitFirstA;
assertCompleted(firstDescriptorA, delayShortMs);
auto awaitFirstB = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptorB, allAfterFirstB] = co_await awaitFirstB;
assertCompleted(firstDescriptorB, delayShortMs);
if (&firstDescriptorA.invokerAs<CalleeIntInvoker>()
!= &firstDescriptorB.invokerAs<CalleeIntInvoker>()) {
throw std::runtime_error("sticky first settlement mismatch");
}
(void)co_await group.getAwaitAllSettlementsInvoker();
(void)allAfterFirstA;
(void)allAfterFirstB;
co_return;
}
CallerDriver addSecondWaveAfterAwaitAll(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker wave1MemberA = waitAndReturnLabel(delayLongMs);
CalleeIntInvoker wave1MemberB = waitAndReturnLabel(delayLongMs);
group.add(wave1MemberA);
group.add(wave1MemberB);
(void)co_await group.getAwaitAllSettlementsInvoker();
CalleeIntInvoker wave2Immediate =
returnLabelImmediately(wave2ImmediateSettlementLabel);
CalleeIntInvoker wave2Slow = waitAndReturnLabel(delayMediumMs);
group.add(wave2Immediate);
group.add(wave2Slow);
co_await waitOnCallerThread(delayShortMs);
if (readCompletedLabel(wave2Immediate)
!= wave2ImmediateSettlementLabel) {
throw std::runtime_error("wave-2 immediate member did not complete");
}
if (group.allInvokersSettled()) {
throw std::runtime_error("wave-2 slow member should still be in flight");
}
auto &allDescriptors =
co_await group.getAwaitAllSettlementsInvoker();
if (allDescriptors.size() != 4) {
throw std::runtime_error("expected four settlements after second wave");
}
co_return;
}
CallerDriver shortTimerAddedAfterLongStillWinsRace(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker longInvoker = waitAndReturnLabel(delayLongMs);
CalleeIntInvoker shortInvoker = waitAndReturnLabel(delayShortMs);
group.add(longInvoker);
group.add(shortInvoker);
auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptor, allAfterFirst] = co_await awaitFirst;
assertCompleted(firstDescriptor, delayShortMs);
if (&firstDescriptor.invokerAs<CalleeIntInvoker>() != &shortInvoker) {
throw std::runtime_error("short timer should win first settlement");
}
(void)co_await group.getAwaitAllSettlementsInvoker();
(void)allAfterFirst;
co_return;
}
CallerDriver nonStdExceptionSettlement(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker failureInvoker = waitThenThrowIntAfterDelay(delayShortMs);
group.add(failureInvoker);
auto &allDescriptors = co_await group.getAwaitAllSettlementsInvoker();
if (allDescriptors.size() != 1) {
throw std::runtime_error("non-std exception count mismatch");
}
assertIntExceptionSettlement(allDescriptors[0]);
try {
group.checkForAndReThrowGroupExceptions();
}
catch (const std::runtime_error &) {
co_return;
}
throw std::runtime_error("expected aggregate for non-std exception");
}
CallerDriver voidViralMemberInGroup(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeVoidInvoker voidInvoker = voidMemberAfterDelay(delayShortMs);
group.add(voidInvoker);
auto &allDescriptors = co_await group.getAwaitAllSettlementsInvoker();
if (allDescriptors.size() != 1) {
throw std::runtime_error("void group count mismatch");
}
if (allDescriptors[0].type
!= sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED) {
throw std::runtime_error("void member did not complete");
}
co_return;
}
CallerDriver returnValuesRemainReadableAfterAwaitFirst(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker slowInvoker = waitAndReturnLabel(delayLongMs);
CalleeIntInvoker fastInvoker = waitAndReturnLabel(delayShortMs);
group.add(slowInvoker);
group.add(fastInvoker);
auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstDescriptor, allAfterFirst] = co_await awaitFirst;
assertCompleted(firstDescriptor, delayShortMs);
const int fastLabelFromDescriptor = readCompletedLabel(
firstDescriptor.invokerAs<CalleeIntInvoker>());
const int fastLabelFromLocal = readCompletedLabel(fastInvoker);
if (fastLabelFromDescriptor != fastLabelFromLocal) {
throw std::runtime_error("descriptor/local return value mismatch");
}
if (allAfterFirst.size() != 2) {
throw std::runtime_error("expected two settlement slots");
}
(void)co_await group.getAwaitAllSettlementsInvoker();
co_return;
}
class GroupEdgeTest
: public ::testing::Test
{
protected:
template <typename Factory>
void runScenario(Factory &&factory)
{
ASSERT_NO_THROW(
sscl::tests::runNonViralPostingTask(
threads.caller(),
std::forward<Factory>(factory)));
}
sscl::tests::PostingThreadSet threads;
};
} // namespace
#define RUN_GROUP_EDGE_SCENARIO(testName, functionName) \
TEST_F(GroupEdgeTest, testName) \
{ \
runScenario( \
[]( \
std::exception_ptr &exceptionPtr, \
std::function<void()> completion) \
{ \
return functionName(exceptionPtr, std::move(completion)); \
}); \
}
RUN_GROUP_EDGE_SCENARIO(
MixedSuccessAndFailureAwaitFirstThenAll,
mixedSuccessAndFailureAwaitFirstThenAll)
RUN_GROUP_EDGE_SCENARIO(
SingleMemberAwaitFirstThenAll,
singleMemberAwaitFirstThenAll)
RUN_GROUP_EDGE_SCENARIO(AllCompleteBeforeCoAwait, allCompleteBeforeCoAwait)
RUN_GROUP_EDGE_SCENARIO(AddWhileAwaitAllSuspended, addWhileAwaitAllSuspended)
RUN_GROUP_EDGE_SCENARIO(AwaitAllOnlyMixedOutcomes, awaitAllOnlyMixedOutcomes)
RUN_GROUP_EDGE_SCENARIO(
CheckForAndReThrowGroupExceptions,
checkForAndReThrowGroupExceptions)
RUN_GROUP_EDGE_SCENARIO(EmptyGroupAwaitAllThrows, emptyGroupAwaitAllThrows)
RUN_GROUP_EDGE_SCENARIO(EmptyGroupAwaitFirstThrows, emptyGroupAwaitFirstThrows)
RUN_GROUP_EDGE_SCENARIO(WrongAwaitInvokerOrder, wrongAwaitInvokerOrder)
RUN_GROUP_EDGE_SCENARIO(DoubleCoAwaitSameAwaitFirst, doubleCoAwaitSameAwaitFirst)
RUN_GROUP_EDGE_SCENARIO(DoubleCoAwaitSameAwaitAll, doubleCoAwaitSameAwaitAll)
RUN_GROUP_EDGE_SCENARIO(
TwoAwaitFirstHandlesSequentially,
twoAwaitFirstHandlesSequentially)
RUN_GROUP_EDGE_SCENARIO(AddSecondWaveAfterAwaitAll, addSecondWaveAfterAwaitAll)
RUN_GROUP_EDGE_SCENARIO(
ShortTimerAddedAfterLongStillWinsRace,
shortTimerAddedAfterLongStillWinsRace)
RUN_GROUP_EDGE_SCENARIO(NonStdExceptionSettlement, nonStdExceptionSettlement)
RUN_GROUP_EDGE_SCENARIO(VoidViralMemberInGroup, voidViralMemberInGroup)
RUN_GROUP_EDGE_SCENARIO(
ReturnValuesRemainReadableAfterAwaitFirst,
returnValuesRemainReadableAfterAwaitFirst)
TEST_F(GroupEdgeTest, NonViralVoidGroupTemplateInstantiates)
{
GTEST_SKIP()
<< "NonViralPostingInvoker does not satisfy Group's awaitable concept.";
}
TEST_F(GroupEdgeTest, EarlyInvokerDestructionIsUnsupported)
{
GTEST_SKIP()
<< "Destroying a member invoker before group settlement completes is undefined.";
}
TEST_F(GroupEdgeTest, OverlappingGroupWaitsAssertInDebug)
{
GTEST_SKIP()
<< "Overlapping group co_await is debug-assert behavior.";
}
+268
View File
@@ -0,0 +1,268 @@
#include <chrono>
#include <exception>
#include <functional>
#include <stdexcept>
#include <string>
#include <gtest/gtest.h>
#include <boost/asio/error.hpp>
#include <boost/system/error_code.hpp>
#include <spinscale/co/group.h>
#include <spinscale/componentThread.h>
#include <support/groupAssertions.h>
#include <support/threadHarness.h>
#include <support/timerAwaiters.h>
namespace {
constexpr int timerDelayShortMs = 50;
constexpr int timerDelayMediumMs = 200;
constexpr int timerDelayLongMs = 500;
constexpr int awaitAllTimingSlackMs = 25;
constexpr int awaitAllLongCancelTimingMarginMs = 50;
using CallerDriver =
sscl::tests::RoleNonViralPostingInvoker<
sscl::tests::PostingThreadRole::CALLER>;
using CalleeIntInvoker =
sscl::tests::RoleViralPostingInvoker<
sscl::tests::PostingThreadRole::CALLEE,
int>;
using Clock = std::chrono::steady_clock;
using Ms = std::chrono::milliseconds;
CalleeIntInvoker waitDeadlineTimer(int timerLabelMilliseconds)
{
const boost::system::error_code waitError =
co_await sscl::tests::DeadlineTimerAwaiter{
sscl::ComponentThread::getSelf()->getIoContext(),
timerLabelMilliseconds};
sscl::tests::throwIfTimerWaitFailed(waitError);
co_return timerLabelMilliseconds;
}
CalleeIntInvoker waitCancelableDeadlineTimer(
int timerLabelMilliseconds,
sscl::tests::CancelableDeadlineTimerRegistry &registry)
{
const boost::system::error_code waitError =
co_await sscl::tests::RegisteredDeadlineTimerAwaiter{
sscl::ComponentThread::getSelf()->getIoContext(),
timerLabelMilliseconds,
timerLabelMilliseconds,
registry};
if (sscl::tests::timerWasCanceled(waitError)) {
co_return timerLabelMilliseconds;
}
sscl::tests::throwIfTimerWaitFailed(waitError);
co_return timerLabelMilliseconds;
}
void throwIfElapsedTooLong(
const Ms &elapsed,
const Ms &limit,
const char *message)
{
if (elapsed > limit) {
throw std::runtime_error(
std::string(message) + ": " + std::to_string(elapsed.count()));
}
}
void throwIfElapsedTooShort(
const Ms &elapsed,
const Ms &limit,
const char *message)
{
if (elapsed < limit) {
throw std::runtime_error(
std::string(message) + ": " + std::to_string(elapsed.count()));
}
}
CallerDriver runGroupTimerRace(
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker invokerShort = waitDeadlineTimer(timerDelayShortMs);
CalleeIntInvoker invokerMedium = waitDeadlineTimer(timerDelayMediumMs);
CalleeIntInvoker invokerLong = waitDeadlineTimer(timerDelayLongMs);
group.add(invokerShort);
group.add(invokerMedium);
group.add(invokerLong);
const auto testStart = Clock::now();
auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstSettlement, allSettlementsAfterFirst] = co_await awaitFirst;
const auto firstElapsedMs =
std::chrono::duration_cast<Ms>(Clock::now() - testStart);
throwIfElapsedTooLong(
firstElapsedMs,
Ms(timerDelayMediumMs - awaitAllTimingSlackMs),
"await-first took too long");
if (&firstSettlement.invokerAs<CalleeIntInvoker>() != &invokerShort) {
throw std::runtime_error("first settlement was not shortest timer");
}
if (group.allInvokersSettled()) {
throw std::runtime_error("await-first returned after all settled");
}
auto awaitAll = group.getAwaitAllSettlementsInvoker();
auto &allSettlements = co_await awaitAll;
const auto allElapsedMs =
std::chrono::duration_cast<Ms>(Clock::now() - testStart);
throwIfElapsedTooShort(
allElapsedMs,
Ms(timerDelayLongMs - awaitAllLongCancelTimingMarginMs),
"await-all finished too soon");
if (allSettlements.size() != 3) {
throw std::runtime_error("expected three settlements");
}
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>(
firstSettlement,
timerDelayShortMs);
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>(
allSettlementsAfterFirst[0],
timerDelayShortMs);
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>(
allSettlementsAfterFirst[1],
timerDelayMediumMs);
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>(
allSettlementsAfterFirst[2],
timerDelayLongMs);
co_return;
}
CallerDriver runGroupTimerCancelLongAfterAwaitFirst(
std::exception_ptr &exceptionPtr,
std::function<void()> completion,
sscl::tests::CancelableDeadlineTimerRegistry &registry)
{
(void)exceptionPtr;
(void)completion;
sscl::co::Group group;
CalleeIntInvoker invokerShort =
waitCancelableDeadlineTimer(timerDelayShortMs, registry);
CalleeIntInvoker invokerMedium =
waitCancelableDeadlineTimer(timerDelayMediumMs, registry);
CalleeIntInvoker invokerLong =
waitCancelableDeadlineTimer(timerDelayLongMs, registry);
group.add(invokerShort);
group.add(invokerMedium);
group.add(invokerLong);
const auto testStart = Clock::now();
auto awaitFirst = group.getAwaitFirstSettlementInvoker();
auto [firstSettlement, allSettlementsAfterFirst] = co_await awaitFirst;
if (&firstSettlement.invokerAs<CalleeIntInvoker>() != &invokerShort) {
throw std::runtime_error("cancel test first settlement mismatch");
}
if (group.allInvokersSettled()) {
throw std::runtime_error("cancel test all settled after await-first");
}
registry.cancel(timerDelayLongMs);
auto awaitAll = group.getAwaitAllSettlementsInvoker();
auto &allSettlements = co_await awaitAll;
const auto allElapsedMs =
std::chrono::duration_cast<Ms>(Clock::now() - testStart);
if (allElapsedMs >= Ms(timerDelayLongMs - awaitAllLongCancelTimingMarginMs)) {
throw std::runtime_error("await-all waited for canceled long timer");
}
throwIfElapsedTooShort(
allElapsedMs,
Ms(timerDelayMediumMs - awaitAllTimingSlackMs),
"await-all finished before medium timer");
if (allSettlements.size() != 3) {
throw std::runtime_error("cancel test expected three settlements");
}
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>(
allSettlements[0],
timerDelayShortMs);
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>(
allSettlements[1],
timerDelayMediumMs);
sscl::tests::expectCompletedIntSettlement<CalleeIntInvoker>(
allSettlements[2],
timerDelayLongMs);
if (&allSettlements[2].invokerAs<CalleeIntInvoker>() != &invokerLong) {
throw std::runtime_error("cancel test long invoker mismatch");
}
(void)allSettlementsAfterFirst;
co_return;
}
class GroupTimerTest
: public ::testing::Test
{
protected:
sscl::tests::PostingThreadSet threads;
};
} // namespace
TEST_F(GroupTimerTest, AwaitFirstReturnsShortestTimerAndAwaitAllWaitsForLongest)
{
ASSERT_NO_THROW(
sscl::tests::runNonViralPostingTask(
threads.caller(),
[](
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
return runGroupTimerRace(
exceptionPtr,
std::move(completion));
}));
}
TEST_F(GroupTimerTest, CancelLongTimerAfterAwaitFirst)
{
sscl::tests::CancelableDeadlineTimerRegistry registry;
ASSERT_NO_THROW(
sscl::tests::runNonViralPostingTask(
threads.caller(),
[&registry](
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
return runGroupTimerCancelLongAfterAwaitFirst(
exceptionPtr,
std::move(completion),
registry);
}));
}
+252
View File
@@ -0,0 +1,252 @@
#include <exception>
#include <functional>
#include <stdexcept>
#include <string>
#include <gtest/gtest.h>
#include <spinscale/co/postTarget.h>
#include <spinscale/componentThread.h>
#include <support/threadHarness.h>
#include <support/timerAwaiters.h>
namespace {
constexpr int expectedReturnValue = 42;
constexpr int explicitTargetReturnValue = 77;
constexpr const char *expectedThrowMessage =
"posting cross-thread intentional failure";
using CallerNonViralInvoker =
sscl::tests::RoleNonViralPostingInvoker<
sscl::tests::PostingThreadRole::CALLER>;
using CalleeNonViralInvoker =
sscl::tests::RoleNonViralPostingInvoker<
sscl::tests::PostingThreadRole::CALLEE>;
template <typename T>
using CalleeViralInvoker =
sscl::tests::RoleViralPostingInvoker<
sscl::tests::PostingThreadRole::CALLEE,
T>;
CalleeViralInvoker<int> returnFromCalleeThread(
sscl::tests::CrossThreadTrace &trace)
{
trace.recordCalleeExecutionThread();
trace.recordFinalSuspendThread();
co_return expectedReturnValue;
}
CalleeViralInvoker<int> returnFromExplicitTargetThread(
sscl::co::ExplicitPostTarget postTarget,
sscl::tests::CrossThreadTrace &trace)
{
(void)postTarget;
trace.recordCalleeExecutionThread();
trace.recordFinalSuspendThread();
co_return explicitTargetReturnValue;
}
CalleeViralInvoker<int> throwFromCalleeThread(
sscl::tests::CrossThreadTrace &trace)
{
constexpr int throwDelayMs = 1;
const boost::system::error_code waitError =
co_await sscl::tests::DeadlineTimerAwaiter{
sscl::ComponentThread::getSelf()->getIoContext(),
throwDelayMs};
sscl::tests::throwIfTimerWaitFailed(waitError);
trace.recordCalleeExecutionThread();
trace.recordFinalSuspendThread();
throw std::runtime_error(expectedThrowMessage);
}
CallerNonViralInvoker awaitCalleeDriver(
std::exception_ptr &exceptionPtr,
std::function<void()> completion,
sscl::tests::CrossThreadTrace &trace)
{
(void)exceptionPtr;
(void)completion;
const int value = co_await returnFromCalleeThread(trace);
trace.recordAwaitResumeThread();
if (value != expectedReturnValue) {
throw std::runtime_error("Unexpected callee return value");
}
co_return;
}
CallerNonViralInvoker awaitExplicitTargetDriver(
std::exception_ptr &exceptionPtr,
std::function<void()> completion,
sscl::tests::CrossThreadTrace &trace)
{
(void)exceptionPtr;
(void)completion;
sscl::co::ExplicitPostTarget postTarget{
sscl::tests::ThreadRegistry::ioContext(
sscl::tests::PostingThreadRole::ALTERNATE)};
const int value = co_await returnFromExplicitTargetThread(
postTarget,
trace);
trace.recordAwaitResumeThread();
if (value != explicitTargetReturnValue) {
throw std::runtime_error("Unexpected explicit-target return value");
}
co_return;
}
CallerNonViralInvoker awaitThrowingCalleeDriver(
std::exception_ptr &exceptionPtr,
std::function<void()> completion,
sscl::tests::CrossThreadTrace &trace)
{
(void)exceptionPtr;
(void)completion;
try {
(void)co_await throwFromCalleeThread(trace);
throw std::runtime_error("Expected callee exception");
}
catch (const std::runtime_error &runtimeError) {
trace.recordAwaitResumeThread();
if (std::string(runtimeError.what()) != expectedThrowMessage) {
throw std::runtime_error("Unexpected callee exception message");
}
}
co_return;
}
CalleeNonViralInvoker nonViralCalleeCompletesToCaller(
std::exception_ptr &exceptionPtr,
std::function<void()> completion,
sscl::tests::CrossThreadTrace &trace)
{
(void)exceptionPtr;
(void)completion;
trace.recordCalleeExecutionThread();
trace.recordFinalSuspendThread();
co_return;
}
class PostingCrossThreadTest
: public ::testing::Test
{
protected:
sscl::tests::PostingThreadSet threads;
};
} // namespace
TEST_F(PostingCrossThreadTest, ViralAwaitPostsCalleeAndResumesCaller)
{
sscl::tests::CrossThreadTrace trace;
ASSERT_NO_THROW(
sscl::tests::runNonViralPostingTask(
threads.caller(),
[&trace](
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
trace.recordConstructionThread();
return awaitCalleeDriver(
exceptionPtr,
std::move(completion),
trace);
}));
EXPECT_EQ(trace.constructionThread(), threads.caller().osThreadId());
EXPECT_EQ(trace.calleeExecutionThread(), threads.callee().osThreadId());
EXPECT_EQ(trace.finalSuspendThread(), threads.callee().osThreadId());
EXPECT_EQ(trace.awaitResumeThread(), threads.caller().osThreadId());
EXPECT_NE(trace.calleeExecutionThread(), trace.awaitResumeThread());
}
TEST_F(PostingCrossThreadTest, NonViralCompletionPostsBackToCaller)
{
sscl::tests::CrossThreadTrace trace;
ASSERT_NO_THROW(
sscl::tests::runNonViralPostingTask(
threads.caller(),
[&trace](
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
trace.recordConstructionThread();
return nonViralCalleeCompletesToCaller(
exceptionPtr,
[&trace, completion = std::move(completion)]() mutable
{
trace.recordCompletionCallbackThread();
completion();
},
trace);
}));
EXPECT_EQ(trace.constructionThread(), threads.caller().osThreadId());
EXPECT_EQ(trace.calleeExecutionThread(), threads.callee().osThreadId());
EXPECT_EQ(trace.finalSuspendThread(), threads.callee().osThreadId());
EXPECT_EQ(trace.completionCallbackThread(), threads.caller().osThreadId());
EXPECT_NE(trace.calleeExecutionThread(), trace.completionCallbackThread());
}
TEST_F(PostingCrossThreadTest, ExplicitPostTargetRoutesCalleeExecution)
{
sscl::tests::CrossThreadTrace trace;
ASSERT_NO_THROW(
sscl::tests::runNonViralPostingTask(
threads.caller(),
[&trace](
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
trace.recordConstructionThread();
return awaitExplicitTargetDriver(
exceptionPtr,
std::move(completion),
trace);
}));
EXPECT_EQ(trace.constructionThread(), threads.caller().osThreadId());
EXPECT_EQ(trace.calleeExecutionThread(), threads.alternate().osThreadId());
EXPECT_EQ(trace.awaitResumeThread(), threads.caller().osThreadId());
EXPECT_NE(trace.calleeExecutionThread(), threads.callee().osThreadId());
EXPECT_NE(trace.calleeExecutionThread(), trace.awaitResumeThread());
}
TEST_F(PostingCrossThreadTest, CalleeExceptionIsObservedOnCallerThread)
{
sscl::tests::CrossThreadTrace trace;
ASSERT_NO_THROW(
sscl::tests::runNonViralPostingTask(
threads.caller(),
[&trace](
std::exception_ptr &exceptionPtr,
std::function<void()> completion)
{
trace.recordConstructionThread();
return awaitThrowingCalleeDriver(
exceptionPtr,
std::move(completion),
trace);
}));
EXPECT_EQ(trace.constructionThread(), threads.caller().osThreadId());
EXPECT_EQ(trace.calleeExecutionThread(), threads.callee().osThreadId());
EXPECT_EQ(trace.awaitResumeThread(), threads.caller().osThreadId());
EXPECT_NE(trace.calleeExecutionThread(), trace.awaitResumeThread());
}
+511
View File
@@ -0,0 +1,511 @@
#include <chrono>
#include <exception>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <gtest/gtest.h>
#include <boost/asio/io_context.hpp>
#include <boost/system/error_code.hpp>
#include <spinscale/co/invokers.h>
#include <support/threadHarness.h>
#include <support/timerAwaiters.h>
namespace {
constexpr int delayShortMs = 50;
constexpr int expectedNonStdThrowValue = 42;
constexpr const char *expectedThrowMessage =
"viral_non_posting_test intentional failure";
template <typename T>
using TestInvoker = sscl::co::ViralNonPostingInvoker<T>;
using TestDriver = TestInvoker<int>;
using TestVoidDriver = TestInvoker<void>;
struct ThreadIdPair
{
std::thread::id callerIdAtCoAwait;
std::thread::id calleeId;
};
struct MoveCountedInt
{
std::shared_ptr<std::size_t> moveCount;
int value = 0;
MoveCountedInt() = default;
MoveCountedInt(
std::shared_ptr<std::size_t> moveCountIn,
int valueIn)
: moveCount(std::move(moveCountIn)),
value(valueIn)
{}
MoveCountedInt(const MoveCountedInt &) = delete;
MoveCountedInt &operator=(const MoveCountedInt &) = delete;
MoveCountedInt(MoveCountedInt &&other) noexcept
: moveCount(std::exchange(other.moveCount, {})),
value(other.value)
{
if (moveCount) {
++(*moveCount);
}
}
MoveCountedInt &operator=(MoveCountedInt &&other) noexcept
{
moveCount = std::exchange(other.moveCount, {});
value = other.value;
return *this;
}
};
template <typename T>
struct CountingAwaiter
{
TestInvoker<T> &invoker;
std::size_t &awaitResumeCallCount;
bool await_ready() const noexcept
{ return invoker.await_ready(); }
template <typename CallerPromise>
bool await_suspend(
std::coroutine_handle<CallerPromise> callerSchedHandle) noexcept
{ return invoker.await_suspend(callerSchedHandle); }
auto await_resume()
{
++awaitResumeCallCount;
return invoker.await_resume();
}
};
class ViralNonPostingTest
: public ::testing::Test
{
protected:
void TearDown() override
{
ioContext.restart();
}
int runDriver(TestDriver &driver)
{
sscl::tests::IoContextPump::pumpUntilIdle(ioContext);
return finishDriver(driver);
}
int finishDriver(TestDriver &driver)
{
if (driver.completedReturnValues().myExceptionPtr) {
std::rethrow_exception(
driver.completedReturnValues().myExceptionPtr);
}
return driver.completedReturnValues().myReturnValue;
}
boost::asio::io_context ioContext;
};
TestInvoker<int> returnLabelImmediately(int label)
{
co_return label;
}
TestInvoker<int> waitAndReturnLabel(
boost::asio::io_context &ioContext,
int delayMilliseconds)
{
const boost::system::error_code waitError =
co_await sscl::tests::DeadlineTimerAwaiter{
ioContext,
delayMilliseconds};
sscl::tests::throwIfTimerWaitFailed(waitError);
co_return delayMilliseconds;
}
TestVoidDriver voidReturnImmediately()
{
co_return;
}
TestInvoker<int> throwRuntimeErrorImmediately()
{
throw std::runtime_error(expectedThrowMessage);
}
TestInvoker<int> throwIntImmediately()
{
throw expectedNonStdThrowValue;
}
TestInvoker<ThreadIdPair> recordThreadIdsAtReturn()
{
ThreadIdPair pair;
pair.calleeId = std::this_thread::get_id();
co_return pair;
}
TestInvoker<ThreadIdPair> recordThreadIdsAfterDelay(
boost::asio::io_context &ioContext,
int delayMilliseconds)
{
const boost::system::error_code waitError =
co_await sscl::tests::DeadlineTimerAwaiter{
ioContext,
delayMilliseconds};
sscl::tests::throwIfTimerWaitFailed(waitError);
ThreadIdPair pair;
pair.calleeId = std::this_thread::get_id();
co_return pair;
}
TestInvoker<MoveCountedInt> returnMoveCountedInt(
std::shared_ptr<std::size_t> moveCount,
int value)
{
co_return MoveCountedInt{std::move(moveCount), value};
}
TestInvoker<int> innerDelayedCoAwait(
boost::asio::io_context &ioContext,
int delayMilliseconds)
{
const int label = co_await waitAndReturnLabel(
ioContext,
delayMilliseconds);
co_return label;
}
TestInvoker<int> nestedNonPostingSum(int left, int right)
{
const int leftSum = co_await returnLabelImmediately(left);
const int rightSum = co_await returnLabelImmediately(right);
co_return leftSum + rightSum;
}
TestInvoker<int> outerCoAwaitingDelayedInner(
boost::asio::io_context &ioContext,
int delayMilliseconds)
{
const int innerLabel = co_await innerDelayedCoAwait(
ioContext,
delayMilliseconds);
co_return innerLabel + 1;
}
TestDriver testImmediateReturnFastPath()
{
const int value = co_await returnLabelImmediately(42);
if (value != 42) {
throw std::runtime_error("immediateReturnFastPath value mismatch");
}
co_return 0;
}
TestDriver testAllCompleteBeforeCoAwait()
{
TestInvoker<int> invokerTen = returnLabelImmediately(10);
TestInvoker<int> invokerTwenty = returnLabelImmediately(20);
TestInvoker<int> invokerThirty = returnLabelImmediately(30);
const int valueTen = co_await invokerTen;
const int valueTwenty = co_await invokerTwenty;
const int valueThirty = co_await invokerThirty;
if (valueTen != 10 || valueTwenty != 20 || valueThirty != 30) {
throw std::runtime_error("allCompleteBeforeCoAwait label mismatch");
}
co_return 0;
}
TestDriver testCallerSuspendsThenResumes(boost::asio::io_context &ioContext)
{
const int value = co_await waitAndReturnLabel(ioContext, delayShortMs);
if (value != delayShortMs) {
throw std::runtime_error("callerSuspendsThenResumes label mismatch");
}
co_return 0;
}
TestDriver testMixedImmediateAndDelayedInSequence(
boost::asio::io_context &ioContext)
{
const int immediate = co_await returnLabelImmediately(7);
const int delayed = co_await waitAndReturnLabel(ioContext, delayShortMs);
if (immediate != 7 || delayed != delayShortMs) {
throw std::runtime_error("mixedImmediateAndDelayed label mismatch");
}
co_return 0;
}
TestDriver testAwaitResumeCalledOnceFastPath()
{
std::size_t awaitResumeCallCount = 0;
TestInvoker<int> invoker = returnLabelImmediately(42);
const int value = co_await CountingAwaiter<int>{
invoker,
awaitResumeCallCount};
if (value != 42 || awaitResumeCallCount != 1) {
throw std::runtime_error("fast path await_resume count mismatch");
}
co_return 0;
}
TestDriver testAwaitResumeCalledOnceSlowPath(
boost::asio::io_context &ioContext)
{
std::size_t awaitResumeCallCount = 0;
TestInvoker<int> invoker = waitAndReturnLabel(ioContext, delayShortMs);
const int value = co_await CountingAwaiter<int>{
invoker,
awaitResumeCallCount};
if (value != delayShortMs || awaitResumeCallCount != 1) {
throw std::runtime_error("slow path await_resume count mismatch");
}
co_return 0;
}
TestDriver testAwaitResumeCalledOnceNested(
boost::asio::io_context &ioContext)
{
std::size_t awaitResumeCallCount = 0;
TestInvoker<int> inner = innerDelayedCoAwait(ioContext, delayShortMs);
const int value = co_await CountingAwaiter<int>{
inner,
awaitResumeCallCount};
if (value != delayShortMs || awaitResumeCallCount != 1) {
throw std::runtime_error("nested await_resume count mismatch");
}
co_return 0;
}
TestDriver testMoveCountedReturnNotDoubleMoved()
{
auto moveCount = std::make_shared<std::size_t>(0);
TestInvoker<MoveCountedInt> invoker =
returnMoveCountedInt(moveCount, 99);
MoveCountedInt result = co_await invoker;
if (result.value != 99) {
throw std::runtime_error("move counted value mismatch");
}
if (*moveCount > 2 || *moveCount < 1) {
throw std::runtime_error("move counted return move-count mismatch");
}
co_return 0;
}
TestDriver testVoidReturnCompletes()
{
co_await voidReturnImmediately();
co_return 0;
}
TestDriver testReturnValuesReadableBeforeDestroy()
{
TestInvoker<int> invoker = returnLabelImmediately(55);
(void)co_await invoker;
if (invoker.completedReturnValues().myReturnValue != 55) {
throw std::runtime_error("completed return value not readable");
}
co_return 0;
}
TestDriver testExceptionRethrowsOnCoAwait()
{
try {
(void)co_await throwRuntimeErrorImmediately();
throw std::runtime_error("expected runtime_error");
}
catch (const std::runtime_error &runtimeError) {
if (std::string(runtimeError.what()) != expectedThrowMessage) {
throw std::runtime_error("unexpected runtime_error message");
}
}
co_return 0;
}
TestDriver testNonStdExceptionRethrows()
{
try {
(void)co_await throwIntImmediately();
throw std::runtime_error("expected int exception");
}
catch (int caughtValue) {
if (caughtValue != expectedNonStdThrowValue) {
throw std::runtime_error("unexpected int exception value");
}
}
co_return 0;
}
TestDriver testCalleeRunsOnCallerThread()
{
const std::thread::id callerThreadId = std::this_thread::get_id();
const ThreadIdPair pair = co_await recordThreadIdsAtReturn();
if (pair.calleeId != callerThreadId) {
throw std::runtime_error("callee thread mismatch");
}
co_return 0;
}
TestDriver testDelayedCalleeStillOnCallerThread(
boost::asio::io_context &ioContext)
{
const std::thread::id callerThreadId = std::this_thread::get_id();
const ThreadIdPair pair =
co_await recordThreadIdsAfterDelay(ioContext, delayShortMs);
if (pair.calleeId != callerThreadId) {
throw std::runtime_error("delayed callee thread mismatch");
}
co_return 0;
}
TestDriver testNestedNonPostingCoAwait()
{
const int sum = co_await nestedNonPostingSum(10, 32);
if (sum != 42) {
throw std::runtime_error("nested sum mismatch");
}
co_return 0;
}
TestDriver testNestedInnerSuspension(boost::asio::io_context &ioContext)
{
const int value = co_await outerCoAwaitingDelayedInner(
ioContext,
delayShortMs);
if (value != delayShortMs + 1) {
throw std::runtime_error("nested inner suspension value mismatch");
}
co_return 0;
}
} // namespace
TEST_F(ViralNonPostingTest, ImmediateReturnFastPath)
{
TestDriver driver = testImmediateReturnFastPath();
EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, AllCompleteBeforeCoAwait)
{
TestDriver driver = testAllCompleteBeforeCoAwait();
EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, CallerSuspendsThenResumes)
{
TestDriver driver = testCallerSuspendsThenResumes(ioContext);
EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, MixedImmediateAndDelayedInSequence)
{
TestDriver driver = testMixedImmediateAndDelayedInSequence(ioContext);
EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, AwaitResumeCalledOnceFastPath)
{
TestDriver driver = testAwaitResumeCalledOnceFastPath();
EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, AwaitResumeCalledOnceSlowPath)
{
TestDriver driver = testAwaitResumeCalledOnceSlowPath(ioContext);
EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, AwaitResumeCalledOnceNested)
{
TestDriver driver = testAwaitResumeCalledOnceNested(ioContext);
EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, MoveCountedReturnNotDoubleMoved)
{
TestDriver driver = testMoveCountedReturnNotDoubleMoved();
EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, VoidReturnCompletes)
{
TestDriver driver = testVoidReturnCompletes();
EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, ReturnValuesReadableBeforeDestroy)
{
TestDriver driver = testReturnValuesReadableBeforeDestroy();
EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, ExceptionRethrowsOnCoAwait)
{
TestDriver driver = testExceptionRethrowsOnCoAwait();
EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, NonStdExceptionRethrows)
{
TestDriver driver = testNonStdExceptionRethrows();
EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, CalleeRunsOnCallerThread)
{
TestDriver driver = testCalleeRunsOnCallerThread();
EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, DelayedCalleeStillOnCallerThread)
{
TestDriver driver = testDelayedCalleeStillOnCallerThread(ioContext);
EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, NestedNonPostingCoAwait)
{
TestDriver driver = testNestedNonPostingCoAwait();
EXPECT_NO_THROW({ EXPECT_EQ(finishDriver(driver), 0); });
}
TEST_F(ViralNonPostingTest, NestedInnerSuspension)
{
TestDriver driver = testNestedInnerSuspension(ioContext);
EXPECT_NO_THROW({ EXPECT_EQ(runDriver(driver), 0); });
}
-2
View File
@@ -330,8 +330,6 @@ TEST_F(QutexTest, Release) {
// Test release without a prior acquire is rejected
TEST_F(QutexTest, ReleaseWithoutAcquireThrows) {
qutex.isOwned = true;
EXPECT_THROW(qutex.release(), std::runtime_error);
EXPECT_TRUE(qutex.queue.empty());
}
+104
View File
@@ -0,0 +1,104 @@
#ifndef SPINSCALE_TEST_SUPPORT_GROUP_ASSERTIONS_H
#define SPINSCALE_TEST_SUPPORT_GROUP_ASSERTIONS_H
#include <exception>
#include <string>
#include <gtest/gtest.h>
#include <spinscale/co/group.h>
namespace sscl::tests {
template <typename Invoker>
int completedIntValue(Invoker &invoker)
{
if (invoker.completedReturnValues().myExceptionPtr) {
std::rethrow_exception(
invoker.completedReturnValues().myExceptionPtr);
}
return invoker.completedReturnValues().myReturnValue;
}
inline void expectCompletedSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor)
{
EXPECT_EQ(
descriptor.type,
sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED);
}
template <typename Invoker>
void expectCompletedIntSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor,
int expectedValue)
{
ASSERT_EQ(
descriptor.type,
sscl::co::Group::SettlementDescriptor::TypeE::COMPLETED);
EXPECT_EQ(completedIntValue(descriptor.invokerAs<Invoker>()), expectedValue);
}
inline void expectExceptionSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor)
{
EXPECT_EQ(
descriptor.type,
sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN);
EXPECT_TRUE(descriptor.calleeException != nullptr);
}
inline void expectRuntimeErrorSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor,
const std::string &expectedMessage)
{
ASSERT_EQ(
descriptor.type,
sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN);
ASSERT_TRUE(descriptor.calleeException != nullptr);
try {
std::rethrow_exception(descriptor.calleeException);
}
catch (const std::runtime_error &runtimeError) {
EXPECT_EQ(std::string(runtimeError.what()), expectedMessage);
return;
}
catch (...) {
FAIL() << "Expected std::runtime_error settlement.";
}
}
inline void expectIntExceptionSettlement(
const sscl::co::Group::SettlementDescriptor &descriptor,
int expectedValue)
{
ASSERT_EQ(
descriptor.type,
sscl::co::Group::SettlementDescriptor::TypeE::EXCEPTION_THROWN);
ASSERT_TRUE(descriptor.calleeException != nullptr);
try {
std::rethrow_exception(descriptor.calleeException);
}
catch (int caughtValue) {
EXPECT_EQ(caughtValue, expectedValue);
return;
}
catch (...) {
FAIL() << "Expected int exception settlement.";
}
}
inline void expectEmptyGroupError(
const std::runtime_error &runtimeError)
{
constexpr const char *expectedMessage =
"co_await: Group has no member invokers; call add() before awaiting";
EXPECT_EQ(std::string(runtimeError.what()), expectedMessage);
}
} // namespace sscl::tests
#endif // SPINSCALE_TEST_SUPPORT_GROUP_ASSERTIONS_H
+413
View File
@@ -0,0 +1,413 @@
#include <support/threadHarness.h>
#include <cstdlib>
#include <iostream>
namespace sscl::tests {
struct DedicatedIoThread::StartupState
{
std::mutex mutex;
std::condition_variable condition;
std::thread::id osThreadId;
std::exception_ptr startupException;
bool allowInitialization = false;
bool initialized = false;
};
namespace {
constexpr const char *callerThreadName = "test:caller";
constexpr const char *calleeThreadName = "test:callee";
constexpr const char *alternateThreadName = "test:alternate";
constexpr const char *bodyThreadName = "test:body";
constexpr const char *worldThreadName = "test:world";
constexpr const char *legThreadName = "test:leg";
void runDedicatedThread(
const std::shared_ptr<DedicatedIoThread::StartupState> &state,
const sscl::PuppeteerThread::EntryFnArguments &args)
{
{
std::unique_lock<std::mutex> lock(state->mutex);
state->condition.wait(
lock,
[&state]() { return state->allowInitialization; });
}
try
{
args.usableBeforeJolt.initializeTls();
{
std::lock_guard<std::mutex> guard(state->mutex);
state->osThreadId = std::this_thread::get_id();
state->initialized = true;
}
state->condition.notify_all();
args.usableBeforeJolt.getIoContext().restart();
args.usableBeforeJolt.getIoContext().run();
}
catch (...)
{
{
std::lock_guard<std::mutex> guard(state->mutex);
state->startupException = std::current_exception();
state->initialized = true;
}
state->condition.notify_all();
}
}
} // namespace
std::string threadRoleName(PostingThreadRole role)
{
switch (role)
{
case PostingThreadRole::CALLER:
return callerThreadName;
case PostingThreadRole::CALLEE:
return calleeThreadName;
case PostingThreadRole::ALTERNATE:
return alternateThreadName;
case PostingThreadRole::BODY:
return bodyThreadName;
case PostingThreadRole::WORLD:
return worldThreadName;
case PostingThreadRole::LEG:
return legThreadName;
}
throw std::runtime_error("Unknown PostingThreadRole");
}
void IoContextPump::pumpUntilIdle(
boost::asio::io_context &ioContext,
std::chrono::milliseconds idleTimeout,
std::chrono::milliseconds totalTimeout)
{
const auto totalDeadline =
std::chrono::steady_clock::now() + totalTimeout;
auto lastProgress = std::chrono::steady_clock::now();
while (std::chrono::steady_clock::now() < totalDeadline)
{
if (ioContext.poll_one() > 0)
{
lastProgress = std::chrono::steady_clock::now();
continue;
}
if (std::chrono::steady_clock::now() - lastProgress >= idleTimeout) {
return;
}
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}
ThreadBoundComponent::ThreadBoundComponent()
: sscl::pptr::PuppeteerComponent(nullptr)
{
}
void ThreadBoundComponent::handleLoopExceptionHook()
{
loopException = std::current_exception();
}
DedicatedIoThread::DedicatedIoThread(PostingThreadRole roleIn)
: role(roleIn),
startupState(std::make_shared<StartupState>()),
component(),
thread(std::make_shared<sscl::PuppeteerThread>(
static_cast<sscl::ThreadId>(roleIn),
threadRoleName(roleIn),
[state = startupState](
const sscl::PuppeteerThread::EntryFnArguments &args)
{
runDedicatedThread(state, args);
},
component,
nullptr))
{
component.thread = thread;
releaseStartupBarrier();
waitUntilInitialized();
}
DedicatedIoThread::~DedicatedIoThread()
{
stopAndJoin();
}
boost::asio::io_context &DedicatedIoThread::ioContext()
{
return thread->getIoContext();
}
sscl::ThreadId DedicatedIoThread::threadId() const noexcept
{
return static_cast<sscl::ThreadId>(role);
}
std::thread::id DedicatedIoThread::osThreadId() const
{
std::lock_guard<std::mutex> guard(startupState->mutex);
return startupState->osThreadId;
}
std::shared_ptr<sscl::PuppeteerThread> DedicatedIoThread::componentThread() const
{
return thread;
}
void DedicatedIoThread::stopAndJoin()
{
if (!thread) {
return;
}
releaseStartupBarrier();
thread->getIoContext().stop();
if (thread->thread.joinable()) {
thread->thread.join();
}
thread.reset();
}
void DedicatedIoThread::releaseStartupBarrier()
{
{
std::lock_guard<std::mutex> guard(startupState->mutex);
startupState->allowInitialization = true;
}
startupState->condition.notify_all();
}
void DedicatedIoThread::waitUntilInitialized()
{
std::unique_lock<std::mutex> lock(startupState->mutex);
const bool initialized = startupState->condition.wait_for(
lock,
defaultPostingTaskTimeout,
[this]() { return startupState->initialized; });
if (!initialized) {
throw std::runtime_error("Timed out waiting for test thread startup");
}
std::exception_ptr startupException = startupState->startupException;
lock.unlock();
if (startupException) {
std::rethrow_exception(startupException);
}
}
void ThreadRegistry::registerThread(
PostingThreadRole role,
DedicatedIoThread &thread)
{
std::lock_guard<std::mutex> guard(registryMutex());
threadsByRole()[role] = &thread;
}
void ThreadRegistry::unregisterThread(PostingThreadRole role)
{
std::lock_guard<std::mutex> guard(registryMutex());
threadsByRole().erase(role);
}
boost::asio::io_context &ThreadRegistry::ioContext(PostingThreadRole role)
{
std::lock_guard<std::mutex> guard(registryMutex());
auto iterator = threadsByRole().find(role);
if (iterator == threadsByRole().end()) {
throw std::runtime_error(
"No test thread registered for " + threadRoleName(role));
}
return iterator->second->ioContext();
}
std::thread::id ThreadRegistry::osThreadId(PostingThreadRole role)
{
std::lock_guard<std::mutex> guard(registryMutex());
auto iterator = threadsByRole().find(role);
if (iterator == threadsByRole().end()) {
throw std::runtime_error(
"No test thread registered for " + threadRoleName(role));
}
return iterator->second->osThreadId();
}
std::mutex &ThreadRegistry::registryMutex()
{
static std::mutex mutex;
return mutex;
}
std::map<PostingThreadRole, DedicatedIoThread *> &
ThreadRegistry::threadsByRole()
{
static std::map<PostingThreadRole, DedicatedIoThread *> threads;
return threads;
}
PostingThreadSet::PostingThreadSet()
: callerThread(PostingThreadRole::CALLER),
calleeThread(PostingThreadRole::CALLEE),
alternateThread(PostingThreadRole::ALTERNATE),
bodyThread(PostingThreadRole::BODY),
worldThread(PostingThreadRole::WORLD),
legThread(PostingThreadRole::LEG)
{
ThreadRegistry::registerThread(PostingThreadRole::CALLER, callerThread);
ThreadRegistry::registerThread(PostingThreadRole::CALLEE, calleeThread);
ThreadRegistry::registerThread(PostingThreadRole::ALTERNATE, alternateThread);
ThreadRegistry::registerThread(PostingThreadRole::BODY, bodyThread);
ThreadRegistry::registerThread(PostingThreadRole::WORLD, worldThread);
ThreadRegistry::registerThread(PostingThreadRole::LEG, legThread);
sscl::ComponentThread::setPuppeteerThreadId(
static_cast<sscl::ThreadId>(PostingThreadRole::CALLER));
sscl::ComponentThread::setPuppeteerThread(callerThread.componentThread());
}
PostingThreadSet::~PostingThreadSet()
{
ThreadRegistry::unregisterThread(PostingThreadRole::CALLER);
ThreadRegistry::unregisterThread(PostingThreadRole::CALLEE);
ThreadRegistry::unregisterThread(PostingThreadRole::ALTERNATE);
ThreadRegistry::unregisterThread(PostingThreadRole::BODY);
ThreadRegistry::unregisterThread(PostingThreadRole::WORLD);
ThreadRegistry::unregisterThread(PostingThreadRole::LEG);
sscl::ComponentThread::setPuppeteerThread(nullptr);
}
DedicatedIoThread &PostingThreadSet::thread(PostingThreadRole role)
{
switch (role)
{
case PostingThreadRole::CALLER:
return callerThread;
case PostingThreadRole::CALLEE:
return calleeThread;
case PostingThreadRole::ALTERNATE:
return alternateThread;
case PostingThreadRole::BODY:
return bodyThread;
case PostingThreadRole::WORLD:
return worldThread;
case PostingThreadRole::LEG:
return legThread;
}
throw std::runtime_error("Unknown PostingThreadRole");
}
DedicatedIoThread &PostingThreadSet::caller()
{
return callerThread;
}
DedicatedIoThread &PostingThreadSet::callee()
{
return calleeThread;
}
DedicatedIoThread &PostingThreadSet::alternate()
{
return alternateThread;
}
DedicatedIoThread &PostingThreadSet::body()
{
return bodyThread;
}
DedicatedIoThread &PostingThreadSet::world()
{
return worldThread;
}
DedicatedIoThread &PostingThreadSet::leg()
{
return legThread;
}
void CrossThreadTrace::recordConstructionThread()
{
record(constructionThreadId);
}
void CrossThreadTrace::recordCalleeExecutionThread()
{
record(calleeExecutionThreadId);
}
void CrossThreadTrace::recordFinalSuspendThread()
{
record(finalSuspendThreadId);
}
void CrossThreadTrace::recordAwaitResumeThread()
{
record(awaitResumeThreadId);
}
void CrossThreadTrace::recordCompletionCallbackThread()
{
record(completionCallbackThreadId);
}
std::thread::id CrossThreadTrace::constructionThread() const
{
return read(constructionThreadId);
}
std::thread::id CrossThreadTrace::calleeExecutionThread() const
{
return read(calleeExecutionThreadId);
}
std::thread::id CrossThreadTrace::finalSuspendThread() const
{
return read(finalSuspendThreadId);
}
std::thread::id CrossThreadTrace::awaitResumeThread() const
{
return read(awaitResumeThreadId);
}
std::thread::id CrossThreadTrace::completionCallbackThread() const
{
return read(completionCallbackThreadId);
}
void CrossThreadTrace::record(std::thread::id &slot)
{
std::lock_guard<std::mutex> guard(mutex);
slot = std::this_thread::get_id();
}
std::thread::id CrossThreadTrace::read(const std::thread::id &slot) const
{
std::lock_guard<std::mutex> guard(mutex);
return slot;
}
} // namespace sscl::tests
+362
View File
@@ -0,0 +1,362 @@
#ifndef SPINSCALE_TEST_SUPPORT_THREAD_HARNESS_H
#define SPINSCALE_TEST_SUPPORT_THREAD_HARNESS_H
#include <chrono>
#include <condition_variable>
#include <exception>
#include <functional>
#include <future>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <stdexcept>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
#include <boost/asio/io_context.hpp>
#include <boost/asio/post.hpp>
#include <spinscale/co/invokers.h>
#include <spinscale/co/postingPromise.h>
#include <spinscale/component.h>
#include <spinscale/componentThread.h>
namespace sscl::tests {
constexpr std::chrono::milliseconds defaultIdleTimeout{800};
constexpr std::chrono::milliseconds defaultTotalTimeout{10000};
constexpr std::chrono::milliseconds defaultPostingTaskTimeout{10000};
enum class PostingThreadRole : sscl::ThreadId
{
CALLER = 70,
CALLEE = 71,
ALTERNATE = 72,
BODY = 73,
WORLD = 74,
LEG = 75,
};
std::string threadRoleName(PostingThreadRole role);
class IoContextPump
{
public:
static void pumpUntilIdle(
boost::asio::io_context &ioContext,
std::chrono::milliseconds idleTimeout = defaultIdleTimeout,
std::chrono::milliseconds totalTimeout = defaultTotalTimeout);
template <typename Predicate>
static bool pumpUntil(
boost::asio::io_context &ioContext,
Predicate &&predicate,
std::chrono::milliseconds idleTimeout = defaultIdleTimeout,
std::chrono::milliseconds totalTimeout = defaultTotalTimeout)
{
const auto totalDeadline =
std::chrono::steady_clock::now() + totalTimeout;
auto lastProgress = std::chrono::steady_clock::now();
while (std::chrono::steady_clock::now() < totalDeadline)
{
if (std::invoke(predicate)) {
return true;
}
if (ioContext.poll_one() > 0)
{
lastProgress = std::chrono::steady_clock::now();
continue;
}
if (std::chrono::steady_clock::now() - lastProgress >= idleTimeout) {
return std::invoke(predicate);
}
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
return std::invoke(predicate);
}
};
class ThreadBoundComponent final
: public sscl::pptr::PuppeteerComponent
{
public:
ThreadBoundComponent();
void handleLoopExceptionHook() override;
std::exception_ptr loopException;
};
class DedicatedIoThread
{
public:
explicit DedicatedIoThread(PostingThreadRole role);
~DedicatedIoThread();
DedicatedIoThread(const DedicatedIoThread &) = delete;
DedicatedIoThread &operator=(const DedicatedIoThread &) = delete;
DedicatedIoThread(DedicatedIoThread &&) = delete;
DedicatedIoThread &operator=(DedicatedIoThread &&) = delete;
boost::asio::io_context &ioContext();
sscl::ThreadId threadId() const noexcept;
std::thread::id osThreadId() const;
std::shared_ptr<sscl::PuppeteerThread> componentThread() const;
void stopAndJoin();
struct StartupState;
template <typename Function>
void post(Function &&function)
{
boost::asio::post(
ioContext(),
std::forward<Function>(function));
}
template <typename Function>
auto runSync(Function &&function)
-> std::invoke_result_t<Function &>
{
using Result = std::invoke_result_t<Function &>;
if (std::this_thread::get_id() == osThreadId()) {
if constexpr (std::is_void_v<Result>) {
std::invoke(function);
return;
} else {
return std::invoke(function);
}
}
auto promise = std::make_shared<std::promise<Result>>();
auto future = promise->get_future();
post(
[promise, function = std::forward<Function>(function)]() mutable
{
try
{
if constexpr (std::is_void_v<Result>)
{
std::invoke(function);
promise->set_value();
}
else
{
promise->set_value(std::invoke(function));
}
}
catch (...)
{
promise->set_exception(std::current_exception());
}
});
return future.get();
}
private:
void releaseStartupBarrier();
void waitUntilInitialized();
PostingThreadRole role;
std::shared_ptr<StartupState> startupState;
ThreadBoundComponent component;
std::shared_ptr<sscl::PuppeteerThread> thread;
};
class ThreadRegistry
{
public:
static void registerThread(
PostingThreadRole role,
DedicatedIoThread &thread);
static void unregisterThread(PostingThreadRole role);
static boost::asio::io_context &ioContext(PostingThreadRole role);
static std::thread::id osThreadId(PostingThreadRole role);
private:
static std::mutex &registryMutex();
static std::map<PostingThreadRole, DedicatedIoThread *> &threadsByRole();
};
template <PostingThreadRole role>
struct PostingThreadTag
{
static boost::asio::io_context &io_context()
{
return ThreadRegistry::ioContext(role);
}
};
template <PostingThreadRole role, typename T>
using RolePostingPromise =
sscl::co::TaggedPostingPromise<T, PostingThreadTag<role>>;
template <PostingThreadRole role>
struct RolePostingPromiseTemplate
{
template <typename T>
using Type = RolePostingPromise<role, T>;
};
template <PostingThreadRole role, typename T>
using RoleViralPostingInvoker =
sscl::co::ViralPostingInvoker<
RolePostingPromiseTemplate<role>::template Type,
T>;
template <PostingThreadRole role>
using RoleNonViralPostingInvoker =
sscl::co::NonViralPostingInvoker<
RolePostingPromiseTemplate<role>::template Type>;
class PostingThreadSet
{
public:
PostingThreadSet();
~PostingThreadSet();
PostingThreadSet(const PostingThreadSet &) = delete;
PostingThreadSet &operator=(const PostingThreadSet &) = delete;
PostingThreadSet(PostingThreadSet &&) = delete;
PostingThreadSet &operator=(PostingThreadSet &&) = delete;
DedicatedIoThread &thread(PostingThreadRole role);
DedicatedIoThread &caller();
DedicatedIoThread &callee();
DedicatedIoThread &alternate();
DedicatedIoThread &body();
DedicatedIoThread &world();
DedicatedIoThread &leg();
private:
DedicatedIoThread callerThread;
DedicatedIoThread calleeThread;
DedicatedIoThread alternateThread;
DedicatedIoThread bodyThread;
DedicatedIoThread worldThread;
DedicatedIoThread legThread;
};
class CrossThreadTrace
{
public:
void recordConstructionThread();
void recordCalleeExecutionThread();
void recordFinalSuspendThread();
void recordAwaitResumeThread();
void recordCompletionCallbackThread();
std::thread::id constructionThread() const;
std::thread::id calleeExecutionThread() const;
std::thread::id finalSuspendThread() const;
std::thread::id awaitResumeThread() const;
std::thread::id completionCallbackThread() const;
private:
void record(std::thread::id &slot);
std::thread::id read(const std::thread::id &slot) const;
mutable std::mutex mutex;
std::thread::id constructionThreadId;
std::thread::id calleeExecutionThreadId;
std::thread::id finalSuspendThreadId;
std::thread::id awaitResumeThreadId;
std::thread::id completionCallbackThreadId;
};
template <typename InvokerFactory>
void runNonViralPostingTask(
DedicatedIoThread &callerThread,
InvokerFactory &&invokerFactory,
std::chrono::milliseconds timeout = defaultPostingTaskTimeout)
{
using Factory = std::decay_t<InvokerFactory>;
using Invoker = std::invoke_result_t<
Factory &, std::exception_ptr &, std::function<void()>>;
struct TaskState
{
explicit TaskState(Factory factoryIn)
: factory(std::move(factoryIn))
{}
Factory factory;
std::exception_ptr coroutineException;
std::exception_ptr taskException;
std::optional<Invoker> invoker;
std::mutex mutex;
std::condition_variable condition;
bool completed = false;
};
auto taskState = std::make_shared<TaskState>(
std::forward<InvokerFactory>(invokerFactory));
callerThread.post(
[taskState]()
{
auto completeTask = [taskState]()
{
taskState->taskException = taskState->coroutineException;
taskState->invoker.reset();
{
std::lock_guard<std::mutex> guard(taskState->mutex);
taskState->completed = true;
}
taskState->condition.notify_one();
};
try
{
taskState->invoker.emplace(
std::invoke(
taskState->factory,
taskState->coroutineException,
std::move(completeTask)));
}
catch (...)
{
{
std::lock_guard<std::mutex> guard(taskState->mutex);
taskState->taskException = std::current_exception();
taskState->completed = true;
}
taskState->condition.notify_one();
}
});
std::unique_lock<std::mutex> lock(taskState->mutex);
const bool completed = taskState->condition.wait_for(
lock,
timeout,
[&taskState]() { return taskState->completed; });
if (!completed) {
throw std::runtime_error("Timed out waiting for posting coroutine task");
}
std::exception_ptr taskException = taskState->taskException;
lock.unlock();
if (taskException) {
std::rethrow_exception(taskException);
}
}
} // namespace sscl::tests
#endif // SPINSCALE_TEST_SUPPORT_THREAD_HARNESS_H
+161
View File
@@ -0,0 +1,161 @@
#ifndef SPINSCALE_TEST_SUPPORT_TIMER_AWAITERS_H
#define SPINSCALE_TEST_SUPPORT_TIMER_AWAITERS_H
#include <coroutine>
#include <memory>
#include <mutex>
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <boost/asio/deadline_timer.hpp>
#include <boost/asio/error.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/date_time/posix_time/posix_time_types.hpp>
#include <boost/system/error_code.hpp>
namespace sscl::tests {
using SharedDeadlineTimer = std::shared_ptr<boost::asio::deadline_timer>;
class CancelableDeadlineTimerRegistry
{
public:
void clear()
{
std::lock_guard<std::mutex> guard(mutex);
timersByLabel.clear();
}
void registerTimer(
int labelMilliseconds,
const SharedDeadlineTimer &timer)
{
std::lock_guard<std::mutex> guard(mutex);
timersByLabel[labelMilliseconds] = timer;
}
void cancel(int labelMilliseconds)
{
std::lock_guard<std::mutex> guard(mutex);
const auto iterator = timersByLabel.find(labelMilliseconds);
if (iterator == timersByLabel.end()) {
throw std::runtime_error(
"No cancelable deadline_timer registered for label "
+ std::to_string(labelMilliseconds));
}
const SharedDeadlineTimer timer = iterator->second.lock();
if (!timer) {
throw std::runtime_error(
"Cancelable deadline_timer expired before cancel for label "
+ std::to_string(labelMilliseconds));
}
timer->cancel();
}
private:
std::mutex mutex;
std::unordered_map<int, std::weak_ptr<boost::asio::deadline_timer>>
timersByLabel;
};
struct DeadlineTimerAwaiter
{
DeadlineTimerAwaiter(
boost::asio::io_context &ioContext,
int delayMilliseconds)
: timer(std::make_shared<boost::asio::deadline_timer>(ioContext))
{
start(delayMilliseconds);
}
DeadlineTimerAwaiter(
SharedDeadlineTimer sharedTimer,
int delayMilliseconds)
: timer(std::move(sharedTimer))
{
start(delayMilliseconds);
}
bool await_ready() const noexcept
{ return waitCompleted; }
bool await_suspend(std::coroutine_handle<> handle) noexcept
{
resumeHandle = handle;
return !waitCompleted;
}
boost::system::error_code await_resume() const noexcept
{ return completionErrorCode; }
private:
void start(int delayMilliseconds)
{
timer->expires_from_now(
boost::posix_time::milliseconds(delayMilliseconds));
timer->async_wait(
[this](const boost::system::error_code &errorCode)
{
completionErrorCode = errorCode;
waitCompleted = true;
if (resumeHandle) {
resumeHandle.resume();
}
});
}
SharedDeadlineTimer timer;
boost::system::error_code completionErrorCode;
bool waitCompleted = false;
std::coroutine_handle<> resumeHandle;
};
struct RegisteredDeadlineTimerAwaiter
{
RegisteredDeadlineTimerAwaiter(
boost::asio::io_context &ioContext,
int delayMilliseconds,
int registrationLabelMilliseconds,
CancelableDeadlineTimerRegistry &registry)
: timer(std::make_shared<boost::asio::deadline_timer>(ioContext))
{
registry.registerTimer(registrationLabelMilliseconds, timer);
waiter.emplace(timer, delayMilliseconds);
}
bool await_ready() const noexcept
{ return waiter->await_ready(); }
bool await_suspend(std::coroutine_handle<> handle) noexcept
{ return waiter->await_suspend(handle); }
boost::system::error_code await_resume() const noexcept
{ return waiter->await_resume(); }
SharedDeadlineTimer timer;
std::optional<DeadlineTimerAwaiter> waiter;
};
inline void throwIfTimerWaitFailed(
const boost::system::error_code &waitError)
{
if (waitError) {
throw std::runtime_error(
"deadline_timer wait failed: " + waitError.message());
}
}
inline bool timerWasCanceled(const boost::system::error_code &waitError)
{
return waitError == boost::asio::error::operation_aborted;
}
} // namespace sscl::tests
#endif // SPINSCALE_TEST_SUPPORT_TIMER_AWAITERS_H