From d8613520a23d40fed7abd449d5da770e87d0e728 Mon Sep 17 00:00:00 2001 From: Hao Lin Date: Mon, 6 Jan 2025 07:54:56 +0000 Subject: [PATCH] Add sync_memory_optimization_level for yamls --- chatlearn/runtime/engine.py | 4 ++-- examples/megatron/configs/llama2/grpo_math_vllm.yaml | 1 + examples/megatron/configs/llama2/online_dpo_vllm.yaml | 1 + examples/megatron/configs/llama2/vllm_param_sync.yaml | 1 + examples/megatron/configs/llama2/vllm_rlhf.yaml | 1 + 5 files changed, 6 insertions(+), 2 deletions(-) diff --git a/chatlearn/runtime/engine.py b/chatlearn/runtime/engine.py index 330c9c7f..3195d03b 100644 --- a/chatlearn/runtime/engine.py +++ b/chatlearn/runtime/engine.py @@ -296,9 +296,9 @@ def learn(self): self.timers("sync_parameters").start() self.model_manager.sync_parameters(requires_grad=False, validate=self.runtime_args.validate_param_sync) self.timers("sync_parameters").stop() + logger.info(f"{LOG_START} " + get_full_proc_memory_info('After first param sync')) logger.info( - f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])} " \ - + get_full_proc_memory_info('After first param sync') + f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])} " ) self._data_loader = data_loader for episode_id in range(self._start_episode, self.runtime_args.num_episode): diff --git a/examples/megatron/configs/llama2/grpo_math_vllm.yaml b/examples/megatron/configs/llama2/grpo_math_vllm.yaml index 64f54bfd..2a7b933b 100644 --- a/examples/megatron/configs/llama2/grpo_math_vllm.yaml +++ b/examples/megatron/configs/llama2/grpo_math_vllm.yaml @@ -66,3 +66,4 @@ runtime: max_relay_episode: 1 exp_name: ${exp_name:chatlearn} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1} diff --git a/examples/megatron/configs/llama2/online_dpo_vllm.yaml b/examples/megatron/configs/llama2/online_dpo_vllm.yaml index e64f78ee..f274ee6b 100644 --- a/examples/megatron/configs/llama2/online_dpo_vllm.yaml +++ b/examples/megatron/configs/llama2/online_dpo_vllm.yaml @@ -60,3 +60,4 @@ runtime: output_dir: ${output_dir} exp_name: ${exp_name:chatlearn} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1} diff --git a/examples/megatron/configs/llama2/vllm_param_sync.yaml b/examples/megatron/configs/llama2/vllm_param_sync.yaml index 9177fe87..4ab36bba 100644 --- a/examples/megatron/configs/llama2/vllm_param_sync.yaml +++ b/examples/megatron/configs/llama2/vllm_param_sync.yaml @@ -49,3 +49,4 @@ runtime: exp_name: ${exp_name:chatlearn} debug: ${debug:False} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1} diff --git a/examples/megatron/configs/llama2/vllm_rlhf.yaml b/examples/megatron/configs/llama2/vllm_rlhf.yaml index b57602b3..34253461 100644 --- a/examples/megatron/configs/llama2/vllm_rlhf.yaml +++ b/examples/megatron/configs/llama2/vllm_rlhf.yaml @@ -82,3 +82,4 @@ runtime: exp_name: ${exp_name:chatlearn} debug: ${debug:False} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1}