diff --git a/chatlearn/runtime/engine.py b/chatlearn/runtime/engine.py index 330c9c7..3195d03 100644 --- a/chatlearn/runtime/engine.py +++ b/chatlearn/runtime/engine.py @@ -296,9 +296,9 @@ def learn(self): self.timers("sync_parameters").start() self.model_manager.sync_parameters(requires_grad=False, validate=self.runtime_args.validate_param_sync) self.timers("sync_parameters").stop() + logger.info(f"{LOG_START} " + get_full_proc_memory_info('After first param sync')) logger.info( - f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])} " \ - + get_full_proc_memory_info('After first param sync') + f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])} " ) self._data_loader = data_loader for episode_id in range(self._start_episode, self.runtime_args.num_episode): diff --git a/examples/megatron/configs/llama2/grpo_math_vllm.yaml b/examples/megatron/configs/llama2/grpo_math_vllm.yaml index 64f54bf..2a7b933 100644 --- a/examples/megatron/configs/llama2/grpo_math_vllm.yaml +++ b/examples/megatron/configs/llama2/grpo_math_vllm.yaml @@ -66,3 +66,4 @@ runtime: max_relay_episode: 1 exp_name: ${exp_name:chatlearn} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1} diff --git a/examples/megatron/configs/llama2/online_dpo_vllm.yaml b/examples/megatron/configs/llama2/online_dpo_vllm.yaml index e64f78e..f274ee6 100644 --- a/examples/megatron/configs/llama2/online_dpo_vllm.yaml +++ b/examples/megatron/configs/llama2/online_dpo_vllm.yaml @@ -60,3 +60,4 @@ runtime: output_dir: ${output_dir} exp_name: ${exp_name:chatlearn} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1} diff --git a/examples/megatron/configs/llama2/vllm_param_sync.yaml b/examples/megatron/configs/llama2/vllm_param_sync.yaml index 9177fe8..4ab36bb 100644 --- a/examples/megatron/configs/llama2/vllm_param_sync.yaml +++ b/examples/megatron/configs/llama2/vllm_param_sync.yaml @@ -49,3 +49,4 @@ runtime: exp_name: ${exp_name:chatlearn} debug: ${debug:False} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1} diff --git a/examples/megatron/configs/llama2/vllm_rlhf.yaml b/examples/megatron/configs/llama2/vllm_rlhf.yaml index b57602b..3425346 100644 --- a/examples/megatron/configs/llama2/vllm_rlhf.yaml +++ b/examples/megatron/configs/llama2/vllm_rlhf.yaml @@ -82,3 +82,4 @@ runtime: exp_name: ${exp_name:chatlearn} debug: ${debug:False} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1}