fix none src model and skip load ckpt for vllm (#50)
stcy07 authored Sep 3, 2024
1 parent d0715bc commit 36c6f48
Showing 4 changed files with 14 additions and 4 deletions.
3 changes: 2 additions & 1 deletion chatlearn/models/vllm_module.py
@@ -177,7 +177,8 @@ def _init_args(self):

     def setup(self):
         """Set up model and load checkpoint"""
-        model = [get_model(self.model_provider, self.model_args)]
+        need_load_ckpt = self.src_parameter_model is None
+        model = [get_model(self.model_provider, self.model_args, need_load_ckpt)]

         assert len(model) == 1, "Above condition should have caught this"
         self.model = model[0]
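In plain terms: a vLLM replica now skips loading checkpoint weights from disk when another model is registered as its parameter source, since parameter sync will overwrite those weights anyway. A minimal standalone sketch of the gate; FakeVLLMModule is a stand-in, and only the src_parameter_model attribute mirrors the real module:

# Standalone sketch of the need_load_ckpt gate (hypothetical class name).
class FakeVLLMModule:
    def __init__(self, src_parameter_model=None):
        # set later by the engine when this model is the destination of a
        # parameter-sync pair (see the engine.py hunk below)
        self.src_parameter_model = src_parameter_model

    def setup(self):
        # load from disk only if no trainer will push weights to us
        need_load_ckpt = self.src_parameter_model is None
        print("load checkpoint from disk:", need_load_ckpt)

FakeVLLMModule().setup()                              # True  -> loads checkpoint
FakeVLLMModule(src_parameter_model=object()).setup()  # False -> skips loading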
3 changes: 2 additions & 1 deletion chatlearn/runtime/dist_actor.py
@@ -306,7 +306,8 @@ def register_func(self):
                           "free_grad_buffers",
                           "build_grad_buffers",
                           "eval",
-                          "train"]:
+                          "train",
+                          "set_src_parameter_model"]:
            dist_call = partial(self.call_replica_func, func_name)
            setattr(self, func_name, dist_call)

Expand Down
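For context, register_func exposes replica methods on the distributed wrapper by binding each name to a partial that fans the call out to every replica; adding "set_src_parameter_model" to the list makes it callable the same way as train/eval. A runnable sketch of that pattern with dummy replicas (class names here are hypothetical; only the registered method names come from the diff):

from functools import partial

class Replica:
    def train(self):
        print("replica train")
    def eval(self):
        print("replica eval")
    def set_src_parameter_model(self, src):
        print(f"replica got src model: {src}")

class DistModel:
    def __init__(self, replicas):
        self.replicas = replicas
        # mirror of register_func: each name becomes a broadcast call
        for func_name in ["train", "eval", "set_src_parameter_model"]:
            dist_call = partial(self.call_replica_func, func_name)
            setattr(self, func_name, dist_call)

    def call_replica_func(self, func_name, *args, **kwargs):
        # in ChatLearn this would return one Ray ObjectRef per replica
        return [getattr(r, func_name)(*args, **kwargs) for r in self.replicas]

DistModel([Replica(), Replica()]).set_src_parameter_model("policy_trainer")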
8 changes: 8 additions & 0 deletions chatlearn/runtime/engine.py
@@ -71,6 +71,14 @@ def setup(self):
         # for ease to access model by self.{model_name}
         for model in self.remote_models:
             setattr(self, model.name, model)
+
+        if hasattr(self, '_param_sync_pairs'):
+            ref_set_src = []
+            for src_model, dst_model in self._param_sync_pairs:
+                remote_src_model = getattr(self, src_model.name)
+                remote_dst_model = getattr(self, dst_model.name)
+                ref_set_src += remote_dst_model.set_src_parameter_model(remote_src_model)
+            future.wait(ref_set_src)
         # include compile in init, compile dependencies need to be called serially
         logger.info(get_full_proc_memory_info('Before model init'))
         for model in self.remote_models:
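The engine change wires each parameter-sync pair before model init: every destination model is told which source model will later push weights to it, and the returned refs are awaited so the flag is set before setup() (and thus before checkpoint loading) runs. A hedged sketch with stand-in objects; FakeRemoteModel and FakeFuture are illustrative, only the loop mirrors the diff:

class FakeFuture:
    @staticmethod
    def wait(refs):
        print(f"waited on {len(refs)} refs")

class FakeRemoteModel:
    def __init__(self, name):
        self.name = name

    def set_src_parameter_model(self, src):
        print(f"{self.name}: src parameter model <- {src.name}")
        return ["ref-" + self.name]  # one ref per replica in the real system

trainer = FakeRemoteModel("policy_trainer")
vllm = FakeRemoteModel("policy_inference")
_param_sync_pairs = [(trainer, vllm)]

ref_set_src = []
for src_model, dst_model in _param_sync_pairs:
    ref_set_src += dst_model.set_src_parameter_model(src_model)
FakeFuture.wait(ref_set_src)  # block until every replica has stored the flag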
4 changes: 2 additions & 2 deletions chatlearn/utils/vllm_utils.py
@@ -396,13 +396,13 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     return args


-def get_model(model_provider, args):
+def get_model(model_provider, args, need_load_ckpt=True):
     with _set_default_torch_dtype(args.get("params_dtype")):
         # Create a model instance.
         # The weights will be initialized as empty tensors.
         model = model_provider()
         model = model.cuda()
-        if args["load"]:
+        if args["load"] and need_load_ckpt:
             model.load_weights()
         else:
             # For accurate performance evaluation, we assign
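Note that need_load_ckpt defaults to True, so existing get_model call sites keep their old behavior; only the vLLM setup path that expects a later parameter sync passes False. A tiny sketch of just the load-gating branch (the real get_model also builds the model and moves it to GPU; the function name and strings here are illustrative):

def get_model_sketch(args, need_load_ckpt=True):
    # mirrors only the gating condition from the diff
    if args["load"] and need_load_ckpt:
        return "weights loaded from checkpoint"
    return "weights left for parameter sync / dummy init"

print(get_model_sketch({"load": "/path/to/ckpt"}))         # old behavior kept
print(get_model_sketch({"load": "/path/to/ckpt"}, False))  # new vLLM path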
