Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed CI builds for RabbitMQ #89

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/containers/x86_64-broadwell-cuda11.6.1/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ FROM nvidia/cuda:11.6.1-devel-ubi8 AS base
MAINTAINER Giorgis Georgakoudis <georgakoudis1@llnl.gov>
RUN \
yum install -y dnf &&\
dnf install -y git xz autoconf automake unzip patch gcc-gfortran bzip2 file &&\
dnf install -y git xz autoconf automake unzip patch gcc-gfortran bzip2 file libevent-devel openssl-devel &&\
dnf upgrade -y &&\
dnf clean all
COPY repo repo
Expand Down
2 changes: 1 addition & 1 deletion .github/containers/x86_64-broadwell-gcc11.2.1/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ MAINTAINER Giorgis Georgakoudis <georgakoudis1@llnl.gov>
RUN \
yum install -y dnf &&\
dnf group install -y "Development Tools" &&\
dnf install -y git gcc-toolset-11 environment-modules &&\
dnf install -y git gcc-toolset-11 environment-modules libevent-devel openssl-devel &&\
dnf upgrade -y
COPY repo repo
RUN \
Expand Down
99 changes: 99 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
branches: [ "develop" ]
workflow_dispatch:


jobs:
build-run-tests:
# The type of runner that the job will run on
Expand Down Expand Up @@ -389,3 +390,101 @@ jobs:
-DWITH_ADIAK=Off \
$GITHUB_WORKSPACE
make

build-rmq-tests:
# The type of runner that the job will run on
runs-on: ubuntu-latest
services:
rabbitmq:
image: rabbitmq:3.11
env:
RABBITMQ_DEFAULT_USER: ams
RABBITMQ_DEFAULT_PASS: ams
ports:
- 5672

container:
image: ghcr.io/llnl/ams-ci-almalinux8:latest
env:
RABBITMQ_USER: ams
RABBITMQ_PASS: ams
RABBITMQ_HOST: rabbitmq
RABBITMQ_PORT: 5672

steps:
- uses: actions/checkout@v4
- name: Build Torch=On FAISS=On RMQ=On AMS
shell: bash -l {0}
run: |
module load gcc/11.2.1
export SPACK_ROOT=/spack/
source /spack/share/spack/setup-env.sh
spack env activate -p /ams-spack-env
rm -rf build/
mkdir build
cd build
export AMS_MFEM_PATH=$(spack location -i mfem)
export AMS_TORCH_PATH=$(spack location -i py-torch)/lib/python3.10/site-packages/torch/share/cmake/Torch
export AMS_FAISS_PATH=$(spack location -i faiss)
export AMS_UMPIRE_PATH=$(spack location -i umpire)
export AMS_HDF5_PATH=$(spack location -i hdf5)
export AMS_AMQPCPP_PATH=$(spack location -i amqp-cpp)/cmake
cmake \
-DBUILD_SHARED_LIBS=On \
-DCMAKE_PREFIX_PATH=$INSTALL_DIR \
-DWITH_CALIPER=On \
-DWITH_HDF5=On \
-DWITH_EXAMPLES=On \
-DAMS_HDF5_DIR=$AMS_HDF5_PATH \
-DCMAKE_INSTALL_PREFIX=./install \
-DCMAKE_BUILD_TYPE=Release \
-DWITH_CUDA=Off \
-DUMPIRE_DIR=$AMS_UMPIRE_PATH \
-DMFEM_DIR=$AMS_MFEM_PATH \
-DWITH_FAISS=On \
-DWITH_MPI=On \
-DWITH_TORCH=On \
-DWITH_TESTS=On \
-DTorch_DIR=$AMS_TORCH_PATH \
-DFAISS_DIR=$AMS_FAISS_PATH \
-DWITH_AMS_DEBUG=On \
-DWITH_WORKFLOW=On \
-DWITH_ADIAK=Off \
-DWITH_RMQ=On \
-Damqpcpp_DIR=$AMS_AMQPCPP_PATH \
$GITHUB_WORKSPACE
make
- name: Run tests Torch=On FAISS=On RMQ=On AMSlib RabbitMQ egress
run: |
cd build
export SPACK_ROOT=/spack/
source /spack/share/spack/setup-env.sh
spack env activate -p /ams-spack-env

# We overwrite the rmq.json created by CMake
echo """{
\"db\": {
\"dbType\": \"rmq\",
\"rmq_config\": {
\"rabbitmq-name\": \"rabbit\",
\"rabbitmq-user\": \"${RABBITMQ_USER}\",
\"rabbitmq-password\": \"${RABBITMQ_PASS}\",
\"service-port\": ${RABBITMQ_PORT},
\"service-host\": \"${RABBITMQ_HOST}\",
\"rabbitmq-vhost\": \"/\",
\"rabbitmq-outbound-queue\": \"test-ci\",
\"rabbitmq-exchange\": \"ams-fanout\",
\"rabbitmq-routing-key\": \"training\"
},
\"update_surrogate\": false
},
\"ml_models\": {},
\"domain_models\": {}
}""" > $GITHUB_WORKSPACE/build/tests/AMSlib/rmq.json

make test
env:
RABBITMQ_USER: ams
RABBITMQ_PASS: ams
RABBITMQ_HOST: rabbitmq
RABBITMQ_PORT: 5672
97 changes: 76 additions & 21 deletions src/AMSWorkflow/ams/rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import json
import pika


class AMSMessage(object):
"""
Represents a RabbitMQ incoming message from AMSLib.
Expand All @@ -28,6 +27,24 @@ class AMSMessage(object):
def __init__(self, body: str):
    """Keep the raw message body and reset every parsed-header field.

    @param body Raw binary payload received from RabbitMQ
    """
    self.body = body
    # Header fields; filled in by _parse_header() once a header is decoded.
    for field in (
        "num_elements",
        "hsize",
        "dtype_byte",
        "mpi_rank",
        "domain_name_size",
        "input_dim",
        "output_dim",
    ):
        setattr(self, field, None)
    # One domain name is appended per AMS message decoded from this body.
    self.domain_names = []

def __str__(self):
    """Readable one-line summary of the parsed message header.

    The datatype is reported as "float" (4-byte elements) or "double"
    (8-byte elements), or None when no header has been parsed yet.
    """
    if not self.dtype_byte:
        dt = None
    else:
        # dtype_byte is the per-element size in bytes from the header
        # (presumably 4 => float32, 8 => float64 — matches dsize math in
        # _parse_header). The original returned the raw int 8 here instead
        # of the label "double".
        dt = "float" if self.dtype_byte == 4 else "double"
    return (
        f"AMSMessage(domain={self.domain_names}, #mpi={self.mpi_rank}, "
        f"num_elements={self.num_elements}, datatype={dt}, "
        f"input_dim={self.input_dim}, output_dim={self.output_dim})"
    )

def __repr__(self):
    """Mirror __str__ so messages render usefully inside containers and logs."""
    return str(self)

def header_format(self) -> str:
"""
This string represents the AMS format in Python pack format:
Expand Down Expand Up @@ -110,6 +127,15 @@ def _parse_header(self, body: str) -> dict:
res["dsize"] = int(res["datatype"]) * int(res["num_element"]) * (int(res["input_dim"]) + int(res["output_dim"]))
res["msg_size"] = hsize + res["dsize"]
res["multiple_msg"] = len(body) != res["msg_size"]

self.num_elements = int(res["num_element"])
self.hsize = int(res["hsize"])
self.dtype_byte = int(res["datatype"])
self.mpi_rank = int(res["mpirank"])
self.domain_name_size = int(res["domain_size"])
self.input_dim = int(res["input_dim"])
self.output_dim = int(res["output_dim"])

return res

def _parse_data(self, body: str, header_info: dict) -> Tuple[str, np.array, np.array]:
Expand Down Expand Up @@ -144,30 +170,37 @@ def _decode(self, body: str) -> Tuple[np.array]:
input = []
output = []
# Multiple AMS messages could be packed in one RMQ message
# TODO: we should manage potential multiple messages per AMSMessage better
while body:
header_info = self._parse_header(body)
domain_name, temp_input, temp_output = self._parse_data(body, header_info)
# print(f"MSG: {domain_name} input shape {temp_input.shape} output shape {temp_output.shape}")
# total size of byte we read for that message
chunk_size = header_info["hsize"] + header_info["dsize"] + header_info["domain_size"]
input.append(temp_input)
output.append(temp_output)
# We remove the current message and keep going
body = body[chunk_size:]
self.domain_names.append(domain_name)
return domain_name, np.concatenate(input), np.concatenate(output)

def decode(self) -> Tuple[str, np.array, np.array]:
    """Decode the stored raw body into (domain_name, inputs, outputs)."""
    raw = self.body
    return self._decode(raw)

def default_ams_callback(method, properties, body):
    """Default consumer callback: wrap the raw payload as an AMS binary message.

    The pika method/properties frames are ignored; only *body* is used.
    """
    return AMSMessage(body)

class AMSChannel:
"""
A wrapper around Pika RabbitMQ channel
"""

def __init__(self, connection, q_name, callback: Optional[Callable] = None, logger: Optional[logging.Logger] = None):
    """Wrap a queue on an already-established pika connection.

    @param connection An established pika connection
    @param q_name Name of the queue this channel operates on
    @param callback Per-message callback; defaults to default_callback (returns the raw body)
    @param logger Logger to use; defaults to the module logger
    """
    self.connection = connection
    self.q_name = q_name
    # Fall back to sensible defaults when the caller supplies nothing.
    self.logger = logger or logging.getLogger(__name__)
    self.callback = callback or self.default_callback

def __enter__(self):
self.open()
Expand All @@ -176,9 +209,9 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
    """Context-manager exit: close the channel; never suppress the exception."""
    self.close()
    return False

def default_callback(self, method, properties, body):
    """Identity callback: hand back the received message body unchanged."""
    return body

def open(self):
self.channel = self.connection.channel()
Expand All @@ -187,18 +220,19 @@ def open(self):
def close(self):
    """Close the underlying pika channel (the one created by open())."""
    self.channel.close()

def receive(self, n_msg: int = None, accum_msg=list()):
def receive(self, n_msg: int = None, timeout: int = None, accum_msg = list()):
"""
Consume a message on the queue and post processing by calling the callback.
@param n_msg The number of messages to receive.
- if n_msg is None, this call will block for ever and will process all messages that arrives
- if n_msg = 1 for example, this function will block until one message has been processed.
@param timeout If None, timout infinite, otherwise timeout in seconds
@return a list containing all received messages
"""

if self.channel and self.channel.is_open:
self.logger.info(
f"Starting to consume messages from queue={self.q_name}, routing_key={self.routing_key} ..."
f"Starting to consume messages from queue={self.q_name} ..."
)
# we will consume only n_msg and requeue all other messages
# if there are more messages in the queue.
Expand All @@ -207,11 +241,15 @@ def receive(self, n_msg: int = None, accum_msg=list()):
n_msg = max(n_msg, 0)
message_consumed = 0
# Consume n_msg messages and break out
for method_frame, properties, body in self.channel.consume(self.q_name):
for method_frame, properties, body in self.channel.consume(self.q_name, inactivity_timeout=timeout):
if (method_frame, properties, body) == (None, None, None):
self.logger.info(f"Timeout after {timeout} seconds")
self.channel.cancel()
break
# Call the call on the message parts
try:
accum_msg.append(
BlockingClient.callback(
self.callback(
method_frame,
properties,
body,
Expand All @@ -223,23 +261,24 @@ def receive(self, n_msg: int = None, accum_msg=list()):
finally:
# Acknowledge the message even on failure
self.channel.basic_ack(delivery_tag=method_frame.delivery_tag)
message_consumed += 1
self.logger.warning(
f"Consumed message {message_consumed+1}/{method_frame.delivery_tag} (exchange={method_frame.exchange}, routing_key={method_frame.routing_key})"
f"Consumed message {message_consumed}/{method_frame.delivery_tag} (exchange=\'{method_frame.exchange}\', routing_key={method_frame.routing_key})"
)
message_consumed += 1
# Escape out of the loop after nb_msg messages
if message_consumed == n_msg:
# Cancel the consumer and return any pending messages
self.channel.cancel()
break
return accum_msg

def send(self, text: str, exchange: str = ""):
    """
    Publish a message on this channel's queue.
    @param text The message payload to publish
    @param exchange Exchange to publish through (default: the nameless exchange)
    """
    self.channel.basic_publish(
        exchange=exchange,
        routing_key=self.q_name,
        body=text,
    )

def get_messages(self):
Expand All @@ -250,26 +289,42 @@ def purge(self):
if self.channel and self.channel.is_open:
self.channel.queue_purge(self.q_name)


class BlockingClient:
"""
BlockingClient is a class that manages a simple blocking RMQ client lifecycle.
"""

def __init__(self, host, port, vhost, user, password, cert, logger: logging.Logger = None):
def __init__(
self,
host: str,
port: int,
vhost: str,
user: str,
password: str,
cert: Optional[str] = None,
callback: Optional[Callable] = None,
logger: Optional[logging.Logger] = None
):
# CA Cert, can be generated with (where $REMOTE_HOST and $REMOTE_PORT can be found in the JSON file):
# openssl s_client -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' rmq-pds.crt
self.logger = logger if logger else logging.getLogger(__name__)
self.cert = cert
self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.context.verify_mode = ssl.CERT_REQUIRED
self.context.check_hostname = False
self.context.load_verify_locations(self.cert)

if self.cert is None or self.cert == "":
ssl_options = None
else:
self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.context.verify_mode = ssl.CERT_REQUIRED
self.context.check_hostname = False
self.context.load_verify_locations(self.cert)
ssl_options = pika.SSLOptions(self.context)

self.host = host
self.vhost = vhost
self.port = port
self.user = user
self.password = password
self.callback = callback

self.credentials = pika.PlainCredentials(self.user, self.password)

Expand All @@ -278,7 +333,7 @@ def __init__(self, host, port, vhost, user, password, cert, logger: logging.Logg
port=self.port,
virtual_host=self.vhost,
credentials=self.credentials,
ssl_options=pika.SSLOptions(self.context),
ssl_options=ssl_options,
)

def __enter__(self):
Expand All @@ -290,7 +345,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def connect(self, queue):
    """Create an AMSChannel on *queue* using this client's connection and callback."""
    channel = AMSChannel(self.connection, queue, self.callback)
    return channel
return AMSChannel(self.connection, queue, self.callback)


class AsyncConsumer(object):
Expand Down
Loading
Loading