Commit

update doc
SeaOfOcean committed Sep 4, 2024
1 parent 6ca98ca commit 2ca0517
Showing 17 changed files with 102 additions and 928 deletions.
4 changes: 2 additions & 2 deletions chatlearn/runtime/decorator.py
@@ -136,15 +136,15 @@ def get_kwarg(key):
                 self.onload()
             generation_batch_size = self.module_args.generation_batch_size
             final_results = None
-            if not trainable and generation_batch_size and not hasattr(self, 'generate_vllm'):
+            if not trainable and generation_batch_size:
                 # split into micro-batches if generation_batch_size < input_batch, then concat the results
                 # this happens when different models have difference batch sizes
                 input_batch = 0
                 for value in args[0].values():
                     input_batch = len(value)
                     break
                 input_data = args[0]
-                if input_batch > generation_batch_size:
+                if input_batch > generation_batch_size and not hasattr(self, 'generate_vllm'):
                     args = list(args)
                     batches = split_along_batch(input_data, generation_batch_size)
                     results = []
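
The micro-batching guarded by this hunk can be sketched in plain Python. This is a minimal illustration, not ChatLearn's implementation: `split_along_batch` below is a hypothetical stand-in that only shares its name with the helper called in the diff, `concat_along_batch` and `run_in_micro_batches` are invented for the example, and a batch is assumed to be a dict of equal-length lists.

```python
from typing import Callable, Dict, List

Batch = Dict[str, list]


def split_along_batch(batch: Batch, size: int) -> List[Batch]:
    """Hypothetical stand-in: slice every value of the dict into chunks of `size`."""
    total = len(next(iter(batch.values())))
    return [
        {key: value[start:start + size] for key, value in batch.items()}
        for start in range(0, total, size)
    ]


def concat_along_batch(results: List[Batch]) -> Batch:
    """Hypothetical helper: join per-micro-batch outputs back into one batch."""
    merged: Batch = {key: [] for key in results[0]}
    for result in results:
        for key, value in result.items():
            merged[key].extend(value)
    return merged


def run_in_micro_batches(step: Callable[[Batch], Batch], batch: Batch, size: int) -> Batch:
    """Call `step` once if the batch fits, otherwise once per micro-batch."""
    input_batch = len(next(iter(batch.values())))
    if input_batch <= size:
        return step(batch)
    return concat_along_batch([step(b) for b in split_along_batch(batch, size)])


def double(batch: Batch) -> Batch:
    # Toy "model": doubles every input value.
    return {"out": [2 * x for x in batch["x"]]}


result = run_in_micro_batches(double, {"x": [1, 2, 3, 4, 5]}, size=2)
```

In the diff itself, the `hasattr(self, 'generate_vllm')` check moves from the outer condition to the inner one, so modules exposing `generate_vllm` still enter the outer branch but skip the split.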
5 changes: 1 addition & 4 deletions docs/en/index.rst
@@ -33,10 +33,7 @@ ChatLearn Documentation
    :maxdepth: 1
    :caption: Programming

-   programming/rlhf
-   programming/dpo
-   programming/online_dpo
-   programming/vllm
+   programming
    config_yaml
    advanced

75 changes: 33 additions & 42 deletions docs/en/programming/rlhf.md → docs/en/programming.md
@@ -1,17 +1,8 @@
# RLHF
# Programming Interface
This chapter will introduce the programming interface of ChatLearn.

This section will introduce the programming interface of ChatLearn. We will start with the main file and explain how to construct the `RLHFEngine`. Then, we will discuss how to write models.


## Main Training File

The following is an example of a user's main training file:

1. Call `chatlearn.init()` to initialize the runtime environment for RLHF.
2. Define the models required for training. Each model needs to have a unique `model_name`. Different models are distinguished by their `model_name` when configuring the model parameters. Please refer to the [training configuration](../config_yaml) for more details.
3. Define the engine [RLHFEngine](../api/engine.rst).
4. Set up the training dataset.
5. Call `engine.learn` to start the RLHF training.
## Training Main File
The following is an example of the user's training main file.

```python
from examples.megatron.models import PolicyInference
@@ -20,63 +11,60 @@ from examples.megatron.models import PolicyTrainer
from examples.megatron.models import RewardInference
from examples.megatron.models import ValueInference
from examples.megatron.models import ValueTrainer

import chatlearn
from chatlearn import RLHFEngine

# init
chatlearn.init()

# define models
policy_model = PolicyInference("policy")
reference_model = PolicyReference("reference")
reward_model = RewardInference("reward")
value_model = ValueInference("value")
ppo_policy_model = PolicyTrainer("ppo_policy")
ppo_value_model = ValueTrainer("ppo_value")

# define engine
engine = RLHFEngine(policy_model,
reference_model,
reward_model,
value_model,
ppo_policy_model,
ppo_value_model)

# set dataset
train_prompts = ["test"] * 4096
engine.set_dataset(train_prompts)

# start rlhf training
engine.learn()
```

1. Call `chatlearn.init()` to initialize the runtime environment of ChatLearn.
2. Define models, where each model needs to define a unique `model_name`. Different model configurations are distinguished by `model_name`. See [training configuration file](config_yaml) for details.
3. Define the engine [RLHFEngine](api/engine.rst).
4. Define an evaluator (optional).
5. Set the training dataset.
6. Call `engine.learn` to start alignment training.

## Model Definition
For a complete example, please refer to [train_rlhf_llama.sh](https://github.com/alibaba/ChatLearn/blob/main/examples/megatron/scripts/train_rlhf_llama.sh)

![image.png](../../images/class.png)
## Define Model

User-defined models need to inherit from `BaseModule` or its subclasses. `TorchModule` is the wrapper for general Torch models, `MegatronModule` is the wrapper for Megatron models, `DeepSpeedModule` is the wrapper for DeepSpeed models, and `VLLMModule` is the wrapper for vLLM generation models, If you want to use VLLMModule, refer to [vLLM generation](vllm.md). If the user's RLHF modeling is based on Megatron-LM, they can directly inherit from `MegatronModule` to complete the model construction.
![image.png](../images/class.png)

Here are examples of model construction for both the inference and training models, using inheritance from `MegatronModule`:
1. For the inference model, users need to implement the `setup` and `forward_step` methods. In `setup`, define the model, initialize parameters, and define global parameters. In `forward_step`, implement the logic required for one forward pass of the model.
2. For the training model, users need to implement the `setup` and `train_step` methods. In `train_step`, implement the logic required for one training step.
3. In addition, the `PolicyInference` model needs to implement the `build_dataset` method to construct the prompt dataset.
The user's model needs to inherit `BaseModule` or its subclasses. `TorchModule` is a general encapsulation of Torch models, `MegatronModule` is an encapsulation of Megatron models, `DeepSpeedModule` is an encapsulation of DeepSpeed models, `VLLMModule` is an encapsulation of vLLM models. The following two code snippets show examples of model construction for inference and training:
1. For inference models, users need to implement the `setup` and `forward_step` methods. In `setup`, implement model definition, parameter initialization, global parameter definition, etc. In `forward_step`, implement the logic required for one forward step of the model.
2. For training models, users need to implement the `setup` and `train_step` methods. In `train_step`, implement the logic required for training a step.
3. In addition, the first model of the engine needs to implement the `build_dataset` method to construct the prompt dataset.

For more API information, refer to [RLHF Module API](../api/module.rst).
Refer to [Module API](api/module.rst) for more API information.
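
The division of labour between the three methods can be illustrated with a toy sketch that does not import ChatLearn. `ToyModule`, `ToyInference`, `ToyTrainer`, and the driver lines are all hypothetical stand-ins; only the method names `setup`, `forward_step`, and `train_step` come from the text above, and the argument lists are assumptions.

```python
class ToyModule:
    """Hypothetical stand-in for a base module; not ChatLearn's BaseModule."""

    def __init__(self, name):
        self.name = name

    def setup(self):
        raise NotImplementedError


class ToyInference(ToyModule):
    def setup(self):
        # Define the "model" and initialize its parameters.
        self.scale = 2

    def forward_step(self, data, iteration=0):
        # One forward pass over a single batch.
        return {"out": [self.scale * x for x in data["x"]]}


class ToyTrainer(ToyModule):
    def setup(self):
        self.weight = 0.0

    def train_step(self, data, iteration):
        # `data` is assumed to be a list of micro-batches for one training step.
        for micro_batch in data:
            self.weight += sum(micro_batch["out"]) / len(micro_batch["out"])
        return self.weight


# Hypothetical driver: set up both models, run inference, then one train step.
policy = ToyInference("policy")
trainer = ToyTrainer("ppo_policy")
for module in (policy, trainer):
    module.setup()

outputs = [policy.forward_step({"x": [1, 2]}), policy.forward_step({"x": [3, 4]})]
weight = trainer.train_step(outputs, iteration=0)
```

In a real run the engine, not user code, decides when each method is called; the driver lines here only make the sketch executable.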

```python
from chatlearn import MegatronModule


class PolicyInference(MegatronModule):
from chatlearn import VLLMModule

class PolicyInference(VLLMModule):
def __init__(self, name):
"""
Args:
name: model name
"""

def setup(self):
"""
1. define model, self.model = xxx
@@ -100,7 +88,6 @@ class PolicyInference(MegatronModule):
def build_dataset(self, train_prompts, is_eval=False):
"""
Build prompt dataset. The implementation of build_dataset is exclusive to PolicyInference, whereas other models are not required to adopt it.
Args:
train_prompts: prompts provided by RLHFEngine.set_dataset(train_prompts)
is_eval: eval mode
@@ -113,9 +100,7 @@ class PolicyInference(MegatronModule):
```python
from chatlearn import MegatronModule


class PolicyTrainer(MegatronModule):

def setup(self):
"""
1. define model, self.model = xxx
@@ -124,7 +109,6 @@ class PolicyTrainer(MegatronModule):
4. init model parameters
"""
pass

def train_step(self, data, iteration):
"""
Perform train_step for one batch, including a list of micro-batches
@@ -134,11 +118,18 @@ class PolicyTrainer(MegatronModule):
"""
pass
```
## Dataset

To utilize a user-defined dataset in the ChatLearn framework, users need to inherit from `torch.utils.data.Dataset` and specify the `collate_fn` method. Inheriting from `torch.utils.data.Dataset` requires users to override the `__init__`, `__getitem__`, and `__len__` methods according to their specific requirements, as explained in the PyTorch tutorial on creating custom datasets see [Creating a Custom Dataset for Your Files](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files). The `collate_fn` method allows users to customize the data collation process, as documented in the [collate_fn API](https://pytorch.org/docs/stable/data.html#dataloader-collate-fn). If users do not require custom data collation, they can set `self.collate_fn = None` in the `__init__` method.
## Define Engine

![image.png](../images/engine.jpg)

Here is an example that demonstrates how to inherit from torch.utils.data.Dataset and specify the collate_fn method:
ChatLearn provides a series of built-in Engine types that users can directly use to construct training. Additionally, users can also construct custom engines to customize the model flow, as described in [Custom Model Flow](tutorial/custom_model_flow.md).

## Define Evaluator
The use of an evaluator can be found in [Constructing Evaluator](tutorial/evaluator.md).

## Dataset
The Dataset used by the user needs to inherit `torch.utils.data.Dataset` and specify the `collate_fn` method. To inherit `torch.utils.data.Dataset`, users need to override the `__init__`, `__getitem__`, and `__len__` methods as per the requirements (see [Creating a Custom Dataset for Your Files](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files)). The `collate_fn` method allows users to customize data collation (see [collate-fn](https://pytorch.org/docs/stable/data.html#dataloader-collate-fn)). If users do not need to customize data collation, they should set `self.collate_fn = None` in the `__init__` method.

```python
class PromptDataset(Dataset):
@@ -147,16 +138,16 @@ class PromptDataset(Dataset):
"""
def __init__(self, data):
self.data = data

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
return {"query": self.data[idx]}

def collate_fn(self, samples):
batched_data = {}
for sample_key, sample_value in samples.items():
batched_data[sample_key] = torch.stack(sample_value)
return batched_data
```
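
As written, the `PromptDataset` example above calls `.items()` on `samples` and passes the values to `torch.stack`. A standard PyTorch `DataLoader` hands `collate_fn` a list of per-item samples, and `torch.stack` accepts only tensors, so string prompts cannot be stacked. The following is a minimal runnable sketch of one possible correction, assuming the dataset is consumed that way and that each item is a dict of plain Python values; the `ImportError` fallback is an addition for illustration so the sketch also runs without PyTorch installed.

```python
try:
    from torch.utils.data import Dataset
except ImportError:
    Dataset = object  # fallback so the sketch still runs without PyTorch


class PromptDataset(Dataset):
    """Wraps a list of prompt strings."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {"query": self.data[idx]}

    def collate_fn(self, samples):
        # `samples` is a list of per-item dicts, so iterate over the list
        # and group the values by key instead of calling samples.items().
        batched_data = {}
        for sample in samples:
            for key, value in sample.items():
                batched_data.setdefault(key, []).append(value)
        return batched_data


dataset = PromptDataset(["hello", "world", "test"])
batch = dataset.collate_fn([dataset[i] for i in range(len(dataset))])
```

With PyTorch available, the same method can be passed as `DataLoader(dataset, batch_size=..., collate_fn=dataset.collate_fn)`. If the items were tensors of equal shape, stacking each grouped list with `torch.stack` would be a natural final step.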
125 changes: 0 additions & 125 deletions docs/en/programming/dpo.md

This file was deleted.

56 changes: 0 additions & 56 deletions docs/en/programming/online_dpo.md

This file was deleted.
