Skip to content

Commit

Permalink
Clean up environment stuff.
Browse files Browse the repository at this point in the history
  • Loading branch information
aarongreig committed Dec 20, 2024
1 parent 5fc0da5 commit 6a1d9ad
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 341 deletions.
317 changes: 6 additions & 311 deletions test/conformance/source/environment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,75 +69,14 @@ AdapterEnvironment::AdapterEnvironment() {

PlatformEnvironment *PlatformEnvironment::instance = nullptr;

constexpr std::pair<const char *, ur_platform_backend_t> backends[] = {
{"LEVEL_ZERO", UR_PLATFORM_BACKEND_LEVEL_ZERO},
{"L0", UR_PLATFORM_BACKEND_LEVEL_ZERO},
{"OPENCL", UR_PLATFORM_BACKEND_OPENCL},
{"CUDA", UR_PLATFORM_BACKEND_CUDA},
{"HIP", UR_PLATFORM_BACKEND_HIP},
{"NATIVE_CPU", UR_PLATFORM_BACKEND_NATIVE_CPU},
{"UNKNOWN", UR_PLATFORM_BACKEND_UNKNOWN},
};

namespace {
/* unused due to platform filtering being commented out
constexpr const char *backend_to_str(ur_platform_backend_t backend) {
for (auto b : backends) {
if (b.second == backend) {
return b.first;
}
}
return "INVALID";
};*/

ur_platform_backend_t str_to_backend(std::string str) {

std::transform(str.begin(), str.end(), str.begin(), ::toupper);
for (auto b : backends) {
if (b.first == str) {
return b.second;
}
}
return UR_PLATFORM_BACKEND_UNKNOWN;
};
} // namespace

std::ostream &operator<<(std::ostream &out,
const std::vector<ur_platform_handle_t> &platforms) {
for (auto platform : platforms) {
out << "\n * \"" << platform << "\"";
}
return out;
}

std::ostream &operator<<(std::ostream &out,
const std::vector<ur_device_handle_t> &devices) {
for (auto device : devices) {
out << "\n * \"" << device << "\"";
}
return out;
}

uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv)
: AdapterEnvironment(), platform_options{parsePlatformOptions(argc, argv)} {
uur::PlatformEnvironment::PlatformEnvironment()
: AdapterEnvironment() {
instance = this;

// Check for errors from parsing platform options
if (!error.empty()) {
return;
}

selectPlatformFromOptions();
populatePlatforms();
}

void uur::PlatformEnvironment::selectPlatformFromOptions() {
struct platform_info {
ur_adapter_handle_t adapter;
ur_platform_handle_t platform;
std::string name;
ur_platform_backend_t backend;
};
std::vector<platform_info> discovered_platforms;
void uur::PlatformEnvironment::populatePlatforms() {
for (auto a : adapters) {
uint32_t count = 0;
ASSERT_SUCCESS(urPlatformGet(&a, 1, 0, nullptr, &count));
Expand All @@ -147,96 +86,8 @@ void uur::PlatformEnvironment::selectPlatformFromOptions() {

for (auto p : platform_list) {
platforms.push_back(p);
ur_platform_backend_t backend;
ASSERT_SUCCESS(urPlatformGetInfo(p, UR_PLATFORM_INFO_BACKEND,
sizeof(ur_platform_backend_t),
&backend, nullptr));

size_t size;
ASSERT_SUCCESS(
urPlatformGetInfo(p, UR_PLATFORM_INFO_NAME, 0, nullptr, &size));
std::vector<char> platform_name{};
platform_name.reserve(size);
ASSERT_SUCCESS(urPlatformGetInfo(p, UR_PLATFORM_INFO_NAME, size,
platform_name.data(), nullptr));

discovered_platforms.push_back(platform_info{
a, p, std::string(platform_name.data()), backend});
}
}

std::string default_name{};
std::map<ur_platform_backend_t, std::string> backend_platform_names{};
auto stream = std::stringstream{platform_options.platform_name};
for (std::string filter; std::getline(stream, filter, ';');) {
auto split = filter.find(':');
if (split == std::string::npos) {
default_name = filter;
} else if (split == filter.length() - 1) {
// E.g: `OPENCL:`, ignore it
} else {
backend_platform_names.insert(
{str_to_backend(filter.substr(0, split)),
filter.substr(split + 1)});
}
}

std::vector<platform_info> platforms_filtered{};
std::copy_if(discovered_platforms.begin(), discovered_platforms.end(),
std::inserter(platforms_filtered, platforms_filtered.begin()),
[&](platform_info info) {
if (!default_name.empty() && default_name != info.name) {
return false;
}
if (backend_platform_names.count(info.backend) &&
backend_platform_names[info.backend] != info.name) {
return false;
}
if (platform_options.platform_backend &&
platform_options.platform_backend != info.backend) {
return false;
}
return true;
});
/*
if (platforms_filtered.size() == 0) {
std::stringstream errstr;
errstr << "No platforms were found with the following filters:";
if (platform_options.platform_backend) {
errstr << " --backend="
<< backend_to_str(*platform_options.platform_backend);
}
if (!platform_options.platform_name.empty()) {
errstr << " --platform=\"" << platform_options.platform_name
<< "\"";
}
if (!platform_options.platform_backend &&
platform_options.platform_name.empty()) {
errstr << " (none)";
}
errstr << "\nAvailable platforms:\n";
for (auto p : platforms) {
errstr << " --backend=" << backend_to_str(p.backend)
<< " --platform=\"" << p.name << "\"\n";
}
FAIL() << errstr.str();
} else if (platforms_filtered.size() == 1 ||
platform_options.platforms_count == 1) {
auto &selected = platforms_filtered[0];
platform = selected.platform;
adapter = selected.adapter;
std::cerr << "Selected platform: [" << backend_to_str(selected.backend)
<< "] " << selected.name << "\n";
} else if (platforms_filtered.size() > 1) {
std::stringstream errstr;
errstr << "Multiple possible platforms found; please select one of the "
"following or set --platforms_count=1:\n";
for (const auto &p : platforms_filtered) {
errstr << " --backend=" << backend_to_str(p.backend)
<< " --platform=\"" << p.name << "\"\n";
}
FAIL() << errstr.str();
}*/
}

void uur::PlatformEnvironment::SetUp() {
Expand All @@ -261,96 +112,9 @@ void uur::PlatformEnvironment::TearDown() {
}
}

PlatformEnvironment::PlatformOptions
PlatformEnvironment::parsePlatformOptions(int argc, char **argv) {
PlatformOptions options{};
auto parse_backend = [&](std::string backend_string) {
options.platform_backend = str_to_backend(backend_string);
if (options.platform_backend == UR_PLATFORM_BACKEND_UNKNOWN) {
std::stringstream errstr{error};
errstr << "--backend not valid; expected one of [";
bool first = true;
for (auto b : backends) {
if (!first) {
errstr << ", ";
}
errstr << b.first;
first = false;
}
errstr << "], but got `" << backend_string << "`";
error = errstr.str();
return false;
}
return true;
};

for (int argi = 1; argi < argc; ++argi) {
const char *arg = argv[argi];
if (!(std::strcmp(arg, "-h") && std::strcmp(arg, "--help"))) {
// TODO - print help
break;
} else if (std::strncmp(
arg, "--platform=", sizeof("--platform=") - 1) == 0) {
options.platform_name =
std::string(&arg[std::strlen("--platform=")]);
} else if (std::strncmp(arg, "--backend=", sizeof("--backend=") - 1) ==
0) {
std::string backend_string{&arg[std::strlen("--backend=")]};
if (!parse_backend(std::move(backend_string))) {
return options;
}
} else if (std::strncmp(arg, "--platforms_count=",
sizeof("--platforms_count=") - 1) == 0) {
options.platforms_count = std::strtoul(
&arg[std::strlen("--platforms_count=")], nullptr, 10);
}
}

/* If a platform was not provided using the --platform/--backend command line options,
* check if environment variable is set to use as a fallback. */
if (options.platform_name.empty()) {
auto env_platform = ur_getenv("UR_CTS_ADAPTER_PLATFORM");
if (env_platform.has_value()) {
options.platform_name = env_platform.value();
}
}
if (!options.platform_backend) {
auto env_backend = ur_getenv("UR_CTS_BACKEND");
if (env_backend.has_value()) {
if (!parse_backend(env_backend.value())) {
return options;
}
}
}

return options;
}

DevicesEnvironment::DeviceOptions
DevicesEnvironment::parseDeviceOptions(int argc, char **argv) {
DeviceOptions options{};
for (int argi = 1; argi < argc; ++argi) {
const char *arg = argv[argi];
if (!(std::strcmp(arg, "-h") && std::strcmp(arg, "--help"))) {
// TODO - print help
break;
} else if (std::strncmp(arg, "--device=", sizeof("--device=") - 1) ==
0) {
options.device_name = std::string(&arg[std::strlen("--device=")]);
} else if (std::strncmp(arg, "--devices_count=",
sizeof("--devices_count=") - 1) == 0) {
options.devices_count = std::strtoul(
&arg[std::strlen("--devices_count=")], nullptr, 10);
}
}
return options;
}

DevicesEnvironment *DevicesEnvironment::instance = nullptr;

DevicesEnvironment::DevicesEnvironment(int argc, char **argv)
: PlatformEnvironment(argc, argv),
device_options(parseDeviceOptions(argc, argv)) {
DevicesEnvironment::DevicesEnvironment() : PlatformEnvironment() {
instance = this;
if (!error.empty()) {
return;
Expand All @@ -376,65 +140,6 @@ DevicesEnvironment::DevicesEnvironment(int argc, char **argv)
error = "Could not find any devices to test";
return;
}
// Get the argument (devices_count) to limit test devices count.
// In case, the devices_count is "0", the variable count will not be changed.
// The CTS will run on all devices.
/* filter devices with options after accumulating them all
if (device_options.device_name.empty()) {
if (device_options.devices_count >
(std::numeric_limits<uint32_t>::max)()) {
error = "Invalid devices_count argument";
return;
} else if (device_options.devices_count > 0) {
count = (std::min)(
count, static_cast<uint32_t>(device_options.devices_count));
}
devices.resize(count);
if (urDeviceGet(platform, UR_DEVICE_TYPE_ALL, count, devices.data(),
nullptr)) {
error = "urDeviceGet() failed to get devices.";
return;
}
} else {
devices.resize(count);
if (urDeviceGet(platform, UR_DEVICE_TYPE_ALL, count, devices.data(),
nullptr)) {
error = "urDeviceGet() failed to get devices.";
return;
}
for (unsigned i = 0; i < count; i++) {
size_t size;
if (urDeviceGetInfo(devices[i], UR_DEVICE_INFO_NAME, 0, nullptr,
&size)) {
error = "urDeviceGetInfo() failed";
return;
}
std::vector<char> device_name(size);
if (urDeviceGetInfo(devices[i], UR_DEVICE_INFO_NAME, size,
device_name.data(), nullptr)) {
error = "urDeviceGetInfo() failed";
return;
}
if (device_options.device_name == device_name.data()) {
device = devices[i];
devices.clear();
devices.resize(1);
devices[0] = device;
break;
}
}
if (!device) {
std::stringstream ss_error;
ss_error << "Device \"" << device_options.device_name
<< "\" not found. Select a single device from below "
"using the "
"--device=NAME command-line options:"
<< devices << std::endl
<< "or set --devices_count=COUNT.";
error = ss_error.str();
return;
}
}*/
}

void DevicesEnvironment::SetUp() {
Expand All @@ -461,7 +166,7 @@ KernelsEnvironment *KernelsEnvironment::instance = nullptr;

KernelsEnvironment::KernelsEnvironment(int argc, char **argv,
const std::string &kernels_default_dir)
: DevicesEnvironment(argc, argv),
: DevicesEnvironment(),
kernel_options(parseKernelOptions(argc, argv, kernels_default_dir)) {
instance = this;
if (!error.empty()) {
Expand Down Expand Up @@ -577,16 +282,6 @@ void KernelsEnvironment::LoadSource(
cached_kernels[kernel_name] = binary_ptr;
binary_out = binary_ptr;
}
/*
void LoadSource(const std::string &kernel_name,
ur_device_handle_t device,
std::shared_ptr<std::vector<char>> &binary_out) {
ur_platform_handle_t platform = nullptr;
if(urDeviceGetInfo(device, UR_DEVICE_INFO_PLATFORM, sizeof(ur_platform_handle_t), &platform, nullptr)) {
FAIL() << "Failed to retrieve platform from device";
}
LoadSource(kernel_name, platform, binary_out);
}*/

ur_result_t KernelsEnvironment::CreateProgram(
ur_platform_handle_t hPlatform, ur_context_handle_t hContext,
Expand Down
4 changes: 2 additions & 2 deletions test/conformance/source/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ int main(int argc, char **argv) {
auto *environment =
new uur::KernelsEnvironment(argc, argv, KERNELS_DEFAULT_DIR);
#elif DEVICES_ENVIRONMENT
auto *environment = new uur::DevicesEnvironment(argc, argv);
auto *environment = new uur::DevicesEnvironment();
#elif PLATFORM_ENVIRONMENT
auto *environment = new uur::PlatformEnvironment(argc, argv);
auto *environment = new uur::PlatformEnvironment();
#else
auto *environment = new uur::AdapterEnvironment();
#endif
Expand Down
Loading

0 comments on commit 6a1d9ad

Please sign in to comment.