From ba2469780da5a689e856e21ab9664ab1bed4fdd5 Mon Sep 17 00:00:00 2001 From: yifanmao Date: Thu, 19 Dec 2024 16:05:24 -0800 Subject: [PATCH] [BE] Combine OptimizerWrapper and OptimizerContainer (#738) Combine `state_dict` and `load_state_dict` from OptimizerWrapper to OptimizerContainer so that we only have one optimzier related class Also, add `get_lr_scheduler_state` to SchedulersContainer when update `lr_scheduler` at self.state --- torchtitan/checkpoint.py | 80 ++------------------------ torchtitan/optimizer.py | 119 +++++++++++++++++++++++++-------------- 2 files changed, 81 insertions(+), 118 deletions(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index a6057c79..db54ccd9 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -21,21 +21,14 @@ import torch.nn as nn from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, - get_optimizer_state_dict, set_model_state_dict, - set_optimizer_state_dict, StateDictOptions, ) from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import DataLoader from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging import init_logger, logger -from torchtitan.optimizer import ( - OptimizersContainer, - OptimizersInBackwardContainer, - SchedulersContainer, - SchedulersInBackwardContainer, -) +from torchtitan.optimizer import OptimizersContainer, SchedulersContainer class IntervalType(enum.Enum): @@ -104,43 +97,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: list(map(func, self.model)) -class OptimizerWrapper(Stateful): - def __init__( - self, - model: Union[nn.Module, List[nn.Module]], - optim: OptimizersContainer, - ) -> None: - self.model = [model] if isinstance(model, nn.Module) else model - if isinstance(optim, OptimizersInBackwardContainer): - self.optim = [ - sub_optim - for optim_group in optim.optimizers - for sub_optim in optim_group - ] - else: - optimizers = optim.optimizers - self.optim = ( - [optimizers] - if isinstance(optimizers, torch.optim.Optimizer) - else optimizers - ) - - def state_dict(self) -> Dict[str, Any]: - func = functools.partial( - get_optimizer_state_dict, - options=StateDictOptions(flatten_optimizer_state_dict=True), - ) - return {k: v for sd in map(func, self.model, self.optim) for k, v in sd.items()} - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - func = functools.partial( - set_optimizer_state_dict, - optim_state_dict=state_dict, - options=StateDictOptions(flatten_optimizer_state_dict=True), - ) - list(map(func, self.model, self.optim)) - - class Terminate: pass @@ -204,7 +160,7 @@ def __init__( restore its optimizer states, others will error. The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan - by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerWrapper. + by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer. 2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also requiring us to reason about multiple 'optim' objects locally. @@ -220,44 +176,16 @@ def __init__( TODO: This is currently unsolved and needs a fix. """ - assert len(model_parts) == len( - optimizers.optimizers - ), "Must pass one optimizer per model part" - assert len(model_parts) == len( - lr_schedulers.schedulers - ), "Must pass one lr_scheduler per model part" - self.states = states self.states.update( { "model": ModelWrapper(model_parts), - "optimizer": OptimizerWrapper( - model_parts, - optimizers, - ), + "optimizer": optimizers, "dataloader": dataloader, } ) - # SchedulersInBackwardContainer has a different structure than SchedulersContainer, List[List[Scheduler]] rahter - # than List[Scheduler], but the schedulers are the same for each list inside, so here just store the first one. - # TODO: Restructure SchedulersInBackwardContainer to be consisitent with SchedulersContainer. - if isinstance(lr_schedulers, SchedulersInBackwardContainer): - if len(lr_schedulers.schedulers) == 1: - self.states["lr_scheduler"] = lr_schedulers.schedulers[0][0] - else: - # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler. - # It should only support saving and loading a distributed checkpoint with the same number of pp ranks - for idx, lr_scheduler in enumerate(lr_schedulers.schedulers): - self.states[f"lr_scheduler_{idx}"] = lr_scheduler[0] - else: - if len(lr_schedulers.schedulers) == 1: - self.states["lr_scheduler"] = lr_schedulers.schedulers[0] - else: - # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler. - # It should only support saving and loading a distributed checkpoint with the same number of pp ranks - for idx, lr_scheduler in enumerate(lr_schedulers.schedulers): - self.states[f"lr_scheduler_{idx}"] = lr_scheduler + self.states.update(lr_schedulers.get_lr_scheduler_state()) self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder) self.interval_type = ( diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index d88d431f..4e205f04 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -5,18 +5,31 @@ # LICENSE file in the root directory of this source tree. import functools +from typing import Any, Dict, List import torch +import torch.nn as nn +from torch.distributed.checkpoint.state_dict import ( + get_optimizer_state_dict, + set_optimizer_state_dict, + StateDictOptions, +) +from torch.distributed.checkpoint.stateful import Stateful from torch.optim.lr_scheduler import LambdaLR from torchtitan.config_manager import JobConfig -class OptimizersContainer: - """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages""" +class OptimizersContainer(Stateful): + """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages + and saving/loading optimizer state_dict at checkpoint. + """ - def __init__(self, model_parts, optimizer_kwargs, name): + def __init__( + self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str + ) -> None: self.optimizers = [] - for model in model_parts: + self.model_parts = model_parts + for model in self.model_parts: if name == "Adam": # TODO: make the optimizer options configurable by toml/cmd args optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs) @@ -25,22 +38,50 @@ def __init__(self, model_parts, optimizer_kwargs, name): else: raise NotImplementedError(f"Optimizer {name} not added.") self.optimizers.append(optimizer) + self._validate_length(len(self.model_parts)) + + def _validate_length(self, expected_length) -> None: + assert expected_length == len( + self.optimizers + ), "Must pass one optimizer per model part or per param if using OptimizersInBackwardContainer" - def step(self): + def step(self) -> None: for optimizer in self.optimizers: optimizer.step() - def zero_grad(self): + def zero_grad(self) -> None: for optimizer in self.optimizers: optimizer.zero_grad() + def state_dict(self) -> Dict[str, Any]: + func = functools.partial( + get_optimizer_state_dict, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + return { + k: v + for sd in map(func, self.model_parts, self.optimizers) + for k, v in sd.items() + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + func = functools.partial( + set_optimizer_state_dict, + optim_state_dict=state_dict, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + list(map(func, self.model_parts, self.optimizers)) + class OptimizersInBackwardContainer(OptimizersContainer): """Optimiers in backward to skip .step() and .zero_grad()""" - def __init__(self, model_parts, optimizer_kwargs, name): + def __init__( + self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str + ) -> None: self.optimizers = [] - for model in model_parts: + self.model_parts = model_parts + for model in self.model_parts: if name == "Adam": # TODO: make the optimizer options configurable by toml/cmd args optim_dict = { @@ -63,17 +104,25 @@ def optim_hook(param) -> None: if param.requires_grad: param.register_post_accumulate_grad_hook(optim_hook) - self.optimizers.append([optim_dict[param] for param in model.parameters()]) + self.optimizers.extend([optim_dict[param] for param in model.parameters()]) + self._validate_length( + sum( + len([param for param in model.parameters()]) + for model in self.model_parts + ) + ) - def step(self): + def step(self) -> None: pass - def zero_grad(self): + def zero_grad(self) -> None: pass # consider split between PP and non-PP -def build_optimizers(model_parts, job_config: JobConfig): +def build_optimizers( + model_parts: List[nn.Module], job_config: JobConfig +) -> OptimizersContainer: """Wrap one optimizer per model part in an OptimizersContainer which provides a single step() and zero_grad() method for all the child optimizers. """ @@ -121,44 +170,30 @@ def linear_warmup_linear_decay( class SchedulersContainer: """Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages""" - def __init__(self, optimizers, lr_lambda): + def __init__(self, optimizers, lr_lambda) -> None: self.schedulers = [] for optimizer in optimizers: self.schedulers.append(LambdaLR(optimizer, lr_lambda=lr_lambda)) - def step(self): - for schedulers in self.schedulers: - schedulers.step() + def step(self) -> None: + for scheduler in self.schedulers: + scheduler.step() + def get_lr_scheduler_state(self) -> Dict[str, Any]: + state_dict = {} + if len(self.schedulers) == 1: + state_dict["lr_scheduler"] = self.schedulers[0] + else: + # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler. + # It should only support saving and loading a distributed checkpoint with the same number of pp ranks + for idx, lr_scheduler in enumerate(self.schedulers): + state_dict[f"lr_scheduler_{idx}"] = lr_scheduler + return state_dict -class SchedulersInBackwardContainer(SchedulersContainer): - """Util for calling step on multiple learning rate schedulers when optimizers are in backward""" - - def __init__(self, optimizers, lr_lambda): - # all the schedulers for each optimizer group are the same, here we only store the first one - # to self.schedulers follow the same structure as SchedulersContainer, but store all of them - # to self.all_schedulers for container.step() to call - self.schedulers = [] - for optim_group in optimizers: - scheduler_group = [] - for sub_optim in optim_group: - scheduler_group.append(LambdaLR(sub_optim, lr_lambda=lr_lambda)) - self.schedulers.append(scheduler_group) - - def step(self): - for scheduler_group in self.schedulers: - for scheduler in scheduler_group: - scheduler.step() - -def build_lr_schedulers(optimizers, job_config: JobConfig): - optim_in_bwd = job_config.optimizer.early_step_in_backward +def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer: warmup_steps = int(job_config.training.warmup_steps) decay_steps = float(max(1, job_config.training.steps - warmup_steps)) lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps) - return ( - SchedulersContainer(optimizers, lr_lambda) - if not optim_in_bwd - else SchedulersInBackwardContainer(optimizers, lr_lambda) - ) + return SchedulersContainer(optimizers, lr_lambda)