308 lines
7.8 KiB
C++
308 lines
7.8 KiB
C++
|
|
#include <computeManager/computeManager.h>
|
||
|
|
#include <iostream>
|
||
|
|
#include <stdexcept>
|
||
|
|
#include <string>
|
||
|
|
#include <string_view>
|
||
|
|
#include <algorithm>
|
||
|
|
|
||
|
|
namespace smo {
|
||
|
|
namespace compute {
|
||
|
|
|
||
|
|
// Extract the numeric version from an OpenCL version string.
//
// The expected layout is "OpenCL <major>.<minor>[ <vendor specific text>]".
// Only the token that directly follows the first run of spaces is examined,
// so a '.' that appears later in vendor text can never be mistaken for the
// version separator.
//
// versionStr: full version string as reported by the OpenCL runtime.
// Returns {major, minor} on success, or {-1, -1} when the string has no
// space, no "<digits>.<digits>" token, or a component that does not fit in
// an int. Signed components (e.g. "-1.2") are rejected. Trailing non-digit
// characters after the minor digits (e.g. "1.2-beta", "1.2.3") are ignored.
static std::pair<int, int> parseOpenClVersion(const std::string& versionStr)
{
    const std::pair<int, int> invalid{-1, -1};
    const char* const digits = "0123456789";

    const size_t spacePos = versionStr.find(' ');
    if (spacePos == std::string::npos) { return invalid; }

    // Skip the separator space(s) and isolate the version token.
    const size_t tokenBegin = versionStr.find_first_not_of(' ', spacePos);
    if (tokenBegin == std::string::npos) { return invalid; }

    const size_t tokenEnd = versionStr.find(' ', tokenBegin);
    const size_t tokenLen = (tokenEnd == std::string::npos)
        ? std::string::npos
        : tokenEnd - tokenBegin;
    const std::string token = versionStr.substr(tokenBegin, tokenLen);

    const size_t dotPos = token.find('.');
    if (dotPos == std::string::npos) { return invalid; }

    // Major: every character before the dot must be a decimal digit.
    const std::string majorText = token.substr(0, dotPos);
    if (majorText.empty() ||
        majorText.find_first_not_of(digits) != std::string::npos)
    {
        return invalid;
    }

    // Minor: the leading run of digits after the dot; must be non-empty.
    const std::string afterDot = token.substr(dotPos + 1);
    const std::string minorText =
        afterDot.substr(0, afterDot.find_first_not_of(digits));
    if (minorText.empty()) { return invalid; }

    try {
        const int major = std::stoi(majorText);
        const int minor = std::stoi(minorText);
        return {major, minor};
    } catch (const std::exception&) {
        // std::stoi throws std::out_of_range for values that overflow int.
        return invalid;
    }
}
|
||
|
|
|
||
|
|
// Helper function to validate OpenCL version
|
||
|
|
static bool validateOpenClVersion(
|
||
|
|
std::string_view versionStr, std::string_view versionType,
|
||
|
|
int minMajor, int minMinor)
|
||
|
|
{
|
||
|
|
auto [major, minor] = parseOpenClVersion(std::string(versionStr));
|
||
|
|
|
||
|
|
if (major == -1 && minor == -1)
|
||
|
|
{
|
||
|
|
std::cerr << __func__ << ": failed to parse OpenCL " << versionType
|
||
|
|
<< " version: " << versionStr << std::endl;
|
||
|
|
return false;
|
||
|
|
}
|
||
|
|
|
||
|
|
if (major < minMajor || (major == minMajor && minor < minMinor))
|
||
|
|
{
|
||
|
|
std::cerr << __func__ << ": OpenCL " << versionType << " version "
|
||
|
|
<< major << "." << minor << " found, but " << minMajor << "."
|
||
|
|
<< minMinor << " or higher is required" << std::endl;
|
||
|
|
return false;
|
||
|
|
}
|
||
|
|
|
||
|
|
std::cout << __func__ << ": OpenCL " << versionType << " version: "
|
||
|
|
<< versionStr << std::endl;
|
||
|
|
return true;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Construct the OpenCL state for one device: store the platform/device
// handles, then create a context for that single device and a command queue
// on it.
//
// platformId: platform handle stored in `platform`.
// deviceId:   device handle stored in `device`.
//
// Throws std::runtime_error (message includes the OpenCL error code) if the
// context or the command queue cannot be created. If the queue fails, the
// context created just before is released first.
ComputeDevice::ComputeDevice(cl_platform_id platformId, cl_device_id deviceId)
    : platform(platformId), device(deviceId),
      context(nullptr), commandQueue(nullptr)
{
    cl_int status = CL_SUCCESS;

    // Single-device context, no properties, no error callback.
    context = clCreateContext(nullptr, 1, &device, nullptr, nullptr, &status);
    if (status != CL_SUCCESS || !context)
    {
        throw std::runtime_error(
            std::string(__func__) + ": failed to create context for device: " +
            std::to_string(status));
    }

    // Queue with no special properties requested.
    const cl_command_queue_properties defaultProps = 0;
    commandQueue = clCreateCommandQueue(context, device, defaultProps, &status);
    if (status == CL_SUCCESS && commandQueue) { return; }

    // Queue creation failed: undo the context before reporting the error.
    clReleaseContext(context);
    context = nullptr;
    throw std::runtime_error(
        std::string(__func__) + ": failed to create command queue for device: " +
        std::to_string(status));
}
|
||
|
|
|
||
|
|
void ComputeManager::initialize()
|
||
|
|
{
|
||
|
|
if (initialized) { return; }
|
||
|
|
|
||
|
|
cl_int err;
|
||
|
|
|
||
|
|
// Get number of platforms
|
||
|
|
cl_uint numPlatforms = 0;
|
||
|
|
err = clGetPlatformIDs(0, nullptr, &numPlatforms);
|
||
|
|
if (err != CL_SUCCESS)
|
||
|
|
{
|
||
|
|
throw std::runtime_error(
|
||
|
|
std::string(__func__) + ": failed to get OpenCL platforms: " +
|
||
|
|
std::to_string(err));
|
||
|
|
}
|
||
|
|
if (numPlatforms == 0)
|
||
|
|
{
|
||
|
|
throw std::runtime_error(
|
||
|
|
std::string(__func__) + ": no OpenCL platforms found");
|
||
|
|
}
|
||
|
|
|
||
|
|
// Get all platforms
|
||
|
|
std::vector<cl_platform_id> platforms(numPlatforms);
|
||
|
|
err = clGetPlatformIDs(numPlatforms, platforms.data(), nullptr);
|
||
|
|
if (err != CL_SUCCESS)
|
||
|
|
{
|
||
|
|
throw std::runtime_error(
|
||
|
|
std::string(__func__) + ": failed to enumerate OpenCL platforms: " +
|
||
|
|
std::to_string(err));
|
||
|
|
}
|
||
|
|
|
||
|
|
// Enumerate devices for each platform
|
||
|
|
for (cl_uint p = 0; p < numPlatforms; ++p)
|
||
|
|
{
|
||
|
|
cl_platform_id platform = platforms[p];
|
||
|
|
|
||
|
|
// Check platform version
|
||
|
|
char platformVersion[128];
|
||
|
|
err = clGetPlatformInfo(
|
||
|
|
platform, CL_PLATFORM_VERSION,
|
||
|
|
sizeof(platformVersion), platformVersion, nullptr);
|
||
|
|
|
||
|
|
if (err == CL_SUCCESS)
|
||
|
|
{
|
||
|
|
if (!validateOpenClVersion(platformVersion, "platform", 1, 2))
|
||
|
|
{
|
||
|
|
std::cout << __func__ << ": skipping platform " << p
|
||
|
|
<< " with incompatible OpenCL version "
|
||
|
|
<< std::string(platformVersion) << std::endl;
|
||
|
|
continue;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Get number of devices
|
||
|
|
cl_uint numDevices = 0;
|
||
|
|
err = clGetDeviceIDs(
|
||
|
|
platform, CL_DEVICE_TYPE_ALL, 0, nullptr, &numDevices);
|
||
|
|
|
||
|
|
if (err != CL_SUCCESS || numDevices == 0)
|
||
|
|
{
|
||
|
|
std::cout << __func__ << ": skipping platform " << p
|
||
|
|
<< " with no devices" << std::endl;
|
||
|
|
continue;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Get all devices
|
||
|
|
std::vector<cl_device_id> platformDevices(numDevices);
|
||
|
|
err = clGetDeviceIDs(
|
||
|
|
platform, CL_DEVICE_TYPE_ALL, numDevices,
|
||
|
|
platformDevices.data(), nullptr);
|
||
|
|
|
||
|
|
if (err != CL_SUCCESS)
|
||
|
|
{
|
||
|
|
throw std::runtime_error(
|
||
|
|
std::string(__func__) + ": failed to enumerate devices for "
|
||
|
|
"platform " + std::to_string(p) + ": " + std::to_string(err));
|
||
|
|
}
|
||
|
|
|
||
|
|
// Create ComputeDevice for each device
|
||
|
|
for (cl_uint d = 0; d < numDevices; ++d)
|
||
|
|
{
|
||
|
|
cl_device_id device = platformDevices[d];
|
||
|
|
|
||
|
|
// Check device version
|
||
|
|
char deviceVersion[128];
|
||
|
|
err = clGetDeviceInfo(
|
||
|
|
device, CL_DEVICE_VERSION,
|
||
|
|
sizeof(deviceVersion), deviceVersion, nullptr);
|
||
|
|
|
||
|
|
if (err == CL_SUCCESS)
|
||
|
|
{
|
||
|
|
if (!validateOpenClVersion(deviceVersion, "device", 1, 2))
|
||
|
|
{
|
||
|
|
std::cout << __func__ << ": skipping device " << d
|
||
|
|
<< " with incompatible OpenCL version "
|
||
|
|
<< std::string(deviceVersion) << std::endl;
|
||
|
|
continue;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Create ComputeDevice (constructor creates context and queue)
|
||
|
|
try
|
||
|
|
{
|
||
|
|
auto deviceObj = std::make_shared<ComputeDevice>(
|
||
|
|
platform, device);
|
||
|
|
devices.push_back(deviceObj);
|
||
|
|
}
|
||
|
|
catch (const std::runtime_error& e)
|
||
|
|
{
|
||
|
|
// Re-throw with more context about which device/platform
|
||
|
|
throw std::runtime_error(
|
||
|
|
std::string(__func__) + ": failed to create ComputeDevice "
|
||
|
|
"for device " + std::to_string(d) + " on platform " +
|
||
|
|
std::to_string(p) + ": " + e.what());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if (devices.empty())
|
||
|
|
{
|
||
|
|
throw std::runtime_error(
|
||
|
|
std::string(__func__) + ": no compatible OpenCL devices found");
|
||
|
|
}
|
||
|
|
|
||
|
|
initialized = true;
|
||
|
|
std::cout << __func__ << ": Initialized with " << devices.size()
|
||
|
|
<< " compute device(s)" << std::endl;
|
||
|
|
}
|
||
|
|
|
||
|
|
void ComputeManager::finalize()
|
||
|
|
{
|
||
|
|
if (!initialized) { return; }
|
||
|
|
|
||
|
|
// Release all devices (their shared_ptrs will clean up contexts/queues)
|
||
|
|
devices.clear();
|
||
|
|
initialized = false;
|
||
|
|
std::cout << __func__ << ": Finalized" << std::endl;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Create one OpenCL buffer per device context, all backed by the same host
// memory.
//
// hostPtr: host memory handed to clCreateBuffer together with
//          CL_MEM_USE_HOST_PTR.
// size:    buffer size in bytes, as passed to clCreateBuffer.
// flags:   additional cl_mem_flags, OR-ed with CL_MEM_USE_HOST_PTR.
// devices: devices to create a buffer for. Null entries and devices without
//          a context are skipped.
//
// Throws std::runtime_error (message includes the OpenCL error code) if any
// buffer cannot be created; buffers created for earlier devices are released
// before the exception is thrown.
ClBuffer::ClBuffer(void* hostPtr, size_t size, cl_mem_flags flags,
                   const std::vector<std::shared_ptr<ComputeDevice>>& devices)
    : hostPtr(hostPtr), size(size), flags(flags)
{
    associations.reserve(devices.size());

    // Create a buffer for each device's context
    for (const auto& device : devices)
    {
        // Guard against a null shared_ptr before touching its members.
        if (!device || !device->context) { continue; }

        cl_int err = CL_SUCCESS;
        const cl_mem_flags bufferFlags = CL_MEM_USE_HOST_PTR | flags;
        cl_mem buffer = clCreateBuffer(
            device->context,
            bufferFlags,
            size, hostPtr,
            &err);

        if (err != CL_SUCCESS || !buffer)
        {
            // Release any buffers already created before throwing
            for (auto& assoc : associations)
            {
                if (assoc.buffer) {
                    clReleaseMemObject(assoc.buffer);
                }
            }
            throw std::runtime_error(
                std::string(__func__) + ": failed to create buffer for "
                "device: " + std::to_string(err));
        }

        associations.emplace_back(buffer, device);
    }
}
|
||
|
|
|
||
|
|
std::shared_ptr<ClBuffer>
|
||
|
|
ComputeManager::createUseHostPtrBuffer(
|
||
|
|
void* hostPtr, size_t size, cl_mem_flags flags)
|
||
|
|
{
|
||
|
|
if (!initialized)
|
||
|
|
{
|
||
|
|
std::cerr << __func__ << ": ComputeManager not initialized"
|
||
|
|
<< std::endl;
|
||
|
|
throw std::runtime_error(
|
||
|
|
std::string(__func__) + ": ComputeManager not initialized");
|
||
|
|
}
|
||
|
|
|
||
|
|
return std::make_shared<ClBuffer>(hostPtr, size, flags, devices);
|
||
|
|
}
|
||
|
|
|
||
|
|
// Intentionally does nothing with the buffer it is given; kept so existing
// callers of the release API continue to compile.
// NOTE(review): actual resource cleanup is expected to live in ClBuffer's
// destructor, which is not visible in this file — confirm.
void ComputeManager::releaseUseHostPtrBuffer(std::shared_ptr<ClBuffer> buffer)
{
    static_cast<void>(buffer);  // silence unused-parameter warnings
}
|
||
|
|
|
||
|
|
std::shared_ptr<ComputeDevice> ComputeManager::getDevice()
|
||
|
|
{
|
||
|
|
if (!initialized || devices.empty()) {
|
||
|
|
return nullptr;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Return first available device
|
||
|
|
// In the future, this will filter based on ComputeDeviceConstraints
|
||
|
|
return devices[0];
|
||
|
|
}
|
||
|
|
|
||
|
|
// Currently a no-op placeholder for a future reference-counting scheme.
// The manager's device list is only emptied by finalize().
void ComputeManager::releaseDevice(std::shared_ptr<ComputeDevice> device)
{
    static_cast<void>(device);  // silence unused-parameter warnings
}
|
||
|
|
|
||
|
|
} // namespace compute
|
||
|
|
} // namespace smo
|