From 501116706880388a4e4c9789807db48559e084a0 Mon Sep 17 00:00:00 2001 From: mori360 Date: Fri, 13 Dec 2024 15:49:38 -0800 Subject: [PATCH 1/8] combine optim and lr_scheduler together --- torchtitan/checkpoint.py | 27 ++-------------- torchtitan/optimizer.py | 68 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 69 insertions(+), 26 deletions(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index a6057c79..8f857afd 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -34,7 +34,6 @@ OptimizersContainer, OptimizersInBackwardContainer, SchedulersContainer, - SchedulersInBackwardContainer, ) @@ -229,35 +228,15 @@ def __init__( self.states = states + optimizers.update_for_checkpoint(model_parts) self.states.update( { "model": ModelWrapper(model_parts), - "optimizer": OptimizerWrapper( - model_parts, - optimizers, - ), + "optimizer": optimizers, "dataloader": dataloader, } ) - # SchedulersInBackwardContainer has a different structure than SchedulersContainer, List[List[Scheduler]] rahter - # than List[Scheduler], but the schedulers are the same for each list inside, so here just store the first one. - # TODO: Restructure SchedulersInBackwardContainer to be consisitent with SchedulersContainer. - if isinstance(lr_schedulers, SchedulersInBackwardContainer): - if len(lr_schedulers.schedulers) == 1: - self.states["lr_scheduler"] = lr_schedulers.schedulers[0][0] - else: - # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler. - # It should only support saving and loading a distributed checkpoint with the same number of pp ranks - for idx, lr_scheduler in enumerate(lr_schedulers.schedulers): - self.states[f"lr_scheduler_{idx}"] = lr_scheduler[0] - else: - if len(lr_schedulers.schedulers) == 1: - self.states["lr_scheduler"] = lr_schedulers.schedulers[0] - else: - # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler. - # It should only support saving and loading a distributed checkpoint with the same number of pp ranks - for idx, lr_scheduler in enumerate(lr_schedulers.schedulers): - self.states[f"lr_scheduler_{idx}"] = lr_scheduler + self.states.update(lr_schedulers.update_state()) self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder) self.interval_type = ( diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index d88d431f..6833bde3 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -5,17 +5,26 @@ # LICENSE file in the root directory of this source tree. 
import functools +from typing import Any, Dict import torch +from torch.distributed.checkpoint.state_dict import ( + get_optimizer_state_dict, + set_optimizer_state_dict, + StateDictOptions, +) +from torch.distributed.checkpoint.stateful import Stateful from torch.optim.lr_scheduler import LambdaLR from torchtitan.config_manager import JobConfig -class OptimizersContainer: - """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages""" +class OptimizersContainer(Stateful): + """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages and save""" def __init__(self, model_parts, optimizer_kwargs, name): self.optimizers = [] + self.model = [] + self.plain_optim = [] for model in model_parts: if name == "Adam": # TODO: make the optimizer options configurable by toml/cmd args @@ -26,6 +35,14 @@ def __init__(self, model_parts, optimizer_kwargs, name): raise NotImplementedError(f"Optimizer {name} not added.") self.optimizers.append(optimizer) + def update_for_checkpoint(self, model): + self.model = [model] if isinstance(model, torch.nn.Module) else model + self.plain_optim = ( + [self.optimizers] + if isinstance(self.optimizers, torch.optim.Optimizer) + else self.optimizers + ) + def step(self): for optimizer in self.optimizers: optimizer.step() @@ -34,6 +51,25 @@ def zero_grad(self): for optimizer in self.optimizers: optimizer.zero_grad() + def state_dict(self) -> Dict[str, Any]: + func = functools.partial( + get_optimizer_state_dict, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + return { + k: v + for sd in map(func, self.model, self.plain_optim) + for k, v in sd.items() + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + func = functools.partial( + set_optimizer_state_dict, + optim_state_dict=state_dict, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + list(map(func, self.model, self.plain_optim)) + class OptimizersInBackwardContainer(OptimizersContainer): """Optimiers in backward to skip .step() and .zero_grad()""" @@ -65,6 +101,12 @@ def optim_hook(param) -> None: self.optimizers.append([optim_dict[param] for param in model.parameters()]) + def update_for_checkpoint(self, model): + self.model = [model] if isinstance(model, torch.nn.Module) else model + self.plain_optim = [ + sub_optim for optim_group in self.optimizers for sub_optim in optim_group + ] + def step(self): pass @@ -130,6 +172,17 @@ def step(self): for schedulers in self.schedulers: schedulers.step() + def update_state(self) -> Dict[str, Any]: + state_dict = {} + if len(self.schedulers) == 1: + state_dict["lr_scheduler"] = self.schedulers[0] + else: + # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler. + # It should only support saving and loading a distributed checkpoint with the same number of pp ranks + for idx, lr_scheduler in enumerate(self.schedulers): + state_dict[f"lr_scheduler_{idx}"] = lr_scheduler + return state_dict + class SchedulersInBackwardContainer(SchedulersContainer): """Util for calling step on multiple learning rate schedulers when optimizers are in backward""" @@ -150,6 +203,17 @@ def step(self): for scheduler in scheduler_group: scheduler.step() + def update_state(self) -> Dict[str, Any]: + state_dict = {} + if len(self.schedulers) == 1: + state_dict["lr_scheduler"] = self.schedulers[0][0] + else: + # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler. 
+ # It should only support saving and loading a distributed checkpoint with the same number of pp ranks + for idx, lr_scheduler in enumerate(self.schedulers): + state_dict[f"lr_scheduler_{idx}"] = lr_scheduler[0] + return state_dict + def build_lr_schedulers(optimizers, job_config: JobConfig): optim_in_bwd = job_config.optimizer.early_step_in_backward From fbcbf6633dc3bc9ac0b64cc5bf5e23cc49047fa2 Mon Sep 17 00:00:00 2001 From: mori360 Date: Fri, 13 Dec 2024 16:34:10 -0800 Subject: [PATCH 2/8] change func name --- torchtitan/checkpoint.py | 2 +- torchtitan/optimizer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 8f857afd..71f896a5 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -236,7 +236,7 @@ def __init__( "dataloader": dataloader, } ) - self.states.update(lr_schedulers.update_state()) + self.states.update(lr_schedulers.get_lr_scheduler_state()) self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder) self.interval_type = ( diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index 6833bde3..8674d53d 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -172,7 +172,7 @@ def step(self): for schedulers in self.schedulers: schedulers.step() - def update_state(self) -> Dict[str, Any]: + def get_lr_scheduler_state(self) -> Dict[str, Any]: state_dict = {} if len(self.schedulers) == 1: state_dict["lr_scheduler"] = self.schedulers[0] @@ -203,7 +203,7 @@ def step(self): for scheduler in scheduler_group: scheduler.step() - def update_state(self) -> Dict[str, Any]: + def get_lr_scheduler_state(self) -> Dict[str, Any]: state_dict = {} if len(self.schedulers) == 1: state_dict["lr_scheduler"] = self.schedulers[0][0] From 3a5b16f660d132d7ed7719b71c22e5021a204557 Mon Sep 17 00:00:00 2001 From: mori360 Date: Mon, 16 Dec 2024 11:14:22 -0800 Subject: [PATCH 3/8] remove optimizerwrapper, combine update model into init --- torchtitan/checkpoint.py | 48 ++-------------------------------------- torchtitan/optimizer.py | 20 ++++++++++------- 2 files changed, 14 insertions(+), 54 deletions(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 71f896a5..ebb81dac 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -21,20 +21,14 @@ import torch.nn as nn from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, - get_optimizer_state_dict, set_model_state_dict, - set_optimizer_state_dict, StateDictOptions, ) from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import DataLoader from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging import init_logger, logger -from torchtitan.optimizer import ( - OptimizersContainer, - OptimizersInBackwardContainer, - SchedulersContainer, -) +from torchtitan.optimizer import OptimizersContainer, SchedulersContainer class IntervalType(enum.Enum): @@ -103,43 +97,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: list(map(func, self.model)) -class OptimizerWrapper(Stateful): - def __init__( - self, - model: Union[nn.Module, List[nn.Module]], - optim: OptimizersContainer, - ) -> None: - self.model = [model] if isinstance(model, nn.Module) else model - if isinstance(optim, OptimizersInBackwardContainer): - self.optim = [ - sub_optim - for optim_group in optim.optimizers - for sub_optim in optim_group - ] - else: - optimizers = optim.optimizers - self.optim = ( - [optimizers] - if isinstance(optimizers, 
torch.optim.Optimizer) - else optimizers - ) - - def state_dict(self) -> Dict[str, Any]: - func = functools.partial( - get_optimizer_state_dict, - options=StateDictOptions(flatten_optimizer_state_dict=True), - ) - return {k: v for sd in map(func, self.model, self.optim) for k, v in sd.items()} - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - func = functools.partial( - set_optimizer_state_dict, - optim_state_dict=state_dict, - options=StateDictOptions(flatten_optimizer_state_dict=True), - ) - list(map(func, self.model, self.optim)) - - class Terminate: pass @@ -203,7 +160,7 @@ def __init__( restore its optimizer states, others will error. The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan - by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerWrapper. + by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizersContainer. 2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also requiring us to reason about multiple 'optim' objects locally. @@ -228,7 +185,6 @@ def __init__( self.states = states - optimizers.update_for_checkpoint(model_parts) self.states.update( { "model": ModelWrapper(model_parts), diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index 8674d53d..f0110018 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -19,13 +19,15 @@ class OptimizersContainer(Stateful): - """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages and save""" + """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages + and saving/loading optimizer state_dict at checkpoint.
+ """ def __init__(self, model_parts, optimizer_kwargs, name): self.optimizers = [] - self.model = [] + self.model = model_parts self.plain_optim = [] - for model in model_parts: + for model in self.model: if name == "Adam": # TODO: make the optimizer options configurable by toml/cmd args optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs) @@ -34,9 +36,13 @@ def __init__(self, model_parts, optimizer_kwargs, name): else: raise NotImplementedError(f"Optimizer {name} not added.") self.optimizers.append(optimizer) + self.plain_optim = ( + [self.optimizers] + if isinstance(self.optimizers, torch.optim.Optimizer) + else self.optimizers + ) def update_for_checkpoint(self, model): - self.model = [model] if isinstance(model, torch.nn.Module) else model self.plain_optim = ( [self.optimizers] if isinstance(self.optimizers, torch.optim.Optimizer) @@ -76,7 +82,8 @@ class OptimizersInBackwardContainer(OptimizersContainer): def __init__(self, model_parts, optimizer_kwargs, name): self.optimizers = [] - for model in model_parts: + self.model = model_parts + for model in self.model: if name == "Adam": # TODO: make the optimizer options configurable by toml/cmd args optim_dict = { @@ -100,9 +107,6 @@ def optim_hook(param) -> None: param.register_post_accumulate_grad_hook(optim_hook) self.optimizers.append([optim_dict[param] for param in model.parameters()]) - - def update_for_checkpoint(self, model): - self.model = [model] if isinstance(model, torch.nn.Module) else model self.plain_optim = [ sub_optim for optim_group in self.optimizers for sub_optim in optim_group ] From fa4eef969a1ff99c910d3ae5d226b0074b59a271 Mon Sep 17 00:00:00 2001 From: mori360 Date: Mon, 16 Dec 2024 19:13:54 -0800 Subject: [PATCH 4/8] add typing --- torchtitan/optimizer.py | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index f0110018..8da56eb1 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -5,9 +5,10 @@ # LICENSE file in the root directory of this source tree. import functools -from typing import Any, Dict +from typing import Any, Dict, List import torch +import torch.nn as nn from torch.distributed.checkpoint.state_dict import ( get_optimizer_state_dict, set_optimizer_state_dict, @@ -23,10 +24,11 @@ class OptimizersContainer(Stateful): and saving/loading optimizer state_dict at checkpoint. 
""" - def __init__(self, model_parts, optimizer_kwargs, name): + def __init__( + self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str + ) -> None: self.optimizers = [] self.model = model_parts - self.plain_optim = [] for model in self.model: if name == "Adam": # TODO: make the optimizer options configurable by toml/cmd args @@ -42,18 +44,11 @@ def __init__(self, model_parts, optimizer_kwargs, name): else self.optimizers ) - def update_for_checkpoint(self, model): - self.plain_optim = ( - [self.optimizers] - if isinstance(self.optimizers, torch.optim.Optimizer) - else self.optimizers - ) - - def step(self): + def step(self) -> None: for optimizer in self.optimizers: optimizer.step() - def zero_grad(self): + def zero_grad(self) -> None: for optimizer in self.optimizers: optimizer.zero_grad() @@ -80,7 +75,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: class OptimizersInBackwardContainer(OptimizersContainer): """Optimiers in backward to skip .step() and .zero_grad()""" - def __init__(self, model_parts, optimizer_kwargs, name): + def __init__( + self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str + ) -> None: self.optimizers = [] self.model = model_parts for model in self.model: @@ -111,15 +108,17 @@ def optim_hook(param) -> None: sub_optim for optim_group in self.optimizers for sub_optim in optim_group ] - def step(self): + def step(self) -> None: pass - def zero_grad(self): + def zero_grad(self) -> None: pass # consider split between PP and non-PP -def build_optimizers(model_parts, job_config: JobConfig): +def build_optimizers( + model_parts: List[nn.Module], job_config: JobConfig +) -> OptimizersContainer: """Wrap one optimizer per model part in an OptimizersContainer which provides a single step() and zero_grad() method for all the child optimizers. 
""" @@ -167,12 +166,12 @@ def linear_warmup_linear_decay( class SchedulersContainer: """Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages""" - def __init__(self, optimizers, lr_lambda): + def __init__(self, optimizers, lr_lambda) -> None: self.schedulers = [] for optimizer in optimizers: self.schedulers.append(LambdaLR(optimizer, lr_lambda=lr_lambda)) - def step(self): + def step(self) -> None: for schedulers in self.schedulers: schedulers.step() @@ -191,7 +190,7 @@ def get_lr_scheduler_state(self) -> Dict[str, Any]: class SchedulersInBackwardContainer(SchedulersContainer): """Util for calling step on multiple learning rate schedulers when optimizers are in backward""" - def __init__(self, optimizers, lr_lambda): + def __init__(self, optimizers, lr_lambda) -> None: # all the schedulers for each optimizer group are the same, here we only store the first one # to self.schedulers follow the same structure as SchedulersContainer, but store all of them # to self.all_schedulers for container.step() to call @@ -202,7 +201,7 @@ def __init__(self, optimizers, lr_lambda): scheduler_group.append(LambdaLR(sub_optim, lr_lambda=lr_lambda)) self.schedulers.append(scheduler_group) - def step(self): + def step(self) -> None: for scheduler_group in self.schedulers: for scheduler in scheduler_group: scheduler.step() @@ -219,7 +218,7 @@ def get_lr_scheduler_state(self) -> Dict[str, Any]: return state_dict -def build_lr_schedulers(optimizers, job_config: JobConfig): +def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer: optim_in_bwd = job_config.optimizer.early_step_in_backward warmup_steps = int(job_config.training.warmup_steps) decay_steps = float(max(1, job_config.training.steps - warmup_steps)) From bcb144c7300b8b6bb112732687fda00b41451d7c Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 17 Dec 2024 11:06:54 -0800 Subject: [PATCH 5/8] restructure optimierInBackward class, combine self.optimziers and self.plain_optimizers --- torchtitan/optimizer.py | 54 +++-------------------------------------- 1 file changed, 4 insertions(+), 50 deletions(-) diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index 8da56eb1..4f30f289 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -38,11 +38,6 @@ def __init__( else: raise NotImplementedError(f"Optimizer {name} not added.") self.optimizers.append(optimizer) - self.plain_optim = ( - [self.optimizers] - if isinstance(self.optimizers, torch.optim.Optimizer) - else self.optimizers - ) def step(self) -> None: for optimizer in self.optimizers: @@ -58,9 +53,7 @@ def state_dict(self) -> Dict[str, Any]: options=StateDictOptions(flatten_optimizer_state_dict=True), ) return { - k: v - for sd in map(func, self.model, self.plain_optim) - for k, v in sd.items() + k: v for sd in map(func, self.model, self.optimizers) for k, v in sd.items() } def load_state_dict(self, state_dict: Dict[str, Any]) -> None: @@ -69,7 +62,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: optim_state_dict=state_dict, options=StateDictOptions(flatten_optimizer_state_dict=True), ) - list(map(func, self.model, self.plain_optim)) + list(map(func, self.model, self.optimizers)) class OptimizersInBackwardContainer(OptimizersContainer): @@ -103,10 +96,7 @@ def optim_hook(param) -> None: if param.requires_grad: param.register_post_accumulate_grad_hook(optim_hook) - self.optimizers.append([optim_dict[param] for param in model.parameters()]) - self.plain_optim = [ - sub_optim for optim_group in 
self.optimizers for sub_optim in optim_group - ] + self.optimizers.extend([optim_dict[param] for param in model.parameters()]) def step(self) -> None: pass @@ -187,45 +177,9 @@ def get_lr_scheduler_state(self) -> Dict[str, Any]: return state_dict -class SchedulersInBackwardContainer(SchedulersContainer): - """Util for calling step on multiple learning rate schedulers when optimizers are in backward""" - - def __init__(self, optimizers, lr_lambda) -> None: - # all the schedulers for each optimizer group are the same, here we only store the first one - # to self.schedulers follow the same structure as SchedulersContainer, but store all of them - # to self.all_schedulers for container.step() to call - self.schedulers = [] - for optim_group in optimizers: - scheduler_group = [] - for sub_optim in optim_group: - scheduler_group.append(LambdaLR(sub_optim, lr_lambda=lr_lambda)) - self.schedulers.append(scheduler_group) - - def step(self) -> None: - for scheduler_group in self.schedulers: - for scheduler in scheduler_group: - scheduler.step() - - def get_lr_scheduler_state(self) -> Dict[str, Any]: - state_dict = {} - if len(self.schedulers) == 1: - state_dict["lr_scheduler"] = self.schedulers[0][0] - else: - # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler. - # It should only support saving and loading a distributed checkpoint with the same number of pp ranks - for idx, lr_scheduler in enumerate(self.schedulers): - state_dict[f"lr_scheduler_{idx}"] = lr_scheduler[0] - return state_dict - - def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer: - optim_in_bwd = job_config.optimizer.early_step_in_backward warmup_steps = int(job_config.training.warmup_steps) decay_steps = float(max(1, job_config.training.steps - warmup_steps)) lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps) - return ( - SchedulersContainer(optimizers, lr_lambda) - if not optim_in_bwd - else SchedulersInBackwardContainer(optimizers, lr_lambda) - ) + return SchedulersContainer(optimizers, lr_lambda) From 491d3722fcdae4209c0aa806afcc80f4f11dcdde Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 17 Dec 2024 11:46:49 -0800 Subject: [PATCH 6/8] change num_optim check due to changes in optimizers.optimizers --- torchtitan/checkpoint.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index ebb81dac..6a14fc23 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -176,12 +176,18 @@ def __init__( TODO: This is currently unsolved and needs a fix.
""" - assert len(model_parts) == len( + if job_config.optimizer.early_step_in_backward: + expected_optim_num = sum( + len([param for param in model.parameters()]) for model in model_parts + ) + else: + expected_optim_num = len(model_parts) + assert expected_optim_num == len( optimizers.optimizers - ), "Must pass one optimizer per model part" - assert len(model_parts) == len( + ), "Must pass one optimizer per model part (or per param for optim_in_bwd)" + assert expected_optim_num == len( lr_schedulers.schedulers - ), "Must pass one lr_scheduler per model part" + ), "Must pass one lr_scheduler per model part (or per param for optim_in_bwd)" self.states = states From 7fe2928becc0696d53f5854e1a59fce8affbac5c Mon Sep 17 00:00:00 2001 From: mori360 Date: Wed, 18 Dec 2024 11:05:30 -0800 Subject: [PATCH 7/8] git assert check to container, optimizer some names --- torchtitan/checkpoint.py | 13 ------------- torchtitan/optimizer.py | 28 ++++++++++++++++++++-------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 6a14fc23..db54ccd9 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -176,19 +176,6 @@ def __init__( TODO: This is currently unsolved and needs a fix. """ - if job_config.optimizer.early_step_in_backward: - expected_optim_num = sum( - len([param for param in model.parameters()]) for model in model_parts - ) - else: - expected_optim_num = len(model_parts) - assert expected_optim_num == len( - optimizers.optimizers - ), "Must pass one optimizer per model part (or per param for optim_in_bwd)" - assert expected_optim_num == len( - lr_schedulers.schedulers - ), "Must pass one lr_scheduler per model part (or per param for optim_in_bwd)" - self.states = states self.states.update( diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index 4f30f289..d31f2826 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -28,8 +28,8 @@ def __init__( self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str ) -> None: self.optimizers = [] - self.model = model_parts - for model in self.model: + self.model_parts = model_parts + for model in self.model_parts: if name == "Adam": # TODO: make the optimizer options configurable by toml/cmd args optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs) @@ -38,6 +38,10 @@ def __init__( else: raise NotImplementedError(f"Optimizer {name} not added.") self.optimizers.append(optimizer) + expected_optim_num = len(self.model_parts) + assert expected_optim_num == len( + self.optimizers + ), "Must pass one optimizer per model part" def step(self) -> None: for optimizer in self.optimizers: @@ -53,7 +57,9 @@ def state_dict(self) -> Dict[str, Any]: options=StateDictOptions(flatten_optimizer_state_dict=True), ) return { - k: v for sd in map(func, self.model, self.optimizers) for k, v in sd.items() + k: v + for sd in map(func, self.model_parts, self.optimizers) + for k, v in sd.items() } def load_state_dict(self, state_dict: Dict[str, Any]) -> None: @@ -62,7 +68,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: optim_state_dict=state_dict, options=StateDictOptions(flatten_optimizer_state_dict=True), ) - list(map(func, self.model, self.optimizers)) + list(map(func, self.model_parts, self.optimizers)) class OptimizersInBackwardContainer(OptimizersContainer): @@ -72,8 +78,8 @@ def __init__( self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str ) -> None: self.optimizers = [] - self.model = 
model_parts - for model in self.model: + self.model_parts = model_parts + for model in self.model_parts: if name == "Adam": # TODO: make the optimizer options configurable by toml/cmd args optim_dict = { @@ -97,6 +103,12 @@ def optim_hook(param) -> None: param.register_post_accumulate_grad_hook(optim_hook) self.optimizers.extend([optim_dict[param] for param in model.parameters()]) + expected_optim_num = sum( + len([param for param in model.parameters()]) for model in self.model_parts + ) + assert expected_optim_num == len( + self.optimizers + ), "Must pass one optimizer per model param part" def step(self) -> None: pass @@ -162,8 +174,8 @@ def __init__(self, optimizers, lr_lambda) -> None: self.schedulers.append(LambdaLR(optimizer, lr_lambda=lr_lambda)) def step(self) -> None: - for schedulers in self.schedulers: - schedulers.step() + for scheduler in self.schedulers: + scheduler.step() def get_lr_scheduler_state(self) -> Dict[str, Any]: state_dict = {} From f91c1dac948379865a5393253bbe6bca707461a7 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 19 Dec 2024 15:44:49 -0800 Subject: [PATCH 8/8] isolate _validate_length --- torchtitan/optimizer.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index d31f2826..4e205f04 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -38,10 +38,12 @@ def __init__( else: raise NotImplementedError(f"Optimizer {name} not added.") self.optimizers.append(optimizer) - expected_optim_num = len(self.model_parts) - assert expected_optim_num == len( + self._validate_length(len(self.model_parts)) + + def _validate_length(self, expected_length) -> None: + assert expected_length == len( self.optimizers - ), "Must pass one optimizer per model part" + ), "Must pass one optimizer per model part or per param if using OptimizersInBackwardContainer" def step(self) -> None: for optimizer in self.optimizers: @@ -103,12 +105,12 @@ def optim_hook(param) -> None: param.register_post_accumulate_grad_hook(optim_hook) self.optimizers.extend([optim_dict[param] for param in model.parameters()]) - expected_optim_num = sum( - len([param for param in model.parameters()]) for model in self.model_parts + self._validate_length( + sum( + len([param for param in model.parameters()]) + for model in self.model_parts + ) ) - assert expected_optim_num == len( - self.optimizers - ), "Must pass one optimizer per model param part" def step(self) -> None: pass
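For reference, a minimal sketch of how the refactored containers fit together once this series is applied: OptimizersContainer is itself Stateful, so checkpoint code registers it directly as the "optimizer" entry, and the lr_scheduler entries are merged in via get_lr_scheduler_state(). The toy model parts, the constant lr_lambda, and the optimizer kwargs below are placeholder assumptions (the real setup goes through build_optimizers/build_lr_schedulers with a JobConfig); this illustrates the post-series API, not torchtitan's actual training loop.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

from torchtitan.optimizer import OptimizersContainer, SchedulersContainer

# Two toy modules stand in for pipeline-parallel model chunks. Distinct
# submodule names keep the flattened optimizer-state keys from colliding.
model_parts = [
    nn.Sequential(OrderedDict(stage0=nn.Linear(8, 8))),
    nn.Sequential(OrderedDict(stage1=nn.Linear(8, 8))),
]

# One optimizer per model part, built inside the container ("Adam" is the
# name the container dispatches on; the kwargs here are illustrative).
optimizers = OptimizersContainer(model_parts, {"lr": 3e-4}, "Adam")

# One LambdaLR per optimizer; a constant lr_lambda stands in for the
# linear_warmup_linear_decay schedule that build_lr_schedulers() installs.
lr_schedulers = SchedulersContainer(optimizers.optimizers, lambda step: 1.0)

# A single training step through the container interfaces (with
# optimizers-in-backward, step()/zero_grad() become no-ops instead).
loss = sum(m(torch.randn(2, 8)).sum() for m in model_parts)
loss.backward()
optimizers.step()
lr_schedulers.step()
optimizers.zero_grad()

# Checkpoint wiring after this series: register the Stateful container
# directly, then merge in the per-rank scheduler entries.
states = {"optimizer": optimizers}
states.update(lr_schedulers.get_lr_scheduler_state())

# The flattened optimizer state_dict round-trips through DCP's helpers.
flat_sd = optimizers.state_dict()
optimizers.load_state_dict(flat_sd)
```

Keeping a single flat list of optimizers (one per model part, or one per parameter with optimizers in backward) is what lets state_dict()/load_state_dict() zip model parts against optimizers and rely on DCP's flatten_optimizer_state_dict option for resharding-friendly keys, instead of the removed OptimizerWrapper and the per-container lr_scheduler branching in checkpoint.py.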