Updated FSDP example with additional Core API / AMP features.
coreystatendet committed Jul 1, 2024
1 parent 644223c commit 00ccc50
Showing 2 changed files with 109 additions and 39 deletions.
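In brief, the commit threads automatic mixed precision through the example: the forward pass and loss run under autocast, and gradients are scaled with FSDP's ShardedGradScaler. A condensed sketch of the pattern added below (compute_loss and train_batches are placeholders, not identifiers from the diff; model and optimizer setup are omitted):

import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler(enabled=use_amp)  # no-op scaling when AMP is off
for batch in train_batches:
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = compute_loss(fsdp_model, batch)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales; skips the step on inf/nan gradients
    scaler.update()                # adjusts the loss scale for the next step
    optimizer.zero_grad(set_to_none=True)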
4 changes: 4 additions & 0 deletions scratchwork/fsdp_min/config.yaml
@@ -19,4 +19,8 @@ hyperparameters:
vocab_size: 32000
report_rate: 10
checkpoint_rate: 50
use_amp: true
validation_batches: 10
core_api_profiler: false
torch_profiler: false
max_restarts: 0
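For reference, a minimal sketch (not part of the diff) of how these hyperparameters reach the training script through the Core API cluster info object, as fsdp.py does in its __main__ block:

import determined as det

info = det.get_cluster_info()
assert info is not None, "this example is meant to run on a Determined cluster"
hparams = info.trial.hparams                         # plain dict mirroring config.yaml
use_amp = hparams["use_amp"]                         # toggles autocast + ShardedGradScaler
validation_batches = hparams["validation_batches"]   # batches averaged per validation pass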
144 changes: 105 additions & 39 deletions scratchwork/fsdp_min/fsdp.py
@@ -1,7 +1,7 @@
import json
import logging
import random
from typing import Any, Generator, Optional
from typing import Any, Dict, Generator, Optional, TypedDict

import numpy as np
import torch
@@ -11,6 +11,7 @@
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, StateDictType
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

import determined as det
@@ -21,51 +22,72 @@


def get_fake_data_iter(
batch_size: int, vocab_size: int, max_seq_len: int, rank: int, device: torch.device
batch_size: int,
vocab_size: int,
max_seq_len: int,
rank: int,
device: torch.device,
is_validation: bool,
) -> Generator[tuple[torch.Tensor, torch.Tensor], None, None]:
"""
Fake dataloader. Repeatedly yields the same (inputs, targets) tuple of tensors, with different
tensors on different ranks.
Fake dataloader. Yields a different set of data for each rank, and for train vs validation.
This data would usually come from a tokenized dataset.
"""
torch.manual_seed(42 + rank)
fake_sequence = torch.randint(vocab_size, (batch_size, max_seq_len), device=device)
inputs, targets = fake_sequence[..., :-1], fake_sequence[..., 1:]
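# Per-rank generator; the large validation offset keeps the validation stream distinct from training.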
generator = torch.Generator(device=device)
generator.manual_seed(42 + rank + 100000 * is_validation)
while True:
fake_sequence = torch.randint(
vocab_size, (batch_size, max_seq_len), device=device, generator=generator
)
inputs, targets = fake_sequence[..., :-1], fake_sequence[..., 1:]
yield inputs, targets
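A quick illustration of the iterator's output (hypothetical call, not part of the diff), showing the one-token shift between inputs and targets used for next-token prediction:

data_iter = get_fake_data_iter(
    batch_size=2, vocab_size=32000, max_seq_len=128,
    rank=0, device=torch.device("cpu"), is_validation=False,
)
inputs, targets = next(data_iter)
# Both tensors have shape (2, 127); within a sequence, targets[i, t] == inputs[i, t + 1].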


def get_loss(fsdp_model: FSDP, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
def get_loss(
fsdp_model: FSDP, batch: tuple[torch.Tensor, torch.Tensor], use_amp: bool
) -> torch.Tensor:
inputs, labels = batch
outputs = fsdp_model(inputs)
outputs_flat = outputs.reshape(-1, outputs.shape[-1])
labels_flat = labels.reshape(-1)
loss = F.cross_entropy(outputs_flat, labels_flat)
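# With AMP enabled, the forward pass and loss run in reduced precision; the scaled backward happens in main().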
with torch.cuda.amp.autocast(enabled=use_amp):
outputs = fsdp_model(inputs)
outputs_flat = outputs.reshape(-1, outputs.shape[-1])
labels_flat = labels.reshape(-1)
loss = F.cross_entropy(outputs_flat, labels_flat)
return loss


def get_reduced_loss_and_report(
loss_history: list[torch.Tensor],
steps_completed: int,
core_context: det.core.Context,
validation: bool,
) -> Optional[float]:
"""
Average the most recent training losses across all processes and report the result. Returns the
reduced loss on rank 0 and None on all other ranks.
Average the most recent losses across all processes and report the result. Returns the reduced
loss on rank 0 and None on all other ranks.
"""

loss_history_t = torch.stack(loss_history).mean()
dist.reduce(loss_history_t, 0, op=dist.ReduceOp.AVG)
if core_context.distributed.rank == 0:
reduced_loss = loss_history_t.item()
core_context.train.report_training_metrics(
steps_completed=steps_completed, metrics={"loss": reduced_loss}
)
# TypedDict pattern to satisfy mypy.
ReportArgs = TypedDict("ReportArgs", {"steps_completed": int, "metrics": Dict[str, float]})
report_args: ReportArgs = {
"steps_completed": steps_completed,
"metrics": {"loss": reduced_loss},
}
if validation:
core_context.train.report_validation_metrics(**report_args)
else:
core_context.train.report_training_metrics(**report_args)
return reduced_loss
return None


def save_checkpoint(
fsdp_model: FSDP,
optimizer: torch.optim.Optimizer,
scaler: ShardedGradScaler,
core_context: det.core.Context,
steps_completed: int,
) -> None:
@@ -86,11 +108,15 @@ def save_checkpoint(
):
torch.save(model_state_dict, path.joinpath("model.bin"))
torch.save(optim_state_dict, path.joinpath("optim.bin"))
# Scaler state is automatically the same across ranks.
scaler_state_dict = scaler.state_dict()
torch.save(scaler_state_dict, path.joinpath("scaler.bin"))


def load_checkpoint(
fsdp_model: FSDP,
optimizer: torch.optim.Optimizer,
scaler: ShardedGradScaler,
core_context: det.core.Context,
device: torch.device,
uuid: str,
@@ -109,6 +135,7 @@ def load_checkpoint(
optim_state_dict=optim_state_dict,
)
optimizer.load_state_dict(optim_state_dict_to_load)
scaler.load_state_dict(torch.load(path.joinpath("scaler.bin")))

with open(path.joinpath("metadata.json"), "r") as f:
metadata = json.load(f)
@@ -170,47 +197,81 @@ def main(
steps_completed = 0
report_rate = hparams["report_rate"]
checkpoint_rate = hparams["checkpoint_rate"]
loss_history = []

data_iter = get_fake_data_iter(
batch_size=hparams["batch_size"],
vocab_size=hparams["vocab_size"],
max_seq_len=hparams["max_seq_len"],
rank=core_context.distributed.rank,
device=device,
)

validation_batches = hparams["validation_batches"]
use_amp = hparams["use_amp"]
use_torch_profiler = hparams["torch_profiler"]
train_loss_history = []

data_iter_arguments = {
"batch_size": hparams["batch_size"],
"vocab_size": hparams["vocab_size"],
"max_seq_len": hparams["max_seq_len"],
"rank": core_context.distributed.rank,
"device": device,
}
train_data_iter = get_fake_data_iter(is_validation=False, **data_iter_arguments)
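# FSDP-aware analogue of torch.cuda.amp.GradScaler; with enabled=False, scale/step/update become pass-throughs.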
scaler = ShardedGradScaler(enabled=use_amp)
# If a previous checkpoint exists, load it now and correct the steps_completed:
if checkpoint_uuid is not None:
steps_completed = load_checkpoint(
fsdp_model, optimizer, core_context, device, checkpoint_uuid
fsdp_model, optimizer, scaler, core_context, device, checkpoint_uuid
)
# If the torch profiler is enabled, write profiling results to TensorBoard, viewable through the WebUI.
if use_torch_profiler:
torch_profiler = torch.profiler.profile(
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(core_context.train.get_tensorboard_path())
),
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
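# Profile in cycles: skip 1 step, warm up for 1, record 2 (the cycle is advanced by torch_profiler.step() below).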
schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
)

for op in core_context.searcher.operations():
# Train for the number of steps specified in searcher.max_length in config.yaml
while steps_completed < op.length:
batch = next(data_iter)
loss = get_loss(fsdp_model, batch)
loss_history.append(loss.detach().clone())
loss.backward()
optimizer.step()
batch = next(train_data_iter)
loss = get_loss(fsdp_model, batch, use_amp)
train_loss_history.append(loss.detach().clone())
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if use_torch_profiler:
torch_profiler.step()

steps_completed += 1
this_is_the_last_step = steps_completed == op.length

if steps_completed % report_rate == 0 or this_is_the_last_step:
reduced_loss = get_reduced_loss_and_report(
loss_history, steps_completed, core_context
# Report the average training loss.
get_reduced_loss_and_report(
train_loss_history, steps_completed, core_context, validation=False
)
train_loss_history.clear()
# Compute and report an average validation loss.
validation_data_iter = get_fake_data_iter(is_validation=True, **data_iter_arguments)
validation_loss_history = []
with torch.inference_mode():
for i in range(validation_batches):
batch = next(validation_data_iter)
loss = get_loss(fsdp_model, batch, use_amp)
validation_loss_history.append(loss)
last_validation_loss = get_reduced_loss_and_report(
validation_loss_history, steps_completed, core_context, validation=True
)
loss_history.clear()

if steps_completed % checkpoint_rate == 0 or this_is_the_last_step:
save_checkpoint(fsdp_model, optimizer, core_context, steps_completed)
save_checkpoint(fsdp_model, optimizer, scaler, core_context, steps_completed)
# Since should_preempt is blocking, we only check at checkpoint_rate to
# maintain performance.
if core_context.preempt.should_preempt():
return

# Tell the master we're done
if core_context.distributed.rank == 0:
op.report_completed(reduced_loss)
op.report_completed(last_validation_loss)


if __name__ == "__main__":
@@ -221,14 +282,19 @@ def main(

checkpoint_uuid = info.latest_checkpoint
hparams = info.trial.hparams
core_api_profiler = hparams["core_api_profiler"]
try:
dist.init_process_group("nccl")
distributed = det.core.DistributedContext.from_torch_distributed()
with det.core.init(distributed=distributed) as core_context:
if core_api_profiler:
core_context.profiler.on()
main(
core_context=core_context,
hparams=hparams,
checkpoint_uuid=checkpoint_uuid,
)
if core_api_profiler:
core_context.profiler.off()
finally:
dist.destroy_process_group()
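Assuming the Determined CLI is installed and pointed at a cluster, the updated example can be submitted in the usual way, for example: det experiment create config.yaml . (run from the scratchwork/fsdp_min directory).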
