* support petrelfs
* fix deepspeed save/load/resume
* add ENV to toggle petrelfs
* support hf save_pretrained
* patch deepspeed engine
Showing 5 changed files with 471 additions and 2 deletions.
@@ -0,0 +1,65 @@
## Features

### Supported

- Save DeepSpeed checkpoints to Ceph
- Resume training from a DeepSpeed checkpoint stored on Ceph
- `pth_to_hf` can read DeepSpeed checkpoints stored on Ceph

### Not yet supported

- `--endpoint-url` is not available
- Loading a HuggingFace model from Ceph during training, which conflicts with how `zero3` loads weights
- Saving to Ceph through HuggingFace `save_pretrained`; the logic is too complex to patch

## Usage

#### 1. Verify the Ceph environment

Before you start, make sure the `petrel sdk` is available and that the Ceph bucket you plan to use exists and is accessible.

Verify the `aws` command-line tool:

```bash
# Verify the aws command-line tool
aws s3 ls $YOUR_BUCKET
```

Verify the `petrel sdk`:

```python
from mmengine import get_file_backend

bucket = 's3://xxx'
backend = get_file_backend(bucket)

# List the top-level entries of the bucket; an error here means the SDK
# or the bucket is not usable.
for f in backend.list_dir_or_file(bucket):
    print(f)
```

#### 2. Save checkpoints to Ceph during training

`XTuner` checks the `DS_CEPH_DIR` environment variable to decide whether to save checkpoints to Ceph:

```bash
DS_CEPH_DIR=s3://xxxx srun ${SRUN_ARGS} xtuner train $CONFIG --launcher slurm
```

#### 3. Resume training from a checkpoint on Ceph

When resuming, pass the full path of the checkpoint on Ceph:

```bash
DS_CEPH_DIR=s3://xxxx srun ${SRUN_ARGS} xtuner train $CONFIG --launcher slurm --resume s3://xxx/yyy/epoch_x.pth
```

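Because `--resume` expects the full Ceph path rather than a bare file name, it can help to check the argument before submitting a job. The following is a minimal sketch and not part of XTuner: `is_full_ceph_path` is a hypothetical helper, and it assumes the `s3://bucket/key` form and the `.pth` suffix used in the command above.

```python
from urllib.parse import urlparse


def is_full_ceph_path(path: str, suffix: str = '.pth') -> bool:
    """Return True if `path` looks like s3://<bucket>/<key> ending in `suffix`.

    Hypothetical helper for illustration only. It inspects the string and
    never contacts Ceph, so it cannot tell whether the object exists.
    """
    parsed = urlparse(path)
    key = parsed.path.strip('/')
    return parsed.scheme == 's3' and bool(parsed.netloc) and key.endswith(suffix)


if __name__ == '__main__':
    for candidate in ('s3://xxx/yyy/epoch_1.pth', 'epoch_1.pth', 's3://xxx'):
        print(candidate, is_full_ceph_path(candidate))
```

A check like this only catches malformed arguments such as a missing bucket or a local file name; confirming that the checkpoint is actually present still requires listing the bucket.
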
#### 4. Convert a checkpoint on Ceph to a HF model

`$HF_DIR` cannot be a Ceph path.

Because the checkpoint also stores optimizer states, loading it is slow. With ZeRO 1 and 2, the conversion can be sped up by loading only the `model_states.pt` file in the checkpoint; with ZeRO 3, the whole checkpoint must be loaded first.

```bash
srun ${SRUN_ARGS} xtuner convert pth_to_hf $CONFIG s3://xxx/yyy/epoch_x.pth $HF_DIR
```

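The difference between ZeRO stages described above can be sketched as a file-selection rule. This is an illustration only and not the logic of `pth_to_hf`: `select_files_to_load` is a hypothetical helper, and the file names in the example are placeholders for whatever the checkpoint actually contains. Stages other than 1, 2, and 3 are not covered by the text above; the sketch treats them like ZeRO 3 as an assumption.

```python
from typing import Iterable, List


def select_files_to_load(zero_stage: int,
                         checkpoint_files: Iterable[str]) -> List[str]:
    """Pick which checkpoint files a converter would need to read.

    Hypothetical helper: it encodes the rule from the README, nothing more.
    """
    files = list(checkpoint_files)
    if zero_stage in (1, 2):
        # ZeRO 1 and 2: the model-states file is enough, so the optimizer
        # states can be skipped.
        return [f for f in files if f.endswith('model_states.pt')]
    # ZeRO 3 (and, by assumption, anything else): read the whole checkpoint.
    return files


if __name__ == '__main__':
    # Placeholder listing; real file names depend on the checkpoint.
    listing = ['mp_rank_00_model_states.pt',
               'zero_pp_rank_0_mp_rank_00_optim_states.pt']
    print(select_files_to_load(2, listing))
    print(select_files_to_load(3, listing))
```
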
@@ -1,7 +1,25 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os
+
 from mmengine.utils import digit_version

 from .entry_point import cli
 from .version import __version__, version_info

-__all__ = ['__version__', 'version_info', 'digit_version', 'cli']
+HF_CEPH_HUB = os.getenv('HF_CEPH_HUB', '')
+HF_USE_CEPH = os.getenv('HF_USE_CEPH', 0) or HF_CEPH_HUB != ''
+DS_CEPH_DIR = os.getenv('DS_CEPH_DIR', None)
+if HF_USE_CEPH:
+    from .utils.fileio import (patch_hf_auto_from_pretrained,
+                               patch_hf_save_pretrained)
+    patch_hf_auto_from_pretrained(HF_CEPH_HUB)
+    patch_hf_save_pretrained()
+
+if DS_CEPH_DIR:
+    from .utils.fileio import patch_deepspeed_engine
+    patch_deepspeed_engine()
+
+__all__ = [
+    '__version__', 'version_info', 'digit_version', 'cli', 'HF_USE_CEPH',
+    'DS_CEPH_DIR'
+]
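One detail of the hunk above is worth checking: `os.getenv` returns a string whenever the variable is set, so `HF_USE_CEPH=0` yields `'0'`, which is truthy in Python and would still enable the HuggingFace patches. The sketch below is not part of the commit. `env_flag` and `hf_use_ceph` are hypothetical helpers, and the set of values treated as "off" is an assumption chosen for illustration.

```python
import os
from typing import Mapping, Optional

# Values treated as "off". This set is an assumption made for the sketch.
_FALSE_VALUES = {'', '0', 'false', 'no', 'off'}


def env_flag(name: str, environ: Optional[Mapping[str, str]] = None) -> bool:
    """Read a boolean switch from the environment (hypothetical helper)."""
    environ = os.environ if environ is None else environ
    raw = environ.get(name, '')
    return raw.strip().lower() not in _FALSE_VALUES


def hf_use_ceph(environ: Optional[Mapping[str, str]] = None) -> bool:
    """Same rule as the hunk, except that HF_USE_CEPH=0 disables the patch."""
    environ = os.environ if environ is None else environ
    return env_flag('HF_USE_CEPH', environ) or environ.get('HF_CEPH_HUB', '') != ''


if __name__ == '__main__':
    print(bool('0'))  # the pitfall: any non-empty string is truthy
    print(hf_use_ceph({'HF_USE_CEPH': '0'}))
    print(hf_use_ceph({'HF_CEPH_HUB': 's3://xxx'}))
```

Passing the environment as a mapping keeps the helpers easy to exercise without touching the real process environment.
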