
Errors when converting Megatron checkpoints to HF format #18

Open
etsurin opened this issue Dec 15, 2024 · 0 comments
etsurin commented Dec 15, 2024

Question

I used Megatron-LM to pretrain a Llama 3.1 model and obtained distcp checkpoints; however, an error occurs when I use this tool to convert the checkpoints to HF format.

Environment

absl-py                  2.1.0
accelerate               1.1.1
aiohappyeyeballs         2.4.3
aiohttp                  3.11.8
aiosignal                1.3.1
annotated-types          0.7.0
antlr4-python3-runtime   4.9.3
apex                     0.1
attrs                    24.2.0
certifi                  2024.8.30
charset-normalizer       3.4.0
dill                     0.3.9
einops                   0.8.0
filelock                 3.16.1
flash_attn               2.6.3
frozenlist               1.5.0
fsspec                   2024.10.0
grpcio                   1.68.0
huggingface-hub          0.26.2
idna                     3.10
importlib_metadata       8.5.0
Jinja2                   3.1.4
lightning-utilities      0.11.9
Markdown                 3.7
MarkupSafe               3.0.2
mpmath                   1.3.0
multidict                6.1.0
networkx                 3.4.2
ninja                    1.11.1.1
numpy                    2.1.3
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        9.1.0.70
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.6.77
nvidia-nvtx-cu12         12.1.105
omegaconf                2.3.0
packaging                24.2
pandas                   2.2.3
pillow                   11.0.0
pip                      24.2
propcache                0.2.0
protobuf                 5.28.3
psutil                   6.1.0
pybind11                 2.13.6
pydantic                 2.9.2
pydantic_core            2.23.4
python-dateutil          2.9.0.post0
pytorch-lightning        2.4.0
pytz                     2024.2
PyYAML                   6.0.2
regex                    2024.11.6
requests                 2.32.3
safetensors              0.4.5
setuptools               75.1.0
six                      1.16.0
sympy                    1.13.3
tensorboard              2.18.0
tensorboard-data-server  0.7.2
tokenizers               0.20.3
torch                    2.4.0
torchaudio               2.4.0
torchmetrics             1.6.0
torchvision              0.19.0
tqdm                     4.67.0
transformer_engine       1.11.0
transformer_engine_cu12  1.11.0
transformer_engine_torch 1.11.0
transformers             4.46.2
triton                   3.0.0
typing_extensions        4.12.2
tzdata                   2024.2
urllib3                  2.2.3
Werkzeug                 3.1.3
wheel                    0.44.0
yarl                     1.18.0
zipp                     3.21.0

Command

Here is the command I used; I just followed the README.

  python tools/checkpoint-convert/convert_fsdp.py \
  --hf-base-model-path xxx/Meta-Llama3.1-8B-hf \
  --tokenizer-path xxx/Meta-Llama3.1-8B-hf \
  --fsdp-checkpoint-path xxx/Meta-Llama3.1-8B-megatron/iter_0001250 \
  --checkpoint-output-path xxx/Meta-Llama3.1-8B-hfckpt/iter_0001250 \
  --sequence-length 8192

Checkpoint files

iter_xxxxxxx
|-.metadata
|-common.pt
|-metadata.json
|-__0_0.distcp
|-__0_1.distcp
|-__1_0.distcp
|-__1_1.distcp
|-__2_0.distcp
|-__2_1.distcp
|-__3_0.distcp
|-__3_1.distcp

Metadata

Below is the content of metadata.json:
{"sharded_backend": "torch_dist", "sharded_backend_version": 1, "common_backend": "torch", "common_backend_version": 1}

Errors

Here is the error I got:

Traceback (most recent call last):
  File "/gs/fs/tgd-kanno-lab003/llm-recipes/tools/checkpoint-convert/convert_fsdp.py", line 68, in <module>
    main()
  File "/gs/fs/tgd-kanno-lab003/llm-recipes/tools/checkpoint-convert/convert_fsdp.py", line 55, in main
    model = load_sharded_model_single_gpu(model, args.fsdp_checkpoint_path)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gs/fs/tgd-kanno-lab003/llm-recipes/tools/checkpoint-convert/convert_fsdp.py", line 20, in load_sharded_model_single_gpu
    dist_cp.load(
  File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py", line 434, in inner_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 168, in load
    _load_state_dict(
  File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 220, in _load_state_dict
    central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py", line 192, in reduce_scatter
    raise result
torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0])
Traceback (most recent call last): (RANK 0)
  File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
    local_data = map_fun()
                 ^^^^^^^^^
  File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 209, in local_step
    local_plan = planner.create_local_plan()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/default_planner.py", line 197, in create_local_plan
    return create_default_local_load_plan(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gs/fs/tgd-kanno-lab003/miniconda3/envs/recipe/lib/python3.12/site-packages/torch/distributed/checkpoint/default_planner.py", line 316, in create_default_local_load_plan
    raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
RuntimeError: Missing key in checkpoint state_dict: model.model.embed_tokens.weight.

It seems that the tool cannot read the sharded checkpoint. Could you please help me with converting the checkpoint?
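
In case it helps with debugging, my understanding is that a sharded torch_dist checkpoint can also be collapsed into a single torch.save file first, whose keys can then be inspected or remapped to the Hugging Face names. A minimal sketch, assuming the format_utils helper shipped with torch 2.4 (paths are placeholders):

  from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

  # Placeholder paths: input is the sharded iter_0001250 directory,
  # output is a single consolidated .pt file.
  dcp_dir = "xxx/Meta-Llama3.1-8B-megatron/iter_0001250"
  out_path = "xxx/Meta-Llama3.1-8B-megatron/iter_0001250_flat.pt"

  # Collapse the distributed (DCP) checkpoint into one torch.save file.
  dcp_to_torch_save(dcp_dir, out_path)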
