From 4c894a709a656d4cddeea63e425914e4d1c9f755 Mon Sep 17 00:00:00 2001 From: Trevor Grant Date: Fri, 27 Oct 2023 15:20:03 -0500 Subject: [PATCH] Fixes #45 Signed-off-by: Trevor Grant --- caikit_ray_backend/blocks/ray_train.py | 3 +++ caikit_ray_backend/ray_submitter.py | 14 ++----------- tests/test_ray_backend.py | 27 +++++++++++++++++++------- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/caikit_ray_backend/blocks/ray_train.py b/caikit_ray_backend/blocks/ray_train.py index abff48c..4ea8bcc 100644 --- a/caikit_ray_backend/blocks/ray_train.py +++ b/caikit_ray_backend/blocks/ray_train.py @@ -218,6 +218,9 @@ def train( error.value_check("", num_gpus > 0) env_vars["requested_gpus"] = num_gpus + training_timeout = self.config.get("training_timeout", 60) + env_vars["training_timeout"] = float(training_timeout) + # Serialize **kwargs and add them to environment variables my_kwargs = {} for key, value in kwargs.items(): diff --git a/caikit_ray_backend/ray_submitter.py b/caikit_ray_backend/ray_submitter.py index efcec68..0160b7a 100644 --- a/caikit_ray_backend/ray_submitter.py +++ b/caikit_ray_backend/ray_submitter.py @@ -14,7 +14,6 @@ # Standard -from time import sleep import base64 import json import os @@ -24,7 +23,6 @@ import ray # First Party -from caikit import get_config from caikit.core.toolkit.errors import error_handler import alog @@ -78,28 +76,20 @@ def main(): if model_path: error.type_check("", str, model_path=model_path) - timeout = 3 - if get_config().training_timeout: - try: - timeout = float(get_config().training_timeout) - except ValueError: - log.warn( - f"training_timeout: '{get_config().training_timeout}' cannot be converted to int, ignoring" - ) + timeout = runtime_env.get("training_timeout", float(60)) # Finally kick off training with alog.ContextTimer(log.debug, "Done training %s in: ", module_class): task = ray_training_tasks.train_and_save.options( num_cpus=num_cpus, num_gpus=num_gpus ).remote(module_class, model_path, *args, **kwargs) - ready, _ = ray.wait([task], timeout=timeout) - if ready: ray.get(task) else: ray.cancel(task) log.error("Task did not complete before time out.") + raise TimeoutError("Task did not complete before time out.") if __name__ == "__main__": diff --git a/tests/test_ray_backend.py b/tests/test_ray_backend.py index c48c01b..a61920b 100644 --- a/tests/test_ray_backend.py +++ b/tests/test_ray_backend.py @@ -16,6 +16,7 @@ """ # Standard from datetime import datetime +import logging import os import time @@ -46,7 +47,10 @@ def jsonl_file_data_stream(): def test_job_submission_client(mock_ray_cluster, jsonl_file_data_stream): - config = {"connection": {"address": mock_ray_cluster.address}} + config = { + "connection": {"address": mock_ray_cluster.address}, + "training_timeout": 30.0, + } trainer = RayJobTrainModule(config, "ray_backend") args = [jsonl_file_data_stream] @@ -82,7 +86,10 @@ def test_job_submission_client(mock_ray_cluster, jsonl_file_data_stream): def test_wait(mock_ray_cluster, jsonl_file_data_stream): - config = {"connection": {"address": mock_ray_cluster.address}} + config = { + "connection": {"address": mock_ray_cluster.address}, + "training_timeout": 30.0, + } trainer = RayJobTrainModule(config, "ray_backend") args = [jsonl_file_data_stream] @@ -101,7 +108,10 @@ def test_wait(mock_ray_cluster, jsonl_file_data_stream): def test_load(mock_ray_cluster, jsonl_file_data_stream): - config = {"connection": {"address": mock_ray_cluster.address}} + config = { + "connection": {"address": mock_ray_cluster.address}, + "training_timeout": 30.0, + } trainer = RayJobTrainModule(config, "ray_backend") args = [jsonl_file_data_stream] @@ -118,7 +128,10 @@ def test_load(mock_ray_cluster, jsonl_file_data_stream): def test_cancel(mock_ray_cluster, jsonl_file_data_stream): - config = {"connection": {"address": mock_ray_cluster.address}} + config = { + "connection": {"address": mock_ray_cluster.address}, + "training_timeout": 30.0, + } trainer = RayJobTrainModule(config, "ray_backend") args = [jsonl_file_data_stream] @@ -145,7 +158,7 @@ def test_cancel(mock_ray_cluster, jsonl_file_data_stream): def test_timeout(mock_ray_cluster, jsonl_file_data_stream): config = { "connection": {"address": mock_ray_cluster.address}, - "training_timeout": 3, + "training_timeout": 0.1, } trainer = RayJobTrainModule(config, "ray_backend") @@ -156,11 +169,11 @@ def test_timeout(mock_ray_cluster, jsonl_file_data_stream): save_path="/tmp", ) - time.sleep(5) + time.sleep(3) status = model_future.get_info().status print("Final status was", status) - assert status == TrainingStatus.CANCELED + assert status == TrainingStatus.ERRORED ## Test Ray Backend