add initial configs for perf testing on ESM2 in JET (bionemo2) (#497)
Adds ESM2 650M partial-convergence and perf configs for JET benchmarking, together with the script changes they require.
dorotat-nv authored Jan 7, 2025
1 parent dca314e commit 5eddee1
Showing 11 changed files with 326 additions and 37 deletions.
54 changes: 54 additions & 0 deletions ci/benchmarks/partial-conv/esm2_pretrain.yaml
@@ -0,0 +1,54 @@
scope: partial-conv
time_limit: 14400
script_args:
# All arguments referenced in the script string must be specified here.
# Arguments not referenced in the script string must have the 'arg' field specified.
# See jet/core/configs.py for the specification of the configuration class
workspace:
value: /workspace/bionemo2
key_segment: False
data_path:
value: /data/20240809_uniref_2024_03/data
key_segment: False
model:
value: esm2
variant:
value: train
config_name:
value: 650M
precision:
value: [bf16-mixed]
nodes:
value: [4]
gpus:
value: 8
batch_size:
value: 16
max_steps:
value: 26500
script: |-
WANDB_API_KEY=$BIONEMO_WANDB_API_KEY ${variant}_${model} \
--train-cluster-path=${data_path}/train_clusters.parquet \
--train-database-path=${data_path}/train.db \
--valid-cluster-path=${data_path}/valid_clusters.parquet \
--valid-database-path=${data_path}/validation.db \
--micro-batch-size=${batch_size} \
--num-nodes=${nodes} \
--num-gpus=${gpus} \
--val-check-interval=1000 \
--limit-val-batches=1 \
--num-steps=${max_steps} \
--min-seq-length=1024 \
--max-seq-length=1024 \
--num-layers=33 \
--hidden-size=1280 \
--num-attention-heads=20 \
--ffn-hidden-size=5120 \
--create-tensorboard-logger \
--experiment-name=${batch_size}bs_${nodes}node_${gpus}gpu_${max_steps}s_${precision}prec \
--result-dir=${tensorboard_dir} \
--wandb-project=${wandb_project_name} \
--wandb-group=${model}_${variant}_${config_name} \
--wandb-job-type=${pipeline_label} \
--log-every-n-steps=50 \
--disable-checkpointing;
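How these fields are consumed: each entry under script_args supplies a value that is substituted into the ${...} placeholders of the script template, and list-valued entries such as precision and nodes expand into one job per element. The Python sketch below only illustrates that substitution under those assumptions; it is not the JET runner itself (see jet/core/configs.py for the real configuration class), and render_job is a hypothetical helper.

from string import Template

# One resolved point of the sweep (precision and nodes are single-element lists above).
script_args = {
    "workspace": "/workspace/bionemo2",
    "data_path": "/data/20240809_uniref_2024_03/data",
    "model": "esm2",
    "variant": "train",
    "config_name": "650M",
    "precision": "bf16-mixed",
    "nodes": 4,
    "gpus": 8,
    "batch_size": 16,
    "max_steps": 26500,
}

def render_job(script_template: str, args: dict) -> str:
    """Hypothetical helper: fill the ${...} placeholders with resolved argument values."""
    # safe_substitute leaves placeholders that script_args does not define
    # (e.g. ${tensorboard_dir}, ${wandb_project_name}, ${pipeline_label}) untouched;
    # those are assumed to be injected by the pipeline itself.
    return Template(script_template).safe_substitute(args)

# render_job(script, script_args) yields a train_esm2 command line with
# --micro-batch-size=16, --num-nodes=4, --num-steps=26500, and so on.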
65 changes: 65 additions & 0 deletions ci/benchmarks/perf/esm2_pretrain.yaml
@@ -0,0 +1,65 @@
scope: perf
time_limit: 1800
script_args:
# All arguments referenced in the script string must be specified here.
# Arguments not referenced in the script string must have the 'arg' field specified.
# See jet/core/configs.py for the specification of the configuration class
workspace:
value: /workspace/bionemo2
key_segment: False
data_path:
value: /data/20240809_uniref_2024_03/data
key_segment: False
model: esm2
variant: train
config_name: 650M
precision: bf16-mixed
max_steps: 200
gpus: 8
acc_grad: 1
products:
- nodes: 1
batch_size: 16
pp: 1
tp: 1
- nodes: 2
batch_size: 16
pp: 2
tp: 1
- nodes: 2
batch_size: 16
pp: 1
tp: 2
- nodes: 2
batch_size: 16
pp: 1
tp: 1
script: |-
WANDB_API_KEY=$BIONEMO_WANDB_API_KEY ${variant}_${model} \
--train-cluster-path=${data_path}/train_clusters.parquet \
--train-database-path=${data_path}/train.db \
--valid-cluster-path=${data_path}/valid_clusters.parquet \
--valid-database-path=${data_path}/validation.db \
--micro-batch-size=${batch_size} \
--num-nodes=${nodes} \
--num-gpus=${gpus} \
--val-check-interval=50 \
--limit-val-batches=1 \
--num-steps=${max_steps} \
--min-seq-length=1024 \
--max-seq-length=1024 \
--num-layers=33 \
--hidden-size=1280 \
--num-attention-heads=20 \
--ffn-hidden-size=5120 \
--create-tensorboard-logger \
--experiment-name=${batch_size}bs_${nodes}node_${gpus}gpu_${max_steps}s_${precision}prec \
--result-dir=${tensorboard_dir} \
--wandb-project=${wandb_project_name} \
--wandb-group=${model}_${variant}_${config_name} \
--wandb-job-type=${pipeline_label} \
--log-every-n-steps=10 \
--accumulate-grad-batches=${acc_grad} \
--pipeline-model-parallel-size=${pp} \
--tensor-model-parallel-size=${tp} \
--disable-checkpointing;
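Unlike the partial-convergence config, the perf config pins the shared arguments at the top of script_args and enumerates the swept combinations under products, so the four entries above map to four benchmark jobs: a single-node baseline plus 2-node pipeline-parallel, tensor-parallel, and pure data-parallel variants. A minimal sketch of how such entries might be merged into per-job argument sets (expand_products is an assumed helper, not JET's API):

shared = {
    "model": "esm2", "variant": "train", "config_name": "650M",
    "precision": "bf16-mixed", "max_steps": 200, "gpus": 8, "acc_grad": 1,
}

products = [
    {"nodes": 1, "batch_size": 16, "pp": 1, "tp": 1},  # single-node baseline
    {"nodes": 2, "batch_size": 16, "pp": 2, "tp": 1},  # 2-node, pipeline parallel
    {"nodes": 2, "batch_size": 16, "pp": 1, "tp": 2},  # 2-node, tensor parallel
    {"nodes": 2, "batch_size": 16, "pp": 1, "tp": 1},  # 2-node, data parallel only
]

def expand_products(shared: dict, products: list[dict]) -> list[dict]:
    """Assumed helper: one fully resolved argument set per products entry."""
    return [{**shared, **combo} for combo in products]

jobs = expand_products(shared, products)
assert len(jobs) == 4 and jobs[1]["pp"] == 2  # second job runs 2-way pipeline parallelism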
21 changes: 19 additions & 2 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/run/main.py
@@ -78,6 +78,13 @@ def parse_args():
default=[0],
help="Enable nsys profiling for these ranks.",
)
parser.add_argument(
"--disable-checkpointing",
action="store_false",
default=True,
dest="create_checkpoint_callback",
help="Disable creating a ModelCheckpoint callback.",
)
return parser.parse_args()
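The new flag relies on argparse's store_false action together with dest, so the parsed namespace exposes create_checkpoint_callback rather than disable_checkpointing: it defaults to True and flips to False only when --disable-checkpointing is passed. A minimal standalone illustration of that behaviour:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-checkpointing",
    action="store_false",          # passing the flag stores False ...
    default=True,                  # ... otherwise the default True is kept
    dest="create_checkpoint_callback",
    help="Disable creating a ModelCheckpoint callback.",
)

assert parser.parse_args([]).create_checkpoint_callback is True
assert parser.parse_args(["--disable-checkpointing"]).create_checkpoint_callback is False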

def string_to_class(path: str):
@@ -87,7 +94,12 @@ def string_to_class(path: str):
module = importlib.import_module(module_path)
return getattr(module, class_name)

def load_config(config_path: str, model_config_cls: Optional[str], data_config_cls: Optional[str]) -> MainConfig:
def load_config(
config_path: str,
model_config_cls: Optional[str],
data_config_cls: Optional[str],
create_checkpoint_callback: bool,
) -> MainConfig:
with open(config_path, "r") as f:
config_dict = yaml.safe_load(f)

@@ -109,10 +121,15 @@ def load_config(config_path: str, model_config_cls: Optional[str], data_config_c
elif isinstance(data_config_cls, str):
data_config_cls = string_to_class(data_config_cls)

# disable checkpointing if called from the command line
if not create_checkpoint_callback:
config_dict["training_config"]["enable_checkpointing"] = create_checkpoint_callback
config_dict["experiment_config"]["create_checkpoint_callback"] = create_checkpoint_callback

return MainConfig[model_config_cls, data_config_cls](**config_dict)
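The override is applied to the raw YAML dictionary before MainConfig validation, so a config file that enables checkpointing can still be run checkpoint-free from the command line. A small sketch of the two affected keys, with the rest of the (assumed) config structure elided:

config_dict = {
    "training_config": {"enable_checkpointing": True},          # other training fields elided
    "experiment_config": {"create_checkpoint_callback": True},  # other experiment fields elided
}

create_checkpoint_callback = False  # as if --disable-checkpointing had been passed
if not create_checkpoint_callback:
    config_dict["training_config"]["enable_checkpointing"] = create_checkpoint_callback
    config_dict["experiment_config"]["create_checkpoint_callback"] = create_checkpoint_callback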

args = parse_args()
config = load_config(args.config, args.model_config_cls, args.data_config_cls)
config = load_config(args.config, args.model_config_cls, args.data_config_cls, args.create_checkpoint_callback)

if args.nsys_profiling:
nsys_config = NsysConfig(
40 changes: 32 additions & 8 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py
@@ -71,13 +71,15 @@ def main(
wandb_offline: bool = False,
wandb_tags: Optional[List[str]] = None,
wandb_group: Optional[str] = None,
wandb_job_type: Optional[str] = None,
wandb_id: Optional[str] = None,
wandb_anonymous: Optional[bool] = False,
wandb_log_model: bool = False,
pipeline_model_parallel_size: int = 1,
tensor_model_parallel_size: int = 1,
create_tensorboard_logger: bool = False,
nemo1_init_path: Optional[Path] = None,
create_checkpoint_callback: bool = True,
restore_from_checkpoint_path: Optional[str] = None,
save_best_checkpoint: bool = True,
save_last_checkpoint: bool = True,
@@ -129,13 +131,15 @@ def main(
wandb_offline (bool): Run offline (data can be streamed later to wandb servers).
wandb_tags (Optional[List[str]]): Tags associated with this run
wandb_group (Optional[str]): A unique string shared by all runs in a given group
wandb_job_type (Optional[str]): Type of run, which is useful when you're grouping runs together into larger experiments using group.
wandb_id (Optional[str]): Sets the version, mainly used to resume a previous run
wandb_anonymous (Optional[bool]): Enables or explicitly disables anonymous logging
wandb_log_model (bool): Save checkpoints in wandb dir to upload on W&B servers
pipeline_model_parallel_size (int): pipeline model parallel size
tensor_model_parallel_size (int): tensor model parallel size
create_tensorboard_logger (bool): create the tensorboard logger
nemo1_init_path (Optional[Path]): Nemo 1 initialization path
create_checkpoint_callback (bool): create a ModelCheckpoint callback and attach it to the pytorch lightning trainer
restore_from_checkpoint_path (Optional[str]): If set, restores the model from the directory passed in. Expects the
checkpoint to be created by using the ModelCheckpoint class and always_save_context=True.
save_best_checkpoint (bool): whether to save the best checkpoint
@@ -199,6 +203,7 @@ def main(
entity=wandb_entity,
tags=wandb_tags,
group=wandb_group,
job_type=wandb_job_type,
id=wandb_id,
anonymous=wandb_anonymous,
log_model=wandb_log_model,
@@ -237,6 +242,7 @@ def main(
grad_reduce_in_fp32=grad_reduce_in_fp32,
autocast_enabled=False,
),
enable_checkpointing=create_checkpoint_callback,
)

tokenizer = get_tokenizer()
@@ -298,14 +304,17 @@ def main(
)

# Configure our custom Checkpointer
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=save_last_checkpoint,
monitor=metric_to_monitor_for_checkpoints, # "val_loss",
save_top_k=save_top_k,
every_n_train_steps=val_check_interval,
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
filename="{epoch}-{val_loss:.2f}-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this.
)
if create_checkpoint_callback:
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=save_last_checkpoint,
monitor=metric_to_monitor_for_checkpoints, # "val_loss",
save_top_k=save_top_k,
every_n_train_steps=val_check_interval,
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
filename="{epoch}-{val_loss:.2f}-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this.
)
else:
checkpoint_callback = None

# Setup the logger and train the model
nemo_logger = setup_nemo_lightning_logger(
@@ -348,6 +357,7 @@ def train_esm2_entrypoint():
wandb_project=args.wandb_project,
wandb_tags=args.wandb_tags,
wandb_group=args.wandb_group,
wandb_job_type=args.wandb_job_type,
wandb_id=args.wandb_id,
wandb_anonymous=args.wandb_anonymous,
wandb_log_model=args.wandb_log_model,
@@ -369,6 +379,7 @@ def train_esm2_entrypoint():
experiment_name=args.experiment_name,
resume_if_exists=args.resume_if_exists,
nemo1_init_path=args.nemo1_init_path,
create_checkpoint_callback=args.create_checkpoint_callback,
restore_from_checkpoint_path=args.restore_from_checkpoint_path,
save_best_checkpoint=args.save_best_checkpoint,
save_last_checkpoint=args.save_last_checkpoint,
@@ -459,6 +470,12 @@ def get_parser():
parser.add_argument(
"--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group"
)
parser.add_argument(
"--wandb-job-type",
type=str,
default=None,
help="A unique string representing a type of run, which is useful when you're grouping runs together into larger experiments using group.",
)
parser.add_argument(
"--wandb-id", type=str, default=None, help="Sets the version, mainly used to resume a previous run"
)
@@ -580,6 +597,13 @@ def get_parser():
required=False,
help="Path to nemo1 file, if desired to load at init time.",
)
parser.add_argument(
"--disable-checkpointing",
action="store_false",
default=True,
dest="create_checkpoint_callback",
help="Disable creating a ModelCheckpoint callback.",
)
parser.add_argument(
"--save-best-checkpoint",
action="store_true",
Changes to the ESM2 pretraining tests (file path not shown)
@@ -82,7 +82,10 @@ def dummy_parquet_train_val_inputs(tmp_path):
return train_cluster_path, valid_cluster_path


def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_train_val_inputs):
@pytest.mark.parametrize("create_checkpoint_callback", [True, False])
def test_main_runs(
monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_train_val_inputs, create_checkpoint_callback
):
train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs

result_dir = Path(tmpdir.mkdir("results"))
@@ -119,19 +122,28 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra
num_attention_heads=2,
hidden_size=4,
ffn_hidden_size=4 * 4,
create_checkpoint_callback=create_checkpoint_callback,
)

assert (result_dir / "test_experiment").exists(), "Could not find test experiment directory."
assert (result_dir / "test_experiment").is_dir(), "Test experiment directory is supposed to be a directory."
children = list((result_dir / "test_experiment").iterdir())
assert len(children) == 1, f"Expected 1 child in test experiment directory, found {children}."
uq_rundir = children[0] # it will be some date.
assert (
result_dir / "test_experiment" / uq_rundir / "checkpoints"
).exists(), "Could not find test experiment checkpoints directory."
assert (
result_dir / "test_experiment" / uq_rundir / "checkpoints"
).is_dir(), "Test experiment checkpoints directory is supposed to be a directory."

# checking directory with checkpoints
expected_exists = create_checkpoint_callback
actual_exists = (result_dir / "test_experiment" / uq_rundir / "checkpoints").exists()
assert expected_exists == actual_exists, (
f"Checkpoints directory existence mismatch. "
f"Expected: {'exists' if expected_exists else 'does not exist'}, "
f"Found: {'exists' if actual_exists else 'does not exist'}."
)

if create_checkpoint_callback:
assert (
result_dir / "test_experiment" / uq_rundir / "checkpoints"
).is_dir(), "Test experiment checkpoints directory is supposed to be a directory."
assert (
result_dir / "test_experiment" / uq_rundir / "nemo_log_globalrank-0_localrank-0.txt"
).is_file(), "Could not find experiment log."
21 changes: 19 additions & 2 deletions sub-packages/bionemo-geneformer/src/bionemo/geneformer/run/main.py
@@ -82,6 +82,13 @@ def parse_args():
default=[0],
help="Enable nsys profiling for these ranks.",
)
parser.add_argument(
"--disable-checkpointing",
action="store_false",
default=True,
dest="create_checkpoint_callback",
help="Disable creating a ModelCheckpoint callback.",
)

return parser.parse_args()

@@ -92,7 +99,12 @@ def string_to_class(path: str):
module = importlib.import_module(module_path)
return getattr(module, class_name)

def load_config(config_path: str, model_config_cls: Optional[str], data_config_cls: Optional[str]) -> MainConfig:
def load_config(
config_path: str,
model_config_cls: Optional[str],
data_config_cls: Optional[str],
create_checkpoint_callback: bool,
) -> MainConfig:
with open(config_path, "r") as f:
config_dict = yaml.safe_load(f)

@@ -106,14 +118,19 @@ def load_config(config_path: str, model_config_cls: Optional[str], data_config_c
# We assume we get a string to some importable config... e.g. in the sub-package jensen, 'bionemo.jensen.configs.MyConfig'
model_config_cls = string_to_class(model_config_cls)

# disable checkpointing if called from the command line
if not create_checkpoint_callback:
config_dict["training_config"]["enable_checkpointing"] = create_checkpoint_callback
config_dict["experiment_config"]["create_checkpoint_callback"] = create_checkpoint_callback

if data_config_cls is None:
data_config_cls = GeneformerPretrainingDataConfig
elif isinstance(data_config_cls, str):
data_config_cls = string_to_class(data_config_cls)
return MainConfig[model_config_cls, data_config_cls](**config_dict)

args = parse_args()
config = load_config(args.config, args.model_config_cls, args.data_config_cls)
config = load_config(args.config, args.model_config_cls, args.data_config_cls, args.create_checkpoint_callback)

if args.nsys_profiling:
nsys_config = NsysConfig(