diff --git a/create_config.py b/create_config.py new file mode 100644 index 00000000..cd9a122b --- /dev/null +++ b/create_config.py @@ -0,0 +1,237 @@ +import argparse +import math +from datetime import datetime +from pathlib import Path + +import torch +from nanotron.config import ( + AdamWOptimizerArgs, + CheckpointsArgs, + Config, + DataArgs, + DatasetStageArgs, + GeneralArgs, + LoggingArgs, + LRSchedulerArgs, + ModelArgs, + OptimizerArgs, + ParallelismArgs, + PretrainDatasetsArgs, + RandomInit, + TokenizerArgs, + TokensArgs, +) +from nanotron.models.llama import LlamaConfig + +if __name__ == "__main__": + ########################################### + ## ADAPT TO YOUR ENVIRONMENT (toy example of smollm-135M on 1 GPU) + + HF_USER_OR_ORG = None + TRAIN_STEPS = 100 + CHECKPOINT_INTERVAL = 200 + SAVE_NAME = "smollm-135M-1gpu-toy" + + ########################################### + + parser = argparse.ArgumentParser() + parser.add_argument("--save-path", help="path to save the configuration file", type=str, default="yaml") + parser.add_argument("--seed", help="seed", type=int, default=8) + args = parser.parse_args() + + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + general = GeneralArgs( + project="smollm", + run="toy-smollm", + seed=args.seed, + temp_dir="temp", + ) + + model_config = LlamaConfig( + bos_token_id=0, + eos_token_id=0, + hidden_act="silu", + hidden_size=576, + initializer_range=0.02, + intermediate_size=1536, + max_position_embeddings=2048, + num_attention_heads=9, + num_hidden_layers=30, + num_key_value_heads=3, + pretraining_tp=1, + rms_norm_eps=1e-05, + rope_scaling=None, + tie_word_embeddings=True, + use_cache=True, + vocab_size=49152, + ) + + # Uncomment to evaluate the model on a set of tasks with lighteval during the training. 
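+    # NOTE: if you uncomment the block below, also import LightEvalConfig, LightEvalTasksArgs
+    # and LightEvalLoggingArgs (defined in nanotron.config.lighteval_config) at the top of this
+    # file, and install the optional `lighteval` extra declared in pyproject.toml.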
+ # lighteval = LightEvalConfig( + # tasks=LightEvalTasksArgs( + # tasks="early-signal", # "generatives", "all" + # custom_tasks="nanotron.lighteval.evaluation_tasks", + # max_samples=1000, + # dataset_loading_processes=8, + # ), + # parallelism=ParallelismArgs( + # dp=8, + # pp=1, + # tp=1, + # pp_engine="1f1b", + # tp_mode="ALL_REDUCE", + # # recompute_granularity="selective", + # tp_linear_async_communication=False, + # ), + # batch_size=16, + # logging=LightEvalLoggingArgs( + # output_dir=None, + # push_to_hub=True, + # push_to_tensorboard=True, + # public_run=False, + # results_org=HF_USER_OR_ORG, + # tensorboard_metric_prefix="eval", + # ), + # ) + + lighteval = None + + checkpoints = CheckpointsArgs( + # checkpoints_path="checkpoints", + checkpoints_path_is_shared_file_system=False, + # resume_checkpoint_path="local_path/to/checkpoint" or s3_path, + checkpoint_interval=CHECKPOINT_INTERVAL, + save_initial_state=False, + ) + + parallelism = ParallelismArgs( + dp=1, + pp=1, + tp=1, + pp_engine="1f1b", + tp_mode="REDUCE_SCATTER", + tp_linear_async_communication=True, + ) + + tokens = TokensArgs( + batch_accumulation_per_replica=1, + micro_batch_size=8, + sequence_length=2048, + train_steps=TRAIN_STEPS, + val_check_interval=-1, + ) + + model = ModelArgs( + model_config=model_config, + init_method=RandomInit( + std=1 / math.sqrt(model_config.hidden_size), + ), + dtype=torch.bfloat16, + ) + + logging = LoggingArgs( + # 'debug', 'info', 'warning', 'error', 'critical' and 'passive' + log_level="info", + log_level_replica="info", + iteration_step_info_interval=1, + ) + + learning_rate_scheduler = LRSchedulerArgs( + learning_rate=3e-3, + lr_warmup_steps=10, + lr_warmup_style="linear", + lr_decay_style="linear", + lr_decay_steps=20, + lr_decay_starting_step=80, + min_decay_lr=0, + ) + + optimizer = OptimizerArgs( + zero_stage=0, + weight_decay=0.01, + clip_grad=1.0, + accumulate_grad_in_fp32=True, + learning_rate_scheduler=learning_rate_scheduler, + optimizer_factory=AdamWOptimizerArgs( + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + ), + ) + + tokenizer = TokenizerArgs( + tokenizer_name_or_path="HuggingFaceTB/cosmo2-tokenizer", + ) + + # Uncomment if you want to upload the checkpoints to s3 or load a ckpt from s3 + # s3_upload = S3UploadArgs( + # upload_s3_path=f"S3_PATH", + # remove_after_upload=True, + # s5cmd_numworkers=16, + # s5cmd_concurrency=5, + # s5cmd_path="PATH_TO_S5CMD", + # ) + + data_stages = [ + DatasetStageArgs( + data=DataArgs( + # 1. Un-tokenized dataset from HuggingFace + dataset=PretrainDatasetsArgs( + hf_dataset_or_datasets="HuggingFaceTB/smollm-corpus", # feel free to replace it by a smaller one if you don't have enough memory + hf_dataset_splits="train", + hf_dataset_config_name="cosmopedia-v2", + text_column_name="text", + ), + # 2. 
Pre-tokenized local dataset with Nanoset + # dataset=NanosetDatasetsArgs( + # dataset_folder="datasets/cosmopedia-v2", + # ), + # num_loading_workers=0, + # seed=general.seed, + ), + name="training stage", + start_training_step=1, + ), + # You can add a decay stage here if you want to change the data mixture + # Example (weight are arbitrary here): + # DatasetStageArgs( + # data=DataArgs( + # dataset=NanosetDatasetsArgs( + # dataset_folder={ + # "datasets/fineweb-edu-dedup": 50, + # "datasets/cosmopedia-v2": 30, + # "datasets/python-edu": 10, + # "datasets/open-web-math": 10, + # } + # ), + # num_loading_workers=0, + # seed=general.seed, + # ), + # name="decay stage", + # start_training_step=optimizer.learning_rate_scheduler.lr_decay_starting_step, + # ), + ] + + config = Config( + general=general, + checkpoints=checkpoints, + parallelism=parallelism, + model=model, + tokenizer=tokenizer, + logging=logging, + tokens=tokens, + optimizer=optimizer, + data_stages=data_stages, + lighteval=lighteval, + ) + + save_path = Path(args.save_path) + save_path.mkdir(parents=True, exist_ok=True) + + config_path_yaml = save_path / f"{SAVE_NAME}.yaml" + config.save_as_yaml(config_path_yaml) + + print(f"πŸ’Ύ Configuration saved in: {str(save_path)}") + print("To launch this configuration, run:") + print(f"python launcher.py --config-path configs/{str(config_path_yaml)}") diff --git a/launcher.py b/launcher.py new file mode 100644 index 00000000..f00e20e9 --- /dev/null +++ b/launcher.py @@ -0,0 +1,395 @@ +import argparse +import json +import os +import subprocess +import tempfile +from datetime import datetime +from pathlib import Path + +import torch +from jinja2 import Template +from nanotron.config import ( + Config, + get_config_from_file, + save_as_yaml, +) +from nanotron.logging import human_format + + +def count_subdirectories(path): + return sum(os.path.isdir(os.path.join(path, item)) for item in os.listdir(path)) + + +def launch_slurm_job(launch_file_contents, *args): + """ + Small helper function to save a sbatch script and call it. + Args: + launch_file_contents: Contents of the sbatch script + *args: any other arguments to pass to the sbatch command + + Returns: the id of the launched slurm job + + """ + with tempfile.NamedTemporaryFile("w") as f: + f.write(launch_file_contents) + f.flush() + return subprocess.check_output(["sbatch", *args, f.name]).decode("utf-8").split()[-1] + + +def set_nested_attribute(obj, path, value): + parts = path.split(".") + for part in parts[:-1]: + if not hasattr(obj, part): + setattr(obj, part, type("", (), {})()) + obj = getattr(obj, part) + setattr(obj, parts[-1], value) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--config-path", help="path to the configuration file", type=str, default=None, required=True) + parser.add_argument("--project", help="name of the project", type=str) + parser.add_argument("--run", help="name of the run", type=str) + parser.add_argument("--logs-path", help="path to the logs folder", type=str, default="logs") + parser.add_argument( + "--override", nargs="+", metavar="KEY=VALUE", help="Override config values. Use dot notation for nested keys." 
+ ) + parser.add_argument("--slurm", action="store_true", help="Launch the job on Slurm") + parser.add_argument("--nodes", type=int, help="Number of nodes to use for the job") + args = parser.parse_args() + + if args.config_path is None: + raise ValueError("Please provide a config path") + + if args.slurm: + if args.nodes is None: + raise ValueError("When using Slurm (--slurm), you must specify the number of nodes (--nodes)") + + # Load the configuration using get_config_from_file + config = get_config_from_file(args.config_path, config_class=Config) + + if config.general.logs_path is None and args.logs_path is None: + raise ValueError("Please provide a logs path") + if config.general.project is None and args.project is None: + raise ValueError("Please provide a project name") + elif args.project is not None: + config.general.project = args.project + + if config.general.run is None and args.run is None: + raise ValueError("Please provide a run name") + elif args.run is not None: + config.general.run = args.run + + num_params = human_format(config.model.model_config.get_llama_param_count()).replace(".", ",") + + if args.override: + for item in args.override: + if "=" not in item: + raise ValueError(f"Invalid override format: {item}. Use KEY=VALUE.") + key, value = item.split("=", 1) + try: + value = eval(value) + except Exception as e: + print(f"Warning: Could not evaluate '{value}': {e}") + + set_nested_attribute(config, key, value) + + print("⇄ Applied overrides:") + for item in args.override: + print(f" {item}") + + # Calculate and print learning rate and global batch size information + lr_initial = config.optimizer.learning_rate_scheduler.learning_rate + lr_min = config.optimizer.learning_rate_scheduler.min_decay_lr + lr_warmup_steps = config.optimizer.learning_rate_scheduler.lr_warmup_steps + lr_decay_steps = config.optimizer.learning_rate_scheduler.lr_decay_steps + lr_decay_start = config.optimizer.learning_rate_scheduler.lr_decay_starting_step + lr_decay_style = config.optimizer.learning_rate_scheduler.lr_decay_style + + # Sample/Token per GPU (at once) + bs_gpu_sample = config.tokens.micro_batch_size + bs_gpu_token = bs_gpu_sample * config.tokens.sequence_length + + # Sample/Token in one step + gbs_sample = bs_gpu_sample * config.parallelism.dp * config.tokens.batch_accumulation_per_replica + gbs_token = gbs_sample * config.tokens.sequence_length + + total_tokens = config.tokens.train_steps * gbs_token + total_tokens_billions = human_format(total_tokens).replace(".", ",") + + print( + f""" +πŸ‹οΈ Model Parameters: +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Total Parameters β”‚ {num_params:>22} β”‚ +β”‚ Layers β”‚ {config.model.model_config.num_hidden_layers:>22d} β”‚ +β”‚ Attention Heads β”‚ {config.model.model_config.num_attention_heads:>22d} β”‚ +β”‚ Hidden Size β”‚ {config.model.model_config.hidden_size:>22d} β”‚ +β”‚ Intermediate Size β”‚ {config.model.model_config.intermediate_size:>22d} β”‚ +β”‚ Context Length β”‚ {config.model.model_config.max_position_embeddings:>22d} β”‚ +β”‚ Tokenizer β”‚ {config.tokenizer.tokenizer_name_or_path[:22]:>22} β”‚ +β”‚ Vocab Size β”‚ {config.model.model_config.vocab_size:>22d} β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +""" + ) + + num_nodes = args.nodes if args.slurm else 1 + print( + f""" +πŸŽ›οΈ Parallelism 
Configuration: +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Nodes β”‚ {num_nodes:>22d} β”‚ +β”‚ Total GPUs β”‚ {config.parallelism.dp*config.parallelism.pp*config.parallelism.tp:>22d} β”‚ +β”‚ Data Parallel (DP) β”‚ {config.parallelism.dp:>22d} β”‚ +β”‚ Pipeline Parallel (PP)β”‚ {config.parallelism.pp:>22d} β”‚ +β”‚ Tensor Parallel (TP) β”‚ {config.parallelism.tp:>22d} β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +""" + ) + + print( + f""" +πŸ“™ Training Configuration: +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Total Tokens β”‚ {total_tokens_billions:>22} β”‚ +β”‚ Batch Size (per GPU) β”‚ {bs_gpu_token:>15,d} Tokens β”‚ +β”‚ Global Batch Size β”‚ {gbs_token:>15,d} Tokens β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +""" + ) + + print( + f""" +πŸ“Š Learning Rate Schedule: +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Initial LR β”‚ {lr_initial:>22.2e} β”‚ +β”‚ Warmup Style β”‚ {config.optimizer.learning_rate_scheduler.lr_warmup_style[:22]:>22} β”‚ +β”‚ Warmup Steps β”‚ {lr_warmup_steps:>22d} β”‚ +β”‚ Decay Style β”‚ {lr_decay_style[:22]:>22} β”‚ +β”‚ Decay Start Step β”‚ {lr_decay_start:>22d} β”‚ +β”‚ Decay Steps β”‚ {lr_decay_steps:>22d} β”‚ +β”‚ Final LR β”‚ {lr_min:>22.2e} β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +""" + ) + print( + f""" +πŸ”§ Optimization Configuration: +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Optimizer β”‚ {config.optimizer.optimizer_factory.__class__.__name__:>22} β”‚ +β”‚ Weight Decay β”‚ {config.optimizer.weight_decay:>22.2e} β”‚ +β”‚ Gradient Clipping β”‚ {config.optimizer.clip_grad:>22.2f} β”‚ +β”‚ Adam Epsilon β”‚ {config.optimizer.optimizer_factory.adam_eps:>22.2e} β”‚ +β”‚ Adam Beta1 β”‚ {config.optimizer.optimizer_factory.adam_beta1:>22.2f} β”‚ +β”‚ Adam Beta2 β”‚ {config.optimizer.optimizer_factory.adam_beta2:>22.2f} β”‚ +β”‚ ZeRO Stage β”‚ {config.optimizer.zero_stage:>22d} β”‚ +β”‚ FP32 Grad Accumulationβ”‚ {str(config.optimizer.accumulate_grad_in_fp32):>22} β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +""" + ) + + config.general.logs_path = args.logs_path + + path = Path(args.logs_path) / f"{config.general.run}" + path.mkdir(parents=True, exist_ok=True) + + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + run_number = count_subdirectories(f"{args.logs_path}/{config.general.run}") + 1 + timestamp_with_run = f"run{run_number:03d}_{timestamp}" + config.general.timestamp_with_run = timestamp_with_run + + config.general.config_logs_path = str( + Path(config.general.logs_path) / config.general.run / timestamp_with_run / "config" + ) + Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) + + if config.checkpoints.checkpoints_path is None: + 
config.checkpoints.checkpoints_path = str( + Path(config.general.logs_path) / config.general.run / timestamp_with_run / "checkpoints" + ) + Path(config.checkpoints.checkpoints_path).mkdir(parents=True, exist_ok=True) + + if args.slurm: + + nodes = args.nodes + + launch_slurm_config_path = Path("slurm/launch_slurm_config.json") + if config.lighteval is not None: + eval_slurm_config_path = Path("slurm/eval_slurm_config.json") + if eval_slurm_config_path.exists(): + config.general.eval_slurm_config = str(eval_slurm_config_path.resolve()) + else: + raise ValueError("Lighteval SLURM configuration is required but not provided.") + if config.general.is_s3_available: + config.general.eval_slurm_template = "slurm/run_eval_s3.slurm.jinja" + else: + config.general.eval_slurm_template = "slurm/run_eval.slurm.jinja" + + with open(launch_slurm_config_path, "r") as f: + launch_slurm_config = json.load(f) + + total_gpus = config.parallelism.dp * config.parallelism.pp * config.parallelism.tp + gpus_per_node = launch_slurm_config.get("gpus_per_node") + if total_gpus < gpus_per_node: + required_nodes = 1 + gpus_per_node = total_gpus + print( + "Warning: The total number of GPUs is less than the GPUs per node. You need to adjust to use all available GPUs." + ) + else: + required_nodes = (total_gpus + gpus_per_node - 1) // gpus_per_node # Ceiling division + + if args.nodes != required_nodes: + raise ValueError( + f"Number of nodes in config ({args.nodes}) does not match the required number of nodes ({required_nodes}) based on the parallelism configuration." + ) + + # Create necessary folders + project_log_folder = Path(config.general.logs_path) + log_folder = project_log_folder / f"{config.general.run}" / f"{timestamp_with_run}" + subfolders = ["launch-script", "slurm-logs"] + if hasattr(config, "lighteval") and config.lighteval is not None: + subfolders.append("evals") + + for subfolder in subfolders: + folder_path = str(log_folder / subfolder) + Path(folder_path).mkdir(parents=True, exist_ok=True) + if subfolder == "launch-script": + config.general.launch_script_path = folder_path + elif subfolder == "slurm-logs": + config.general.slurm_logs_path = folder_path + elif subfolder == "evals": + config.general.evals_logs_path = folder_path + for evals_subfolder in ["launch-config", "logs", "lighteval-logs"]: + if evals_subfolder == "lighteval-logs": + if config.lighteval.logging.output_dir is None: + evals_subfolder_path = str(Path(config.general.evals_logs_path) / evals_subfolder) + Path(evals_subfolder_path).mkdir(parents=True, exist_ok=True) + config.lighteval.logging.output_dir = evals_subfolder_path + else: + evals_subfolder_path = str(Path(config.general.evals_logs_path) / evals_subfolder) + Path(evals_subfolder_path).mkdir(parents=True, exist_ok=True) + + torchrun_args = "" + if "torchrun_args" in launch_slurm_config and launch_slurm_config["torchrun_args"]: + torchrun_args = " ".join([f"--{k} {v}" for k, v in launch_slurm_config["torchrun_args"].items()]) + + launch_slurm_config.update( + { + "job_name": f"{config.general.project}-{config.general.run}", + "nodes": args.nodes, + "slurm_logs_path": config.general.slurm_logs_path, + "path_to_trainer_python_file": os.path.join(os.path.dirname(__file__), "run_train.py"), + "config_path_yaml": f"{config.general.config_logs_path}/launch_config.yaml", + "torchrun_args": torchrun_args, + } + ) + + # Load Jinja2 template + template_path = Path("slurm/launch_training.slurm.jinja") + with open(template_path, "r") as f: + template = Template(f.read()) + + # Render the 
template + sbatch_script = template.render(**launch_slurm_config) + if launch_slurm_config_path.exists(): + config.general.launch_slurm_config = str(launch_slurm_config_path.resolve()) + else: + config.general.launch_slurm_config = None + + if config.lighteval is not None: + # Save the lighteval configuration + lighteval_config = config.lighteval + Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) + config.general.lighteval_config_path = str(Path(config.general.config_logs_path) / "lighteval_config.yaml") + save_as_yaml(lighteval_config, config.general.lighteval_config_path) + + config_path_yaml = str(Path(config.general.config_logs_path) / "launch_config.yaml") + Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) + config.save_as_yaml(config_path_yaml) + + # Launch the Slurm job + job_id = launch_slurm_job(sbatch_script) + print(f"πŸš€ Slurm job launched with id={job_id}") + + # Save the Slurm script if a path is provided + if config.general.launch_script_path: + Path(config.general.launch_script_path).mkdir(parents=True, exist_ok=True) + script_filename = "slurm_launch_script.slurm" + script_path = str(Path(config.general.launch_script_path) / script_filename) + script_path = os.path.join(config.general.launch_script_path, script_filename) + + with open(script_path, "w") as f: + f.write(sbatch_script) + print(" πŸ€– Slurm Configuration Details:") + + slurm_config_keys = ["qos", "gpus_per_node", "cpus_per_task", "constraint", "account", "reservation"] + for key in slurm_config_keys: + if key in launch_slurm_config: + if launch_slurm_config[key] is not None: + print(f" {key}: {launch_slurm_config[key]}") + + print(" ") + print(" πŸ“ Log structure:") + print(f" {config.general.logs_path}/{config.general.run}/") + print(f" └── {timestamp_with_run}/") + if config.checkpoints.checkpoints_path == str( + Path(config.general.logs_path) / config.general.run / timestamp_with_run / "checkpoints" + ): + print(" β”œβ”€β”€ checkpoints/") + print(" β”œβ”€β”€ config/") + print(" β”œβ”€β”€ launch-script/") + print(" β”œβ”€β”€ slurm-logs/") + if hasattr(config, "lighteval") and config.lighteval is not None: + print(" └── evals/") + print(" β”œβ”€β”€ launch-config/") + print(" └── logs/") + if config.lighteval.logging.output_dir == str(Path(config.general.evals_logs_path) / "lighteval-logs"): + print(" └── lighteval-logs/") + + else: + # Check if running on an interactive node + try: + gpu_count = torch.cuda.device_count() + is_interactive = gpu_count > 0 + except Exception as e: + print(f"Warning: Could not get GPU count: {e}") + is_interactive = False + + if is_interactive: + print("πŸ’» Running on an interactive node with GPUs.") + gpu_config = config.parallelism.dp * config.parallelism.tp * config.parallelism.pp + if gpu_count < gpu_config: + raise ValueError( + f"Error: Your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " + f"requires {gpu_config} GPUs, but only {gpu_count} are available." + ) + elif gpu_count == gpu_config: + print( + f"πŸš€ Running on {gpu_count} GPUs, which matches your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp})" + ) + total_gpus = gpu_count + elif gpu_count > gpu_config: + total_gpus = gpu_config + print( + f"⚠️ Warning: Your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " + f"uses {total_gpus} GPUs, but {gpu_count} are available. 
" + f"You are not fully utilizing all available GPUs on this device." + ) + + config_path_yaml = str(Path(config.general.config_logs_path) / "launch_config.yaml") + os.makedirs(config.general.config_logs_path, exist_ok=True) + config.save_as_yaml(config_path_yaml) + + trainer_python_file = "run_train.py" + cmd = f"{trainer_python_file} --config-file {config_path_yaml}" + + launch_cmd = f"CUDA_DEVICE_MAX_CONNECTIONS='1' torchrun --nproc_per_node {total_gpus} {cmd}" + print(f"πŸš€ Launching interactive job with command: {launch_cmd}") + + subprocess.run(launch_cmd, shell=True, check=True) + else: + print( + "❌ Not running on a Slurm cluster or an interactive node with GPUs. Please submit a Slurm job or use an interactive node with GPUs." + ) diff --git a/pyproject.toml b/pyproject.toml index 9794ab78..dbde7f0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,9 @@ s3 = [ "s5cmd", ] +lighteval = [ + "lighteval[nanotron]@git+https://github.com/huggingface/lighteval.git", +] [build-system] requires = [ "setuptools", diff --git a/slurm/eval_slurm_config.json b/slurm/eval_slurm_config.json new file mode 100644 index 00000000..51bd4912 --- /dev/null +++ b/slurm/eval_slurm_config.json @@ -0,0 +1,26 @@ +{ + "job_name": "", + "n_tasks_per_node": 1, + "cpus_per_task": 32, + "gpus_per_node": 8, + "partition": "hopper-prod", + "qos": "high", + "mail_type": null, + "mail_user": null, + "exclude_nodes": null, + "time": "1:00:00", + "constraint": null, + "account": null, + "reservation": null, + "torchrun_args": { + "node_rank": "$SLURM_PROCID", + "role": "$SLURMD_NODENAME", + "max_restarts": 0, + "tee": 3 + }, + "hf_cache": "~/.cache", + "array": null, + "mem": null, + "begin": null + } + \ No newline at end of file diff --git a/slurm/launch_slurm_config.json b/slurm/launch_slurm_config.json new file mode 100644 index 00000000..b82b4bc6 --- /dev/null +++ b/slurm/launch_slurm_config.json @@ -0,0 +1,25 @@ +{ + "job_name": "", + "n_tasks_per_node": 1, + "cpus_per_task": 88, + "gpus_per_node": 8, + "partition": "hopper-prod", + "qos": "high", + "mail_type": null, + "mail_user": null, + "exclude_nodes": ["ip-26-0-161-138"], + "time": null, + "constraint": null, + "account": null, + "reservation": null, + "torchrun_args": { + "node_rank": "$SLURM_PROCID", + "role": "$SLURMD_NODENAME", + "max_restarts": 0, + "tee": 3 + }, + "hf_cache": "~/.cache", + "array": null, + "mem": null, + "begin": null +} diff --git a/slurm/launch_training.slurm.jinja b/slurm/launch_training.slurm.jinja new file mode 100644 index 00000000..4e71e88a --- /dev/null +++ b/slurm/launch_training.slurm.jinja @@ -0,0 +1,95 @@ +#!/bin/bash +#SBATCH --job-name={{ job_name }} +#SBATCH --nodes={{ nodes }} +#SBATCH --ntasks-per-node={{ n_tasks_per_node }} +#SBATCH --gres=gpu:{{ gpus_per_node }} +{% if cpus_per_task %} +#SBATCH --cpus-per-task={{ cpus_per_task }} +{% endif %} +#SBATCH --partition={{ partition }} +#SBATCH --output={{ slurm_logs_path }}/train-%j.out +#SBATCH --error={{ slurm_logs_path }}/train-%j.err +{% if qos %} +#SBATCH --qos={{ qos }} +{% endif %} +{% if mail_type %} +#SBATCH --mail-type={{ mail_type }} +{% endif %} +{% if mail_user %} +#SBATCH --mail-user={{ mail_user }} +{% endif %} +{% if exclude_nodes %} +#SBATCH --exclude={{ exclude_nodes|join(',') }} +{% endif %} +{% if time %} +#SBATCH --time={{ time }} +{% endif %} +{% if constraint %} +#SBATCH --constraint={{ constraint }} +{% endif %} +{% if account %} +#SBATCH --account={{ account }} +{% endif %} +{% if reservation %} +#SBATCH --reservation={{ reservation }} 
+{% endif %} + +set -e + +TRAINER_PYTHON_FILE={{ path_to_trainer_python_file }} +nvidia-smi + +# Show some environment variables +echo python3 version = `python3 --version` +echo "Python path: $(which python3)" +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +echo "START TIME: $(date)" +secs_to_human() { + echo "$(( ${1} / 3600 )):$(( (${1} / 60) % 60 )):$(( ${1} % 60 ))" +} +start=$(date +%s) +echo "$(date -d @${start} "+%Y-%m-%d %H:%M:%S"): ${SLURM_JOB_NAME} start id=${SLURM_JOB_ID}\n" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +export TMPDIR=/scratch +export CUDA_DEVICE_MAX_CONNECTIONS="1" + +export HUGGINGFACE_HUB_CACHE={{ hf_cache }} +export HF_DATASETS_CACHE={{ hf_cache }} +export HF_MODULES_CACHE={{ hf_cache }} +export HF_HOME={{ hf_cache }} + +echo go $COUNT_NODE +echo $HOSTNAMES + +CMD=" $TRAINER_PYTHON_FILE \ + --config-file {{ config_path_yaml }} \ + " +export LAUNCHER="torchrun \ + --nproc_per_node {{ gpus_per_node }} \ + --nnodes $COUNT_NODE \ + {{ torchrun_args }} \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: \ + --max_restarts 0 \ + --tee 3 \ + " + +# Wait a random number between 0 and 1000 (milliseconds) to avoid too many concurrent requests to the hub +random_milliseconds=$(( RANDOM % 1001 )) +sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000") +echo "Sleeping for $sleep_time seconds..." +sleep $sleep_time + +launch_args="srun $SRUN_ARGS -u bash -c $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + +srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + +echo "END TIME: $(date)" \ No newline at end of file diff --git a/slurm/run_eval.slurm.jinja b/slurm/run_eval.slurm.jinja new file mode 100644 index 00000000..6444858a --- /dev/null +++ b/slurm/run_eval.slurm.jinja @@ -0,0 +1,95 @@ +#!/bin/bash +#SBATCH --job-name={{ job_name }} +#SBATCH --nodes={{ nodes }} +#SBATCH --ntasks-per-node={{ n_tasks_per_node }} +#SBATCH --gres=gpu:{{ gpus_per_node }} +{% if cpus_per_task %} +#SBATCH --cpus-per-task={{ cpus_per_task }} +{% endif %} +#SBATCH --partition={{ partition }} +#SBATCH --output={{ eval_path }}/%x-%n-%j.out +#SBATCH --error={{ eval_path }}/%x-%n-%j.err +{% if qos %} +#SBATCH --qos={{ qos }} +{% endif %} +{% if mail_type %} +#SBATCH --mail-type={{ mail_type }} +{% endif %} +{% if mail_user %} +#SBATCH --mail-user={{ mail_user }} +{% endif %} +{% if exclude_nodes %} +#SBATCH --exclude={{ exclude_nodes|join(',') }} +{% endif %} +{% if time %} +#SBATCH --time={{ time }} +{% endif %} +{% if constraint %} +#SBATCH --constraint={{ constraint }} +{% endif %} +{% if account %} +#SBATCH --account={{ account }} +{% endif %} +{% if reservation %} +#SBATCH --reservation={{ reservation }} +{% endif %} + +set -e +echo "START TIME: $(date)" +#Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export 
COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +# Hugging Face token +if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then + # Attempt to read the token from the cache + if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then + export HUGGING_FACE_HUB_TOKEN=$TOKEN + else + echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." + exit 1 + fi +fi + +export CUBLAS_WORKSPACE_CONFIG=":4096:8" +export CUDA_DEVICE_MAX_CONNECTIONS="1" + +export HUGGINGFACE_HUB_CACHE={{ hf_cache }} +export HF_DATASETS_CACHE={{ hf_cache }} +export HF_MODULES_CACHE={{ hf_cache }} +export HF_HOME={{ hf_cache }} + +echo go $COUNT_NODE +echo $HOSTNAMES + + +CMD="/fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ + --checkpoint-config-path {{ model_checkpoint_path }}/config.yaml \ + --lighteval-config-path {{ lighteval_config_path }} \ + " + +export LAUNCHER="torchrun \ + --nproc_per_node {{ gpus_per_node }} \ + --nnodes $COUNT_NODE \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: \ + --max_restarts 0 \ + --tee 3 \ + " + +# Wait a random number between 0 and 1000 (milliseconds) to avoid too many concurrent requests to the hub +random_milliseconds=$(( RANDOM % 1001 )) +sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000") +echo "Sleeping for $sleep_time seconds..." +sleep $sleep_time + +launch_args="srun $SRUN_ARGS -u bash -c $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + +srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" diff --git a/slurm/run_eval_s3.slurm.jinja b/slurm/run_eval_s3.slurm.jinja new file mode 100644 index 00000000..04441638 --- /dev/null +++ b/slurm/run_eval_s3.slurm.jinja @@ -0,0 +1,100 @@ +#!/bin/bash +#SBATCH --job-name={{ job_name }} +#SBATCH --nodes={{ nodes }} +#SBATCH --ntasks-per-node={{ n_tasks_per_node }} +#SBATCH --gres=gpu:{{ gpus_per_node }} +{% if cpus_per_task %} +#SBATCH --cpus-per-task={{ cpus_per_task }} +{% endif %} +#SBATCH --partition={{ partition }} +#SBATCH --output={{ eval_path }}/%x-%n-%j.out +#SBATCH --error={{ eval_path }}/%x-%n-%j.err +{% if qos %} +#SBATCH --qos={{ qos }} +{% endif %} +{% if mail_type %} +#SBATCH --mail-type={{ mail_type }} +{% endif %} +{% if mail_user %} +#SBATCH --mail-user={{ mail_user }} +{% endif %} +{% if exclude_nodes %} +#SBATCH --exclude={{ exclude_nodes|join(',') }} +{% endif %} +{% if time %} +#SBATCH --time={{ time }} +{% endif %} +{% if constraint %} +#SBATCH --constraint={{ constraint }} +{% endif %} +{% if account %} +#SBATCH --account={{ account }} +{% endif %} +{% if reservation %} +#SBATCH --reservation={{ reservation }} +{% endif %} + +LOCAL_DOWNLOAD_CHECKPOINT_FOLDER={{ local_path }} + +set -e +echo "START TIME: $(date)" +#Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +# Hugging Face token +if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then + # Attempt to read the token from the cache + if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then + export HUGGING_FACE_HUB_TOKEN=$TOKEN + else + echo "Error: The environment variable 
HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." + exit 1 + fi +fi + +export CUBLAS_WORKSPACE_CONFIG=":4096:8" +export CUDA_DEVICE_MAX_CONNECTIONS="1" + +export HUGGINGFACE_HUB_CACHE={{ hf_cache }} +export HF_DATASETS_CACHE={{ hf_cache }} +export HF_MODULES_CACHE={{ hf_cache }} +export HF_HOME={{ hf_cache }} + +echo go $COUNT_NODE +echo $HOSTNAMES + +# Copying checkpoint from s3 to the node on node +mkdir -p $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER +s5cmd cp --exclude "optimizer/*" {{ model_checkpoint_path }}* $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER + +CMD="/fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ + --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \ + --lighteval-config-path {{ lighteval_config_path }} \ + " + +export LAUNCHER="torchrun \ + --nproc_per_node {{ gpus_per_node }} \ + --nnodes $COUNT_NODE \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: \ + --max_restarts 0 \ + --tee 3 \ + " + +# Wait a random number between 0 and 1000 (milliseconds) to avoid too many concurrent requests to the hub +random_milliseconds=$(( RANDOM % 1001 )) +sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000") +echo "Sleeping for $sleep_time seconds..." +sleep $sleep_time + +launch_args="srun $SRUN_ARGS -u bash -c $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + +srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index c50334f6..31f3ee4d 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -8,9 +8,9 @@ import torch import yaml from dacite import from_dict -from datasets.download.streaming_download_manager import xPath from yaml.loader import SafeLoader +from datasets.download.streaming_download_manager import xPath from nanotron.config.lighteval_config import LightEvalConfig from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit from nanotron.config.parallelism_config import ParallelismArgs @@ -24,6 +24,7 @@ from nanotron.logging import get_logger from nanotron.parallel.pipeline_parallel.engine import PipelineEngine from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +from nanotron.s3_checkpoints import check_path_is_local logger = get_logger(__name__) @@ -96,17 +97,19 @@ def __post_init__(self): class S3UploadArgs: """Arguments related to uploading checkpoints on s3""" - upload_s3_path: xPath - remove_after_upload: bool - s5cmd_numworkers: Optional[int] - s5cmd_concurrency: Optional[int] - s5cmd_path: Optional[xPath] + remove_after_upload: Optional[bool] = True + upload_s3_path: Optional[ + str + ] = None # set to None if we want to use S3UploadArgs to download checkpoints from s3 but not upload checkpoints on s3 + s5cmd_numworkers: Optional[int] = None + s5cmd_concurrency: Optional[int] = None + s5cmd_path: Optional[str] = None def __post_init__(self): - if isinstance(self.upload_s3_path, str): + if isinstance(self.upload_s3_path, str) and self.upload_s3_path is not None: self.upload_s3_path = xPath(self.upload_s3_path) if isinstance(self.s5cmd_path, str): - self.s5cmd_path = xPath(self.s5cmd_path) + self.s5cmd_path = Path(self.s5cmd_path) @dataclass @@ -154,18 +157,21 @@ class CheckpointsArgs: resume_checkpoint_path: if you want to load from a specific checkpoint path """ - checkpoints_path: Path checkpoint_interval: int + checkpoints_path: Optional[str] = None save_initial_state: 
Optional[bool] = False save_final_state: Optional[bool] = False - resume_checkpoint_path: Optional[xPath] = None + resume_checkpoint_path: Optional[str] = None checkpoints_path_is_shared_file_system: Optional[bool] = False def __post_init__(self): if isinstance(self.checkpoints_path, str): self.checkpoints_path = xPath(self.checkpoints_path) if isinstance(self.resume_checkpoint_path, str): - self.resume_checkpoint_path = xPath(self.resume_checkpoint_path) + if check_path_is_local(self.resume_checkpoint_path): + self.resume_checkpoint_path = Path(self.resume_checkpoint_path) + else: + self.resume_checkpoint_path = xPath(self.resume_checkpoint_path) @dataclass @@ -180,8 +186,20 @@ class GeneralArgs: ignore_sanity_checks: Whether to ignore sanity checks """ - project: str + project: Optional[str] = None run: Optional[str] = None + logs_path: Optional[str] = None + launch_slurm_config: Optional[str] = None + eval_slurm_config: Optional[str] = None + eval_slurm_template: Optional[str] = None + lighteval_config_path: Optional[str] = None + is_s3_available: Optional[bool] = None + timestamp_with_run: Optional[str] = None + launch_script_path: Optional[str] = None + slurm_logs_path: Optional[str] = None + config_logs_path: Optional[str] = None + evals_logs_path: Optional[str] = None + temp_dir: Optional[str] = None seed: Optional[int] = None step: Optional[int] = None consumed_train_samples: Optional[int] = None @@ -358,10 +376,14 @@ def create_empty(cls): return cls(**{f.name: None for f in cls_fields}) def __post_init__(self): + self.general.__post_init__() if self.s3_upload is not None: self.s3_upload.__post_init__() + self.general.is_s3_available = True + else: + self.general.is_s3_available = False # Some final sanity checks across separate arguments sections: if self.profiler is not None and self.profiler.profiler_export_path is not None: assert self.tokens.train_steps < 10 @@ -394,15 +416,16 @@ def __post_init__(self): for i in range(len(self.data_stages) - 1) ), "The stages are not sorted by start_training_step in increasing order" - # # if lighteval, we need tokenizer to be defined - # if self.checkpoints.lighteval is not None: - # assert self.tokenizer.tokenizer_name_or_path is not None + # if lighteval, we need tokenizer to be defined + if self.lighteval is not None: + assert self.tokenizer.tokenizer_name_or_path is not None @property def global_batch_size(self): return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp def save_as_yaml(self, file_path: str): + config_dict = serialize(self) file_path = str(file_path) with open(file_path, "w") as f: @@ -473,7 +496,7 @@ def get_config_from_file( skip_unused_config_keys: whether to skip unused first-nesting-level keys in the config file (for config with additional sections) skip_null_keys: whether to skip keys with value None at first and second nesting level """ - # Open the file and load the file + with open(config_path) as f: config_dict = yaml.load(f, Loader=SafeLoader) @@ -490,3 +513,14 @@ def get_config_from_file( ) config.model.model_config = model_config_class(**config.model.model_config) return config + + +def save_as_yaml(config: Union[Config, LightEvalConfig], file_path: str): + config_class = type(config) + config_dict = serialize(config) + file_path = str(file_path) + with open(file_path, "w") as f: + yaml.dump(config_dict, f) + + # Sanity test config can be reloaded + _ = get_config_from_file(file_path, config_class=config_class) diff --git a/src/nanotron/config/lighteval_config.py 
b/src/nanotron/config/lighteval_config.py index b5f12059..3808d60c 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from pathlib import Path from typing import Dict, Optional, Union from nanotron.config.parallelism_config import ParallelismArgs @@ -33,44 +32,27 @@ def __post_init__(self): class LightEvalLoggingArgs: """Arguments related to logging for LightEval""" - local_output_path: Optional[Path] = None - push_results_to_hub: Optional[bool] = None - push_details_to_hub: Optional[bool] = None - push_results_to_tensorboard: Optional[bool] = None - hub_repo_results: Optional[str] = None - hub_repo_details: Optional[str] = None - hub_repo_tensorboard: Optional[str] = None - tensorboard_metric_prefix: Optional[str] = None - - def __post_init__(self): - if isinstance(self.local_output_path, str): - self.local_output_path = Path(self.local_output_path) + output_dir: Optional[str] = None + save_details: bool = True + push_to_hub: bool = False + push_to_tensorboard: bool = False + public_run: bool = False + results_org: str | None = None + tensorboard_metric_prefix: str = "eval" @dataclass class LightEvalTasksArgs: """Arguments related to tasks for LightEval""" - tasks: Optional[str] = None + tasks: str custom_tasks: Optional[str] = None max_samples: Optional[int] = None num_fewshot_seeds: Optional[int] = None - dataset_loading_processes: Optional[int] = 8 + dataset_loading_processes: int = 8 multichoice_continuations_start_space: Optional[bool] = None - no_multichoice_continuations_start_space: Optional[bool] = None - - -@dataclass -class LightEvalWandbLoggerConfig: - """Arguments related to the local Wandb logger""" - - wandb_project: str = "" - wandb_entity: Optional[str] = None - wandb_run_name: Optional[str] = None - - def __post_init__(self): - assert self.wandb_project != "", "Please specify a wandb_project" + pair_wise_tokenization: bool = False @dataclass @@ -81,13 +63,8 @@ class LightEvalConfig: the saved config when running LightEval after training. 
""" - slurm_template: Optional[str] = None - slurm_script_dir: Optional[str] = None - - checkpoints_path: Optional[str] = None - parallelism: Optional[ParallelismArgs] = None - batch_size: Optional[int] = None + logging: LightEvalLoggingArgs + tasks: LightEvalTasksArgs + parallelism: ParallelismArgs + batch_size: int = 0 generation: Optional[Union[GenerationArgs, Dict[str, GenerationArgs]]] = None - tasks: Optional[LightEvalTasksArgs] = None - logging: Optional[LightEvalLoggingArgs] = None - wandb: Optional[LightEvalWandbLoggerConfig] = None diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index d92de405..34c5c076 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -55,6 +55,48 @@ class LlamaConfig: use_cache: bool = True vocab_size: int = 32000 + def get_llama_param_count(self): + # Embedding layer + embedding_params = self.vocab_size * self.hidden_size + + # Input RMS Norm + input_rms = self.num_hidden_layers * self.hidden_size + # Post attention RMS Norm + after_attention_rms = self.num_hidden_layers * self.hidden_size + + # Attention layers + attn_params = self.num_hidden_layers * ( + # Query projection + self.num_attention_heads * (self.hidden_size // self.num_attention_heads) * self.hidden_size + + # Key and Value projections (different than query in case of GQA) + 2 * self.num_key_value_heads * (self.hidden_size // self.num_attention_heads) * self.hidden_size + + # Output projection + self.num_attention_heads * (self.hidden_size // self.num_attention_heads) * self.hidden_size + ) + + # MLP layers + mlp_params = self.num_hidden_layers * ( + # First linear layer (2 for gated) + 2* self.hidden_size * self.intermediate_size + + # Second linear layer + self.intermediate_size * self.hidden_size + ) + + + # Final RMS Norm + final_rms = self.hidden_size + + total_params = ( + embedding_params + + input_rms + + after_attention_rms + + attn_params + + mlp_params + + final_rms + ) + + return total_params + def __post_init__(self): # NOTE: user don't set self._init_method, ModelArgs will set it # then we only pass LlamaConfig around diff --git a/src/nanotron/config/utils_config.py b/src/nanotron/config/utils_config.py index f4c07146..87d69585 100644 --- a/src/nanotron/config/utils_config.py +++ b/src/nanotron/config/utils_config.py @@ -4,6 +4,7 @@ import torch +from datasets.download.streaming_download_manager import xPath from nanotron.generation.sampler import SamplerType from nanotron.parallel.pipeline_parallel.engine import ( AllForwardAllBackwardPipelineEngine, @@ -31,6 +32,8 @@ def serialize(data) -> dict: value = getattr(data, field.name) if hasattr(value, "__dataclass_fields__"): result[field.name] = serialize(value) + elif isinstance(value, xPath): + result[field.name] = str(value) elif isinstance(value, Path): result[field.name] = str(value) elif isinstance(value, PipelineEngine): diff --git a/src/nanotron/lighteval/__init__.py b/src/nanotron/lighteval/__init__.py new file mode 100644 index 00000000..d7ea002c --- /dev/null +++ b/src/nanotron/lighteval/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa: F401 + +from .one_job_runner import LightEvalRunner diff --git a/src/nanotron/lighteval/evaluation_tasks.py b/src/nanotron/lighteval/evaluation_tasks.py new file mode 100644 index 00000000..2dd9820c --- /dev/null +++ b/src/nanotron/lighteval/evaluation_tasks.py @@ -0,0 +1,659 @@ +# ruff: noqa: F405, F403, F401 +""" +Custom evaluation tasks for lighteval + +This file generally create just a TASKS_TABLE and 
TASKS_GROUPS which are then imported by LightEval. +""" +import re +from dataclasses import asdict +from typing import Dict, List, Tuple + +import lighteval.tasks.default_prompts as prompt +from lighteval.metrics.metrics import Metrics +from lighteval.tasks.default_prompts import LETTER_INDICES +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + +_TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = [] +_TASKS: List[LightevalTaskConfig] = [] + +trust_remote_code = True + +## COMMON_SENSE_REASONING_TASKS ## + + +def commonsense_qa_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["question"], + choices=[f" {c}" for c in line["choices"]["text"]], + gold_index=LETTER_INDICES.index(line["answerKey"].strip()), + instruction="", + ) + + +def siqa_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["context"] + " " + line["question"], + choices=[f" {c}" for c in [line["answerA"], line["answerB"], line["answerC"]]], + gold_index=int(line["label"]) - 1, + instruction="", + ) + + +COMMON_SENSE_REASONING_TASKS = [ + LightevalTaskConfig( + name="hellaswag", + prompt_function=prompt.hellaswag_harness, # Updated prompt function + hf_repo="hellaswag", + hf_subset="default", + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric + trust_dataset=trust_remote_code, + ), + LightevalTaskConfig( + name="winogrande", + prompt_function=prompt.winogrande, # Updated prompt function + hf_repo="winogrande", + hf_subset="winogrande_xl", + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric + trust_dataset=trust_remote_code, + ), + LightevalTaskConfig( + name="piqa", + prompt_function=prompt.piqa_harness, # Updated prompt function + hf_repo="piqa", + hf_subset="plain_text", + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric + trust_dataset=trust_remote_code, + ), + LightevalTaskConfig( + name="siqa", + prompt_function=siqa_prompt, # Updated prompt function + hf_repo="lighteval/siqa", + hf_subset="default", + hf_avail_splits=["train", "validation"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric + trust_dataset=trust_remote_code, + ), + LightevalTaskConfig( + name="openbookqa", + prompt_function=prompt.openbookqa, # Updated prompt function + hf_repo="openbookqa", + hf_subset="main", + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric + trust_dataset=trust_remote_code, + ), + LightevalTaskConfig( + name="arc:easy", + prompt_function=prompt.arc, # Updated prompt function + hf_repo="ai2_arc", + hf_subset="ARC-Easy", + evaluation_splits=["test"], + generation_size=1, + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric + trust_dataset=trust_remote_code, + ), + LightevalTaskConfig( + name="arc:challenge", + prompt_function=prompt.arc, # Updated prompt function + hf_repo="ai2_arc", + hf_subset="ARC-Challenge", + evaluation_splits=["test"], + generation_size=1, + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric + trust_dataset=trust_remote_code, + ), + LightevalTaskConfig( + name="commonsense_qa", + prompt_function=commonsense_qa_prompt, # Updated prompt function + hf_repo="commonsense_qa", + hf_subset="default", + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric + trust_dataset=trust_remote_code, + ), +] + + +# 0 short 
for common sense +COMMON_SENSE_REASONING_STRING = [(t, f"custom|{t.name}|0|1") for t in COMMON_SENSE_REASONING_TASKS] +_TASKS_STRINGS.extend(COMMON_SENSE_REASONING_STRING) +_TASKS += COMMON_SENSE_REASONING_TASKS + +## WORLD_KNOWLEDGE_TASKS ## + + +def natural_questions_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["question"] + "?\nAnswer: ", + choices=[line["short_answers"]], + gold_index=0, + instruction="", + ) + + +WORLD_KNOWLEDGE_TASKS = [ + LightevalTaskConfig( + name="trivia_qa", + prompt_function=prompt.triviaqa, + hf_repo="trivia_qa", + hf_subset="rc.nocontext", + metric=[Metrics.quasi_exact_match], + generation_size=20, + stop_sequence=["\n", ".", ","], + trust_dataset=trust_remote_code, + ), + LightevalTaskConfig( + name="natural_questions", + prompt_function=natural_questions_prompt, + hf_repo="lighteval/natural_questions_clean", + hf_subset="default", + metric=[Metrics.quasi_exact_match], + generation_size=20, + stop_sequence=["\n", ".", ","], + trust_dataset=trust_remote_code, + ), +] + + +WORLD_KNOWLEDGE_STRING = [(t, f"custom|{t.name}|5|1") for t in WORLD_KNOWLEDGE_TASKS] +# WORLD_KNOWLEDGE_STRING = {t: f'custom|{t.name}|0|1' for t in WORLD_KNOWLEDGE_TASKS} +_TASKS_STRINGS.extend(WORLD_KNOWLEDGE_STRING) +_TASKS += WORLD_KNOWLEDGE_TASKS + +## Reading comprehension ## +def boolq_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['passage']}\nQuestion: {line['question'].capitalize()}?\nAnswer:", + choices=[" No", " Yes"], # Only gold + gold_index=int(line["label"]), + ) + + +READING_COMP_TASKS = [ + LightevalTaskConfig( + name="super_glue:boolq", + prompt_function=boolq_prompt, + hf_repo="super_glue", + hf_subset="boolq", + metric=[Metrics.target_perplexity], + trust_dataset=trust_remote_code, + ), + LightevalTaskConfig( + name="quac", + prompt_function=prompt.quac, + hf_repo="lighteval/quac_helm", + hf_subset="deault", + metric=[Metrics.quasi_exact_match], + generation_size=20, + stop_sequence=["\n", ".", ","], + trust_dataset=trust_remote_code, + ), +] + + +READING_COMP_STRING = [(t, f"custom|{t.name}|0|1") for t in READING_COMP_TASKS] +_TASKS_STRINGS.extend(READING_COMP_STRING) +_TASKS += READING_COMP_TASKS + + +## MATH ## +class CustomMathEvaluationTask(LightevalTaskConfig): + """Custom class for math tasks with all the defaults set""" + + def __init__( + self, + name, + prompt_function=prompt.math, + hf_repo="lighteval/MATH", + hf_subset=None, + metric=[Metrics.quasi_exact_match_math], + hf_avail_splits=None, + evaluation_splits=["test"], + few_shots_split=None, + few_shots_select=None, + suite=["custom"], + generation_size=40, + stop_sequence=None, + output_regex=None, + frozen=False, + trust_dataset=trust_remote_code, + ): + super().__init__( + name=name, + prompt_function=prompt_function, + hf_repo=hf_repo, + hf_subset=hf_subset, + metric=metric, + hf_avail_splits=hf_avail_splits, + evaluation_splits=evaluation_splits, + few_shots_split=few_shots_split, + few_shots_select=few_shots_select, + suite=suite, + generation_size=generation_size, + stop_sequence=stop_sequence, + output_regex=output_regex, + frozen=frozen, + trust_dataset=trust_dataset, + ) + + +MATH_TASKS = [ + CustomMathEvaluationTask(name="math:algebra", hf_subset="algebra"), + CustomMathEvaluationTask(name="math:counting_and_probability", hf_subset="counting_and_probability"), + CustomMathEvaluationTask(name="math:geometry", hf_subset="geometry"), + CustomMathEvaluationTask(name="math:intermediate_algebra", 
hf_subset="intermediate_algebra"), + CustomMathEvaluationTask(name="math:number_theory", hf_subset="number_theory"), + CustomMathEvaluationTask(name="math:prealgebra", hf_subset="prealgebra"), + CustomMathEvaluationTask(name="math:precalculus", hf_subset="precalculus"), +] +GSM8K = LightevalTaskConfig( + name="gsm8k", + prompt_function=prompt.gsm8k, + hf_repo="gsm8k", + hf_subset="main", + hf_avail_splits=["train", "test"], + evaluation_splits=["test"], + metric=[Metrics.perfect_exact_match], + generation_size=10, + stop_sequence=["\n"], + trust_dataset=trust_remote_code, +) + + +MATH_STRING = [(t, f"custom|{t.name}|4|1") for t in MATH_TASKS] +GSM8K_STRING = [(GSM8K, f"custom|{GSM8K.name}|8|1")] +_TASKS_STRINGS.extend(MATH_STRING) +_TASKS_STRINGS.extend(GSM8K_STRING) +_TASKS += MATH_TASKS + [GSM8K] + + +## MMLU ## +def mmlu_harness(line, task_name: str = None): + topic = line["subject"] + prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" + prompt += line["question"] + "\n" + prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) + prompt += "Answer:" + + gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] + "__few_shots" in line and line["__few_shots"] is True # We are adding few shots + + return Doc( + task_name=task_name, + query=prompt, + choices=[" A", " B", " C", " D"], + target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix], + gold_index=gold_ix, + instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", + ) + + +def mmlu_prompt(line, task_name: str = None): + """MMLU prompt without letters""" + topic = line["subject"] + prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: " + prompt += line["question"] + "\nAnswer:" + + return Doc( + task_name=task_name, + query=prompt, + choices=[f" {c}" for c in line["choices"]], + gold_index=line["answer"], + instruction=f"The following are questions about {topic.replace('_', ' ')}.\n", + ) + + +class CustomMMLUEvaluationTask(LightevalTaskConfig): + def __init__( + self, + name, + prompt_function=mmlu_prompt, + hf_repo="lighteval/mmlu", + hf_subset=None, + # metric=[Metrics.loglikelihood_acc_single_token], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], + hf_avail_splits=None, + evaluation_splits=["test"], + few_shots_split="dev", + few_shots_select=None, + suite=["custom"], + generation_size=-1, + stop_sequence=None, + output_regex=None, + frozen=False, + trust_dataset=trust_remote_code, + ): + super().__init__( + name=name, + prompt_function=prompt_function, + hf_repo=hf_repo, + hf_subset=hf_subset, + metric=metric, + hf_avail_splits=hf_avail_splits, + evaluation_splits=evaluation_splits, + few_shots_split=few_shots_split, + few_shots_select=few_shots_select, + suite=suite, + generation_size=generation_size, + stop_sequence=stop_sequence, + output_regex=output_regex, + frozen=frozen, + trust_dataset=trust_dataset, + ) + + +MMLU_TASKS = [ + CustomMMLUEvaluationTask(name="mmlu:abstract_algebra", hf_subset="abstract_algebra"), + CustomMMLUEvaluationTask(name="mmlu:anatomy", hf_subset="anatomy"), + CustomMMLUEvaluationTask(name="mmlu:astronomy", hf_subset="astronomy"), + CustomMMLUEvaluationTask(name="mmlu:business_ethics", hf_subset="business_ethics"), + CustomMMLUEvaluationTask(name="mmlu:clinical_knowledge", hf_subset="clinical_knowledge"), + 
CustomMMLUEvaluationTask(name="mmlu:college_biology", hf_subset="college_biology"), + CustomMMLUEvaluationTask(name="mmlu:college_chemistry", hf_subset="college_chemistry"), + CustomMMLUEvaluationTask(name="mmlu:college_computer_science", hf_subset="college_computer_science"), + CustomMMLUEvaluationTask(name="mmlu:college_mathematics", hf_subset="college_mathematics"), + CustomMMLUEvaluationTask(name="mmlu:college_medicine", hf_subset="college_medicine"), + CustomMMLUEvaluationTask(name="mmlu:college_physics", hf_subset="college_physics"), + CustomMMLUEvaluationTask(name="mmlu:computer_security", hf_subset="computer_security"), + CustomMMLUEvaluationTask(name="mmlu:conceptual_physics", hf_subset="conceptual_physics"), + CustomMMLUEvaluationTask(name="mmlu:econometrics", hf_subset="econometrics"), + CustomMMLUEvaluationTask(name="mmlu:electrical_engineering", hf_subset="electrical_engineering"), + CustomMMLUEvaluationTask(name="mmlu:elementary_mathematics", hf_subset="elementary_mathematics"), + CustomMMLUEvaluationTask(name="mmlu:formal_logic", hf_subset="formal_logic"), + CustomMMLUEvaluationTask(name="mmlu:global_facts", hf_subset="global_facts"), + CustomMMLUEvaluationTask(name="mmlu:high_school_biology", hf_subset="high_school_biology"), + CustomMMLUEvaluationTask(name="mmlu:high_school_chemistry", hf_subset="high_school_chemistry"), + CustomMMLUEvaluationTask(name="mmlu:high_school_computer_science", hf_subset="high_school_computer_science"), + CustomMMLUEvaluationTask(name="mmlu:high_school_european_history", hf_subset="high_school_european_history"), + CustomMMLUEvaluationTask(name="mmlu:high_school_geography", hf_subset="high_school_geography"), + CustomMMLUEvaluationTask( + name="mmlu:high_school_government_and_politics", hf_subset="high_school_government_and_politics" + ), + CustomMMLUEvaluationTask(name="mmlu:high_school_macroeconomics", hf_subset="high_school_macroeconomics"), + CustomMMLUEvaluationTask(name="mmlu:high_school_mathematics", hf_subset="high_school_mathematics"), + CustomMMLUEvaluationTask(name="mmlu:high_school_microeconomics", hf_subset="high_school_microeconomics"), + CustomMMLUEvaluationTask(name="mmlu:high_school_physics", hf_subset="high_school_physics"), + CustomMMLUEvaluationTask(name="mmlu:high_school_psychology", hf_subset="high_school_psychology"), + CustomMMLUEvaluationTask(name="mmlu:high_school_statistics", hf_subset="high_school_statistics"), + CustomMMLUEvaluationTask(name="mmlu:high_school_us_history", hf_subset="high_school_us_history"), + CustomMMLUEvaluationTask(name="mmlu:high_school_world_history", hf_subset="high_school_world_history"), + CustomMMLUEvaluationTask(name="mmlu:human_aging", hf_subset="human_aging"), + CustomMMLUEvaluationTask(name="mmlu:human_sexuality", hf_subset="human_sexuality"), + CustomMMLUEvaluationTask(name="mmlu:international_law", hf_subset="international_law"), + CustomMMLUEvaluationTask(name="mmlu:jurisprudence", hf_subset="jurisprudence"), + CustomMMLUEvaluationTask(name="mmlu:logical_fallacies", hf_subset="logical_fallacies"), + CustomMMLUEvaluationTask(name="mmlu:machine_learning", hf_subset="machine_learning"), + CustomMMLUEvaluationTask(name="mmlu:management", hf_subset="management"), + CustomMMLUEvaluationTask(name="mmlu:marketing", hf_subset="marketing"), + CustomMMLUEvaluationTask(name="mmlu:medical_genetics", hf_subset="medical_genetics"), + CustomMMLUEvaluationTask(name="mmlu:miscellaneous", hf_subset="miscellaneous"), + CustomMMLUEvaluationTask(name="mmlu:moral_disputes", hf_subset="moral_disputes"), + 
CustomMMLUEvaluationTask(name="mmlu:moral_scenarios", hf_subset="moral_scenarios"), + CustomMMLUEvaluationTask(name="mmlu:nutrition", hf_subset="nutrition"), + CustomMMLUEvaluationTask(name="mmlu:philosophy", hf_subset="philosophy"), + CustomMMLUEvaluationTask(name="mmlu:prehistory", hf_subset="prehistory"), + CustomMMLUEvaluationTask(name="mmlu:professional_accounting", hf_subset="professional_accounting"), + CustomMMLUEvaluationTask(name="mmlu:professional_law", hf_subset="professional_law"), + CustomMMLUEvaluationTask(name="mmlu:professional_medicine", hf_subset="professional_medicine"), + CustomMMLUEvaluationTask(name="mmlu:professional_psychology", hf_subset="professional_psychology"), + CustomMMLUEvaluationTask(name="mmlu:public_relations", hf_subset="public_relations"), + CustomMMLUEvaluationTask(name="mmlu:security_studies", hf_subset="security_studies"), + CustomMMLUEvaluationTask(name="mmlu:sociology", hf_subset="sociology"), + CustomMMLUEvaluationTask(name="mmlu:us_foreign_policy", hf_subset="us_foreign_policy"), + CustomMMLUEvaluationTask(name="mmlu:virology", hf_subset="virology"), + CustomMMLUEvaluationTask(name="mmlu:world_religions", hf_subset="world_religions"), +] + + +# MMLU_STRING = {t: f'custom|{t.name}|5|1' for t in MMLU_TASKS} +MMLU_STRING = [(t, f"custom|{t.name}|0|1") for t in MMLU_TASKS] +_TASKS_STRINGS.extend(MMLU_STRING) +_TASKS += MMLU_TASKS + +## BBH ## + + +def bbh_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["input"] + "\nAnswer: ", + choices=[line["target"]], + gold_index=0, + ) + + +class CustomBBHEvaluationTask(LightevalTaskConfig): + def __init__( + self, + name, + prompt_function=bbh_prompt, + hf_repo="lighteval/big_bench_hard", + hf_subset=None, + metric=[Metrics.exact_match], + hf_avail_splits=["train"], + evaluation_splits=["train"], + few_shots_split="train", + few_shots_select=None, + suite=["custom"], + generation_size=4, + stop_sequence=None, + output_regex=None, + frozen=False, + trust_dataset=trust_remote_code, + ): + super().__init__( + name=name, + prompt_function=prompt_function, + hf_repo=hf_repo, + hf_subset=hf_subset, + metric=metric, + hf_avail_splits=hf_avail_splits, + evaluation_splits=evaluation_splits, + few_shots_split=few_shots_split, + few_shots_select=few_shots_select, + suite=suite, + generation_size=generation_size, + stop_sequence=stop_sequence, + output_regex=output_regex, + frozen=frozen, + trust_dataset=trust_dataset, + ) + + +BBH_TASKS = [ + CustomBBHEvaluationTask(name="bbh:boolean_expressions", hf_subset="boolean_expressions"), + CustomBBHEvaluationTask(name="bbh:causal_judgement", hf_subset="causal_judgement"), + CustomBBHEvaluationTask(name="bbh:date_understanding", hf_subset="date_understanding"), + CustomBBHEvaluationTask(name="bbh:disambiguation_qa", hf_subset="disambiguation_qa"), + CustomBBHEvaluationTask(name="bbh:dyck_languages", hf_subset="dyck_languages"), + CustomBBHEvaluationTask(name="bbh:formal_fallacies", hf_subset="formal_fallacies"), + CustomBBHEvaluationTask(name="bbh:geometric_shapes", hf_subset="geometric_shapes"), + CustomBBHEvaluationTask(name="bbh:hyperbaton", hf_subset="hyperbaton"), + CustomBBHEvaluationTask(name="bbh:logical_deduction_five_objects", hf_subset="logical_deduction_five_objects"), + CustomBBHEvaluationTask(name="bbh:logical_deduction_seven_objects", hf_subset="logical_deduction_seven_objects"), + CustomBBHEvaluationTask(name="bbh:logical_deduction_three_objects", hf_subset="logical_deduction_three_objects"), + 
CustomBBHEvaluationTask(name="bbh:movie_recommendation", hf_subset="movie_recommendation"), + CustomBBHEvaluationTask(name="bbh:multistep_arithmetic_two", hf_subset="multistep_arithmetic_two"), + CustomBBHEvaluationTask(name="bbh:navigate", hf_subset="navigate"), + CustomBBHEvaluationTask(name="bbh:object_counting", hf_subset="object_counting"), + CustomBBHEvaluationTask(name="bbh:penguins_in_a_table", hf_subset="penguins_in_a_table"), + CustomBBHEvaluationTask(name="bbh:reasoning_about_colored_objects", hf_subset="reasoning_about_colored_objects"), + CustomBBHEvaluationTask(name="bbh:ruin_names", hf_subset="ruin_names"), + CustomBBHEvaluationTask( + name="bbh:salient_translation_error_detection", hf_subset="salient_translation_error_detection" + ), + CustomBBHEvaluationTask(name="bbh:snarks", hf_subset="snarks"), + CustomBBHEvaluationTask(name="bbh:sports_understanding", hf_subset="sports_understanding"), + CustomBBHEvaluationTask(name="bbh:temporal_sequences", hf_subset="temporal_sequences"), + CustomBBHEvaluationTask( + name="bbh:tracking_shuffled_objects_five_objects", hf_subset="tracking_shuffled_objects_five_objects" + ), + CustomBBHEvaluationTask( + name="bbh:tracking_shuffled_objects_seven_objects", hf_subset="tracking_shuffled_objects_seven_objects" + ), + CustomBBHEvaluationTask( + name="bbh:tracking_shuffled_objects_three_objects", hf_subset="tracking_shuffled_objects_three_objects" + ), + CustomBBHEvaluationTask(name="bbh:web_of_lies", hf_subset="web_of_lies"), + CustomBBHEvaluationTask(name="bbh:word_sorting", hf_subset="word_sorting"), +] + + +# BBH_STRING = {t: f'custom|{t.name}|3|1' for t in BBH_TASKS} +BBH_STRING = [(t, f"custom|{t.name}|0|1") for t in BBH_TASKS] +_TASKS_STRINGS.extend(BBH_STRING) +_TASKS += BBH_TASKS + + +## AGI eval ## + + +def agi_eval_math_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["question"], + choices=[line["answer"]], + gold_index=0, + instruction="", + ) + + +def agi_eval_prompt(line, task_name: str = None): + cleaned_options = [o.replace("(", "").replace(")", " ") for o in line["options"]] + prompt = "The following are multiple choice questions (with answers).\n\n" + prompt += line["question"] + "\n" + "\n".join(cleaned_options) + "\n" + prompt += "Answer: " + + choices = LETTER_INDICES[: len(line["options"])] + + output = Doc( + query=prompt, + instruction="The following are multiple choice questions (with answers).\n\n", + choices=None, # updated below + gold_index=None, # updated below + ) + + if line["label"]: + output.choices = choices + output.gold_index = LETTER_INDICES.index(line["label"].strip()) + else: + output.choices = [line["answer"]] + output.gold_index = 0 + + return output + + +def agi_eval_prompt_no_letters(line, task_name: str = None): + cleaned_options = [ + " " + o.replace("(A)", "").replace("(B)", "").replace("(C)", "").replace("(D)", "").replace("(E)", "") + for o in line["options"] + ] + + output = Doc( + query=line["question"], + choices=cleaned_options, + gold_index=LETTER_INDICES.index(line["label"].strip()), + instruction="", + ) + + return output + + +class CustomAGIEvalEvaluationTask(LightevalTaskConfig): + def __init__( + self, + name, + prompt_function=agi_eval_prompt_no_letters, + hf_repo="lighteval/agi_eval_en", + hf_subset=None, + # metric=[Metrics.loglikelihood_acc_single_token], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], + hf_avail_splits=["train", "validation"], + evaluation_splits=["train"], + few_shots_split="validation", + 
few_shots_select=None, + suite=["custom"], + generation_size=-1, + stop_sequence=None, + output_regex=None, + frozen=False, + trust_dataset=trust_remote_code, + ): + super().__init__( + name=name, + prompt_function=prompt_function, + hf_repo=hf_repo, + hf_subset=hf_subset, + metric=metric, + hf_avail_splits=hf_avail_splits, + evaluation_splits=evaluation_splits, + few_shots_split=few_shots_split, + few_shots_select=few_shots_select, + suite=suite, + generation_size=generation_size, + stop_sequence=stop_sequence, + output_regex=output_regex, + frozen=frozen, + trust_dataset=trust_dataset, + ) + + +AGIEVAL_TASKS = [ + CustomAGIEvalEvaluationTask(name="agi_eval:aqua_rat", hf_subset="aqua_rat"), + CustomAGIEvalEvaluationTask(name="agi_eval:logiqa-en", hf_subset="logiqa-en"), + CustomAGIEvalEvaluationTask(name="agi_eval:lsat-ar", hf_subset="lsat-ar"), + CustomAGIEvalEvaluationTask(name="agi_eval:lsat-lr", hf_subset="lsat-lr"), + CustomAGIEvalEvaluationTask(name="agi_eval:lsat-rc", hf_subset="lsat-rc"), + CustomAGIEvalEvaluationTask( + name="agi_eval:math", + hf_subset="math", + prompt_function="agi_eval_math_prompt", + metric=[Metrics.exact_match, Metrics.quasi_exact_match], + generation_size=40, + ), + CustomAGIEvalEvaluationTask(name="agi_eval:sat-en", hf_subset="sat-en"), + CustomAGIEvalEvaluationTask(name="agi_eval:sat-math", hf_subset="sat-math"), +] + + +# AGIEVAL_STRING = {t: f'custom|{t.name}|5|1' for t in AGIEVAL_TASKS} +AGIEVAL_STRING = [(t, f"custom|{t.name}|0|1") for t in AGIEVAL_TASKS] +_TASKS_STRINGS.extend(AGIEVAL_STRING) +_TASKS += AGIEVAL_TASKS + + +OPEN_LLM_LEADERBOARD_STRING = [ + "custom|arc:challenge|25|1", + "custom|hellaswag|10|1", + "lighteval|truthfulqa:mc|0|1", + "custom|winogrande|5|1", + "lighteval|gsm8k|5|1", +] + [f"custom|{t.name}|5|1" for t in MMLU_TASKS] + + +## HUMAN EVAL ## +# human_eval = LightevalTaskConfig( +# name="human_eval", +# prompt_function="human_eval", +# hf_repo="lighteval/human_eval", +# metric=["human_eval_pass_at_1"], +# ), + + +EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING]) + +# Convert to dict for lighteval +TASKS_TABLE = _TASKS +# You can have a few pre-organised groups of tasks +TASKS_GROUPS = { + "all": ",".join(t[1] for t in _TASKS_STRINGS), + "early-signal": EARLY_SIGNAL_TASKS, + "open-llm-leaderboard": ",".join(OPEN_LLM_LEADERBOARD_STRING), +} + +if __name__ == "__main__": + print(t["name"] for t in TASKS_TABLE) + print(len(TASKS_TABLE)) diff --git a/src/nanotron/lighteval/one_job_runner.py b/src/nanotron/lighteval/one_job_runner.py new file mode 100644 index 00000000..3321e7ce --- /dev/null +++ b/src/nanotron/lighteval/one_job_runner.py @@ -0,0 +1,190 @@ +""" Mostly complete a SLURM template with a link to a single checkpoint on s3 and launch it +""" +import datetime +import json +import os +import re +import subprocess +from typing import List, Optional, Tuple + +import jinja2 + +from nanotron import logging +from nanotron.config import Config, LightEvalConfig +from nanotron.logging import log_rank +from nanotron.parallel import ParallelContext + +logger = logging.get_logger(__name__) + + +class LightEvalRunner: + def __init__(self, config: Config, parallel_context: Optional[ParallelContext] = None): + self.config = config + self.lighteval_config = config.lighteval + self.parallel_context = parallel_context + + def eval_single_checkpoint_no_s3(self, checkpoint_path: str) -> Tuple[str, str]: + if not os.path.exists(checkpoint_path): + log_rank( + f"Checkpoint path does not 
exist: {checkpoint_path}. Unable to evaluate.", + logger=logger, + level=logging.ERROR, + group=self.parallel_context.dp_pg if self.parallel_context is not None else None, + rank=0, + ) + return None, None + + slurm_job_id, slurm_log = run_slurm_one_job( + config=self.config, + lighteval_config=self.lighteval_config, + slurm_template=self.config.general.eval_slurm_template, + model_checkpoint_path=checkpoint_path, + current_step=self.config.general.step, + ) + + return slurm_job_id, slurm_log + + def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: + """Run light evaluation on uploaded files.""" + logger.warning(f"Lighteval Runner got {len(uploaded_files)} files. Checking configs.") + config_files = [ + f for f in uploaded_files if "config.py" in f["destination"] or "config.yaml" in f["destination"] + ] + # Sanity check on the config files len (we want only one) + if len(config_files) == 0: + log_rank( + "No config files founds in uploaded checkpoints. Not running evaluation.", + logger=logger, + level=logging.ERROR, + group=self.parallel_context.dp_pg if self.parallel_context is not None else None, + rank=0, + ) + return + if len(config_files) > 1: + log_rank( + "Found multiple config files in uploaded checkpoints.", + logger=logger, + level=logging.ERROR, + group=self.parallel_context.dp_pg if self.parallel_context is not None else None, + rank=0, + ) + return + checkpoint_path = config_files[0]["destination"].replace("config.yaml", "") + + slurm_job_id, slurm_log = run_slurm_one_job( + config=self.config, + lighteval_config=self.lighteval_config, + slurm_template=self.config.general.eval_slurm_template, + model_checkpoint_path=checkpoint_path, + current_step=self.config.general.step, + ) + + return slurm_job_id, slurm_log + + +def run_slurm_one_job( + config: Config, + lighteval_config: LightEvalConfig, + model_checkpoint_path: str, + slurm_template: str, + current_step: int, + slurm_name: Optional[str] = "eval", +): + """Launch a single job on Slurm with the given mapping + Args: + slurm_config: Slurm configuration + mapping: Mapping to use for the job script (see SLURM_ONE_JOB_MAPPING) + """ + s3 = config.general.is_s3_available + eval_launch_script_path = os.path.join(config.general.evals_logs_path, "launch-config", str(current_step)) + eval_logs_path = os.path.join(config.general.evals_logs_path, "logs", str(current_step)) + + with open(config.general.eval_slurm_config, "r") as f: + eval_slurm_config = json.load(f) + + os.makedirs(eval_launch_script_path, exist_ok=True) + os.makedirs(eval_logs_path, exist_ok=True) + + environment = jinja2.Environment( + comment_start_string="{=", + comment_end_string="=}", + ) + + with open(slurm_template, "r") as f: + SLURM_JOBS_ARRAY_TEMPLATE = environment.from_string(f.read()) + + # Update the config with additional required parameters + # Calculate the number of nodes based on parallelism config and gpus_per_node + total_gpus_needed = ( + lighteval_config.parallelism.dp * lighteval_config.parallelism.pp * lighteval_config.parallelism.tp + ) + gpus_per_node = eval_slurm_config.get("gpus_per_node") + nodes = (total_gpus_needed + gpus_per_node - 1) // gpus_per_node # Ceiling division + + if s3: + eval_slurm_config.update( + { + "nodes": nodes, # Assuming we want to run on a single node + "job_name": f"eval-{current_step}", + "eval_path": eval_logs_path, + "local_path": f"{config.general.temp_dir}/eval_{config.general.timestamp_with_run}/{current_step}", + "model_checkpoint_path": model_checkpoint_path, + 
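# --- Illustrative aside (not part of the patch): how the node count above is sized. ---
# The dp/pp/tp and gpus_per_node values here are hypothetical, not taken from any shipped config.
dp, pp, tp = 8, 1, 2
gpus_per_node = 8                                                  # read from the eval slurm config JSON
total_gpus_needed = dp * pp * tp                                   # 16
nodes = (total_gpus_needed + gpus_per_node - 1) // gpus_per_node   # ceiling division
print(nodes)                                                       # 2 -> 16 GPUs need two 8-GPU nodes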
"lighteval_config_path": config.general.lighteval_config_path, + } + ) + else: + eval_slurm_config.update( + { + "nodes": nodes, # Assuming we want to run on a single node + "job_name": f"eval-{current_step}", + "eval_path": eval_logs_path, + "model_checkpoint_path": model_checkpoint_path, + "lighteval_config_path": config.general.lighteval_config_path, + } + ) + + launch_string = SLURM_JOBS_ARRAY_TEMPLATE.render(**eval_slurm_config) + + match = re.match(r"#SBATCH --output=(.*)", launch_string) + slurm_output_path = match.group(1) if match else "" + + current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + launch_script_path = os.path.join(eval_launch_script_path, f"launch_script-{current_time}.slurm") + + # make sure the folder exists before write + # Extract the folder path from launch_script_path + folder_path = os.path.dirname(launch_script_path) + + # Check if the folder exists. If not, create it. + if not os.path.exists(folder_path): + os.makedirs(folder_path) + + with open(launch_script_path, "w") as f: + f.write(launch_string) + + logger.warning(f'Launching Slurm job {slurm_name} with launch script "{launch_script_path}"') + + # Preserve important environment variables + env = { + "PATH": os.environ["PATH"], + "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", ""), + "HOME": os.path.expanduser("~"), + } + + try: + # Use subprocess.run instead of check_output for better error handling + result = subprocess.run(["sbatch", launch_script_path], env=env, check=True, capture_output=True, text=True) + output = result.stdout + job_ids = output.split()[-1] + output_log = ( + slurm_output_path.replace("%x", slurm_name).replace("%j", job_ids).replace("%n", "0").replace("%t", "0") + ) + logger.warning(f'Slurm job launched successfully with id={job_ids}, logging outputs at "{output_log}"') + except subprocess.CalledProcessError as e: + logger.error(f"Error while launching Slurm job: {e}") + logger.error(f"Command output: {e.output}") + logger.error(f"Command stderr: {e.stderr}") + job_ids = None + output_log = None + + return job_ids, output_log diff --git a/src/nanotron/lighteval/run_evals.py b/src/nanotron/lighteval/run_evals.py new file mode 100644 index 00000000..1fd4b178 --- /dev/null +++ b/src/nanotron/lighteval/run_evals.py @@ -0,0 +1,39 @@ +# flake8: noqa: C901 +import argparse + +from lighteval.main_nanotron import main + +from nanotron.config import Config + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint-config-path", + type=str, + required=True, + help="Path to the Nanotron checkpoint YAML or python config file, potentially on S3", + ) + parser.add_argument( + "--lighteval-config-path", + type=str, + required=True, + help="Path to an optional YAML or python Lighteval config", + ) + parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="Cache directory", + ) + return parser + + +if __name__ == "__main__": + parser = get_parser() + args, unknowns = parser.parse_known_args() + main( + checkpoint_config_path=args.checkpoint_config_path, + lighteval_config_path=args.lighteval_config_path, + cache_dir=args.cache_dir, + ) diff --git a/src/nanotron/logging.py b/src/nanotron/logging.py index 708393b5..cdbd0b78 100644 --- a/src/nanotron/logging.py +++ b/src/nanotron/logging.py @@ -236,7 +236,7 @@ def warn_once( def human_format(num: float, billions: bool = False, divide_by_1024: bool = False) -> str: if abs(num) < 1: return "{:.3g}".format(num) - SIZES = ["", "K", "M", "G", "T", "P", "E"] + SIZES = ["", "K", 
"M", "B", "T", "P", "E"] num = float("{:.3g}".format(num)) magnitude = 0 i = 0 diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 88fb6bcb..e7510e58 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -917,8 +917,9 @@ def get_block_compute_costs(self): d_qkv = model_config.hidden_size // model_config.num_attention_heads block_compute_costs = { # CausalSelfAttention (qkv proj + attn out) + MLP - LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size - + 3 * d_ff * model_config.hidden_size, + LlamaDecoderLayer: 2 * model_config.num_attention_heads * d_qkv * model_config.hidden_size # Q output projection + + 2 * model_config.num_key_value_heads * d_qkv * model_config.hidden_size # KV + + 3 * d_ff * model_config.hidden_size, # for the MLP (3 because of the gated mechanism) # This is the last lm_head TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, } @@ -1132,6 +1133,7 @@ def get_flops( 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * hidden_size_per_head + 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * hidden_size_per_head ) + ## qk logits decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * seq_len ## v logits @@ -1146,6 +1148,10 @@ def get_flops( ## 2nd layer decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size + # Layer Norm (RMSNorm for LLaMA) + # There are typically 2 layer norms per transformer layer, plus one at the end + layer_norm_flops_fwd = (2 * num_layers + 1) * batch_size * seq_len * hidden_size * 2 # multiply by 2 for division and square root (square root take significatively more time ?) 
+ decoder_flops_fwd = ( decoder_qkv_proj_flops_fwd + decoder_qk_logits_flops_fwd @@ -1153,6 +1159,7 @@ def get_flops( + decoder_attn_out_flops_fwd + decoder_ffn_1_flops_fwd + decoder_ffn_2_flops_fwd + + layer_norm_flops_fwd ) # lm head diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 7991cbd4..c6c5fb07 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -3,7 +3,6 @@ from typing import Optional, cast import torch -from datasets.download.streaming_download_manager import xPath from torch import nn from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import LambdaLR @@ -255,7 +254,7 @@ def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Option load_from_candidate = config.checkpoints.resume_checkpoint_path if load_from_candidate is not None: if check_path_is_local(load_from_candidate): - latest_meta_path: xPath = config.checkpoints.resume_checkpoint_path / "latest.txt" + latest_meta_path: Path = config.checkpoints.resume_checkpoint_path / "latest.txt" if latest_meta_path.exists(): with fs_open(config.checkpoints.resume_checkpoint_path / "latest.txt", mode="r") as fi: # TODO @thomasw21: make a better structure system so that we get typing correct diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 45d704ee..62986fba 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -46,6 +46,7 @@ log_throughput, lr_scheduler_builder, ) +from nanotron.lighteval import LightEvalRunner from nanotron.logging import ( LoggerWriter, LogItem, @@ -259,9 +260,9 @@ def pre_init(self): self.init_checkpoint_path = parse_ckpt_path(config=self.config, parallel_context=self.parallel_context) def post_init(self): - # S3 Mover and save initial state - if self.config.s3_upload is not None: - # NOTE: Only local rank 0 should upload + # S3 Mover and save initial state (only if we need to upload checkpoints on s3) + if self.config.s3_upload is not None and self.config.s3_upload.upload_s3_path is not None: + # Only local rank 0 should upload dummy = bool(int(os.environ.get("LOCAL_RANK", None)) != 0) self.s3_mover = S3Mover( local_path=self.config.checkpoints.checkpoints_path, @@ -274,6 +275,22 @@ def post_init(self): ) else: self.s3_mover = None + if dist.get_rank(self.parallel_context.world_pg) == 0: + # check if slurm is configured + # TODO @eliebak rewrite the logic by self.slurm + there can be s3upload AND checkpoint path local which is not support for now + if self.config.lighteval is not None and self.config.general.eval_slurm_config is not None: + self.lighteval_runner = LightEvalRunner(config=self.config, parallel_context=self.parallel_context) + if self.s3_mover is not None: + self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint + self.post_checkpoint_callback = None + else: + # Use the no_s3 version of the evaluation function + # TODO: make it one function + make it automatic to switch to the right jinja template + self.post_checkpoint_callback = self.lighteval_runner.eval_single_checkpoint_no_s3 + else: + self.post_checkpoint_callback = None + else: + self.post_checkpoint_callback = None def pre_training(self, *args, **kwargs): self._print_training_plan() @@ -287,16 +304,27 @@ def pre_training(self, *args, **kwargs): rank=0, ) - current_time = datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S") - if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: + if 
dist.get_rank(self.parallel_context.world_pg) == 0 and wandb is not None: wandb.init( project=self.config.general.project, - name=f"{current_time}_{self.config.general.run}", + name=f"{self.config.general.run}_{self.config.general.timestamp_with_run}", config={"nanotron_config": self.config.as_dict()}, ) + # Define tokens metric as x-axis for all metrics + wandb.define_metric("Tokens") + wandb.define_metric("*", step_metric="Tokens") - def post_train_step(self): + # Handle resuming from a previous run + initial_step = getattr(self.config.general, "step", 0) + if initial_step is None: + initial_step = 0 + initial_tokens = initial_step * self.global_batch_size + + # Log initial tokens to set the starting point + wandb.log({"Tokens": initial_tokens}) + + def post_train_step(self): # Update our background upload/removal of checkpoints if self.s3_mover is not None: self.s3_mover.update() @@ -593,7 +621,7 @@ def train_step_logs( lr = self.lr_scheduler.get_last_lr()[0] log_entries = [ - # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), + # LogItem("consumed_samples", self.metadata.consumed_train_samples, "human_format"), # , "12d"), LogItem( "consumed_tokens", self.metadata.consumed_train_samples * self.config.tokens.sequence_length, @@ -637,6 +665,7 @@ def train_step_logs( { **{log_item.tag: log_item.scalar_value for log_item in log_entries}, "iteration_step": self.iteration_step, + "Tokens": self.metadata.consumed_train_samples * self.config.tokens.sequence_length, } ) @@ -723,6 +752,7 @@ def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: root_folder=self.init_checkpoint_path, ) reloaded_from_checkpoint = True + if not reloaded_from_checkpoint: log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO, rank=0) if isinstance(self.config.model.init_method, ExistingCheckpointInit): @@ -859,18 +889,27 @@ def pre_save_checkpoint(self) -> Path: self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg) if self.s3_mover.post_upload_callback_outputs is not None: slurm_job_id, slurm_log = self.s3_mover.post_upload_callback_outputs - self.log_object({"job_id": slurm_job_id, "log": slurm_log}, "slurm_eval") + log_rank( + f"launching eval job: job_id={slurm_job_id} log at {slurm_log} slurm_eval", + logger=logger, + level=logging.INFO, + rank=0, + ) def post_save_checkpoint(self): # Upload to S3 if self.s3_mover is not None: self.s3_mover.start_uploading() + elif self.post_checkpoint_callback is not None: + # If we're not using S3, but we have a post-checkpoint callback for evals + checkpoint_path = Path(self.config.checkpoints.checkpoints_path) / f"{self.config.general.step}" + self.post_checkpoint_callback(checkpoint_path) + def save_checkpoint(self) -> Path: self.pre_save_checkpoint() - checkpoints_path = self.config.checkpoints.checkpoints_path - checkpoint_path = checkpoints_path / f"{self.iteration_step}" + checkpoint_path = Path(checkpoints_path) / f"{self.iteration_step}" if self.config.checkpoints.checkpoints_path_is_shared_file_system: should_mkdir = dist.get_rank(self.parallel_context.world_pg) == 0 else: diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index b3831801..07cd4898 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -2,6 +2,7 @@ import inspect import os import random +import re import socket from contextlib import ExitStack, contextmanager from typing import ContextManager, List, Optional @@ -160,3 +161,9 @@ def find_free_port(min_port: int = 2000, max_port: int = 
65000) -> int: return port except OSError: continue + + +def check_path_is_s3(path: str) -> bool: + # TODO: maybe replace this with a more robust check + s3_pattern = r"^s3://|^https?://[\w\-\.]+\.s3\.amazonaws\.com/|^https?://s3\.amazonaws\.com/[\w\-\.]+" + return bool(re.match(s3_pattern, path))
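# --- Illustrative aside (not part of the patch): quick sanity check of check_path_is_s3. ---
# The bucket names and paths below are made up.
for p in [
    "s3://my-bucket/checkpoints/100",
    "https://my-bucket.s3.amazonaws.com/checkpoints/100",
    "https://s3.amazonaws.com/my-bucket/checkpoints/100",
    "/fsx/checkpoints/100",
]:
    print(p, check_path_is_s3(p))
# The three S3-style locations match the pattern (True); the local filesystem path does not (False).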