Tests: Add all tests from the coro creation repo

We went back and brought along all the tests we implemented while
we were building the new coro framework.
This commit is contained in:
2026-06-13 17:17:57 -04:00
parent 1763685c0e
commit a29c779f6e
11 changed files with 3199 additions and 28 deletions
+362
View File
@@ -0,0 +1,362 @@
#ifndef SPINSCALE_TEST_SUPPORT_THREAD_HARNESS_H
#define SPINSCALE_TEST_SUPPORT_THREAD_HARNESS_H
#include <chrono>
#include <condition_variable>
#include <exception>
#include <functional>
#include <future>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <stdexcept>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
#include <boost/asio/io_context.hpp>
#include <boost/asio/post.hpp>
#include <spinscale/co/invokers.h>
#include <spinscale/co/postingPromise.h>
#include <spinscale/component.h>
#include <spinscale/componentThread.h>
namespace sscl::tests {
constexpr std::chrono::milliseconds defaultIdleTimeout{800};
constexpr std::chrono::milliseconds defaultTotalTimeout{10000};
constexpr std::chrono::milliseconds defaultPostingTaskTimeout{10000};
enum class PostingThreadRole : sscl::ThreadId
{
CALLER = 70,
CALLEE = 71,
ALTERNATE = 72,
BODY = 73,
WORLD = 74,
LEG = 75,
};
std::string threadRoleName(PostingThreadRole role);
class IoContextPump
{
public:
static void pumpUntilIdle(
boost::asio::io_context &ioContext,
std::chrono::milliseconds idleTimeout = defaultIdleTimeout,
std::chrono::milliseconds totalTimeout = defaultTotalTimeout);
template <typename Predicate>
static bool pumpUntil(
boost::asio::io_context &ioContext,
Predicate &&predicate,
std::chrono::milliseconds idleTimeout = defaultIdleTimeout,
std::chrono::milliseconds totalTimeout = defaultTotalTimeout)
{
const auto totalDeadline =
std::chrono::steady_clock::now() + totalTimeout;
auto lastProgress = std::chrono::steady_clock::now();
while (std::chrono::steady_clock::now() < totalDeadline)
{
if (std::invoke(predicate)) {
return true;
}
if (ioContext.poll_one() > 0)
{
lastProgress = std::chrono::steady_clock::now();
continue;
}
if (std::chrono::steady_clock::now() - lastProgress >= idleTimeout) {
return std::invoke(predicate);
}
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
return std::invoke(predicate);
}
};
class ThreadBoundComponent final
: public sscl::pptr::PuppeteerComponent
{
public:
ThreadBoundComponent();
void handleLoopExceptionHook() override;
std::exception_ptr loopException;
};
class DedicatedIoThread
{
public:
explicit DedicatedIoThread(PostingThreadRole role);
~DedicatedIoThread();
DedicatedIoThread(const DedicatedIoThread &) = delete;
DedicatedIoThread &operator=(const DedicatedIoThread &) = delete;
DedicatedIoThread(DedicatedIoThread &&) = delete;
DedicatedIoThread &operator=(DedicatedIoThread &&) = delete;
boost::asio::io_context &ioContext();
sscl::ThreadId threadId() const noexcept;
std::thread::id osThreadId() const;
std::shared_ptr<sscl::PuppeteerThread> componentThread() const;
void stopAndJoin();
struct StartupState;
template <typename Function>
void post(Function &&function)
{
boost::asio::post(
ioContext(),
std::forward<Function>(function));
}
template <typename Function>
auto runSync(Function &&function)
-> std::invoke_result_t<Function &>
{
using Result = std::invoke_result_t<Function &>;
if (std::this_thread::get_id() == osThreadId()) {
if constexpr (std::is_void_v<Result>) {
std::invoke(function);
return;
} else {
return std::invoke(function);
}
}
auto promise = std::make_shared<std::promise<Result>>();
auto future = promise->get_future();
post(
[promise, function = std::forward<Function>(function)]() mutable
{
try
{
if constexpr (std::is_void_v<Result>)
{
std::invoke(function);
promise->set_value();
}
else
{
promise->set_value(std::invoke(function));
}
}
catch (...)
{
promise->set_exception(std::current_exception());
}
});
return future.get();
}
private:
void releaseStartupBarrier();
void waitUntilInitialized();
PostingThreadRole role;
std::shared_ptr<StartupState> startupState;
ThreadBoundComponent component;
std::shared_ptr<sscl::PuppeteerThread> thread;
};
class ThreadRegistry
{
public:
static void registerThread(
PostingThreadRole role,
DedicatedIoThread &thread);
static void unregisterThread(PostingThreadRole role);
static boost::asio::io_context &ioContext(PostingThreadRole role);
static std::thread::id osThreadId(PostingThreadRole role);
private:
static std::mutex &registryMutex();
static std::map<PostingThreadRole, DedicatedIoThread *> &threadsByRole();
};
template <PostingThreadRole role>
struct PostingThreadTag
{
static boost::asio::io_context &io_context()
{
return ThreadRegistry::ioContext(role);
}
};
template <PostingThreadRole role, typename T>
using RolePostingPromise =
sscl::co::TaggedPostingPromise<T, PostingThreadTag<role>>;
template <PostingThreadRole role>
struct RolePostingPromiseTemplate
{
template <typename T>
using Type = RolePostingPromise<role, T>;
};
template <PostingThreadRole role, typename T>
using RoleViralPostingInvoker =
sscl::co::ViralPostingInvoker<
RolePostingPromiseTemplate<role>::template Type,
T>;
template <PostingThreadRole role>
using RoleNonViralPostingInvoker =
sscl::co::NonViralPostingInvoker<
RolePostingPromiseTemplate<role>::template Type>;
class PostingThreadSet
{
public:
PostingThreadSet();
~PostingThreadSet();
PostingThreadSet(const PostingThreadSet &) = delete;
PostingThreadSet &operator=(const PostingThreadSet &) = delete;
PostingThreadSet(PostingThreadSet &&) = delete;
PostingThreadSet &operator=(PostingThreadSet &&) = delete;
DedicatedIoThread &thread(PostingThreadRole role);
DedicatedIoThread &caller();
DedicatedIoThread &callee();
DedicatedIoThread &alternate();
DedicatedIoThread &body();
DedicatedIoThread &world();
DedicatedIoThread &leg();
private:
DedicatedIoThread callerThread;
DedicatedIoThread calleeThread;
DedicatedIoThread alternateThread;
DedicatedIoThread bodyThread;
DedicatedIoThread worldThread;
DedicatedIoThread legThread;
};
class CrossThreadTrace
{
public:
void recordConstructionThread();
void recordCalleeExecutionThread();
void recordFinalSuspendThread();
void recordAwaitResumeThread();
void recordCompletionCallbackThread();
std::thread::id constructionThread() const;
std::thread::id calleeExecutionThread() const;
std::thread::id finalSuspendThread() const;
std::thread::id awaitResumeThread() const;
std::thread::id completionCallbackThread() const;
private:
void record(std::thread::id &slot);
std::thread::id read(const std::thread::id &slot) const;
mutable std::mutex mutex;
std::thread::id constructionThreadId;
std::thread::id calleeExecutionThreadId;
std::thread::id finalSuspendThreadId;
std::thread::id awaitResumeThreadId;
std::thread::id completionCallbackThreadId;
};
template <typename InvokerFactory>
void runNonViralPostingTask(
DedicatedIoThread &callerThread,
InvokerFactory &&invokerFactory,
std::chrono::milliseconds timeout = defaultPostingTaskTimeout)
{
using Factory = std::decay_t<InvokerFactory>;
using Invoker = std::invoke_result_t<
Factory &, std::exception_ptr &, std::function<void()>>;
struct TaskState
{
explicit TaskState(Factory factoryIn)
: factory(std::move(factoryIn))
{}
Factory factory;
std::exception_ptr coroutineException;
std::exception_ptr taskException;
std::optional<Invoker> invoker;
std::mutex mutex;
std::condition_variable condition;
bool completed = false;
};
auto taskState = std::make_shared<TaskState>(
std::forward<InvokerFactory>(invokerFactory));
callerThread.post(
[taskState]()
{
auto completeTask = [taskState]()
{
taskState->taskException = taskState->coroutineException;
taskState->invoker.reset();
{
std::lock_guard<std::mutex> guard(taskState->mutex);
taskState->completed = true;
}
taskState->condition.notify_one();
};
try
{
taskState->invoker.emplace(
std::invoke(
taskState->factory,
taskState->coroutineException,
std::move(completeTask)));
}
catch (...)
{
{
std::lock_guard<std::mutex> guard(taskState->mutex);
taskState->taskException = std::current_exception();
taskState->completed = true;
}
taskState->condition.notify_one();
}
});
std::unique_lock<std::mutex> lock(taskState->mutex);
const bool completed = taskState->condition.wait_for(
lock,
timeout,
[&taskState]() { return taskState->completed; });
if (!completed) {
throw std::runtime_error("Timed out waiting for posting coroutine task");
}
std::exception_ptr taskException = taskState->taskException;
lock.unlock();
if (taskException) {
std::rethrow_exception(taskException);
}
}
} // namespace sscl::tests
#endif // SPINSCALE_TEST_SUPPORT_THREAD_HARNESS_H