355 lines
9.2 KiB
C++
355 lines
9.2 KiB
C++
#include <iostream>
|
|
#include <stdexcept>
|
|
#include <optional>
|
|
#include <algorithm>
|
|
#include <stimBuffApis/stimBuffApiManager.h>
|
|
#include <stimBuffApis/stimBuffApiLib.h>
|
|
#include <comparatorLibs/comparatorApiManager.h>
|
|
#include <loadableLib/loadableLibraryManager.h>
|
|
#include <body/bodyThread.h>
|
|
#include <componentThread.h>
|
|
#include <opts.h>
|
|
#include <user/smoHooks.h>
|
|
#include <mind.h>
|
|
#include <deviceManager/deviceManager.h>
|
|
#include <marionette/marionette.h>
|
|
#include <computeManager/computeManager.h>
|
|
|
|
namespace smo {
|
|
namespace stim_buff {
|
|
|
|
namespace {
|
|
|
|
void assertBodyThread()
|
|
{
|
|
auto self = sscl::ComponentThread::getSelf();
|
|
if (self->id != SmoThreadId::BODY)
|
|
{
|
|
throw std::runtime_error(
|
|
std::string(__func__)
|
|
+ ": Must be executed on Body thread");
|
|
}
|
|
}
|
|
|
|
std::shared_ptr<sscl::ComponentThread> ComponentThread_getSelf()
|
|
{
|
|
return sscl::ComponentThread::getSelf();
|
|
}
|
|
|
|
/* Local static function to wrap OptionParser::getOptions for SmoCallbacks */
|
|
OptionParser& OptionParser_getOptions()
|
|
{
|
|
return OptionParser::getOptions();
|
|
}
|
|
|
|
/* Local static functions to wrap ComputeManager methods for SmoCallbacks */
|
|
std::shared_ptr<smo::compute::ClBuffer>
|
|
ComputeManager_createUseHostPtrBuffer(
|
|
void* hostPtr, size_t size, cl_mem_flags flags)
|
|
{
|
|
return smo::compute::ComputeManager::getInstance().createUseHostPtrBuffer(
|
|
hostPtr, size, flags);
|
|
}
|
|
|
|
void ComputeManager_releaseUseHostPtrBuffer(
|
|
std::shared_ptr<smo::compute::ClBuffer> buffer)
|
|
{
|
|
smo::compute::ComputeManager::getInstance().releaseUseHostPtrBuffer(
|
|
buffer);
|
|
}
|
|
|
|
std::shared_ptr<smo::compute::ComputeDevice> ComputeManager_getDevice()
|
|
{
|
|
return smo::compute::ComputeManager::getInstance().getDevice();
|
|
}
|
|
|
|
void ComputeManager_releaseDevice(
|
|
std::shared_ptr<smo::compute::ComputeDevice> device)
|
|
{
|
|
smo::compute::ComputeManager::getInstance().releaseDevice(device);
|
|
}
|
|
|
|
std::optional<std::string> searchForLibInSmoSearchPathsHook(
|
|
const std::string& libraryPath)
|
|
{
|
|
return loadable_lib::LoadableLibraryManager::getInstance()
|
|
.searchForLibInSmoSearchPaths(libraryPath);
|
|
}
|
|
|
|
std::shared_ptr<cologex::ExportedComparatorTypeDesc>
|
|
ComparatorManager_getComparatorTypeHook(cologex::ComparatorTypeId typeId)
|
|
{
|
|
return comparator_lib::ComparatorApiManager::getInstance()
|
|
.getComparatorType(typeId);
|
|
}
|
|
|
|
std::unique_ptr<cologex::Comparator> Comparator_getNewInstanceHook(
|
|
const std::shared_ptr<cologex::ExportedComparatorTypeDesc>& comparatorType)
|
|
{
|
|
return comparator_lib::ComparatorApiManager::getInstance()
|
|
.getNewComparatorInstance(comparatorType);
|
|
}
|
|
|
|
/* Hooks to be provided to stimBuffApiLibs, enabling them to call into Salmanoff
|
|
* code.
|
|
*/
|
|
SmoCallbacks smoCallbacks =
|
|
{
|
|
.searchForLibInSmoSearchPaths = searchForLibInSmoSearchPathsHook,
|
|
.ComponentThread_getSelf = ComponentThread_getSelf,
|
|
.OptionParser_getOptions = OptionParser_getOptions,
|
|
.ComputeManager_createUseHostPtrBuffer =
|
|
ComputeManager_createUseHostPtrBuffer,
|
|
.ComputeManager_releaseUseHostPtrBuffer =
|
|
ComputeManager_releaseUseHostPtrBuffer,
|
|
.ComputeManager_getDevice = ComputeManager_getDevice,
|
|
.ComputeManager_releaseDevice = ComputeManager_releaseDevice,
|
|
.ComparatorManager_getComparatorType =
|
|
ComparatorManager_getComparatorTypeHook,
|
|
.Comparator_getNewInstance = Comparator_getNewInstanceHook
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/* Read-only access to the file-local SmoCallbacks table, the same object
 * that loadStimBuffApiLib passes to each library's descriptor function. */
const SmoCallbacks& getSmoCallbacks()
{
    const SmoCallbacks& callbacks = smoCallbacks;
    return callbacks;
}
|
|
|
|
/* Load the stimBuffApiLib shared object at `libraryPath` and register it.
 *
 * The steps are:
 *  1. Reject the path if a library with the same path is already registered.
 *  2. Check that the Salmanoff threading-model descriptor is initialized.
 *  3. Load the shared library.
 *  4. Resolve its SMO_GET_STIM_BUFF_API_DESC_FN_NAME_STR entry point and call
 *     it with smoCallbacks and the threading model.
 *  5. Append the resulting StimBuffApiLib to s.rsrc.libs.
 *
 * Returns a reference to the new StimBuffApiLib, which is owned by the
 * shared_ptr stored in s.rsrc.libs.
 *
 * Throws std::runtime_error if the path is already registered or the
 * threading model has not been initialized. Any exception raised after the
 * shared library was loaded (symbol resolution, the library's descriptor
 * function, allocation, registration) is rethrown unchanged. Before that,
 * the library is handed back to LoadableLibraryManager::unloadSharedLibrary(),
 * so a failed load no longer leaves a loaded-but-untracked library behind.
 * NOTE(review): the rollback assumes a library whose descriptor function
 * threw holds no live state that would outlive the unload.
 */
StimBuffApiLib& StimBuffApiManager::loadStimBuffApiLib(
    const std::string& libraryPath)
{
    loadable_lib::LoadableLibraryManager& llm =
        loadable_lib::LoadableLibraryManager::getInstance();

    if (findStimBuffApiLibByLibraryPath(libraryPath))
    {
        throw std::runtime_error(
            std::string(__func__) + ": StimBuffApiLib already loaded: "
            + libraryPath);
    }

    /* Check the threading model before loading anything. The descriptor
     * function cannot be called without it, and failing here leaves nothing
     * to roll back. */
    const stim_buff::SmoThreadingModelDesc& threadingModel =
        mrntt::getSmoThreadingModelDesc();
    if (!threadingModel.componentThread)
    {
        throw std::runtime_error(
            std::string(__func__)
            + ": SmoThreadingModelDesc has not been initialized");
    }

    std::shared_ptr<loadable_lib::LoadableLibraryManager::LoadedSharedLibrary>
        loadedLibrary = llm.loadSharedLibrary(libraryPath);

    try
    {
        auto descFn = loadable_lib::LoadableLibraryManager::resolveSymbol<
            SMO_GET_STIM_BUFF_API_DESC_FN_TYPEDEF *>(
            loadedLibrary->getDlopenHandle(),
            SMO_GET_STIM_BUFF_API_DESC_FN_NAME_STR);

        const StimBuffApiDesc& libApiDesc = descFn(
            smoCallbacks, threadingModel);

        auto lib = std::make_shared<StimBuffApiLib>(loadedLibrary, descFn);
        lib->setStimBuffApiDesc(libApiDesc);
        // Last throwing step: once push_back succeeds the list owns `lib`.
        s.rsrc.libs.push_back(lib);

        return *lib;
    }
    catch (...)
    {
        /* Roll back the load. A failure during rollback is reported but must
         * not replace the exception that caused the rollback. */
        try
        {
            llm.unloadSharedLibrary(loadedLibrary);
        }
        catch (const std::exception& e)
        {
            std::cerr << std::string(__func__)
                + ": failed to unload '" + libraryPath
                + "' after a load error: "
                << e.what() << '\n';
        }
        catch (...)
        {
            std::cerr << std::string(__func__)
                + ": failed to unload '" + libraryPath
                + "' after a load error (unknown exception)"
                << '\n';
        }
        throw;
    }
}
|
|
|
|
std::optional<std::shared_ptr<StimBuffApiLib>>
|
|
StimBuffApiManager::findStimBuffApiLibByLibraryPath(
|
|
const std::string& libraryPath)
|
|
{
|
|
auto& libs = s.rsrc.libs;
|
|
auto it = std::find_if(
|
|
libs.begin(), libs.end(),
|
|
[&libraryPath](const std::shared_ptr<StimBuffApiLib>& lib) {
|
|
return lib->loadedSharedLibrary->libraryPath == libraryPath;
|
|
});
|
|
|
|
if (it != libs.end()) { return *it; }
|
|
return std::nullopt;
|
|
}
|
|
|
|
std::optional<std::shared_ptr<StimBuffApiLib>>
|
|
StimBuffApiManager::findStimBuffApiLibByApiName(const std::string& apiName)
|
|
{
|
|
auto& libs = s.rsrc.libs;
|
|
auto it = std::find_if(
|
|
libs.begin(), libs.end(),
|
|
[&apiName](const std::shared_ptr<StimBuffApiLib>& lib) {
|
|
return lib->stimBuffApiDesc.name == apiName;
|
|
});
|
|
|
|
if (it != libs.end()) { return *it; }
|
|
return std::nullopt;
|
|
}
|
|
|
|
/* Like findStimBuffApiLibByApiName(), but returns the library by reference
 * and throws std::runtime_error when no registered library provides
 * `apiName`.
 */
StimBuffApiLib& StimBuffApiManager::getStimBuffApiLibByApiName(
    const std::string& apiName)
{
    if (auto found = findStimBuffApiLibByApiName(apiName)) {
        return **found;
    }

    throw std::runtime_error(
        std::string(__func__) + ": No library for API '" + apiName + "'");
}
|
|
|
|
void StimBuffApiManager::unloadStimBuffApiLib(const std::string& libraryPath)
|
|
{
|
|
auto& libs = s.rsrc.libs;
|
|
auto it = std::find_if(
|
|
libs.begin(), libs.end(),
|
|
[&libraryPath](const std::shared_ptr<StimBuffApiLib>& lib) {
|
|
return lib->loadedSharedLibrary->libraryPath == libraryPath;
|
|
});
|
|
|
|
if (it == libs.end())
|
|
{
|
|
std::cerr << std::string(__func__) + ": Library not found: "
|
|
<< libraryPath << '\n';
|
|
return;
|
|
}
|
|
|
|
std::shared_ptr<loadable_lib::LoadableLibraryManager::LoadedSharedLibrary>
|
|
loadedLibrary = (*it)->loadedSharedLibrary;
|
|
libs.erase(it);
|
|
|
|
loadable_lib::LoadableLibraryManager::getInstance()
|
|
.unloadSharedLibrary(loadedLibrary);
|
|
}
|
|
|
|
void StimBuffApiManager::unloadAllStimBuffApiLibs(void)
|
|
{
|
|
std::vector<std::shared_ptr<
|
|
loadable_lib::LoadableLibraryManager::LoadedSharedLibrary>>
|
|
loadedLibrariesTmp;
|
|
loadedLibrariesTmp.reserve(s.rsrc.libs.size());
|
|
|
|
for (const auto& lib : s.rsrc.libs) {
|
|
loadedLibrariesTmp.push_back(lib->loadedSharedLibrary);
|
|
}
|
|
|
|
s.rsrc.libs.clear();
|
|
|
|
loadable_lib::LoadableLibraryManager& llm =
|
|
loadable_lib::LoadableLibraryManager::getInstance();
|
|
for (const auto& loadedLibrary : loadedLibrariesTmp) {
|
|
llm.unloadSharedLibrary(loadedLibrary);
|
|
}
|
|
}
|
|
|
|
void StimBuffApiManager::loadAllStimBuffApiLibsFromOptions(void)
|
|
{
|
|
const auto& options = OptionParser::getOptions();
|
|
for (const auto& libPath : options.senseApiLibs) {
|
|
loadStimBuffApiLib(libPath);
|
|
}
|
|
}
|
|
|
|
std::string StimBuffApiManager::stringifyLibs() const
|
|
{
|
|
std::string result;
|
|
for (const auto& lib : s.rsrc.libs)
|
|
{
|
|
if (!result.empty()) {
|
|
result += "\n";
|
|
}
|
|
result += lib->stringify();
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/* Coroutine: run `lib`'s sal_mgmt_libOps.initializeCInd() hook.
 *
 * Must run on the Body thread; assertBodyThread() throws otherwise.
 * When `acquireListLock` is true the manager's s.lock is acquired first and
 * held until the coroutine finishes. Callers that already hold it (the
 * *All* variant) pass false. The library's own lock (lib.s.lock) is then
 * acquired and held across the hook. Throws std::runtime_error if the
 * library did not supply initializeCInd.
 */
body::BodyViralPostingInvoker<void>
StimBuffApiManager::initializeStimBuffApiLibCReq(
    StimBuffApiLib& lib, bool acquireListLock)
{
    assertBodyThread();

    std::optional<sscl::co::CoQutex::ReleaseHandle> managerLockHold;
    if (acquireListLock) {
        managerLockHold.emplace(
            co_await s.lock.getAcquireInvocationAndSuspensionPolicy());
    }

    auto& libOps = lib.stimBuffApiDesc.sal_mgmt_libOps;
    if (!libOps.initializeCInd) {
        throw std::runtime_error(
            std::string(__func__) + ": initializeCInd() is NULL for library '"
            + lib.loadedSharedLibrary->libraryPath + "'");
    }

    // Declared after managerLockHold, so it is released first.
    sscl::co::CoQutex::ReleaseHandle libLockHold =
        co_await lib.s.lock.getAcquireInvocationAndSuspensionPolicy();

    co_await libOps.initializeCInd();
    co_return;
}
|
|
|
|
/* Coroutine: run `lib`'s sal_mgmt_libOps.finalizeCInd() hook.
 *
 * Must run on the Body thread. Locking mirrors
 * initializeStimBuffApiLibCReq(): s.lock is taken first when
 * `acquireListLock` is true, then lib.s.lock. With the library lock held,
 * the loaded library's isBeingDestroyed flag is set to true before the hook
 * is awaited. Throws std::runtime_error if finalizeCInd is missing, and in
 * that case the flag is left unchanged.
 */
body::BodyViralPostingInvoker<void>
StimBuffApiManager::finalizeStimBuffApiLibCReq(
    StimBuffApiLib& lib, bool acquireListLock)
{
    assertBodyThread();

    std::optional<sscl::co::CoQutex::ReleaseHandle> managerLockHold;
    if (acquireListLock) {
        managerLockHold.emplace(
            co_await s.lock.getAcquireInvocationAndSuspensionPolicy());
    }

    auto& libOps = lib.stimBuffApiDesc.sal_mgmt_libOps;
    if (!libOps.finalizeCInd) {
        throw std::runtime_error(
            std::string(__func__) + ": finalizeCInd() is NULL for library '"
            + lib.loadedSharedLibrary->libraryPath + "'");
    }

    // Declared after managerLockHold, so it is released first.
    sscl::co::CoQutex::ReleaseHandle libLockHold =
        co_await lib.s.lock.getAcquireInvocationAndSuspensionPolicy();

    // Flag is raised before the library's finalize hook runs.
    lib.loadedSharedLibrary->isBeingDestroyed.store(true);
    co_await libOps.finalizeCInd();
    co_return;
}
|
|
|
|
/* Coroutine: initialize every registered library, in list order, while
 * holding s.lock for the whole pass. Must run on the Body thread. The first
 * failure propagates and the remaining libraries are not initialized.
 */
body::BodyViralPostingInvoker<void>
StimBuffApiManager::initializeAllStimBuffApiLibsCReq()
{
    assertBodyThread();

    sscl::co::CoQutex::ReleaseHandle managerLockHold =
        co_await s.lock.getAcquireInvocationAndSuspensionPolicy();

    auto& libs = s.rsrc.libs;
    for (auto it = libs.begin(), end = libs.end(); it != end; ++it) {
        // s.lock is already held here, so do not ask the callee to take it.
        co_await initializeStimBuffApiLibCReq(**it, false);
    }

    co_return;
}
|
|
|
|
/* Coroutine: finalize every registered library, in list order, while
 * holding s.lock for the whole pass. Must run on the Body thread. The first
 * failure propagates and the remaining libraries are not finalized.
 */
body::BodyViralPostingInvoker<void>
StimBuffApiManager::finalizeAllStimBuffApiLibsCReq()
{
    assertBodyThread();

    sscl::co::CoQutex::ReleaseHandle managerLockHold =
        co_await s.lock.getAcquireInvocationAndSuspensionPolicy();

    auto& libs = s.rsrc.libs;
    for (auto it = libs.begin(), end = libs.end(); it != end; ++it) {
        // s.lock is already held here, so do not ask the callee to take it.
        co_await finalizeStimBuffApiLibCReq(**it, false);
    }

    co_return;
}
|
|
|
|
} // namespace stim_buff
|
|
} // namespace smo
|