
[cp] add option to choose kv shards rotation method #684

Closed
wants to merge 394 commits into from

Conversation


@XilunWu XilunWu commented Nov 20, 2024

Stack from ghstack (oldest at bottom):

Summary
Requires the land of pytorch/pytorch#142093
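For reference, the new knob might be surfaced in the job config roughly like this (a sketch only; the exact table, option name, and accepted values are assumptions based on the PR title and are not confirmed by this page):

```toml
[experimental]
# choose how KV shards rotate between ranks during context-parallel attention
context_parallel_rotate_method = "allgather"  # or "alltoall" (values assumed)
```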

wanchaol and others added 30 commits April 19, 2024 22:26
This PR adds support for Llama3 8b/70b, mainly it:
- add tiktokenizer, add instructions to download the tokenizer
- add options for the llama model to support Llama3
- add Llama3 8b/70b configs
ghstack-source-id: 4dd1cdb033e840e00cacd98339780424231b595b
Pull Request resolved: #257
As titled. The test tokenizer is borrowed from torchtune
(https://github.com/pytorch/torchtune/blob/main/tests/assets/tiktoken_small.model),
where this small test model was generated offline from
https://gist.github.com/ebsmothers/54b133dd87db6679b14318545aaa2de4, so
it should have no correlation with any specific model/data.
ghstack-source-id: b4fe7f63f15bab367cf00b5d408eb43c640541c2
Pull Request resolved: #262
ghstack-source-id: a9bd1d33bf7bc9f5055a645c9639bcbe628afbfb
Pull Request resolved: #258
ghstack-source-id: 34b380d251e0a80ac5328fdaeb33a1e488f9c735
Pull Request resolved: #261
This PR mainly fixes the spelling where activation checkpointing was
missing an n (**checkpoiting**).
Not sure how I missed it earlier, but it's glaring when you see the
charts in visual form (vs. text).

<img width="578" alt="Screenshot 2024-04-24 at 2 45 25 PM"
src="https://github.com/pytorch/torchtitan/assets/46302957/a81727b2-07b1-4d69-a0c1-743d74d2aa5a">

fixed:
<img width="592" alt="Screenshot 2024-04-24 at 3 10 30 PM"
src="https://github.com/pytorch/torchtitan/assets/46302957/769e51db-4aa6-4dbd-99d8-7e691658e280">


Also add a couple line breaks to help with layout, and one or two minor
grammar updates.
Update to final legal license terms requested by Meta legal for release.
ghstack-source-id: 2b74fe48dbeae0367a41214c6d0e8b1fcd608db8
Pull Request resolved: #270
* Image was very blurry
* Markdown formatting was off
* Simplified some sentences
ghstack-source-id: 606aee2c4815173958b30ca34a3dbf8e90aed8de
Pull Request resolved: #275
ghstack-source-id: cc29739b147fe1f52bfc5b791330fd7cf1659652
Pull Request resolved: #271
1. update readme
2. small refactor to loss_parallel part
ghstack-source-id: d410f30ec715bfb4206459becb95abeed5a4ae02
Pull Request resolved: #281
ghstack-source-id: 77f650e8281dae12f2a7ccdb415be88f9abd88cc
Pull Request resolved: #283
# Summary

Add more of the possible options to the configs, and add a note at the
top of the file on how to get the dependency.
ghstack-source-id: dbd201ad2976537487123fa583c86ddab06a7387
Pull Request resolved: #250
as titled, fixes #286
ghstack-source-id: 932e7cce828a15c788b34f07c264e119068777fe
Pull Request resolved: #287
Runs the integration test hourly and updates the signal badge. Tested on
the existing integration test. I will update the badge with the periodic
test signal once the workflow has landed in this PR.
<img width="516" alt="Screenshot 2024-04-30 at 6 12 00 PM"
src="https://github.com/pytorch/torchtitan/assets/1779702/8adaab3d-df18-483d-a39f-5af316b7edbc">
ghstack-source-id: 9daa99020c76fdfe429b6a9ee6d44fd1dd319fc3
Pull Request resolved: #280
Adds a new command, ./create_seed_checkpoint.sh, which largely
reuses code inside train.py to create the model and then save its
initial state as a step-0 checkpoint, for use with the meta-initialization
loading flow.

ghstack-source-id: 3e1aa9eab847c1f1341f22772ca8ae3688883454
Pull Request resolved: #172
ghstack-source-id: fa9aaf337b5489d88945f15b65a8ba8cc544ded6
Pull Request resolved: #295
This appears to be a holdover from a previous way the initialization
worked.

freqs_cis should already be on the GPU device after initialization.

ghstack-source-id: 7159320d4ecfb436bd2193277a88c04d136e9ad0
Pull Request resolved: #298
…int (#293)

Summary:
The profiler currently maintains a counter locally and that counter is
not synchronized with the checkpointed train step. This PR fixes the
issue.
As titled. This makes 1-D and 2-D work with the latest main
build. Thanks @bdhirsh for all the fixes!

We should figure out why dynamic shapes get turned on, as a follow-up.
ghstack-source-id: bbedad3819ab9ef90b233209c34dd1dbc846b06a
Pull Request resolved: #299
Summary:
This PR implements two different async checkpoint modes. The first uses
DCP.async_save; the other uses pinned memory plus a separate process
to avoid GIL issues.

ghstack-source-id: 87fb6c28d7bc3e514c0bee7646be5188f1f66bbd
Pull Request resolved: #313
XilunWu added a commit that referenced this pull request Dec 4, 2024
ghstack-source-id: 7a0226603c357bee505d32999cd8126afe4becce
Pull Request resolved: #684
wconstab and others added 4 commits December 4, 2024 13:05
Pytorch stopped releasing cu121 nightlies.

ghstack-source-id: 39850c42c5ec0a8898a208718f35392e98a427f9
Pull Request resolved: #718
ghstack-source-id: 6eb7a87df8b3585e53993684d0b9682aeb99cfe5
Pull Request resolved: #719
XilunWu added a commit that referenced this pull request Dec 5, 2024
ghstack-source-id: 030f53c4e1520715f29f35f8cacbe3f1c939b9e4
Pull Request resolved: #684
XilunWu added a commit that referenced this pull request Dec 5, 2024
ghstack-source-id: 1021f844de071c43428d4dfe9d5ae08c653b7076
Pull Request resolved: #684
XilunWu added a commit that referenced this pull request Dec 5, 2024
ghstack-source-id: 1021f844de071c43428d4dfe9d5ae08c653b7076
Pull Request resolved: #684
msaroufim and others added 8 commits December 5, 2024 11:56
Fixes #662

Followed @fegin's advice to test this, and indeed things are working:
https://gist.github.com/msaroufim/2925b3f17b631bf370a49f185b6e169d

```
[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
interval_type = "steps"
interval = 10
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
load_step = 10
```
EDIT: removed the specific new functions in hf_datasets.py, kept most
of the doc changes, and will not go for a registration-based API.

Fixes #311

This PR describes the status quo of how new datasets should be
registered today, in that there's the implicit assumption that people
are installing torchtitan from source and updating hf_datasets.py to
support new datasets. As an example I passed in the wikipedia dataset

The main "nice" thing about this PR is that `class HuggingFaceDataset`
is now agnostic to the c4 dataset, which makes it easier for new people
to add datasets without reading the rest of the file.

There's another direction this PR could have gone in, which was to allow
custom dataset registration. The benefit is that people could support new
datasets without installing titan from source, but registration APIs can
feel kind of "bureaucratic", and presumably people would need to register
the dataset somewhere, probably `train.py`?

Not totally sure which is more in line with the repo's goals, so opening
this PR to discuss.

```python
from typing import Any, Callable, Dict, Optional

from datasets import load_dataset

# Registries mapping a dataset name to its loader and text processor.
DATASET_LOADERS: Dict[str, Callable] = {}
DATASET_TEXT_PROCESSORS: Dict[str, Callable] = {}


def register_dataset(
    name: str,
    loader: Callable[[str, Dict[str, Any]], Any],
    processor: Callable[[Dict[str, Any]], str],
    path: Optional[str] = None,
) -> None:
    DATASET_LOADERS[name] = loader
    DATASET_TEXT_PROCESSORS[name] = processor


def wikipedia_loader(dataset_path: str, **kwargs):
    return load_dataset(
        dataset_path,
        name="20220301.en",
        split="train",
        streaming=True,
        trust_remote_code=True,
    )


def wikipedia_processor(sample: Dict[str, Any]) -> str:
    return f"{sample['title']}\n\n{sample['text']}"


register_dataset(
    name="wikipedia",
    loader=wikipedia_loader,
    processor=wikipedia_processor,
    path="wikipedia",
)
```
The tp_enabled parameter was not being used within the apply_fsdp
function.
Removing it simplifies the function signature while maintaining all
existing functionality.
Add `--training.enable_optimizer_in_backward` to enable the optimizer in
the backward pass.
With this feature, `optim.step()` and `optim.zero_grad()` can be
conducted during `loss.backward()`, reducing peak memory during
training.

Caution: optimizer in backward frees the gradients during backward, so
it cannot work with gradient clipping.

Results:
TL;DR: optimizer_in_backward achieves a memory saving of around 3 GiB
on the llama3 test .toml, with slight improvements in TPS and MFU.
We conduct tests based on `llama3_8b.toml` by running
`CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh`.
During the tests, we vary `compile`, `enable_optimizer_in_backward`, and
`activation_checkpoint.mode` to show the gains from
optimizer_in_backward in the logged memory cost, TPS, and MFU.
<img width="761" alt="Screenshot 2024-12-02 at 3 59 41 PM"
src="https://github.com/user-attachments/assets/0079bebc-04b2-4116-a61c-6347c3f7adb6">



Analysis:
When using optim in backward, the amount of memory reserved for the
reduce-scatter stream is smaller, causing the lower **reserved memory**
shown in the table, around 2-3 GiB.
Take `activation_checkpoint.mode = selective` as an example:
<img width="628" alt="Screenshot 2024-12-02 at 3 31 46 PM"
src="https://github.com/user-attachments/assets/cf1eddf8-5c7c-4244-8692-0141178b5596">


Also, when `activation_checkpoint.mode = full`, there is a decrease in
**active memory** of around
`(num_params-2)*(param_grad_mem) - peak_memory_at_forward`,
around 1.2 GiB in the figure below.
<img width="627" alt="Screenshot 2024-12-02 at 2 57 01 PM"
src="https://github.com/user-attachments/assets/1e8b71d0-0f2e-470f-a66e-70ca26367ea9">
DTensor has existing RNG management. It requires a shared seed for every
rank in its 'world' (SPMD world). It then manages per-rank offsets with
its own RNG tracker to ensure the same or different random values
across ranks, depending on the device mesh and the type of sharding in
the current operation. (TODO: link to docs)

When used together with pipeline parallelism, it is important to use a
different seed for each separate SPMD world.  E.g. if the user specified
seed 1234, then we can literally use 1234 for all the ranks on PP=0, but
then we should use a different seed (e.g. 1234 + 1) for ranks on PP=1.
This partitions the world into PP separate SPMD worlds and uses a unique
seed for each SPMD world.
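The seed partitioning above can be sketched as follows (a hypothetical helper, not torchtitan's actual function): every rank inside one pipeline stage shares a seed, and each stage gets a distinct seed, partitioning the world into PP separate SPMD worlds.

```python
def seed_for_pp_stage(user_seed: int, pp_rank: int) -> int:
    # All ranks within one PP stage (one SPMD world) share this seed;
    # different stages get distinct seeds.
    return user_seed + pp_rank

# With user seed 1234 and 4 pipeline stages:
stage_seeds = [seed_for_pp_stage(1234, r) for r in range(4)]
print(stage_seeds)  # [1234, 1235, 1236, 1237]
```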

Control 'deterministic' mode separately from rng seed

The use case for 'deterministic' mode may be more for debugging, while
users may want to control RNG seeds used for real runs.

ghstack-source-id: b6cc5dbca6ffef616111a354c4c6bdf80ab8335f
Pull Request resolved: #689
ghstack-source-id: 46d2eebe1b19130597093b74d9575e9c8b4a9d8e
Pull Request resolved: #707
…on method"


**Summary**
Requires the land of pytorch/pytorch#142093

[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Dec 11, 2024
ghstack-source-id: bbdd91d09aaf75a80389da422972aa81a222f958
Pull Request resolved: #684
…ent (#720)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #684
* #685
* __->__ #720

**Summary**
This PR improves the design of the DeviceMesh hierarchy in torchtitan.
Now, we define all device meshes except `world_mesh` in 2 categories:
1. Basic meshes: meshes defined in the job `.toml` file by users. These
include `pp` (`pipeline_parallel_degree`), `dp_replicate`
(`data_parallel_replicate_degree`), `dp_shard`
(`data_parallel_shard_degree`), `tp` (`tensor_parallel_degree`), and
`cp` (`context_parallel_degree`).
2. Synthesized meshes (also called "derived meshes"): meshes
synthesized from basic meshes by `_flatten()`. If a mesh is synthesized
from a single mesh, it is equivalent to an alias. So far we use
2 synthesized meshes: `dp` and `dp_shard_cp`. The `dp` mesh is used for
data loading and the `dp_shard_cp` mesh is used for model-parameter
sharding.
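The two categories can be sketched in pure Python (the degrees below are illustrative, and torchtitan builds real `DeviceMesh` objects rather than this bookkeeping):

```python
# Basic meshes, as a user would set them in the job .toml file.
basic = {"pp": 2, "dp_replicate": 1, "dp_shard": 2, "tp": 2, "cp": 1}

def synthesized_degree(names):
    # A derived mesh flattens several basic meshes into one dimension;
    # flattening a single mesh is just an alias for it.
    degree = 1
    for name in names:
        degree *= basic[name]
    return degree

dp = synthesized_degree(["dp_replicate", "dp_shard"])  # data loading
dp_shard_cp = synthesized_degree(["dp_shard", "cp"])   # param sharding
```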

**Test**
CI
…on method"


**Summary**
Requires the land of pytorch/pytorch#142093

[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Dec 11, 2024
ghstack-source-id: 3856be3ec600bafd425b6b0b22c936800e13c92a
Pull Request resolved: #684
XilunWu added a commit that referenced this pull request Dec 11, 2024
…ss and lower mem usage (#685)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #684
* __->__ #685

**Summary**
Previously, CP forgot to shard the model via `apply_fsdp` when DP was
not combined with CP. This led to high peak memory usage and diverging
loss.

**Test**
1. modify `train_configs/llama3_8b.toml`
```
steps = 20
context_parallel_degree = 8
```
2.  run training on 8xH100 GPUs
`CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8
LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh`
Before: CUDA OutOfMemory 
After: successful 20-steps training
@XilunWu XilunWu changed the base branch from main to gh/wconstab/16/base December 11, 2024 22:46
@XilunWu XilunWu closed this Dec 11, 2024
Labels
CLA Signed This label is managed by the Meta Open Source bot.