From 10e62a5312ddf339193da823f7fb15d859332902 Mon Sep 17 00:00:00 2001 From: Loic Pottier Date: Mon, 23 Sep 2024 17:00:21 -0700 Subject: [PATCH] Added CI tests for RabbitMQ components in AMS Signed-off-by: Loic Pottier --- .../x86_64-broadwell-cuda11.6.1/Dockerfile | 2 +- .../x86_64-broadwell-gcc11.2.1/Dockerfile | 2 +- .github/workflows/ci.yml | 99 +++++++++ src/AMSWorkflow/ams/rmq.py | 97 +++++++-- src/AMSWorkflow/ams/util.py | 18 ++ src/AMSlib/AMS.cpp | 10 +- src/AMSlib/wf/basedb.hpp | 13 +- src/AMSlib/wf/rmqdb.cpp | 17 +- tests/AMSlib/CMakeLists.txt | 33 ++- tests/AMSlib/ams_rmq_env.cpp | 205 ++++++++++++++++++ tests/AMSlib/json_configs/rmq.json.in | 2 +- tests/AMSlib/verify_rmq.py | 85 ++++++++ tools/rmq/recv_binary.py | 49 ++++- 13 files changed, 586 insertions(+), 46 deletions(-) create mode 100644 tests/AMSlib/ams_rmq_env.cpp create mode 100644 tests/AMSlib/verify_rmq.py diff --git a/.github/containers/x86_64-broadwell-cuda11.6.1/Dockerfile b/.github/containers/x86_64-broadwell-cuda11.6.1/Dockerfile index 15bc767f..25474ba9 100644 --- a/.github/containers/x86_64-broadwell-cuda11.6.1/Dockerfile +++ b/.github/containers/x86_64-broadwell-cuda11.6.1/Dockerfile @@ -2,7 +2,7 @@ FROM nvidia/cuda:11.6.1-devel-ubi8 AS base MAINTAINER Giorgis Georgakoudis RUN \ yum install -y dnf &&\ - dnf install -y git xz autoconf automake unzip patch gcc-gfortran bzip2 file &&\ + dnf install -y git xz autoconf automake unzip patch gcc-gfortran bzip2 file libevent-devel openssl-devel &&\ dnf upgrade -y &&\ dnf clean all COPY repo repo diff --git a/.github/containers/x86_64-broadwell-gcc11.2.1/Dockerfile b/.github/containers/x86_64-broadwell-gcc11.2.1/Dockerfile index b26e7f41..3e13ab1e 100644 --- a/.github/containers/x86_64-broadwell-gcc11.2.1/Dockerfile +++ b/.github/containers/x86_64-broadwell-gcc11.2.1/Dockerfile @@ -3,7 +3,7 @@ MAINTAINER Giorgis Georgakoudis RUN \ yum install -y dnf &&\ dnf group install -y "Development Tools" &&\ - dnf install -y git gcc-toolset-11 environment-modules &&\ + dnf install -y git gcc-toolset-11 environment-modules libevent-devel openssl-devel &&\ dnf upgrade -y COPY repo repo RUN \ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0e6657ee..ac183da6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,6 +6,7 @@ on: branches: [ "develop" ] workflow_dispatch: + jobs: build-run-tests: # The type of runner that the job will run on @@ -389,3 +390,101 @@ jobs: -DWITH_ADIAK=Off \ $GITHUB_WORKSPACE make + + build-rmq-tests: + # The type of runner that the job will run on + runs-on: ubuntu-latest + services: + rabbitmq: + image: rabbitmq:3.11 + env: + RABBITMQ_DEFAULT_USER: ams + RABBITMQ_DEFAULT_PASS: ams + ports: + - 5672 + + container: + image: ghcr.io/llnl/ams-ci-almalinux8:latest + env: + RABBITMQ_USER: ams + RABBITMQ_PASS: ams + RABBITMQ_HOST: rabbitmq + RABBITMQ_PORT: 5672 + + steps: + - uses: actions/checkout@v4 + - name: Build Torch=On FAISS=On RMQ=On AMS + shell: bash -l {0} + run: | + module load gcc/11.2.1 + export SPACK_ROOT=/spack/ + source /spack/share/spack/setup-env.sh + spack env activate -p /ams-spack-env + rm -rf build/ + mkdir build + cd build + export AMS_MFEM_PATH=$(spack location -i mfem) + export AMS_TORCH_PATH=$(spack location -i py-torch)/lib/python3.10/site-packages/torch/share/cmake/Torch + export AMS_FAISS_PATH=$(spack location -i faiss) + export AMS_UMPIRE_PATH=$(spack location -i umpire) + export AMS_HDF5_PATH=$(spack location -i hdf5) + export AMS_AMQPCPP_PATH=$(spack location -i amqp-cpp)/cmake + cmake \ 
+ -DBUILD_SHARED_LIBS=On \ + -DCMAKE_PREFIX_PATH=$INSTALL_DIR \ + -DWITH_CALIPER=On \ + -DWITH_HDF5=On \ + -DWITH_EXAMPLES=On \ + -DAMS_HDF5_DIR=$AMS_HDF5_PATH \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DCMAKE_BUILD_TYPE=Release \ + -DWITH_CUDA=Off \ + -DUMPIRE_DIR=$AMS_UMPIRE_PATH \ + -DMFEM_DIR=$AMS_MFEM_PATH \ + -DWITH_FAISS=On \ + -DWITH_MPI=On \ + -DWITH_TORCH=On \ + -DWITH_TESTS=On \ + -DTorch_DIR=$AMS_TORCH_PATH \ + -DFAISS_DIR=$AMS_FAISS_PATH \ + -DWITH_AMS_DEBUG=On \ + -DWITH_WORKFLOW=On \ + -DWITH_ADIAK=Off \ + -DWITH_RMQ=On \ + -Damqpcpp_DIR=$AMS_AMQPCPP_PATH \ + $GITHUB_WORKSPACE + make + - name: Run tests Torch=On FAISS=On RMQ=On AMSlib RabbitMQ egress + run: | + cd build + export SPACK_ROOT=/spack/ + source /spack/share/spack/setup-env.sh + spack env activate -p /ams-spack-env + + # We overwrite the rmq.json created by CMake + echo """{ + \"db\": { + \"dbType\": \"rmq\", + \"rmq_config\": { + \"rabbitmq-name\": \"rabbit\", + \"rabbitmq-user\": \"${RABBITMQ_USER}\", + \"rabbitmq-password\": \"${RABBITMQ_PASS}\", + \"service-port\": ${RABBITMQ_PORT}, + \"service-host\": \"${RABBITMQ_HOST}\", + \"rabbitmq-vhost\": \"/\", + \"rabbitmq-outbound-queue\": \"test-ci\", + \"rabbitmq-exchange\": \"ams-fanout\", + \"rabbitmq-routing-key\": \"training\" + }, + \"update_surrogate\": false + }, + \"ml_models\": {}, + \"domain_models\": {} + }""" > $GITHUB_WORKSPACE/build/tests/AMSlib/rmq.json + + make test + env: + RABBITMQ_USER: ams + RABBITMQ_PASS: ams + RABBITMQ_HOST: rabbitmq + RABBITMQ_PORT: 5672 \ No newline at end of file diff --git a/src/AMSWorkflow/ams/rmq.py b/src/AMSWorkflow/ams/rmq.py index dc588c25..73cd385a 100644 --- a/src/AMSWorkflow/ams/rmq.py +++ b/src/AMSWorkflow/ams/rmq.py @@ -16,7 +16,6 @@ import json import pika - class AMSMessage(object): """ Represents a RabbitMQ incoming message from AMSLib. 
@@ -28,6 +27,24 @@ class AMSMessage(object):
 
     def __init__(self, body: str):
         self.body = body
+        self.num_elements = None
+        self.hsize = None
+        self.dtype_byte = None
+        self.mpi_rank = None
+        self.domain_name_size = None
+        self.domain_names = []
+        self.input_dim = None
+        self.output_dim = None
+
+    def __str__(self):
+        dt = "float" if self.dtype_byte == 4 else "double"
+        if not self.dtype_byte:
+            dt = None
+        return f"AMSMessage(domain={self.domain_names}, #mpi={self.mpi_rank}, num_elements={self.num_elements}, datatype={dt}, input_dim={self.input_dim}, output_dim={self.output_dim})"
+
+    def __repr__(self):
+        return self.__str__()
 
     def header_format(self) -> str:
         """
         This string represents the AMS format in Python pack format:
@@ -110,6 +127,15 @@ def _parse_header(self, body: str) -> dict:
         res["dsize"] = int(res["datatype"]) * int(res["num_element"]) * (int(res["input_dim"]) + int(res["output_dim"]))
         res["msg_size"] = hsize + res["dsize"]
         res["multiple_msg"] = len(body) != res["msg_size"]
+
+        self.num_elements = int(res["num_element"])
+        self.hsize = int(res["hsize"])
+        self.dtype_byte = int(res["datatype"])
+        self.mpi_rank = int(res["mpirank"])
+        self.domain_name_size = int(res["domain_size"])
+        self.input_dim = int(res["input_dim"])
+        self.output_dim = int(res["output_dim"])
+
         return res
 
     def _parse_data(self, body: str, header_info: dict) -> Tuple[str, np.array, np.array]:
@@ -144,30 +170,37 @@ def _decode(self, body: str) -> Tuple[np.array]:
         input = []
         output = []
         # Multiple AMS messages could be packed in one RMQ message
+        # TODO: we should manage potential multiple messages per AMSMessage better
         while body:
             header_info = self._parse_header(body)
             domain_name, temp_input, temp_output = self._parse_data(body, header_info)
+            # print(f"MSG: {domain_name} input shape {temp_input.shape} output shape {temp_output.shape}")
             # total size of byte we read for that message
             chunk_size = header_info["hsize"] + header_info["dsize"] + header_info["domain_size"]
             input.append(temp_input)
             output.append(temp_output)
             # We remove the current message and keep going
             body = body[chunk_size:]
+            self.domain_names.append(domain_name)
         return domain_name, np.concatenate(input), np.concatenate(output)
 
     def decode(self) -> Tuple[str, np.array, np.array]:
         return self._decode(self.body)
 
+def default_ams_callback(method, properties, body):
+    """Simple callback that decodes incoming messages, assuming they are AMS binary messages"""
+    return AMSMessage(body)
 
 class AMSChannel:
     """
     A wrapper around Pika RabbitMQ channel
     """
 
-    def __init__(self, connection, q_name, logger: logging.Logger = None):
+    def __init__(self, connection, q_name, callback: Optional[Callable] = None, logger: Optional[logging.Logger] = None):
         self.connection = connection
         self.q_name = q_name
         self.logger = logger if logger else logging.getLogger(__name__)
+        self.callback = callback if callback else self.default_callback
 
     def __enter__(self):
         self.open()
@@ -176,9 +209,9 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.close()
 
-    @staticmethod
-    def callback(method, properties, body):
-        return body.decode("utf-8")
+    def default_callback(self, method, properties, body):
+        """Simple callback that returns the message received"""
+        return body
 
     def open(self):
         self.channel = self.connection.channel()
@@ -187,18 +220,19 @@ def open(self):
     def close(self):
         self.channel.close()
 
-    def receive(self, n_msg: int = None, accum_msg=list()):
+    def receive(self, n_msg: int = None, timeout: int = None, accum_msg = list()):
         """
         Consume a message on the queue and post processing by calling the callback.
         @param n_msg The number of messages to receive.
             - if n_msg is None, this call will block for ever and will process all messages that arrives
             - if n_msg = 1 for example, this function will block until one message has been processed.
+        @param timeout If None, timeout is infinite, otherwise timeout in seconds
         @return a list containing all received messages
         """
         if self.channel and self.channel.is_open:
             self.logger.info(
-                f"Starting to consume messages from queue={self.q_name}, routing_key={self.routing_key} ..."
+                f"Starting to consume messages from queue={self.q_name} ..."
             )
             # we will consume only n_msg and requeue all other messages
             # if there are more messages in the queue.
@@ -207,11 +241,15 @@ def receive(self, n_msg: int = None, accum_msg=list()):
                 n_msg = max(n_msg, 0)
             message_consumed = 0
             # Comsume n_msg messages and break out
-            for method_frame, properties, body in self.channel.consume(self.q_name):
+            for method_frame, properties, body in self.channel.consume(self.q_name, inactivity_timeout=timeout):
+                if (method_frame, properties, body) == (None, None, None):
+                    self.logger.info(f"Timeout after {timeout} seconds")
+                    self.channel.cancel()
+                    break
                 # Call the call on the message parts
                 try:
                     accum_msg.append(
-                        BlockingClient.callback(
+                        self.callback(
                             method_frame,
                             properties,
                             body,
@@ -223,10 +261,10 @@ def receive(self, n_msg: int = None, accum_msg=list()):
                 finally:
                     # Acknowledge the message even on failure
                     self.channel.basic_ack(delivery_tag=method_frame.delivery_tag)
+                    message_consumed += 1
                     self.logger.warning(
-                        f"Consumed message {message_consumed+1}/{method_frame.delivery_tag} (exchange={method_frame.exchange}, routing_key={method_frame.routing_key})"
+                        f"Consumed message {message_consumed}/{method_frame.delivery_tag} (exchange=\'{method_frame.exchange}\', routing_key={method_frame.routing_key})"
                     )
-                message_consumed += 1
                 # Escape out of the loop after nb_msg messages
                 if message_consumed == n_msg:
                     # Cancel the consumer and return any pending messages
                     break
         return accum_msg
 
-    def send(self, text: str):
+    def send(self, text: str, exchange : str = ""):
         """
         Send a message
         @param text The text to send
+        @param exchange Exchange to use
         """
-        self.channel.basic_publish(exchange="", routing_key=self.q_name, body=text)
+        self.channel.basic_publish(exchange=exchange, routing_key=self.q_name, body=text)
         return
 
     def get_messages(self):
@@ -250,26 +289,42 @@ def purge(self):
         if self.channel and self.channel.is_open:
             self.channel.queue_purge(self.q_name)
 
-
 class BlockingClient:
     """
     BlockingClient is a class that manages a simple blocking RMQ client lifecycle.
""" - def __init__(self, host, port, vhost, user, password, cert, logger: logging.Logger = None): + def __init__( + self, + host: str, + port: int, + vhost: str, + user: str, + password: str, + cert: Optional[str] = None, + callback: Optional[Callable] = None, + logger: Optional[logging.Logger] = None + ): # CA Cert, can be generated with (where $REMOTE_HOST and $REMOTE_PORT can be found in the JSON file): # openssl s_client -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' rmq-pds.crt self.logger = logger if logger else logging.getLogger(__name__) self.cert = cert - self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - self.context.verify_mode = ssl.CERT_REQUIRED - self.context.check_hostname = False - self.context.load_verify_locations(self.cert) + + if self.cert is None or self.cert == "": + ssl_options = None + else: + self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.context.verify_mode = ssl.CERT_REQUIRED + self.context.check_hostname = False + self.context.load_verify_locations(self.cert) + ssl_options = pika.SSLOptions(self.context) + self.host = host self.vhost = vhost self.port = port self.user = user self.password = password + self.callback = callback self.credentials = pika.PlainCredentials(self.user, self.password) @@ -278,7 +333,7 @@ def __init__(self, host, port, vhost, user, password, cert, logger: logging.Logg port=self.port, virtual_host=self.vhost, credentials=self.credentials, - ssl_options=pika.SSLOptions(self.context), + ssl_options=ssl_options, ) def __enter__(self): @@ -290,7 +345,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def connect(self, queue): """Connect to the queue""" - return AMSChannel(self.connection, queue) + return AMSChannel(self.connection, queue, self.callback) class AsyncConsumer(object): diff --git a/src/AMSWorkflow/ams/util.py b/src/AMSWorkflow/ams/util.py index 2cc1b141..5a3d759c 100644 --- a/src/AMSWorkflow/ams/util.py +++ b/src/AMSWorkflow/ams/util.py @@ -5,9 +5,11 @@ import datetime import socket +import subprocess import uuid from pathlib import Path +from typing import Tuple def get_unique_fn(): # Randomly generate the output file name. 
We use the uuid4 function with the socket name and the current
@@ -20,6 +22,22 @@ def get_unique_fn():
     ]
     return "_".join(fn)
 
+def generate_tls_certificate(host: str, port: int) -> Tuple[bool,str]:
+    """Generate TLS certificate for RabbitMQ
+
+    :param str host: The RabbitMQ hostname
+    :param int port: The RabbitMQ port
+
+    :rtype: Tuple[bool,str]
+    :return: a tuple with a boolean set to True if the certificate got generated and the TLS certificate (otherwise the second element contains stderr)
+    """
+    openssl = subprocess.run(["openssl", "s_client", "-connect", f"{host}:{port}", "-showcerts"], input=b"", capture_output=True)
+    if openssl.returncode != 0:
+        return False, openssl.stderr.decode().strip()
+    sed = subprocess.run(["sed", "-ne", r"/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p"], input=openssl.stdout, capture_output=True)
+    if sed.returncode != 0:
+        return False, sed.stderr.decode().strip()
+    return True, sed.stdout.decode().strip()
 
 def mkdir(root_path, fn):
     _tmp = root_path / Path(fn)
diff --git a/src/AMSlib/AMS.cpp b/src/AMSlib/AMS.cpp
index 4842e4b2..2b5afd17 100644
--- a/src/AMSlib/AMS.cpp
+++ b/src/AMSlib/AMS.cpp
@@ -5,9 +5,9 @@
  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  */
 
-#include 
-
 #include "AMS.h"
+
+#include 
 #ifdef __ENABLE_MPI__
 #include 
 #endif
@@ -378,7 +378,6 @@ class AMSWrap
             getEntry(rmq_entry, "rabbitmq-password");
         std::string rmq_user = getEntry(rmq_entry, "rabbitmq-user");
         std::string rmq_vhost = getEntry(rmq_entry, "rabbitmq-vhost");
-        std::string rmq_cert = getEntry(rmq_entry, "rabbitmq-cert");
         std::string rmq_out_queue =
             getEntry(rmq_entry, "rabbitmq-outbound-queue");
         std::string exchange =
             getEntry(rmq_entry, "rabbitmq-exchange");
             getEntry(rmq_entry, "rabbitmq-routing-key");
         bool update_surrogate = getEntry(entry, "update_surrogate");
 
+        // We allow connection to RabbitMQ without TLS certificate
+        std::string rmq_cert = "";
+        if (rmq_entry.contains("rabbitmq-cert"))
+          rmq_cert = getEntry(rmq_entry, "rabbitmq-cert");
+
         auto &DB = ams::db::DBManager::getInstance();
         DB.instantiate_rmq_db(port,
                               host,
diff --git a/src/AMSlib/wf/basedb.hpp b/src/AMSlib/wf/basedb.hpp
index d87dee52..8ce86879 100644
--- a/src/AMSlib/wf/basedb.hpp
+++ b/src/AMSlib/wf/basedb.hpp
@@ -1010,7 +1010,7 @@ class AMSMessageInbound
 class RMQHandler : public AMQP::LibEventHandler
 {
 protected:
-  /** @brief Path to TLS certificate */
+  /** @brief Path to TLS certificate (if empty, no TLS certificate) */
   std::string _cacert;
   /** @brief MPI rank (0 if no MPI support) */
   uint64_t _rId;
@@ -1035,7 +1035,7 @@ class RMQHandler : public AMQP::LibEventHandler
    */
   RMQHandler(uint64_t rId,
              std::shared_ptr loop,
-             std::string cacert);
+             std::string cacert = "");
 
   ~RMQHandler() = default;
 
@@ -2076,10 +2076,11 @@ class DBManager
   {
     fs::path Path(rmq_cert);
     std::error_code ec;
-    CFATAL(AMS,
-           !fs::exists(Path, ec),
-           "Certificate file '%s' for RMQ server does not exist",
-           rmq_cert.c_str());
+    CWARNING(AMS,
+             !fs::exists(Path, ec),
+             "Certificate file '%s' for RMQ server does not exist. AMS will "
+             "try to connect without it.",
+             rmq_cert.c_str());
     dbType = AMSDBType::AMS_RMQ;
     updateSurrogate = update_surrogate;
 #ifdef __ENABLE_RMQ__
diff --git a/src/AMSlib/wf/rmqdb.cpp b/src/AMSlib/wf/rmqdb.cpp
index a33cb3b6..8141d9a6 100644
--- a/src/AMSlib/wf/rmqdb.cpp
+++ b/src/AMSlib/wf/rmqdb.cpp
@@ -261,6 +261,12 @@ bool RMQHandler::connectionValid()
 
 bool RMQHandler::onSecuring(AMQP::TcpConnection* connection, SSL* ssl)
 {
+  // No TLS certificate provided
+  if (_cacert.empty()) {
+    DBG(RMQHandler, "No TLS certificate.
Bypassing.") + return true; + } + ERR_clear_error(); unsigned long err; #if OPENSSL_VERSION_NUMBER < 0x10100000L @@ -943,11 +949,12 @@ bool RMQInterface::connect(std::string rmq_name, _rId = static_cast(distrib(generator)); AMQP::Login login(rmq_user, rmq_password); - _address = std::make_shared(service_host, - service_port, - login, - rmq_vhost, - /*is_secure*/ true); + bool is_secure = true; + // No TLS certificate provided + if (_cacert.empty()) is_secure = false; + + _address = std::make_shared( + service_host, service_port, login, rmq_vhost, is_secure); _publisher = std::make_shared(_rId, *_address, _cacert, _queue_sender); diff --git a/tests/AMSlib/CMakeLists.txt b/tests/AMSlib/CMakeLists.txt index a3a9c339..7de043f8 100644 --- a/tests/AMSlib/CMakeLists.txt +++ b/tests/AMSlib/CMakeLists.txt @@ -33,6 +33,22 @@ function(JSON_TESTS db_type) unset(JSON_FP) endfunction() +function(CHECK_RMQ_CONFIG file) + # Read the JSON file. + file(READ ${file} MY_JSON_STRING) + message(STATUS "RabbitMQ config ${file}") + + string(JSON DB_CONF GET ${MY_JSON_STRING} db) + string(JSON DB_CONF GET ${DB_CONF} rmq_config) + string(JSON RMQ_HOST GET ${DB_CONF} "service-host") + string(JSON RMQ_PORT GET ${DB_CONF} "service-port") + + if(NOT "${RMQ_HOST}" STREQUAL "" AND NOT "${RMQ_PORT}" STREQUAL "0") + message(STATUS "RabbitMQ config ${file}: ${RMQ_HOST}:${RMQ_PORT}") + else() + message(WARNING "RabbitMQ config file ${file} looks empty! Make sure to fill these fields before running the tests") + endif() +endfunction() function(INTEGRATION_TEST_ENV) JSON_TESTS("csv") @@ -43,12 +59,23 @@ function(INTEGRATION_TEST_ENV) add_test(NAME AMSEndToEndFromJSON::DuqMean::DuqMax::Double::DB::hdf5-debug::HOST COMMAND bash -c "AMS_OBJECTS=${JSON_FP} ${CMAKE_CURRENT_BINARY_DIR}/ams_end_to_end_env 0 8 9 \"double\" 1 1024 app_uq_mean_debug app_uq_max_debug;AMS_OBJECTS=${JSON_FP} python3 ${CMAKE_CURRENT_SOURCE_DIR}/verify_ete.py 0 8 9 \"double\" 1024 app_uq_mean_debug app_uq_max_debug") unset(JSON_FP) endif() +endfunction() + +function(INTEGRATION_TEST_RMQ) if (WITH_RMQ) - configure_file("${CMAKE_CURRENT_SOURCE_DIR}/json_configs/rmq.json.in" "rmq.json" @ONLY) + if(EXISTS "${CMAKE_CURRENT_BINARY_DIR}/rmq.json") + # If file exists we do not overwrite it + message(STATUS "Ctest will use ${CMAKE_CURRENT_BINARY_DIR}/rmq.json as RabbitMQ configuration for testing. 
Make sure RabbitMQ parameters are valid.") + else() + message(STATUS "Copying empty configuration to ${CMAKE_CURRENT_BINARY_DIR}/rmq.json") + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/json_configs/rmq.json.in" "rmq.json" @ONLY) + endif() + set(JSON_FP "${CMAKE_CURRENT_BINARY_DIR}/rmq.json") + CHECK_RMQ_CONFIG(${JSON_FP}) + add_test(NAME AMSEndToEndFromJSON::NoModel::Double::DB::rmq::HOST COMMAND bash -c "AMS_OBJECTS=${JSON_FP} ${CMAKE_CURRENT_BINARY_DIR}/ams_rmq 0 8 9 \"double\" 2 1024; AMS_OBJECTS=${JSON_FP} python3 ${CMAKE_CURRENT_SOURCE_DIR}/verify_rmq.py 0 8 9 \"double\" 2 1024") endif() endfunction() - function (INTEGRATION_TEST) ####################################################### # TEST: output format @@ -186,6 +213,8 @@ endif() INTEGRATION_TEST() BUILD_TEST(ams_end_to_end_env ams_ete_env.cpp) INTEGRATION_TEST_ENV() + BUILD_TEST(ams_rmq ams_rmq_env.cpp) + INTEGRATION_TEST_RMQ() # UQ Tests diff --git a/tests/AMSlib/ams_rmq_env.cpp b/tests/AMSlib/ams_rmq_env.cpp new file mode 100644 index 00000000..da3ad1b4 --- /dev/null +++ b/tests/AMSlib/ams_rmq_env.cpp @@ -0,0 +1,205 @@ +#ifdef __AMS_ENABLE_MPI__ +#include +#endif +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "AMS.h" +#include "wf/debug.h" + +void createUmpirePool(std::string parent_name, std::string pool_name) +{ + auto &rm = umpire::ResourceManager::getInstance(); + auto alloc_resource = rm.makeAllocator( + pool_name, rm.getAllocator(parent_name)); +} + + +AMSDType getDataType(char *d_type) +{ + AMSDType dType = AMSDType::AMS_DOUBLE; + if (std::strcmp(d_type, "float") == 0) { + dType = AMSDType::AMS_SINGLE; + } else if (std::strcmp(d_type, "double") == 0) { + dType = AMSDType::AMS_DOUBLE; + } else { + assert(false && "Unknown data type"); + } + return dType; +} + +template +struct Problem { + int num_inputs; + int num_outputs; + Problem(int ni, int no) : num_inputs(ni), num_outputs(no) {} + + void run(long num_elements, DType **inputs, DType **outputs) + { + for (int i = 0; i < num_elements; i++) { + DType sum = 0; + for (int j = 0; j < num_inputs; j++) { + sum += inputs[j][i]; + } + + for (int j = 0; j < num_outputs; j++) { + outputs[j][i] = sum; + } + } + } + + + const DType *initialize_inputs(DType *inputs, long length) + { + for (int i = 0; i < length; i++) { + inputs[i] = static_cast(i); + } + return inputs; + } + + void ams_run(AMSExecutor &wf, + AMSResourceType resource, + int iterations, + int num_elements) + { + auto &rm = umpire::ResourceManager::getInstance(); + + for (int i = 0; i < iterations; i++) { + int elements = num_elements; // * ((DType)(rand()) / RAND_MAX) + 1; + std::vector inputs; + std::vector outputs; + + // Allocate Input memory + for (int j = 0; j < num_inputs; j++) { + DType *data = new DType[elements]; + inputs.push_back(initialize_inputs(data, elements)); + } + + // Allocate Output memory + for (int j = 0; j < num_outputs; j++) { + outputs.push_back(new DType[elements]); + } + + AMSExecute(wf, + (void *)this, + elements, + reinterpret_cast(inputs.data()), + reinterpret_cast(outputs.data()), + inputs.size(), + outputs.size()); + + for (int i = 0; i < num_outputs; i++) { + delete[] outputs[i]; + outputs[i] = nullptr; + } + + + for (int i = 0; i < num_inputs; i++) { + delete[] inputs[i]; + inputs[i] = nullptr; + } + } + } +}; + +void callBackDouble(void *cls, long elements, void **inputs, void **outputs) +{ + std::cout << "Called the double precision model\n"; + static_cast *>(cls)->run(elements, + (double **)(inputs), + (double 
**)(outputs)); +} + + +void callBackSingle(void *cls, long elements, void **inputs, void **outputs) +{ + std::cout << "Called the single precision model\n"; + static_cast *>(cls)->run(elements, + (float **)(inputs), + (float **)(outputs)); +} + + +int main(int argc, char **argv) +{ + + if (argc != 7) { + std::cout << "Wrong cli\n"; + std::cout << argv[0] + << " use_device(0|1) num_inputs num_outputs " + "data_type(float|double) " + "num_iterations num_elements" << std::endl; + return -1; + } + + + int use_device = std::atoi(argv[1]); + int num_inputs = std::atoi(argv[2]); + int num_outputs = std::atoi(argv[3]); + AMSDType data_type = getDataType(argv[4]); + int num_iterations = std::atoi(argv[5]); + int num_elements = std::atoi(argv[6]); + AMSResourceType resource = AMSResourceType::AMS_HOST; + srand(time(NULL)); + + // int num_inputs = 2; + // int num_outputs = 4; + // AMSDType data_type = getDataType("double"); + // int num_iterations = 1; + // int num_elements = 10; + + // Configure DB + // auto db_type = "rmq"; + + // AMSDBType dbType = AMSDBType::AMS_NONE; + // if (std::strcmp(db_type, "csv") == 0) { + // dbType = AMSDBType::AMS_CSV; + // } else if (std::strcmp(db_type, "hdf5") == 0) { + // dbType = AMSDBType::AMS_HDF5; + // } else if (std::strcmp(db_type, "rmq") == 0) { + // dbType = AMSDBType::AMS_RMQ; + // } + + createUmpirePool("HOST", "TEST_HOST"); + AMSSetAllocator(AMSResourceType::AMS_HOST, "TEST_HOST"); + + AMSCAbstrModel ams_model = AMSRegisterAbstractModel("rmq_db_no_model", + AMSUQPolicy::AMS_RANDOM, + 0.5, + "", + "", + "rmq_db_no_model", + 1); + + if (data_type == AMSDType::AMS_SINGLE) { + Problem prob(num_inputs, num_outputs); + AMSExecutor wf = AMSCreateExecutor(ams_model, + AMSDType::AMS_SINGLE, + resource, + (AMSPhysicFn)callBackSingle, + 0, + 1); + + prob.ams_run(wf, resource, num_iterations, num_elements); + } else { + Problem prob(num_inputs, num_outputs); + AMSExecutor wf = AMSCreateExecutor(ams_model, + AMSDType::AMS_DOUBLE, + resource, + (AMSPhysicFn)callBackDouble, + 0, + 1); + prob.ams_run(wf, resource, num_iterations, num_elements); + } + + return 0; +} diff --git a/tests/AMSlib/json_configs/rmq.json.in b/tests/AMSlib/json_configs/rmq.json.in index 9c29487a..9137c366 100644 --- a/tests/AMSlib/json_configs/rmq.json.in +++ b/tests/AMSlib/json_configs/rmq.json.in @@ -2,7 +2,7 @@ "db" : { "dbType" : "rmq", "rmq_config" : { - "service-port": , + "service-port": 0, "service-host": "", "rabbitmq-erlang-cookie": "", "rabbitmq-name": "", diff --git a/tests/AMSlib/verify_rmq.py b/tests/AMSlib/verify_rmq.py new file mode 100644 index 00000000..a40ad205 --- /dev/null +++ b/tests/AMSlib/verify_rmq.py @@ -0,0 +1,85 @@ +import sys +import json +from pathlib import Path +import os + +from ams.rmq import BlockingClient, default_ams_callback + +def verify( + use_device, + num_inputs, + num_outputs, + data_type, + num_iterations, + num_elements, + rmq_json, + timeout = None, + domain_test = "rmq_db_no_model" # defined in ams_rmq_env.cpp +): + host = rmq_json["service-host"] + vhost = rmq_json["rabbitmq-vhost"] + port = rmq_json["service-port"] + user = rmq_json["rabbitmq-user"] + password = rmq_json["rabbitmq-password"] + queue = rmq_json["rabbitmq-outbound-queue"] + cert = None + if "rabbitmq-cert" in rmq_json: + cert = rmq_json["rabbitmq-cert"] + cert = None if cert == "" else cert + + dtype = 4 + if data_type == "double": + dtype = 8 + + with BlockingClient(host, port, vhost, user, password, cert, default_ams_callback) as client: + with client.connect(queue) as channel: + 
            msgs = channel.receive(n_msg = num_iterations, timeout = timeout)
+
+    assert len(msgs) == num_iterations, f"Received incorrect number of messages ({len(msgs)}): expected #msgs ({num_iterations})"
+
+    for i, msg in enumerate(msgs):
+        domain, _, _ = msg.decode()
+        assert msg.num_elements == num_elements, f"Message #{i}: incorrect #elements ({msg.num_elements}) vs. expected #elements ({num_elements})"
+        assert msg.input_dim == num_inputs, f"Message #{i}: incorrect #inputs ({msg.input_dim}) vs. expected #inputs ({num_inputs})"
+        assert msg.output_dim == num_outputs, f"Message #{i}: incorrect #outputs ({msg.output_dim}) vs. expected #outputs ({num_outputs})"
+        assert msg.dtype_byte == dtype, f"Message #{i}: incorrect datatype ({msg.dtype_byte} bytes) vs. expected type ({dtype} bytes)"
+        assert domain == domain_test, f"Message #{i}: incorrect domain name (got {domain}) vs. expected ({domain_test})"
+
+    return 0
+
+def from_json(argv):
+    print(argv)
+    use_device = int(argv[0])
+    num_inputs = int(argv[1])
+    num_outputs = int(argv[2])
+    data_type = argv[3]
+    num_iterations = int(argv[4])
+    num_elements = int(argv[5])
+
+    env_file = Path(os.environ["AMS_OBJECTS"])
+    if not env_file.exists():
+        print("Environment file does not exist")
+        return -1
+
+    with open(env_file, "r") as fd:
+        rmq_json = json.load(fd)
+
+    res = verify(
+        use_device,
+        num_inputs,
+        num_outputs,
+        data_type,
+        num_iterations,
+        num_elements,
+        rmq_json["db"]["rmq_config"],
+        timeout = 60 # in seconds
+    )
+    if res != 0:
+        return res
+    print("[Success] rmq test received")
+    return 0
+
+if __name__ == "__main__":
+    if "AMS_OBJECTS" in os.environ:
+        sys.exit(from_json(sys.argv[1:]))
+    sys.exit(1)
diff --git a/tools/rmq/recv_binary.py b/tools/rmq/recv_binary.py
index 2aa628d0..ecb07058 100755
--- a/tools/rmq/recv_binary.py
+++ b/tools/rmq/recv_binary.py
@@ -92,7 +92,7 @@ def parse_data(body: str, header_info: dict) -> Tuple[str, np.array, np.array]:
     try:
         if data_size == 4: #if datatype takes 4 bytes
-            data = np.frombuffer(body[header_siz+domain_name_size:header_size+domain_name_size+data_size], dtype=np.float32)
+            data = np.frombuffer(body[header_size+domain_name_size:header_size+domain_name_size+data_size], dtype=np.float32)
         else:
             data = np.frombuffer(body[header_size+domain_name_size:header_size+domain_name_size+data_size], dtype=np.float64)
     except ValueError as e:
@@ -133,7 +133,7 @@ def callback(ch, method, properties, body, args = None):
         stream = stream[chunk_size:]
         i += 1
 
-def main(credentials: str, cacert: str, queue: str):
+def main(credentials: str, cacert: str, queue: str, n_msgs : int = None, timeout : int = None):
     conn = get_rmq_connection(credentials)
     if cacert is None:
         ssl_options = None
@@ -165,17 +165,48 @@ def main(credentials: str, cacert: str, queue: str):
     result = channel.queue_declare(queue=queue, exclusive=False)
     queue_name = result.method.queue
 
-    channel.basic_consume(queue=queue_name, on_message_callback=callback, auto_ack=True)
-    print(f"Listening on queue = {queue_name}")
+    print(f"Listening on queue {queue_name}")
+    if timeout:
+        print(f"Set timeout of {timeout} secs")
     print(" [*] Waiting for messages. To exit press CTRL+C")
-    channel.start_consuming()
+
+    message_consumed = 0
+    # Consume messages until n_msgs have been received or we time out
+    for method_frame, properties, body in channel.consume(queue_name, inactivity_timeout=timeout):
+        if (method_frame, properties, body) == (None, None, None):
+            print(f"Timed out after {timeout} seconds")
+            break
+        # Acknowledge the message
+        channel.basic_ack(method_frame.delivery_tag)
+
+        callback(channel, method_frame, properties, body, None)
+
+        message_consumed += 1
+        # if n_msgs is None, consume forever
+        if message_consumed == n_msgs:
+            print(f"Consumed {message_consumed} messages")
+            break
+
+    # Cancel the consumer and return any pending messages
+    requeued_messages = channel.cancel()
+    if requeued_messages > 0:
+        print(f"Requeued {requeued_messages} messages")
+
+    # Close the channel and the connection
+    channel.close()
+    connection.close()
+
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Tools that consumes AMS-encoded messages from RabbitMQ queue")
     parser.add_argument('-c', '--creds', help="Credentials file (JSON)", required=True)
     parser.add_argument('-t', '--tls-cert', help="TLS certificate file", required=False)
     parser.add_argument('-q', '--queue', help="Queue to listen to", required=True)
+    parser.add_argument('-n', '--nmsgs', type=int, help="Max messages to consume", required=False, default=None)
+    parser.add_argument('--timeout', type=int, help="Number of seconds after which the consumer will timeout", default=10)
 
     args = parser.parse_args()
     return args
@@ -183,7 +214,13 @@ def parse_args():
 
 if __name__ == "__main__":
     try:
         args = parse_args()
-        main(credentials = args.creds, cacert = args.tls_cert, queue = args.queue)
+        main(
+            credentials = args.creds,
+            cacert = args.tls_cert,
+            queue = args.queue,
+            n_msgs = args.nmsgs,
+            timeout = args.timeout
+        )
     except KeyboardInterrupt:
         print("")
         print("Done")