Skip to content

Commit

Permalink
fix grad flow
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Jul 10, 2024
1 parent 89a36f9 commit 1068f89
Show file tree
Hide file tree
Showing 24 changed files with 431 additions and 199 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,5 @@ wandb/
*.pt
*.ipynb
tests/.test_cache/*

debug/*
76 changes: 7 additions & 69 deletions examples/fp8/ablations/configs/sanity_bf16.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,12 @@ checkpoints:
# resume_checkpoint_path: checkpoints
save_initial_state: false
data_stages:
# - data:
# dataset:
# dataset_overwrite_cache: false
# dataset_processing_num_proc_per_process: 1
# hf_dataset_config_name: null
# hf_dataset_or_datasets: roneneldan/TinyStories
# hf_dataset_splits: train
# text_column_name: text
# num_loading_workers: 1
# seed: 42
# name: Stable Training Stage
# start_training_step: 1
# - data:
# dataset:
# dataset_overwrite_cache: false
# dataset_processing_num_proc_per_process: 1
# hf_dataset_config_name: null
# hf_dataset_or_datasets: stas/openwebtext-10k
# hf_dataset_splits: train
# text_column_name: text
# num_loading_workers: 1
# seed: 42
# name: Annealing Phase
# start_training_step: 10
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: JeanKaddour/minipile
hf_dataset_or_datasets: nanotron/minipile_100_samples
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
Expand All @@ -46,15 +22,15 @@ general:
consumed_train_samples: null
ignore_sanity_checks: true
project: fp8_for_nanotron
run: bfloat16_2_layers_and_seq_len_16_and_micro_batch_256_and_lr_2.0e-4
run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_2.0e-4_and_minipile_overfitting_and_fp8_branch_and_layernorm_and_custom_adam_and_tp_1_and_lr_1.0
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
monitor_model_states: true
monitor_model_states: false
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
Expand Down Expand Up @@ -83,43 +59,11 @@ model:
tie_word_embeddings: false
use_cache: true
vocab_size: 49152
# optimizer:
# accumulate_grad_in_fp32: false
# clip_grad: 1.0
# # learning_rate_scheduler:
# # learning_rate: 0.01
# # lr_decay_starting_step: null
# # lr_decay_steps: 13
# # lr_decay_style: linear
# # lr_warmup_steps: 2
# # lr_warmup_style: constant
# # min_decay_lr: 1.0e-05

# # learning_rate_scheduler:
# # learning_rate: 0.00015
# # lr_decay_starting_step: null
# # # lr_decay_steps: null
# # lr_decay_style: linear
# # lr_warmup_steps: 60
# # lr_warmup_style: constant
# # min_decay_lr: 1.0e-05

# optimizer_factory:
# adam_beta1: 0.9
# adam_beta2: 0.999
# adam_eps: 1.0e-08
# name: adam
# torch_adam_is_fused: true
# weight_decay: 0.1
# zero_stage: 0


optimizer:
accumulate_grad_in_fp32: false
# clip_grad: 1.0
learning_rate_scheduler:
# learning_rate: 0.0015 # note: 1/2 of pythia use this for a 400m model
learning_rate: 0.0006
learning_rate: 1.0
lr_decay_starting_step: null
lr_decay_steps: null
lr_decay_style: cosine
Expand All @@ -130,7 +74,7 @@ optimizer:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adam
name: custom_adam
torch_adam_is_fused: true
weight_decay: 0.1
zero_stage: 0
Expand All @@ -140,10 +84,7 @@ parallelism:
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
tp: 2
# tp_linear_async_communication: true
# tp_mode: REDUCE_SCATTER

tp: 1
tp_linear_async_communication: false
tp_mode: ALL_REDUCE

Expand All @@ -153,13 +94,10 @@ tokenizer:
tokenizer_name_or_path: lvwerra/the-tokenizer-v1
tokenizer_revision: null
tokens:
# NOTE: micro_batch_size * sequence_length * batch_accumulation_per_replica
# = 128 * 256 * 1 = 16384
batch_accumulation_per_replica: 1
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 128 # 256
# micro_batch_size: 1
sequence_length: 256
train_steps: 24376
train_steps: 1000
val_check_interval: -1
134 changes: 134 additions & 0 deletions examples/fp8/ablations/configs/sanity_bf16_for_main_branch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
checkpoints:
checkpoint_interval: 50000
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
# resume_checkpoint_path: checkpoints
save_initial_state: false
data_stages:
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: nanotron/minipile_100_samples
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: fp8_for_nanotron
run: bfloat16_2_layers_and_seq_len_256_and_micro_batch_128_and_lr_2.0e-4_and_minipile_overfitting_and_main_branch_and_layernorm
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
# monitor_model_states: true
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
# std: 0.25 # sqrt(1/16)
# std: 0.125 # sqrt(1/64)
std: 0.04419417382415922 # sqrt(1/512)
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
# hidden_act: silu
hidden_act: gelu
hidden_size: 512
initializer_range: 0.02
intermediate_size: 2048
is_llama_config: true
max_position_embeddings: 256
num_attention_heads: 16
num_hidden_layers: 2
num_key_value_heads: 16
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: false
use_cache: true
vocab_size: 49152
# optimizer:
# accumulate_grad_in_fp32: false
# clip_grad: 1.0
# # learning_rate_scheduler:
# # learning_rate: 0.01
# # lr_decay_starting_step: null
# # lr_decay_steps: 13
# # lr_decay_style: linear
# # lr_warmup_steps: 2
# # lr_warmup_style: constant
# # min_decay_lr: 1.0e-05

# # learning_rate_scheduler:
# # learning_rate: 0.00015
# # lr_decay_starting_step: null
# # # lr_decay_steps: null
# # lr_decay_style: linear
# # lr_warmup_steps: 60
# # lr_warmup_style: constant
# # min_decay_lr: 1.0e-05

# optimizer_factory:
# adam_beta1: 0.9
# adam_beta2: 0.999
# adam_eps: 1.0e-08
# name: adam
# torch_adam_is_fused: true
# weight_decay: 0.1
# zero_stage: 0


optimizer:
accumulate_grad_in_fp32: false
# clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0006
lr_decay_starting_step: null
lr_decay_steps: null
lr_decay_style: cosine
lr_warmup_steps: 200 # 10% warm up of total training steps
lr_warmup_style: linear
min_decay_lr: 0.00006
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.1
zero_stage: 0

parallelism:
dp: 1
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: false
tp_mode: ALL_REDUCE

profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: lvwerra/the-tokenizer-v1
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 1
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 128 # 256
sequence_length: 256
train_steps: 1000
val_check_interval: -1
8 changes: 6 additions & 2 deletions examples/fp8/ablations/configs/sanity_fp8.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ fp8:
- module_name: model.decoder.0.mlp.down_proj
accum_dtype: KFLOAT16
input:
dtype: float16
dtype: FP8E4M3
margin: 0
interval: 1
weight:
dtype: FP8E4M3
margin: 0
Expand Down Expand Up @@ -199,7 +201,9 @@ fp8:
- module_name: model.decoder.1.attn.qkv_proj
accum_dtype: KFLOAT16
input:
dtype: float16
dtype: FP8E4M3
margin: 0
interval: 1
weight:
dtype: FP8E4M3
margin: 0
Expand Down
14 changes: 6 additions & 8 deletions run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,12 @@ def get_dataloader_from_data_stage(
)

# Check if we have enough samples for train_steps
total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length
num_tokens_needed_for_training = (
num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length
)
assert num_tokens_needed_for_training <= total_tokens_dataset, (
f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
)
len(dataloader.dataset) * trainer.sequence_length
(num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length)
# assert num_tokens_needed_for_training <= total_tokens_dataset, (
# f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
# f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
# )

# Case 3: Nanosets
elif isinstance(data.dataset, NanosetDatasetsArgs):
Expand Down
8 changes: 7 additions & 1 deletion src/nanotron/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,13 @@ def get_train_dataloader(
consumed_train_samples=consumed_train_samples,
)

return DataLoader(
class CyclingDataLoader(DataLoader):
def __iter__(self):
import itertools

return itertools.cycle(super().__iter__())

return CyclingDataLoader(
train_dataset,
batch_size=micro_batch_size,
sampler=train_sampler,
Expand Down
Loading

0 comments on commit 1068f89

Please sign in to comment.