lighteval support after checkpoint, UX refactor #222

Open

wants to merge 45 commits into main
Conversation

@eliebak eliebak commented Aug 24, 2024

Info

  • Follow-up of the s3upload PR #220

Changes

  1. Automated Lighteval Job Execution

    • Triggers after each checkpoint save (locally or on S3); see the sketch after this list
    • Implemented via:
      • self.s3_mover.post_upload_callback (S3 enabled)
      • self.post_checkpoint_callback (local)
  2. Enhanced Wandb

    • Store ID and project in general args
    • Log tokens on the x axis instead of steps
    • Enables logging for both training and evaluation
  3. New scripts for users: launcher.py and create_config.py

  4. Fix model parameter counting

    • New function: model_config.get_llama_param_count()
    • Updated get_flops() method:
      • Corrected block_compute_costs:
        • Included layer norm and GQA in the FLOPs calculation
  5. Customizable Slurm Folder

    • Separate modifications for training and evaluation jobs
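
A minimal sketch of the checkpoint-triggered evaluation flow described in item 1 (illustrative only; the helper names here are hypothetical, the real hooks are self.post_checkpoint_callback and self.s3_mover.post_upload_callback):

import subprocess
from pathlib import Path

def make_post_checkpoint_callback(eval_script: Path, use_slurm: bool):
    """Return a callback that launches a lighteval job for a freshly saved checkpoint.
    Hypothetical sketch, not the PR's actual code."""
    def callback(checkpoint_path: str) -> None:
        cmd = ["python", str(eval_script), "--checkpoint", checkpoint_path]
        if use_slurm:
            # submit as a Slurm job instead of running in-process
            cmd = ["sbatch", "--wrap", " ".join(cmd)]
        # don't crash training if the eval submission fails
        subprocess.run(cmd, check=False)
    return callback

# hooked in after every local checkpoint save, or after the S3 upload completes:
# trainer.post_checkpoint_callback = make_post_checkpoint_callback(Path("run_evals.py"), use_slurm=True)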

Usage Instructions

create_config.py

  1. Modify the Python file as needed
  2. Run: python create_config.py --save_path <path_to_save>
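
As a rough illustration of the workflow (not the actual create_config.py, whose fields follow nanotron's Config dataclasses), the idea is: edit a few values in the script, then dump the result to YAML:

import argparse
from dataclasses import dataclass, asdict
import yaml  # pyyaml

@dataclass
class TrainConfig:  # hypothetical stand-in for nanotron's Config
    run: str = "smol-360M"
    sequence_length: int = 2048
    micro_batch_size: int = 8
    learning_rate: float = 1e-3

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_path", required=True)
    args = parser.parse_args()
    with open(args.save_path, "w") as f:
        # edit the defaults above, then run the script to produce the YAML config
        yaml.safe_dump(asdict(TrainConfig()), f)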

launcher.py

Basic usage:
python launcher.py --logs_path <path_to_log> --run <name_of_the_run>

  • Creates a timestamped run folder under <logs_path>/<run>
  • Numbering starts at 001 for the first run

Additional options:

  • Slurm: Add --slurm --nodes <number_of_nodes>
  • Override config arguments:
    python nanotron/launcher.py [other_args] --override KEY1=VALUE1 KEY2=VALUE2 ...
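
A short sketch of how dotted KEY=VALUE overrides can be applied to a nested config (illustrative only; the real launcher applies them to nanotron's config objects):

import ast

def apply_overrides(config: dict, overrides: list) -> dict:
    """Apply dotted KEY=VALUE overrides in place, e.g. tokens.micro_batch_size=8."""
    for item in overrides:
        key, raw = item.split("=", 1)
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})
        try:
            node[leaf] = ast.literal_eval(raw)  # numbers, booleans, lists, ...
        except (ValueError, SyntaxError):
            node[leaf] = raw                    # fall back to a plain string
    return config

# apply_overrides(cfg, ["tokens.micro_batch_size=8",
#                       "optimizer.learning_rate_scheduler.learning_rate=1e-3"])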

Minor Changes

  1. Changed 'G' to 'B' in the human format function (see the sketch after this list)
  2. Lighteval branch fixes:
  • Added evaluation to Wandb
  • Set trust_remote_code=True globally
  • Created repository for uploading details and results
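
As referenced in item 1 above, a hedged sketch of a human-format helper that uses B rather than G for billions (not the exact function in the repo):

def human_format(num: float) -> str:
    """1_048_576 -> '1.05M', 629_000_000_000 -> '629B'. Illustrative sketch only."""
    for suffix, scale in (("T", 1e12), ("B", 1e9), ("M", 1e6), ("K", 1e3)):
        if abs(num) >= scale:
            return f"{num / scale:.3g}{suffix}"
    return f"{num:g}"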

Recommended Workflow

  1. Create a new folder for each project
  2. Copy into the folder and customize:
  • launcher.py
  • create_config.py
  • slurm folder (if using Slurm)

Fancy prints ✨

  • Command
    python launcher.py --base-config "smollm-360M-4nodes" --run "smol-360M" --override "tokens.micro_batch_size=8" "optimizer.learning_rate_scheduler.learning_rate=1e-3" --slurm --nodes 4

  • Output

⇄ Applied overrides:
  tokens.micro_batch_size=8
  optimizer.learning_rate_scheduler.learning_rate=1e-3

🏋️  Model Parameters:
┌───────────────────────┬────────────────────────┐
│ Total Parameters      │                   362M │
│ Layers                │                     32 │
│ Attention Heads       │                     15 │
│ Hidden Size           │                    960 │
│ Intermediate Size     │                   2560 │
│ Context Length        │                   2048 │
│ Tokenizer             │ HuggingFaceTB/cosmo2-t │
│ Vocab Size            │                  49152 │
└───────────────────────┴────────────────────────┘


🎛️ Parallelism Configuration:
┌───────────────────────┬────────────────────────┐
│ Nodes                 │                      4 │
│ Total GPUs            │                     32 │
│ Data Parallel (DP)    │                     32 │
│ Pipeline Parallel (PP)│                      1 │
│ Tensor Parallel (TP)  │                      1 │ 
└───────────────────────┴────────────────────────┘


📙 Training Configuration:
┌───────────────────────┬────────────────────────┐
│ Total Tokens          │                   629B │
│ Global Batch Size     │              1,048,576 │
│ Batch Size (per GPU)  │                 32,768 │
└───────────────────────┴────────────────────────┘


📊 Learning Rate Schedule:
┌───────────────────────┬────────────────────────┐
│ Initial LR            │               1.00e-03 │
│ Warmup Style          │                 linear │
│ Warmup Steps          │                   5000 │
│ Decay Style           │                 1-sqrt │
│ Decay Start Step      │                 500000 │
│ Decay Steps           │                 100000 │
│ Final LR              │               0.00e+00 │
└───────────────────────┴────────────────────────┘


🔧 Optimization Configuration:
┌───────────────────────┬────────────────────────┐
│ Optimizer             │     AdamWOptimizerArgs │
│ Weight Decay          │               1.00e-02 │
│ Gradient Clipping     │                   1.00 │
│ Adam Epsilon          │               1.00e-08 │
│ Adam Beta1            │                   0.90 │
│ Adam Beta2            │                   0.95 │
│ ZeRO Stage            │                      0 │
│ FP32 Grad Accumulation│                   True │
└───────────────────────┴────────────────────────┘

🚀 Slurm job launched with id=8552705
    🤖 Slurm Configuration Details:
        qos: high
        gpus_per_node: 8
        cpus_per_task: 88
        
    📁 Log structure:
    logs/smol-360M/
    └── run001_2024-09-03_05-45-23/
        ├── config/
        ├── launch-script/
        ├── slurm-logs/
        └── evals/
            ├── launch-config/
            └── logs/

@eliebak eliebak changed the title [WIP] Add lighteval support after checkpoint (when s3 upload) + launcher.py [WIP] Add lighteval support after checkpoint + launcher.py Aug 30, 2024
@eliebak eliebak marked this pull request as draft August 30, 2024 03:43
@3outeille 3outeille self-assigned this Aug 30, 2024
@eliebak eliebak changed the title [WIP] Add lighteval support after checkpoint + launcher.py lighteval support after checkpoint, UX refactor Sep 3, 2024
@zzhhjjj (Collaborator) commented Sep 13, 2024

What's the reason we have to "Install lighteval from https://github.com/eliebak/lighteval/tree/current-nanotron otherwise it will not work"

@eliebak (Contributor, Author) commented Sep 13, 2024

What's the reason we have to "Install lighteval from https://github.com/eliebak/lighteval/tree/current-nanotron otherwise it will not work"

I changed the installation to https://github.com/huggingface/lighteval/tree/nanotron-compatible in the .toml file. It's an older version of lighteval that works with the current code; at the time I made the change, lighteval was changing a bit, so I used an older commit (the one used for the FW ablation). I (or other people) can update to the current version later :)

@@ -287,16 +304,30 @@ def pre_training(self, *args, **kwargs):
rank=0,
)

current_time = datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S")
if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S")
Member:

remove

# Log initial tokens to set the starting point
wandb.log({"Tokens": initial_tokens})

print(f"Initial Tokens: {initial_tokens}")
Member:

log_rank(message, logger=logger, level=logging.INFO, rank=0)

if wandb is not None and dist.get_rank(self.parallel_context.dp_pg) == 0:
if self.config.general.wandb_id is None:
self.config.general.wandb_id = wandb.run.id
self.config.general.wandb_project = wandb.run.project
Member:

why don't we keep it as the default in the initial config?

@eliebak (Contributor, Author) commented Sep 24, 2024:

That was because we previously needed to store the wandb.run.id (and project) to pass to lighteval, so the eval would be logged to the same wandb run. But we don't save to wandb anymore, so there's no need for this; I'll remove it. Thanks for noticing!

"Update the wandb run due too resume from checkpoint", logger=logger, level=logging.WARNING, rank=0
)
self.config.general.wandb_id = wandb.run.id
self.config.general.wandb_project = wandb.run.project
Member:

same as above

from contextlib import ExitStack, contextmanager
from typing import ContextManager, List, Optional
import json
Member:

seems like we don't use it? could u rerun precommit please

@@ -60,6 +60,9 @@ s3 = [
"s5cmd",
]

lighteval = [
"lighteval[nanotron]@git+https://github.com/huggingface/lighteval.git",
Member:

nicee

tokens=tokens,
optimizer=optimizer,
data_stages=data_stages,
lighteval=lighteval,
Member:

not defined variable?

@@ -1,7 +1,10 @@
import os
from pathlib import Path
from typing import Optional, cast
from datasets.download.streaming_download_manager import xPath
Member:

rerun precommit

class S3UploadArgs:
"""Arguments related to uploading checkpoints on s3"""

remove_after_upload: bool
Member:

should be True by default?

Comment on lines 399 to 401
if hasattr(self, "_post_init_done"):
return
self._post_init_done = True
Member:

this seems unnecessary

eval_slurm_config: Optional[str] = None
eval_slurm_template: Optional[str] = None
lighteval_config_path: Optional[str] = None
is_s3_available: Optional[bool] = None
Member:

remove? because we set it in lines 406 and 409?

@eliebak (Contributor, Author) commented Sep 24, 2024:

I think it's clearer to have it in the config and define it later, no? I don't know what the standard practice is here. (Update: I removed it; I think it's better, you're right.)

@@ -490,3 +536,14 @@ def get_config_from_file(
)
config.model.model_config = model_config_class(**config.model.model_config)
return config


def save_as_yaml(config, config_class, file_path: str):
Member:

add type hints
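
For illustration, one possible hinted signature (assuming nanotron's Config class; the actual parameter and return types may differ):

from typing import Type

def save_as_yaml(config: "Config", config_class: Type["Config"], file_path: str) -> None:
    ...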

@xrsrke (Member) left a comment:

Looks good overall. I left a few requested changes and one question: When is the lighteval supposed to run? It doesn’t seem to launch any lighteval runs after each checkpoint is saved on slurm? Thanks

@eliebak (Contributor, Author) commented Sep 24, 2024

Thanks for the review, I made the changes! :) It launches a Slurm job using run_slurm_one_job for the evaluation here if there's no S3, and here if you upload checkpoints to S3 (the same way as in brrr for the S3 case).
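
For context, a hedged sketch of the general pattern a run_slurm_one_job-style helper follows (write a one-off batch script, submit it with sbatch); this is not the actual function, and the eval entrypoint name is hypothetical:

import subprocess
import tempfile

def submit_eval_job(checkpoint_path: str, nodes: int = 1, qos: str = "high") -> str:
    """Submit a one-off Slurm job that evaluates a checkpoint. Illustrative sketch only."""
    script = f"""#!/bin/bash
#SBATCH --job-name=lighteval
#SBATCH --nodes={nodes}
#SBATCH --qos={qos}
srun python run_evals.py --checkpoint {checkpoint_path}  # hypothetical eval entrypoint
"""
    with tempfile.NamedTemporaryFile("w", suffix=".slurm", delete=False) as f:
        f.write(script)
        batch_file = f.name
    out = subprocess.run(["sbatch", batch_file], capture_output=True, text=True, check=True)
    return out.stdout.strip()  # e.g. "Submitted batch job 8552705"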

Comment on lines 624 to 629
# LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"),
LogItem(
"consumed_tokens",
self.metadata.consumed_train_samples * self.config.tokens.sequence_length,
"human_format",
), # , "12d"),
# LogItem(
# "consumed_tokens",
# self.metadata.consumed_train_samples * self.config.tokens.sequence_length,
# "human_format",
# ), # , "12d"),
Member:

Why is this commented out?

Comment on lines +35 to +36
elif isinstance(value, xPath):
result[field.name] = str(value)
Member:

Why xPath? I thought we removed it in favor of Path

@eliebak (Contributor, Author) commented Sep 26, 2024:

No, we still need it in s3upload to upload to S3 (with Path, an S3 path gets treated as a local path)
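
For illustration (not part of the diff), this is the issue with plain pathlib and S3 URIs:

from pathlib import Path
str(Path("s3://bucket/checkpoints/10"))  # -> 's3:/bucket/checkpoints/10', the '//' is collapsed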

@xrsrke (Member) left a comment:

Hi, I think you made some breaking changes, checkpoint saving is broken. To reproduce:

CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc-per-node 8 --master_port=25621 run_train.py --config /fsx/phuc/temp/env_for_review_pr/nanotron/pr_config_new.yaml
  File "/fsx/phuc/temp/env_for_review_pr/nanotron/run_train.py", line 237, in <module>
    trainer.train(dataloader)
  File "/fsx/phuc/temp/env_for_review_pr/nanotron/src/nanotron/trainer.py", line 489, in train
    self.save_checkpoint()
  File "/fsx/phuc/temp/env_for_review_pr/nanotron/src/nanotron/trainer.py", line 912, in save_checkpoint
    checkpoint_path = Path(checkpoints_path) / f"{self.iteration_step}"
  File "/fsx/phuc/temp/env_for_review_pr/env/lib/python3.10/pathlib.py", line 960, in __new__
    self = cls._from_parts(args)
  File "/fsx/phuc/temp/env_for_review_pr/env/lib/python3.10/pathlib.py", line 594, in _from_parts
    drv, root, parts = self._parse_args(args)
  File "/fsx/phuc/temp/env_for_review_pr/env/lib/python3.10/pathlib.py", line 578, in _parse_args
    a = os.fspath(a)
TypeError: expected str, bytes or os.PathLike object, not NoneType

(the same traceback is printed, interleaved, by every rank)
[2024-10-02 20:09:03,528] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2082316 closing signal SIGTERM
[2024-10-02 20:09:03,529] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2082319 closing signal SIGTERM
[2024-10-02 20:09:03,529] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2082323 closing signal SIGTERM
[2024-10-02 20:09:03,706] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 1 (pid: 2082317) of binary: /fsx/phuc/temp/env_for_review_pr/env/bin/python
Traceback (most recent call last):
  File "/fsx/phuc/temp/env_for_review_pr/env/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/fsx/phuc/temp/env_for_review_pr/env/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/fsx/phuc/temp/env_for_review_pr/env/lib/python3.10/site-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/fsx/phuc/temp/env_for_review_pr/env/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/fsx/phuc/temp/env_for_review_pr/env/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/fsx/phuc/temp/env_for_review_pr/env/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================

@eliebak (Contributor, Author) commented Oct 2, 2024

In your config the checkpoint path is set to null (checkpoint_path: null); that's why you got this error.
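
For reference, a minimal sketch of the relevant YAML section with a non-null path (key names assume nanotron's usual checkpoints block and may differ in your config):

checkpoints:
  checkpoints_path: checkpoints/pr-review-run   # must not be null
  checkpoint_interval: 1000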
