Question
I used Megatron-LM to pretrain a Llama 3.1 model and obtained distcp checkpoints; however, errors occur when I use this tool to convert the checkpoints into HF format.
Environment
Command
Here is the command I used; I just followed the README.
Files of checkpoint
iter_xxxxxxx
|-.metadata
|-common.pt
|-metadata.json
|-__0_0.distcp
|-__0_1.distcp
|-__1_0.distcp
|-__1_1.distcp
|-__2_0.distcp
|-__2_1.distcp
|-__3_0.distcp
|-__3_1.distcp
metadata
Below is the content of metadata.json:
{"sharded_backend": "torch_dist", "sharded_backend_version": 1, "common_backend": "torch", "common_backend_version": 1}
Errors
Here are the errors I've got:
Traceback (most recent call last):
File "/gs/fs/tgd-kanno-lab003/llm-recipes/tools/checkpoint-convert/convert_fsdp.py", line 68, in <module>
main()
File "/gs/fs/tgd-kanno-lab003/llm-recipes/tools/checkpoint-convert/convert_fsdp.py", line 55, in main
model = load_sharded_model_single_gpu(model, args.fsdp_checkpoint_path)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gs/fs/tgd-kanno-lab003/llm-recipes/tools/checkpoint-convert/convert_fsdp.py", line 20, in load_sharded_model_single_gpu
dist_cp.load(
File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
result = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py", line 434, in inner_func
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 168, in load
_load_state_dict(
File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 220, in _load_state_dict
central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py", line 192, in reduce_scatter
raise result
torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0])
Traceback (most recent call last): (RANK 0)
File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
local_data = map_fun()
^^^^^^^^^
File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
result = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 209, in local_step
local_plan = planner.create_local_plan()
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/default_planner.py", line 197, in create_local_plan
return create_default_local_load_plan(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/default_planner.py", line 316, in create_default_local_load_plan
raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
RuntimeError: Missing key in checkpoint state_dict: model.model.embed_tokens.weight.
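For context, judging from the traceback, load_sharded_model_single_gpu in convert_fsdp.py appears to follow the common DCP loading pattern below. This is a hedged sketch reconstructed from the stack frames, not the actual source; the point is that DCP's default load planner requires every key of the pre-built state_dict to exist in the checkpoint:

import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint import FileSystemReader

def load_sharded_model_single_gpu(model, checkpoint_path):
    # DCP fills this dict in place. Its keys are the HF model's parameter
    # names prefixed with "model." (e.g. "model.model.embed_tokens.weight"),
    # and create_default_local_load_plan raises "Missing key in checkpoint
    # state_dict" for any key the checkpoint does not contain.
    state_dict = {"model": model.state_dict()}
    dist_cp.load(
        state_dict=state_dict,
        storage_reader=FileSystemReader(checkpoint_path),
    )
    model.load_state_dict(state_dict["model"])
    return model

If the Megatron-LM torch_dist checkpoint stores its tensors under Megatron-style FQNs rather than the HF-style names the converter builds, the load plan fails on the first missing key exactly as shown above.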
It seems that the tool cannot read the sharded checkpoint. Could you please help me with the checkpoint conversion?