Commit
[Feature]Support CEPH (#266)
* support petrelfs

* fix deepspeed save/load/resume

* add ENV to toggle petrelfs

* support hf save_pretrained

* patch deepspeed engine
pppppM authored Jan 24, 2024
1 parent b0f36f3 commit 076375d
Showing 5 changed files with 471 additions and 2 deletions.
65 changes: 65 additions & 0 deletions docs/zh_cn/user_guides/ceph.md
@@ -0,0 +1,65 @@
## Overview

### Supported Features

- Saving DeepSpeed checkpoints to Ceph
- Resuming training from a DeepSpeed checkpoint on Ceph
- Converting DeepSpeed checkpoints on Ceph with `pth_to_hf`

### Not Yet Supported

- The `--endpoint-url` option is unavailable
- Loading a HuggingFace model from Ceph during training, which conflicts with `zero3` weight loading
- Saving to Ceph with HuggingFace `save_pretrained`; the logic is too convoluted to patch

## Usage

#### 1. Verify the Ceph Environment

Before use, make sure the `petrel` SDK is available and that the Ceph bucket you intend to use exists and is accessible.

Verify the `aws` command-line tool:

```bash
# verify the aws CLI by listing the bucket
aws s3 ls $YOUR_BUCKET
```

Verify the `petrel` SDK:

```python
bucket = 's3://xxx'

from mmengine import get_file_backend
backend = get_file_backend(bucket)

for f in backend.list_dir_or_file(bucket):
    print(f)
```

#### 2. Save Checkpoints to Ceph During Training

`XTuner` decides whether to save checkpoints to Ceph based on the environment variable `DS_CEPH_DIR`:

```bash
DS_CEPH_DIR=s3://xxxx srun ${SRUN_ARGS} xtuner train $CONFIG --launcher slurm
```
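
With this setting, the patched `save_checkpoint` (see the `deepspeed.py` diff below) replaces the parent directory of the work dir with `DS_CEPH_DIR`; assuming the default `work_dirs/$CONFIG` layout, a local checkpoint such as `work_dirs/$CONFIG/iter_500.pth` would be written to `s3://xxxx/$CONFIG/iter_500.pth`.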

#### 3. Resume Training from a Checkpoint on Ceph

When resuming, provide the checkpoint's full path on Ceph:

```bash
DS_CEPH_DIR=s3://xxxx srun ${SRUN_ARGS} xtuner train $CONFIG --launcher slurm --resume s3://xxx/yyy/epoch_x.pth
```

#### 4. Convert a Checkpoint on Ceph to a HuggingFace Model

`$HF_DIR` may not be a Ceph path.

Because the checkpoint also stores optimizer state, loading it is time-consuming. For ZeRO 1 & 2, the `model_states.pt` file inside the checkpoint can be loaded directly to speed up the conversion; for ZeRO 3, the entire checkpoint must be loaded first (a sketch of the ZeRO 1 & 2 shortcut follows the command below).

```bash
srun ${SRUN_ARGS} xtuner convert pth_to_hf $CONFIG s3://xxx/yyy/epoch_x.pth $HF_DIR
```
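
For illustration, the ZeRO 1 & 2 shortcut could be sketched as below. The `mp_rank_00_model_states.pt` file name and the `'module'` key reflect the usual DeepSpeed checkpoint layout; they are assumptions for this sketch, not part of this commit.

```python
import io

import torch
from mmengine import get_file_backend

# A DeepSpeed "checkpoint" such as epoch_x.pth is actually a directory
# of shards, so we address individual files inside it.
ckpt_dir = 's3://xxx/yyy/epoch_x.pth'
backend = get_file_backend(ckpt_dir)

# ZeRO 1 & 2: read only the (assumed) consolidated model weights and
# skip the much larger optimizer shards.
data = backend.get(f'{ckpt_dir}/mp_rank_00_model_states.pt')
state = torch.load(io.BytesIO(data), map_location='cpu')
model_state_dict = state['module']  # DeepSpeed stores weights under 'module'
```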
20 changes: 19 additions & 1 deletion xtuner/__init__.py
@@ -1,7 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os

from mmengine.utils import digit_version

from .entry_point import cli
from .version import __version__, version_info

__all__ = ['__version__', 'version_info', 'digit_version', 'cli']
HF_CEPH_HUB = os.getenv('HF_CEPH_HUB', '')
HF_USE_CEPH = os.getenv('HF_USE_CEPH', 0) or HF_CEPH_HUB != ''
DS_CEPH_DIR = os.getenv('DS_CEPH_DIR', None)
if HF_USE_CEPH:
    from .utils.fileio import (patch_hf_auto_from_pretrained,
                               patch_hf_save_pretrained)
    patch_hf_auto_from_pretrained(HF_CEPH_HUB)
    patch_hf_save_pretrained()

if DS_CEPH_DIR:
    from .utils.fileio import patch_deepspeed_engine
    patch_deepspeed_engine()

__all__ = [
    '__version__', 'version_info', 'digit_version', 'cli', 'HF_USE_CEPH',
    'DS_CEPH_DIR'
]
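
Since these patches are applied when `xtuner` is imported, the toggles must be set beforehand. A minimal sketch of hypothetical usage (bucket paths are placeholders):

```python
import os

# A non-empty HF_CEPH_HUB also implies HF_USE_CEPH (see above).
os.environ['HF_CEPH_HUB'] = 's3://my-bucket/hf-models'
os.environ['DS_CEPH_DIR'] = 's3://my-bucket/checkpoints'

import xtuner  # noqa: E402  # the Ceph patches are applied at import time
```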
33 changes: 33 additions & 0 deletions xtuner/engine/_strategy/deepspeed.py
@@ -1,6 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine._strategy import DeepSpeedStrategy as MMEngineDeepSpeedStrategy

from xtuner import DS_CEPH_DIR
from xtuner.utils.fileio import patch_fileio


class DeepSpeedStrategy(MMEngineDeepSpeedStrategy):

@@ -20,3 +23,33 @@ def _wrap_model(self, model):
        assert hasattr(wrapper.model, 'data_preprocessor')
        wrapper.model.data_preprocessor.cuda()
        return wrapper

    def save_checkpoint(self, *args, **kwargs) -> None:
        if DS_CEPH_DIR:
            from os import path as osp
            work_dir_prefix = osp.split(self.work_dir)[0]

            filename = kwargs['filename'].replace(work_dir_prefix, DS_CEPH_DIR)
            kwargs['filename'] = filename
            with patch_fileio():
                super().save_checkpoint(*args, **kwargs)
        else:
            super().save_checkpoint(*args, **kwargs)

    def load_checkpoint(self, *args, **kwargs) -> None:
        if DS_CEPH_DIR:
            with patch_fileio():
                checkpoint = super().load_checkpoint(*args, **kwargs)
        else:
            checkpoint = super().load_checkpoint(*args, **kwargs)
        return checkpoint

    def resume(self, *args, **kwargs) -> None:
        if DS_CEPH_DIR:
            with patch_fileio():
                checkpoint = super().resume(*args, **kwargs)
        else:
            checkpoint = super().resume(*args, **kwargs)
        return checkpoint
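
The `patch_fileio` context manager used above comes from `xtuner/utils/fileio.py`, the fifth changed file, whose diff is not rendered on this page. As a rough, hypothetical sketch of the pattern such a patch follows (not the actual implementation), redirecting `torch.save` to a Ceph backend might look like this:

```python
import contextlib
import io

import torch
from mmengine import get_file_backend


@contextlib.contextmanager
def patch_torch_save_for_s3():
    """Illustrative stand-in for patch_fileio (hypothetical)."""
    orig_save = torch.save

    def patched_save(obj, f, *args, **kwargs):
        # Serialize to memory, then upload via the Petrel backend when
        # the target is an s3:// URI; otherwise fall back to the original.
        if isinstance(f, str) and f.startswith('s3://'):
            buffer = io.BytesIO()
            orig_save(obj, buffer, *args, **kwargs)
            get_file_backend(f).put(buffer.getvalue(), f)
        else:
            orig_save(obj, f, *args, **kwargs)

    torch.save = patched_save
    try:
        yield
    finally:
        torch.save = orig_save
```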
10 changes: 9 additions & 1 deletion xtuner/tools/model_converters/pth_to_hf.py
@@ -4,6 +4,7 @@
import shutil

from mmengine.config import Config, DictAction
from mmengine.fileio import PetrelBackend, get_file_backend

from xtuner.configs import cfgs_name_path
from xtuner.model.utils import guess_load_checkpoint
@@ -63,7 +64,14 @@ def main():

    model = BUILDER.build(cfg.model)

    state_dict = guess_load_checkpoint(args.pth_model)
    backend = get_file_backend(args.pth_model)
    if isinstance(backend, PetrelBackend):
        from xtuner.utils.fileio import patch_fileio
        with patch_fileio():
            state_dict = guess_load_checkpoint(args.pth_model)
    else:
        state_dict = guess_load_checkpoint(args.pth_model)

    model.load_state_dict(state_dict, strict=False)
    print(f'Load PTH model from {args.pth_model}')
