148 lines
3.6 KiB
C++
148 lines
3.6 KiB
C++
#include <user/compute.h>

#include <charconv>
#include <iostream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <system_error>
#include <utility>
#include <vector>
|
|
|
|
namespace smo {
|
|
namespace compute {
|
|
|
|
// Parses the "<major>.<minor>" number out of an OpenCL version string, e.g.
// "OpenCL 1.2 <vendor info>" or, for OpenCL C version strings,
// "OpenCL C 1.2 <vendor info>".
//
// The number is read from the first space-separated token, after the leading
// token, that starts with a decimal digit. Any characters after the minor
// number (" CUDA", ".3", ...) are ignored.
//
// Returns {major, minor} on success. Returns {-1, -1} in these cases:
//   - the string has no space;
//   - no later token starts with a digit;
//   - the first digit-led token is not "<digits>.<digits>...";
//   - a number does not fit in an int.
static std::pair<int, int> parseOpenClVersion(const std::string& versionStr)
{
    const std::pair<int, int> kInvalid{-1, -1};
    const auto isDigit = [](char c) { return c >= '0' && c <= '9'; };

    const char* const strEnd = versionStr.data() + versionStr.size();

    // Token 0 is the "OpenCL" prefix, so only tokens that follow a space are
    // candidates. OpenCL C strings have an extra "C" token before the number.
    // It is skipped because it does not start with a digit.
    for (size_t spacePos = versionStr.find(' ');
         spacePos != std::string::npos;
         spacePos = versionStr.find(' ', spacePos + 1))
    {
        const size_t tokenPos = spacePos + 1;
        if (tokenPos >= versionStr.size() || !isDigit(versionStr[tokenPos]))
        {
            continue;
        }

        // The first digit-led token decides the outcome. A malformed number
        // is reported as a failure rather than searched past.
        int major = 0;
        const std::from_chars_result majorResult =
            std::from_chars(versionStr.data() + tokenPos, strEnd, major);
        if (majorResult.ec != std::errc{} || majorResult.ptr == strEnd ||
            *majorResult.ptr != '.')
        {
            return kInvalid;
        }

        const char* const minorBegin = majorResult.ptr + 1;
        if (minorBegin == strEnd || !isDigit(*minorBegin)) { return kInvalid; }

        int minor = 0;
        // from_chars does not throw; overflow is reported through `ec`.
        if (std::from_chars(minorBegin, strEnd, minor).ec != std::errc{})
        {
            return kInvalid;
        }
        return {major, minor};
    }

    return kInvalid;
}
|
|
|
|
// Implementation of validateOpenClVersion (declared in user/compute.h)
|
|
bool validateOpenClVersion(
|
|
std::string_view versionStr, std::string_view versionType,
|
|
int minMajor, int minMinor)
|
|
{
|
|
auto [major, minor] = parseOpenClVersion(std::string(versionStr));
|
|
|
|
if (major == -1 && minor == -1)
|
|
{
|
|
std::cerr << __func__ << ": failed to parse OpenCL " << versionType
|
|
<< " version: " << versionStr << std::endl;
|
|
return false;
|
|
}
|
|
|
|
if (major < minMajor || (major == minMajor && minor < minMinor))
|
|
{
|
|
std::cerr << __func__ << ": OpenCL " << versionType << " version "
|
|
<< major << "." << minor << " found, but " << minMajor << "."
|
|
<< minMinor << " or higher is required" << std::endl;
|
|
return false;
|
|
}
|
|
|
|
std::cout << __func__ << ": OpenCL " << versionType << " version: "
|
|
<< versionStr << std::endl;
|
|
return true;
|
|
}
|
|
|
|
// Creates an OpenCL context containing only `deviceId`. It then creates a
// command queue on that context with no queue properties (0).
//
// Throws std::runtime_error carrying the OpenCL error code if either call
// fails. If queue creation fails, the already-created context is released
// before throwing.
ComputeDevice::ComputeDevice(cl_platform_id platformId, cl_device_id deviceId)
    : platform(platformId), device(deviceId),
      context(nullptr), commandQueue(nullptr)
{
    cl_int status = CL_SUCCESS;

    // Single-device context: no context properties, no notify callback.
    context = clCreateContext(nullptr, 1, &device, nullptr, nullptr, &status);
    const bool contextOk = (status == CL_SUCCESS) && (context != nullptr);
    if (!contextOk)
    {
        throw std::runtime_error(
            std::string(__func__) +
            ": failed to create context for device: " +
            std::to_string(status));
    }

    const cl_command_queue_properties noQueueProps = 0;
    commandQueue = clCreateCommandQueue(context, device, noQueueProps, &status);
    const bool queueOk = (status == CL_SUCCESS) && (commandQueue != nullptr);
    if (!queueOk)
    {
        // A throwing constructor never runs the destructor, so the context
        // must be released here or it would leak.
        clReleaseContext(context);
        context = nullptr;
        throw std::runtime_error(
            std::string(__func__) +
            ": failed to create command queue for device: " +
            std::to_string(status));
    }
}
|
|
|
|
// Creates one OpenCL buffer per device context. All buffers wrap the same
// host memory: CL_MEM_USE_HOST_PTR is always OR-ed into `flags`.
//
// Entries in `devices` that are null or have no context are skipped, so
// `associations` may hold fewer entries than `devices`.
//
// Throws std::runtime_error with the OpenCL error code if any clCreateBuffer
// call fails. Buffers already created for earlier devices are released first.
ClBuffer::ClBuffer(void* hostPtr, size_t size, cl_mem_flags flags,
                   const std::vector<std::shared_ptr<ComputeDevice>>& devices)
    : hostPtr(hostPtr), size(size), flags(flags)
{
    associations.reserve(devices.size());

    // Loop-invariant: every buffer wraps the caller's host memory.
    const cl_mem_flags bufferFlags = CL_MEM_USE_HOST_PTR | flags;

    for (const auto& device : devices)
    {
        // A null entry would otherwise be dereferenced (undefined behaviour).
        // Skip it the same way as a device without a context.
        if (!device || !device->context) { continue; }

        cl_int err = CL_SUCCESS;
        cl_mem buffer = clCreateBuffer(
            device->context,
            bufferFlags,
            size, hostPtr,
            &err);

        if (err != CL_SUCCESS || !buffer)
        {
            // The constructor is about to throw and the destructor will not
            // run, so release the buffers created so far.
            // NOTE(review): the released handles stay in `associations`.
            // Confirm the association type's destructor does not release
            // them a second time.
            for (auto& assoc : associations)
            {
                if (assoc.buffer) {
                    clReleaseMemObject(assoc.buffer);
                }
            }

            throw std::runtime_error(
                std::string(__func__) + ": failed to create buffer for "
                "device: " + std::to_string(err));
        }

        associations.emplace_back(buffer, device);
    }
}
|
|
|
|
// Returns the buffer handle created for `device` by this ClBuffer. Returns
// nullptr if no association matches; the first match wins.
//
// Matching compares the association's stored device with `device` using ==.
// Throws std::invalid_argument if `device` is null.
cl_mem ClBuffer::getAssociatedBufferHandleForDevice(
    const std::shared_ptr<ComputeDevice>& device) const
{
    if (device == nullptr)
    {
        throw std::invalid_argument(std::string(__func__) + ": device is nullptr");
    }

    cl_mem match = nullptr;
    for (const auto& entry : associations)
    {
        if (entry.device == device)
        {
            match = entry.buffer;
            break;
        }
    }
    return match;
}
|
|
|
|
} // namespace compute
|
|
} // namespace smo
|