From ae38aabe2b460c9c9a3b50df8383cc9b906f2457 Mon Sep 17 00:00:00 2001 From: Matthieu Dorier Date: Thu, 1 Feb 2024 17:04:47 +0000 Subject: [PATCH] started adding tests for dependency finder --- include/bedrock/DependencyFinder.hpp | 30 ++-- include/bedrock/module.h | 29 ++-- python/mochi/bedrock/test_dependencies.py | 150 +++++++++++++++++++ python/mochi/bedrock/test_service_handle.py | 2 +- src/ClientManager.cpp | 6 +- src/DependencyFinder.cpp | 153 ++++++++------------ src/ProviderManager.cpp | 6 +- tests/Client.cpp | 6 +- tests/modules/ModuleC.cpp | 61 ++++++++ 9 files changed, 318 insertions(+), 125 deletions(-) create mode 100644 python/mochi/bedrock/test_dependencies.py create mode 100644 tests/modules/ModuleC.cpp diff --git a/include/bedrock/DependencyFinder.hpp b/include/bedrock/DependencyFinder.hpp index e1a2335..5c6a794 100644 --- a/include/bedrock/DependencyFinder.hpp +++ b/include/bedrock/DependencyFinder.hpp @@ -81,27 +81,38 @@ class DependencyFinder { /** * @brief Resolve a specification, returning a void* handle to it. * This function throws an exception if the specification could not - * be resolved. A specification string follows the following grammar: + * be resolved. A specification is either a name, or a string follows + * the following grammar: * * SPEC := IDENTIFIER * | IDENTIFIER '@' LOCATION - * IDENTIFIER := NAME + * IDENTIFIER := SPECIFIER + * | NAME '->' SPECIFIER + * SPECIFIER := NAME * | TYPE ':' ID * LOCATION := ADDRESS - * | 'ssg://' GROUP '/' RANK + * | 'ssg://' NAME '/' RANK * ADDRESS := * NAME := * ID := * + * + * For instance, "abc" represents the name "abc". + * "abc:123" represents a provider of type "abc" with + * provider id 123. "abc->def@address" represents a provider handle + * created from client named "abc", pointing to a provider named "def" + * at address "address". + * * @param [in] type Type of dependency. + * @param [in] kind Kind of dependency (BEDROCK_KIND_*). * @param [in] spec Specification string. * @param [out] Resolved specification. * * @return handle to dependency */ std::shared_ptr - find(const std::string& type, const std::string& spec, - std::string* resolved) const; + find(const std::string& type, int32_t kind, + const std::string& spec, std::string* resolved) const; /** * @brief Find a local provider based on a type and provider id. @@ -145,15 +156,6 @@ class DependencyFinder { std::shared_ptr findClient( const std::string& type, const std::string& name) const; - /** - * @brief Get an admin of a given type. - * - * @param type Type of admin. - * - * @return An abstract pointer to the dependency. - */ - std::shared_ptr getAdmin(const std::string& type) const; - /** * @brief Make a provider handle to a specified provider. * Throws an exception if no provider was found with this diff --git a/include/bedrock/module.h b/include/bedrock/module.h index 3e3d74f..eda97f3 100644 --- a/include/bedrock/module.h +++ b/include/bedrock/module.h @@ -20,6 +20,12 @@ extern "C" { #define BEDROCK_REQUIRED 0x1 #define BEDROCK_ARRAY 0x2 +#define BEDROCK_KIND_CLIENT (0x1 << 2) +#define BEDROCK_KIND_PROVIDER_HANDLE (0x2 << 2) +#define BEDROCK_KIND_PROVIDER (0x3 << 2) + +#define BEDROCK_GET_KIND_FROM_FLAG(__flag__) (__flag__ & ~0b11) + typedef struct bedrock_args* bedrock_args_t; #define BEDROCK_ARGS_NULL ((bedrock_args_t)NULL) @@ -36,15 +42,20 @@ typedef void* bedrock_module_client_t; * a module dependency. The name correspondings to the name * of the dependency in the module configuration. The type * corresponds to the type of dependency (name of other modules - * from which the dependency comes from). The flags field - * allows parameterizing the dependency. It should be an or-ed + * from which the dependency comes from). + * + * The flags field allows parameterizing the dependency. It should be an or-ed * value from BEDROCK_REQUIRED (this dependency is required) * and BEDROCK_ARRAY (this dependency is an array). Note that * BEDROCK_REQUIRED | BEDROCK_ARRAY indicates that the array * should contain at least 1 entry. * + * The flag should be or-ed with one of the BEDROCK_KIND_* + * values to specify the kind of dependency that is expected + * if the dependency is from a module (client, provider, or provider handle). + * * For example, the following bedrock_dependency - * { "storage", "bake", BEDROCK_REQUIRED | BEDROCK_ARRAY } + * { "storage", "bake", BEDROCK_REQUIRED | BEDROCK_ARRAY | BEDROCK_KIND_PROVIDER_HANDLE } * indicates that a provider for this module requires to be * created with a "dependencies" section in its JSON looking * like the following: @@ -52,17 +63,17 @@ typedef void* bedrock_module_client_t; * "storage" : [ "bake:34@na+sm://1234", ... ] * } * that is, a "storage" key is expected (name = "storage"), - * and it will resolve to an array of bake constructs - * (e.g. providers or provider handles). + * and it will resolve to an array of at least one bake (type = "bake") + * provider handles (flags has BEDROCK_KIND_PROVIDER_HANDLE). */ struct bedrock_dependency { - const char* name; - const char* type; - int32_t flags; + const char* name; + const char* type; + int32_t flags; }; #define BEDROCK_NO_MORE_DEPENDENCIES \ - { NULL, NULL, 0 } + { NULL, NULL, 0} /** * @brief Type of function called to register a provider. diff --git a/python/mochi/bedrock/test_dependencies.py b/python/mochi/bedrock/test_dependencies.py new file mode 100644 index 0000000..ae01c0e --- /dev/null +++ b/python/mochi/bedrock/test_dependencies.py @@ -0,0 +1,150 @@ +import unittest +import pymargo.logging +import mochi.bedrock.server as mbs +import mochi.bedrock.spec as spec + + +class TestProviderManager(unittest.TestCase): + + def setUp(self): + config = { + "libraries": { + "module_a": "libModuleA.so", + "module_b": "libModuleB.so", + "module_c": "libModuleC.so", + }, + "providers": [ + { + "name": "my_provider_a", + "type": "module_a", + "provider_id": 1 + }, + { + "name": "my_provider_b", + "type": "module_b", + "provider_id": 2 + } + ], + "clients": [ + { + "name": "my_client_a", + "type": "module_a" + }, + { + "name": "my_client_b", + "type": "module_b" + } + ] + } + self.server = mbs.Server(address="na+sm", config=config) + self.server.margo.engine.logger.set_log_level(pymargo.logging.level.critical) + + def tearDown(self): + self.server.finalize() + del self.server + + def make_client_params(self, expected_dependencies: dict={}): + params = { + "name": "my_provider_C", + "type": "module_c", + "config": { + "expected_client_dependencies": expected_dependencies + } + } + return params + + def make_provider_params(self, expected_dependencies: dict={}): + params = self.make_client_params({}) + params["provider_id"] = 3 + params["pool"] = "__primary__" + params["config"]["expected_provider_dependencies"] = expected_dependencies + return params + + + def test_no_dependency(self): + providers = self.server.providers + self.assertEqual(len(providers), 2) + clients = self.server.clients + self.assertEqual(len(clients), 2) + + client_params = self.make_client_params() + clients.create(**client_params) + + provider_params = self.make_provider_params() + providers.create(**provider_params) + + self.assertEqual(len(providers), 3) + self.assertEqual(len(clients), 3) + + def test_optional_dependency(self): + providers = self.server.providers + self.assertEqual(len(providers), 2) + clients = self.server.clients + self.assertEqual(len(clients), 2) + + client_params = self.make_client_params([ + {"name": "dep1", + "type": "module_a", + "kind": "provider_handle", + "is_array": False, + "is_required": False, + }]) + clients.create(**client_params) + + provider_params = self.make_provider_params([ + {"name": "dep1", + "type": "module_a", + "kind": "provider_handle", + "is_array": False, + "is_required": False, + }]) + providers.create(**provider_params) + + self.assertEqual(len(providers), 3) + self.assertEqual(len(clients), 3) + + def test_required_dependency(self): + providers = self.server.providers + self.assertEqual(len(providers), 2) + clients = self.server.clients + self.assertEqual(len(clients), 2) + + # Try creating a client without the required dependency + client_params = self.make_client_params([ + {"name": "dep1", + "type": "module_a", + "kind": "provider_handle", + "is_array": False, + "is_required": True, + }]) + with self.assertRaises(mbs.BedrockException): + clients.create(**client_params) + + # Try creating a client with the required dependency + client_params["dependencies"] = { + "dep1": "my_provider_a@local" + } + clients.create(**client_params) + self.assertEqual(len(clients), 3) + + # Try creating a provider without the required dependency + provider_params = self.make_provider_params([ + {"name": "dep1", + "type": "module_a", + "kind": "provider_handle", + "is_array": False, + "is_required": True, + }]) + with self.assertRaises(mbs.BedrockException): + providers.create(**provider_params) + + # Try creating a provider with the required dependency + provider_params["dependencies"] = { + "dep1": "my_provider_a@local" + } + providers.create(**provider_params) + self.assertEqual(len(providers), 3) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/mochi/bedrock/test_service_handle.py b/python/mochi/bedrock/test_service_handle.py index d2a64b8..16c2d49 100644 --- a/python/mochi/bedrock/test_service_handle.py +++ b/python/mochi/bedrock/test_service_handle.py @@ -70,7 +70,7 @@ def test_load_module(self): self.sh.load_module("module_a", "libModuleA.so") self.sh.load_module("module_b", "libModuleB.so") with self.assertRaises(mbc.ClientException): - self.sh.load_module("module_c", "libModuleC.so") + self.sh.load_module("module_x", "libModuleX.so") def add_pool(self, config): initial_num_pools = len(self.server.margo.pools) diff --git a/src/ClientManager.cpp b/src/ClientManager.cpp index 7222444..0ab5957 100644 --- a/src/ClientManager.cpp +++ b/src/ClientManager.cpp @@ -251,7 +251,8 @@ ClientManager::addClientFromJSON(const std::string& jsonString) { throw DETAILED_EXCEPTION("Dependency {} should be a string", dependency.name); } - auto ptr = dependencyFinder.find(dependency.type, + auto ptr = dependencyFinder.find( + dependency.type, BEDROCK_GET_KIND_FROM_FLAG(dependency.flags), dep_config.get(), nullptr); resolved_dependency_map[dependency.name].dependencies.push_back(ptr); resolved_dependency_map[dependency.name].is_array = false; @@ -267,7 +268,8 @@ ClientManager::addClientFromJSON(const std::string& jsonString) { "Item in dependency array {} should be a string", dependency.name); } - auto ptr = dependencyFinder.find(dependency.type, + auto ptr = dependencyFinder.find( + dependency.type, BEDROCK_GET_KIND_FROM_FLAG(dependency.flags), elem.get(), nullptr); resolved_dependency_map[dependency.name].dependencies.push_back(ptr); resolved_dependency_map[dependency.name].is_array = true; diff --git a/src/DependencyFinder.cpp b/src/DependencyFinder.cpp index fd31ab0..5dc48e3 100644 --- a/src/DependencyFinder.cpp +++ b/src/DependencyFinder.cpp @@ -44,16 +44,9 @@ DependencyFinder::~DependencyFinder() = default; DependencyFinder::operator bool() const { return static_cast(self); } -static bool isPositiveNumber(const std::string& str) { - if (str.empty()) return false; - for (unsigned i = 0; i < str.size(); i++) - if (!isdigit(str[i])) return false; - return true; -} - std::shared_ptr DependencyFinder::find( - const std::string& type, const std::string& spec, - std::string* resolved) const { + const std::string& type, int32_t kind, + const std::string& spec, std::string* resolved) const { spdlog::trace("DependencyFinder search for {} of type {}", spec, type); if (type == "pool") { // Argobots pool @@ -90,9 +83,9 @@ std::shared_ptr DependencyFinder::find( return mona_id; } else if (type == "ssg") { // SSG group - // + auto ssg_manager_impl = self->m_ssg_context.lock(); - if(!ssg_manager_impl) { + if (!ssg_manager_impl) { throw Exception("Could not resolve SSG dependency: no SSGManager found"); } auto group = SSGManager(ssg_manager_impl).getGroup(spec); @@ -102,98 +95,70 @@ std::shared_ptr DependencyFinder::find( if (resolved) { *resolved = spec; } return group; - } else { // Provider or provider handle + } else if (kind == BEDROCK_KIND_CLIENT) { + + auto client = findClient(type, spec); + if (client) { *resolved = client->getName(); } + return client; + + } else if (kind == BEDROCK_KIND_PROVIDER) { + + // the spec can be in the form "name" or "type:id" + std::regex re( + "([a-zA-Z_][a-zA-Z0-9_]*)" // identifier (name or type) + "(?::([0-9]+))?"); // specifier ("client" or provider id) + std::smatch match; + if (!std::regex_search(spec, match, re) || match.str(0) != spec) { + throw Exception("Ill-formated dependency specification \"{}\"", spec); + } + auto identifier = match.str(1); // name or type + auto provider_id_str = match.str(2); // provider id + + if(provider_id_str.empty()) { // identifier is a name + uint16_t provider_id; + auto ptr = findProvider(type, identifier, &provider_id); + if (resolved) *resolved = type + ":" + std::to_string(provider_id); + return ptr; + } else { + uint16_t provider_id = std::atoi(provider_id_str.c_str()); + auto ptr = findProvider(type, provider_id); + if (resolved) *resolved = type + ":" + std::to_string(provider_id); + return ptr; + } + + } else { // Provider handle std::regex re( "(?:([a-zA-Z_][a-zA-Z0-9_]*)\\->)?" // client name (name->) "([a-zA-Z_][a-zA-Z0-9_]*)" // identifier (name or type) - "(?::(client|[0-9]+))?" // specifier ("client" or provider id) - "(?:@(.+))?"); // locator (@address) + "(?::([0-9]+))?" // optional provider id + "(?:@(.+))?"); // optional locator (@address) std::smatch match; - if (std::regex_search(spec, match, re)) { - if (match.str(0) != spec) { - throw Exception("Ill-formated dependency specification \"{}\"", - spec); - } - auto client_name = match.str(1); - auto identifier = match.str(2); // name or type - auto specifier - = match.str(3); // provider id, or "client" or "admin" - auto locator = match.str(4); // address or "local" + if (!std::regex_search(spec, match, re) || match.str(0) != spec) { + throw Exception("Ill-formated dependency specification \"{}\"", spec); + } - if (locator.empty() && !client_name.empty()) { - throw Exception( - "Client name (\"{}\") specified in dependency that is not " - "a provider handle", - client_name); - } + auto client_name = match.str(1); // client to use for provider handles + auto identifier = match.str(2); // name or type + auto provider_id_str = match.str(3); // provider id + auto locator = match.str(4); // address or "local" + if(locator.empty()) locator = "local"; - if (locator.empty()) { // local dependency to a provider or a client - // or admin - if (specifier == "client") { // requesting a client - auto client = findClient(type, ""); - if (client) { *resolved = client->getName(); } - return client; - } else if (isPositiveNumber( - specifier)) { // dependency to a provider - // specified by type:id - uint16_t provider_id = atoi(specifier.c_str()); - if (type != identifier) { - throw Exception( - "Invalid provider type in \"{}\" (expected {})", - spec, type); - } - if (resolved) { *resolved = spec; } - return findProvider(type, provider_id); - } else { // dependency to a provider specified by name, or a - // client - try { - try { - auto ptr = findClient(type, identifier); - if (resolved) { *resolved = identifier; } - return ptr; - } catch (const Exception&) { - // didn't fine a client, try a provider - uint16_t provider_id; - auto ptr - = findProvider(type, identifier, &provider_id); - if (resolved) { - *resolved - = type + ":" + std::to_string(provider_id); - } - return ptr; - } - } catch (const Exception&) { - throw Exception( - "Could not find client or provider with " - "specification \"{}\"", - spec); - } - } - } else { // dependency to a provider handle - if (specifier.empty()) { - // dependency specified as name@location - return makeProviderHandle(client_name, type, identifier, - locator, resolved); - } else if (isPositiveNumber(specifier)) { - // dependency specified as type:id@location - uint16_t provider_id = atoi(specifier.c_str()); - if (type != identifier) { - throw Exception( - "Invalid provider type in \"{}\" (expected {})", - spec, type); - } - return makeProviderHandle(client_name, type, provider_id, - locator, resolved); - } else { // invalid - throw Exception( - "Ill-formated dependency specification \"{}\"", spec); - } - } + if (provider_id_str.empty()) { + // dependency specified as client->name@location + return makeProviderHandle( + client_name, type, identifier, locator, resolved); } else { - throw Exception("Ill-formated dependency specification \"{}\"", - spec); + // dependency specified as client->type:id@location + uint16_t provider_id = atoi(provider_id_str.c_str()); + if (type != identifier) { + throw Exception( + "Invalid provider type in \"{}\" (expected {})", + spec, type); + } + return makeProviderHandle( + client_name, type, provider_id, locator, resolved); } } return nullptr; diff --git a/src/ProviderManager.cpp b/src/ProviderManager.cpp index b074e61..1fd02f8 100644 --- a/src/ProviderManager.cpp +++ b/src/ProviderManager.cpp @@ -253,7 +253,8 @@ ProviderManager::addProviderFromJSON(const std::string& jsonString) { "Dependency \"{}\" should be a string", dependency.name); } auto dep_handle = dependencyFinder.find( - dependency.type, dep_config.get(), nullptr); + dependency.type, BEDROCK_GET_KIND_FROM_FLAG(dependency.flags), + dep_config.get(), nullptr); resolved_dependency_map[dependency.name].is_array = false; resolved_dependency_map[dependency.name].dependencies.push_back(dep_handle); @@ -273,7 +274,8 @@ ProviderManager::addProviderFromJSON(const std::string& jsonString) { dependency.name); } auto dep_handle = dependencyFinder.find( - dependency.type, elem.get(), nullptr); + dependency.type, BEDROCK_GET_KIND_FROM_FLAG(dependency.flags), + elem.get(), nullptr); resolved_dependency_map[dependency.name].is_array = true; resolved_dependency_map[dependency.name].dependencies.push_back(dep_handle); } diff --git a/tests/Client.cpp b/tests/Client.cpp index addfcac..5ac31fd 100644 --- a/tests/Client.cpp +++ b/tests/Client.cpp @@ -141,7 +141,7 @@ TEST_CASE("Tests various object creation and removal via a ServiceHandle", "[ser REQUIRE(bedrock::ModuleContext::getServiceFactory("module_b") != nullptr); // load libModuleC.so, which does not exist REQUIRE_THROWS_AS( - serviceHandle.loadModule("module_c", "libModuleC.so"), + serviceHandle.loadModule("module_x", "libModuleX.so"), bedrock::Exception); } @@ -194,11 +194,11 @@ TEST_CASE("Tests various object creation and removal via a ServiceHandle", "[ser // create a provider of an invalid type REQUIRE_THROWS_AS( - serviceHandle.startProvider("my_provider_c", "module_c", 234), + serviceHandle.startProvider("my_provider_x", "module_x", 234), bedrock::Exception); // create a provider of an invalid type asynchronously serviceHandle.startProvider( - "my_provider_c", "module_c", 234, nullptr, + "my_provider_x", "module_x", 234, nullptr, "", "{}", bedrock::DependencyMap(), {}, &req); REQUIRE_THROWS_AS(req.wait(), bedrock::Exception); } diff --git a/tests/modules/ModuleC.cpp b/tests/modules/ModuleC.cpp new file mode 100644 index 0000000..0796604 --- /dev/null +++ b/tests/modules/ModuleC.cpp @@ -0,0 +1,61 @@ +#include "BaseModule.hpp" +#include + +using json = nlohmann::json; + +class ModuleCServiceFactory : public BaseServiceFactory { + + static inline std::vector extractDependencies( + const json& expected_dependencies) { + if(!expected_dependencies.is_array()) + return {}; + std::vector dependencies; + for(const auto& dep : expected_dependencies) { + if(!dep.is_object()) continue; + if(!dep.contains("name") || !dep.contains("type")) + continue; + int32_t flags = 0; + auto name = dep["name"].get(); + auto type = dep["type"].get(); + auto kind = dep.value("kind", std::string{""}); + if(kind == "client") { + flags |= BEDROCK_KIND_CLIENT; + } + if(kind == "provider") { + flags |= BEDROCK_KIND_PROVIDER; + } + if(kind == "provider_handle") { + flags |= BEDROCK_KIND_PROVIDER_HANDLE; + } + if(dep.value("is_required", false)) { + flags |= BEDROCK_REQUIRED; + } + if(dep.value("is_array", false)) { + flags |= BEDROCK_ARRAY; + } + dependencies.push_back(bedrock::Dependency{name, type, flags}); + } + return dependencies; + } + + std::vector getProviderDependencies(const char* cfg) override { + auto config = json::parse(cfg); + std::vector dependencies; + if(!config.contains("expected_provider_dependencies")) + return dependencies; + auto& expected_dependencies = config["expected_provider_dependencies"]; + return extractDependencies(expected_dependencies); + } + + std::vector getClientDependencies(const char* cfg) override { + auto config = json::parse(cfg); + std::vector dependencies; + if(!config.contains("expected_client_dependencies")) + return dependencies; + auto& expected_dependencies = config["expected_client_dependencies"]; + return extractDependencies(expected_dependencies); + } + +}; + +BEDROCK_REGISTER_MODULE_FACTORY(module_c, ModuleCServiceFactory)