adapt it to the current lighteval main
eliebak committed Sep 20, 2024
1 parent 11d60c8 commit 5e8361c
Showing 12 changed files with 520 additions and 528 deletions.
108 changes: 27 additions & 81 deletions create_config.py
@@ -1,57 +1,42 @@
-import os
-from pathlib import Path
-import subprocess
-from datetime import datetime
-import math
-import torch
-
+import argparse
+import math
+from datetime import datetime
+from pathlib import Path
 
-from nanotron.models.llama import LlamaConfig
-
+import torch
 from nanotron.config import (
+    AdamWOptimizerArgs,
+    CheckpointsArgs,
     Config,
     DataArgs,
-    NanosetDatasetsArgs,
-    PretrainDatasetsArgs,
-    S3UploadArgs,
-    CheckpointsArgs,
+    DatasetStageArgs,
     GeneralArgs,
     LightEvalConfig,
     LightEvalLoggingArgs,
     LightEvalTasksArgs,
     LoggingArgs,
     LRSchedulerArgs,
     ModelArgs,
     OptimizerArgs,
-    AdamWOptimizerArgs,
     ParallelismArgs,
+    PretrainDatasetsArgs,
     RandomInit,
     TokenizerArgs,
     TokensArgs,
-    DatasetStageArgs,
 )
+from nanotron.models.llama import LlamaConfig
 
 if __name__ == "__main__":
     ###########################################
     ## ADAPT TO YOUR ENVIRONMENT (toy example of smollm-135M on 1 GPU)
 
-    HF_USER_OR_ORG = "eliebak"
+    HF_USER_OR_ORG = None
     TRAIN_STEPS = 100
     CHECKPOINT_INTERVAL = 200
-    SAVE_NAME="smollm-135M-1gpu-toy"
-
+    SAVE_NAME = "smollm-135M-1gpu-toy"
 
     ###########################################
 
     parser = argparse.ArgumentParser()
     parser.add_argument("--save-path", help="path to save the configuration file", type=str, default="yaml")
     parser.add_argument("--seed", help="seed", type=int, default=8)
-    parser.add_argument("--launch", action="store_true", help="Launch the configuration immediately")
-    parser.add_argument("--logs-path", help="path to the logs folder", type=str)
-    parser.add_argument("--run", help="name of the run", type=str)
-    parser.add_argument("--slurm", help="use slurm", action="store_true")
-    parser.add_argument("--nodes", help="specify the number of nodes", type=int)
     args = parser.parse_args()
 
     timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
@@ -78,7 +63,7 @@
         rope_scaling=None,
         tie_word_embeddings=True,
         use_cache=True,
-        vocab_size=49152,
+        vocab_size=49152,
     )
 
     # Uncomment to evaluate the model on a set of tasks with lighteval during the training.
@@ -100,24 +85,16 @@
    #     ),
    #     batch_size=16,
    #     logging=LightEvalLoggingArgs(
-    #         local_output_path="lighteval-logs",
-    #         private=True,
-    #         push_details_to_hub=True,
-    #         push_results_to_hub=True,
-    #         push_results_to_tensorboard=True,
-    #         hf_user_or_org=HF_USER_OR_ORG,
-    #         hub_repo_results="lighteval-results",
-    #         hub_repo_details="lighteval-details",
-    #         hub_repo_tensorboard="smollm-evals-visualization",
+    #         output_dir=None,
+    #         push_to_hub=True,
+    #         push_to_tensorboard=True,
+    #         public_run=False,
+    #         results_org=HF_USER_OR_ORG,
+    #         tensorboard_metric_prefix="eval",
    #     ),
-    #     temp_dir = "temp_dir",
-    #     slurm_template="slurm/run_eval.slurm.jinja",
-    #     # slurm_template="slurm/run_eval_s3.slurm.jinja", if s3
-
    # )
 
     lighteval = None
-    # lighteval = None
 
     checkpoints = CheckpointsArgs(
         # checkpoints_path="checkpoints",
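The + lines above are the heart of this commit: the old hub-centric logging fields are replaced by the new lighteval API. As a reference, a minimal sketch of that block uncommented, using only the names visible in this diff (the tasks section sits in the collapsed lines and is omitted; this is meant to drop into create_config.py, which already imports LightEvalConfig and LightEvalLoggingArgs):

    lighteval = LightEvalConfig(
        batch_size=16,
        logging=LightEvalLoggingArgs(
            output_dir=None,
            push_to_hub=True,
            push_to_tensorboard=True,
            public_run=False,
            results_org=HF_USER_OR_ORG,  # requires HF_USER_OR_ORG to be set above
            tensorboard_metric_prefix="eval",
        ),
    )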
@@ -137,7 +114,7 @@
     )
 
     tokens = TokensArgs(
-        batch_accumulation_per_replica=8,
+        batch_accumulation_per_replica=1,
         micro_batch_size=8,
         sequence_length=2048,
         train_steps=TRAIN_STEPS,
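The accumulation change from 8 to 1 shrinks the effective batch by 8x. A quick sanity check on the numbers above, assuming data parallelism of 1 for the single-GPU toy setup:

    # Back-of-the-envelope arithmetic, not part of the config.
    dp = 1                                 # assumed: the 1-GPU toy example
    global_batch = 1 * 8 * dp              # batch_accumulation_per_replica * micro_batch_size * dp
    tokens_per_step = global_batch * 2048  # * sequence_length
    print(global_batch, tokens_per_step)   # 8 sequences, 16384 tokens (was 64 and 131072)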
@@ -147,7 +124,7 @@
     model = ModelArgs(
         model_config=model_config,
         init_method=RandomInit(
-            std=1/math.sqrt(model_config.hidden_size),
+            std=1 / math.sqrt(model_config.hidden_size),
         ),
         dtype=torch.bfloat16,
     )
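The std edit is formatting only, but the formula is the standard 1/sqrt(hidden_size) initialization. For example, assuming the SmolLM-135M width of 576 (the real value lives in the collapsed model_config above):

    import math

    hidden_size = 576  # assumption; the actual value is not shown in this hunk
    print(round(1 / math.sqrt(hidden_size), 4))  # 0.0417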
@@ -164,12 +141,11 @@
         lr_warmup_steps=10,
         lr_warmup_style="linear",
         lr_decay_style="linear",
-        lr_decay_steps = 20,
-        lr_decay_starting_step=80 ,
+        lr_decay_steps=20,
+        lr_decay_starting_step=80,
         min_decay_lr=0,
     )
 
-
     optimizer = OptimizerArgs(
         zero_stage=0,
         weight_decay=0.01,
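Read together, the scheduler arguments describe: 10 steps of linear warmup, a plateau at the peak rate, then a linear decay that starts at step 80 and reaches min_decay_lr=0 after 20 steps, i.e. exactly at TRAIN_STEPS=100. A rough illustration of that shape (lr_peak is a placeholder, since learning_rate sits in the collapsed lines, and nanotron's exact boundary handling may differ):

    def lr_at(step, lr_peak=3e-4):  # lr_peak is a stand-in, not the config's value
        warmup, decay_start, decay_steps, lr_min = 10, 80, 20, 0.0
        if step < warmup:
            return lr_peak * step / warmup          # linear warmup
        if step < decay_start:
            return lr_peak                          # plateau at peak rate
        frac = min((step - decay_start) / decay_steps, 1.0)
        return lr_peak + frac * (lr_min - lr_peak)  # linear decay to min

    # lr_at(0) == 0.0, lr_at(50) == 3e-4, lr_at(90) == 1.5e-4, lr_at(100) == 0.0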
@@ -197,12 +173,12 @@
    #     s5cmd_path="PATH_TO_S5CMD",
    # )
 
-    data_stages=[
+    data_stages = [
         DatasetStageArgs(
             data=DataArgs(
                 # 1. Un-tokenized dataset from HuggingFace
                 dataset=PretrainDatasetsArgs(
-                    hf_dataset_or_datasets="HuggingFaceTB/smollm-corpus", # feel free to replace it by a smaller one if you don't have enough memory
+                    hf_dataset_or_datasets="HuggingFaceTB/smollm-corpus",  # feel free to replace it by a smaller one if you don't have enough memory
                     hf_dataset_splits="train",
                     hf_dataset_config_name="cosmopedia-v2",
                     text_column_name="text",
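To eyeball the corpus before committing to a full run, the same coordinates can be streamed with the datasets library (a standalone check, independent of create_config.py):

    from datasets import load_dataset

    ds = load_dataset(
        "HuggingFaceTB/smollm-corpus",
        "cosmopedia-v2",
        split="train",
        streaming=True,  # avoids downloading the whole corpus
    )
    print(next(iter(ds))["text"][:80])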
@@ -250,42 +226,12 @@
         lighteval=lighteval,
     )
 
-    save_path= Path(args.save_path)
+    save_path = Path(args.save_path)
     save_path.mkdir(parents=True, exist_ok=True)
 
     config_path_yaml = save_path / f"{SAVE_NAME}.yaml"
     config.save_as_yaml(config_path_yaml)
 
     print(f"💾 Configuration saved in: {str(save_path)}")
-
-    if args.launch:
-
-        # Sanity check for logs_path and run
-        if not args.logs_path:
-            raise ValueError("--logs_path must be defined. Please provide a path for the logs.")
-        if not args.run:
-            raise ValueError("--run must be defined. Please provide a name for the run.")
-
-        launcher_path = Path("launcher.py")
-        if not launcher_path.exists():
-            raise FileNotFoundError(f"Launcher not found at {launcher_path}. Please ensure the file exists or change the launcher path in the create_config.py file.")
-        launch_command = [
-            "python", str(launcher_path),
-            "--config-path", str(config_path_yaml),
-        ]
-        launch_command.extend([
-            "--logs-path", args.logs_path,
-            "--run", args.run
-        ])
-        if args.slurm:
-            launch_command.append("--slurm")
-
-        if args.nodes:
-            launch_command.extend(["--nodes", str(args.nodes)])
-
-
-        print(f"🧪 Launching configuration with command: {' '.join(launch_command)}")
-        subprocess.run(launch_command, check=True)
-    else:
-        print("To launch this configuration, run:")
-        print(f"python 'launcher.py' configs/{str(config_path_yaml)}")
+    print("To launch this configuration, run:")
+    print(f"python launcher.py --config-path configs/{str(config_path_yaml)}")
