Skip to content

Commit

Permalink
added address resolution from MPI rank
Browse files Browse the repository at this point in the history
  • Loading branch information
mdorier committed Jul 22, 2024
1 parent 17307d3 commit ff1d496
Show file tree
Hide file tree
Showing 15 changed files with 315 additions and 120 deletions.
6 changes: 4 additions & 2 deletions include/bedrock/DependencyFinder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <bedrock/SSGManager.hpp>
#include <bedrock/ProviderManager.hpp>
#include <bedrock/ClientManager.hpp>
#include <bedrock/MPIEnv.hpp>
#include <string>
#include <memory>

Expand All @@ -36,14 +37,15 @@ class DependencyFinder {
public:
/**
* @brief Constructor.
*
* @param mpi MPI context
* @param margo Margo context
* @param abtio ABT-IO context
* @param ssg SSG context
* @param pmanager Provider manager
* @param cmanager Client manager
*/
DependencyFinder(const MargoManager& margo, const ABTioManager& abtio,
DependencyFinder(const MPIEnv& mpi,
const MargoManager& margo, const ABTioManager& abtio,
const SSGManager& ssg, const MonaManager& mona,
const ProviderManager& pmanager,
const ClientManager& cmanager);
Expand Down
3 changes: 2 additions & 1 deletion include/bedrock/Jx9Manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <vector>
#include <unordered_map>
#include <memory>
#include <bedrock/MPIEnv.hpp>

namespace bedrock {

Expand Down Expand Up @@ -41,7 +42,7 @@ class Jx9Manager {
/**
* @brief Constructor.
*/
Jx9Manager();
Jx9Manager(MPIEnv mpiEnv);

/**
* @brief Copy-constructor.
Expand Down
80 changes: 80 additions & 0 deletions include/bedrock/MPIEnv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* (C) 2024 The University of Chicago
*
* See COPYRIGHT in top-level directory.
*/
#ifndef __BEDROCK_MPI_H
#define __BEDROCK_MPI_H

#include <bedrock/Exception.hpp>
#include <thallium.hpp>

namespace bedrock {

class DependencyFinder;
class Server;
struct MPIEnvImpl;

/**
 * @brief Copyable handle to Bedrock's MPI environment.
 *
 * The handle only stores a std::shared_ptr to an MPIEnvImpl, so copies
 * refer to the same underlying implementation object. Constructing a
 * handle from scratch, and reaching the implementation pointer, is
 * restricted to the friend classes Server and DependencyFinder.
 */
class MPIEnv {

friend class Server;
friend class DependencyFinder;

public:

// Copying/moving a handle copies/moves the shared implementation pointer.
MPIEnv(const MPIEnv&) = default;
MPIEnv(MPIEnv&&) = default;
MPIEnv& operator=(const MPIEnv&) = default;
MPIEnv& operator=(MPIEnv&&) = default;

/**
 * @brief The last call to the destructor will finalize MPI if it has been
 * initialized by Bedrock.
 */
~MPIEnv();

/**
 * @brief Returns true if Bedrock was built with MPI support.
 * If it hasn't, all the methods below will throw an Exception.
 *
 * @return true if MPI support is available, false otherwise.
 */
bool isEnabled() const;

/**
 * @brief Return the size of MPI_COMM_WORLD.
 *
 * @return Number of processes in MPI_COMM_WORLD.
 */
int globalSize() const;

/**
 * @brief Return the rank of the current process in MPI_COMM_WORLD.
 *
 * @return Rank of the calling process.
 */
int globalRank() const;

/**
 * @brief Return the Mercury address of the given rank.
 *
 * @param rank Rank of the process.
 *
 * @return Reference to the address string associated with that rank.
 */
const std::string& addressOfRank(int rank) const;

private:

/**
 * @brief This constructor will initialize MPI if it hasn't been initialized.
 */
MPIEnv();

/**
 * @brief Wraps an existing implementation object in a handle.
 *
 * @param s Shared pointer to the implementation; it is moved into the handle.
 */
MPIEnv(std::shared_ptr<MPIEnvImpl> s)
: self(std::move(s)) {}

/**
 * @brief Gives access to the underlying implementation pointer.
 *
 * @return A copy of the shared pointer held by this handle.
 */
operator std::shared_ptr<MPIEnvImpl>() const {
return self;
}

// Shared implementation object behind this handle.
std::shared_ptr<MPIEnvImpl> self;
};

} // namespace bedrock

#endif
53 changes: 53 additions & 0 deletions python/mochi/bedrock/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,59 @@ def test_dependency_on_ph_with_id_and_address(self):
providers.create(**provider_params)
self.assertEqual(len(providers), 3)

def test_dependency_on_ph_with_id_and_rank(self):
providers = self.server.providers
self.assertEqual(len(providers), 2)
clients = self.server.clients
self.assertEqual(len(clients), 2)

# Get the rank of the process to use instead of "local"
rank = 0

# Try creating a client without the required dependency
client_params = self.make_client_params([
{"name": "dep1", "type": "module_a",
"is_required": True, "kind": "provider_handle"}])
with self.assertRaises(mbs.BedrockException):
clients.create(**client_params)

# Try creating a client with the wrong provider ID
client_params["dependencies"] = {"dep1": "module_a:999@{rank}"}
with self.assertRaises(mbs.BedrockException):
clients.create(**client_params)

# Try creating a client with the wrong module
client_params["dependencies"] = {"dep1": "module_b:1@{rank}"}
with self.assertRaises(mbs.BedrockException):
clients.create(**client_params)

# Try creating a client with the required dependency
client_params["dependencies"] = {"dep1": f"module_a:1@{rank}"}
clients.create(**client_params)
self.assertEqual(len(clients), 3)

# Try creating a provider without the required dependency
provider_params = self.make_provider_params([
{"name": "dep1", "type": "module_a",
"is_required": True, "kind": "provider_handle"}])
with self.assertRaises(mbs.BedrockException):
providers.create(**provider_params)

# Try creating a provider with the wrong provider ID
provider_params["dependencies"] = {"dep1": "module_a:999@{rank}"}
with self.assertRaises(mbs.BedrockException):
providers.create(**provider_params)

# Try creating a provider with the wrong module
provider_params["dependencies"] = {"dep1": "module_b:1@{rank}"}
with self.assertRaises(mbs.BedrockException):
providers.create(**provider_params)

# Try creating a provider with the required dependency
provider_params["dependencies"] = {"dep1": f"module_a:1@{rank}"}
providers.create(**provider_params)
self.assertEqual(len(providers), 3)

def test_dependency_on_ph_with_ssg(self):
providers = self.server.providers
self.assertEqual(len(providers), 2)
Expand Down
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ set (server-src-files
DependencyFinder.cpp
MargoLogging.cpp
Jx9Manager.cpp
MPI.cpp)
MPIEnv.cpp)

set (client-src-files
Client.cpp
Expand Down
18 changes: 16 additions & 2 deletions src/DependencyFinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ namespace tl = thallium;

namespace bedrock {

DependencyFinder::DependencyFinder(const MargoManager& margo,
DependencyFinder::DependencyFinder(const MPIEnv& mpi,
const MargoManager& margo,
const ABTioManager& abtio,
const SSGManager& ssg,
const MonaManager& mona,
const ProviderManager& pmanager,
const ClientManager& cmanager)
: self(std::make_shared<DependencyFinderImpl>(margo.getMargoInstance())) {
self->m_mpi = mpi.self;
self->m_margo_context = margo;
self->m_abtio_context = abtio.self;
self->m_ssg_context = ssg.self;
Expand Down Expand Up @@ -242,14 +244,26 @@ std::shared_ptr<NamedDependency>
DependencyFinder::makeProviderHandle(const std::string& client_name,
const std::string& type,
uint16_t provider_id,
const std::string& locator,
const std::string& locatorArg,
std::string* resolved) const {
auto locator = locatorArg;
spdlog::trace("Making provider handle of type {} with id {} and locator {}",
type, provider_id, locator);
auto mid = MargoManager(self->m_margo_context).getMargoInstance();
auto client = findClient(type, client_name);
auto service_factory = ModuleContext::getServiceFactory(type);
hg_addr_t addr = HG_ADDR_NULL;
bool locator_is_number = true;
int rank = 0;
for(auto c : locator) {
if(c >= '0' && c <= '9') {
rank = rank*10 + (c - '0');
continue;
}
locator_is_number = false;
break;
}
if(locator_is_number) locator = MPIEnv(self->m_mpi).addressOfRank(rank);

if (locator == "local") {

Expand Down
4 changes: 3 additions & 1 deletion src/DependencyFinderImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
#include "ProviderManagerImpl.hpp"
#include "ClientManagerImpl.hpp"
#include "Formatting.hpp"
#include "MPIEnvImpl.hpp"
#include "bedrock/VoidPtr.hpp"
#include "bedrock/RequestResult.hpp"
#include "bedrock/Exception.hpp"
#include <bedrock/Exception.hpp>
#include <thallium.hpp>
#include <string>
#include <unordered_map>
Expand All @@ -30,6 +31,7 @@ class DependencyFinderImpl {

public:
tl::engine m_engine;
std::shared_ptr<MPIEnvImpl> m_mpi;
std::shared_ptr<MargoManagerImpl> m_margo_context;
std::weak_ptr<ABTioManagerImpl> m_abtio_context;
std::weak_ptr<SSGManagerImpl> m_ssg_context;
Expand Down
10 changes: 5 additions & 5 deletions src/Jx9Manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ using nlohmann::json;

// LCOV_EXCL_START

Jx9Manager::Jx9Manager()
: self(std::make_shared<Jx9ManagerImpl>()) {}
/**
 * @brief Builds the manager around a newly allocated Jx9ManagerImpl.
 *
 * @param mpi MPI environment handle; it is taken by value and moved into
 * the implementation object.
 */
Jx9Manager::Jx9Manager(MPIEnv mpi)
: self(std::make_shared<Jx9ManagerImpl>(std::move(mpi))) {}

Jx9Manager::Jx9Manager(const Jx9Manager&) = default;

Expand Down Expand Up @@ -72,10 +72,10 @@ std::string Jx9Manager::executeQuery(
// installing MPI_COMM_WORLD
json comm_world = nullptr;
#ifdef ENABLE_MPI
if(self->m_mpi->enabled()) {
if(self->m_mpi.isEnabled()) {
comm_world = json::object();
comm_world["rank"] = self->m_mpi->rank();
comm_world["size"] = self->m_mpi->size();
comm_world["rank"] = self->m_mpi.globalRank();
comm_world["size"] = self->m_mpi.globalSize();
}
#endif
jx9_value* jx9_comm_world = nullptr;
Expand Down
8 changes: 4 additions & 4 deletions src/Jx9ManagerImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <vector>
#include <string>
#include <unordered_map>
#include "MPI.hpp"
#include <bedrock/MPIEnv.hpp>

namespace bedrock {

Expand All @@ -25,10 +25,10 @@ class Jx9ManagerImpl {
jx9* m_engine = nullptr;
tl::mutex m_mtx;
std::unordered_map<std::string, std::string> m_global_variables;
std::shared_ptr<MPI> m_mpi;
MPIEnv m_mpi;

Jx9ManagerImpl()
: m_mpi(std::make_shared<MPI>()) {
Jx9ManagerImpl(MPIEnv mpi)
: m_mpi(std::move(mpi)) {
spdlog::trace("Initializing Jx9 engine");
jx9_init(&m_engine);
}
Expand Down
15 changes: 0 additions & 15 deletions src/MPI.cpp

This file was deleted.

Loading

0 comments on commit ff1d496

Please sign in to comment.