[BE] Combine OptimizerWrapper and OptimizerContainer #738
Changes from 6 commits
@@ -21,21 +21,14 @@ | |
import torch.nn as nn | ||
from torch.distributed.checkpoint.state_dict import ( | ||
get_model_state_dict, | ||
get_optimizer_state_dict, | ||
set_model_state_dict, | ||
set_optimizer_state_dict, | ||
StateDictOptions, | ||
) | ||
from torch.distributed.checkpoint.stateful import Stateful | ||
from torch.utils.data import DataLoader | ||
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP | ||
from torchtitan.logging import init_logger, logger | ||
from torchtitan.optimizer import ( | ||
OptimizersContainer, | ||
OptimizersInBackwardContainer, | ||
SchedulersContainer, | ||
SchedulersInBackwardContainer, | ||
) | ||
from torchtitan.optimizer import OptimizersContainer, SchedulersContainer | ||
|
||
|
||
class IntervalType(enum.Enum): | ||
|
@@ -104,43 +97,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: | |
list(map(func, self.model)) | ||
|
||
|
||
class OptimizerWrapper(Stateful): | ||
def __init__( | ||
self, | ||
model: Union[nn.Module, List[nn.Module]], | ||
optim: OptimizersContainer, | ||
) -> None: | ||
self.model = [model] if isinstance(model, nn.Module) else model | ||
if isinstance(optim, OptimizersInBackwardContainer): | ||
self.optim = [ | ||
sub_optim | ||
for optim_group in optim.optimizers | ||
for sub_optim in optim_group | ||
] | ||
else: | ||
optimizers = optim.optimizers | ||
self.optim = ( | ||
[optimizers] | ||
if isinstance(optimizers, torch.optim.Optimizer) | ||
else optimizers | ||
) | ||
|
||
def state_dict(self) -> Dict[str, Any]: | ||
func = functools.partial( | ||
get_optimizer_state_dict, | ||
options=StateDictOptions(flatten_optimizer_state_dict=True), | ||
) | ||
return {k: v for sd in map(func, self.model, self.optim) for k, v in sd.items()} | ||
|
||
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: | ||
func = functools.partial( | ||
set_optimizer_state_dict, | ||
optim_state_dict=state_dict, | ||
options=StateDictOptions(flatten_optimizer_state_dict=True), | ||
) | ||
list(map(func, self.model, self.optim)) | ||
|
||
|
||
class Terminate: | ||
pass | ||
|
||
|
@@ -204,7 +160,7 @@ def __init__( | |
restore its optimizer states, others will error. | ||
|
||
The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan | ||
by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerWrapper. | ||
by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer. | ||
|
||
2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also | ||
requiring us to reason about multiple 'optim' objects locally. | ||
|
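To make the flattening described in the docstring above concrete, here is a minimal, hedged sketch (not from the PR): it contrasts a plain torch.optim state_dict, which keys parameter state by integer index, with the form returned by get_optimizer_state_dict when StateDictOptions(flatten_optimizer_state_dict=True) is passed. The toy model, the single-rank gloo group, and the exact key layout are assumptions for illustration only.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    get_optimizer_state_dict,
    StateDictOptions,
)

# The DCP state_dict helpers are normally used under torchrun; a 1-rank gloo group
# keeps this sketch self-contained (assumption: gloo and this port are available locally).
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

model = nn.Linear(4, 4)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(2, 4)).sum().backward()
optim.step()

# Plain torch.optim state: parameter states are keyed by integer index, which is
# ambiguous once parameters are sharded across a different number of ranks.
print(list(optim.state_dict()["state"].keys()))

# The flattened form keys optimizer state by parameter FQN rather than by index,
# which is what allows a checkpoint to be restored after resharding.
flat_sd = get_optimizer_state_dict(
    model, optim, options=StateDictOptions(flatten_optimizer_state_dict=True)
)
print(sorted(flat_sd.keys()))

dist.destroy_process_group()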
@@ -220,44 +176,29 @@ def __init__( | |
|
||
TODO: This is currently unsolved and needs a fix. | ||
""" | ||
assert len(model_parts) == len( | ||
if job_config.optimizer.early_step_in_backward: | ||
expected_optim_num = sum( | ||
len([param for param in model.parameters()]) for model in model_parts | ||
) | ||
else: | ||
expected_optim_num = len(model_parts) | ||
assert expected_optim_num == len( | ||
optimizers.optimizers | ||
), "Must pass one optimizer per model part" | ||
assert len(model_parts) == len( | ||
), "Must pass one optimizer per model part (or per param for optim_in_bwd)" | ||
assert expected_optim_num == len( | ||
lr_schedulers.schedulers | ||
), "Must pass one lr_scheduler per model part" | ||
), "Must pass one lr_scheduler per model part (or per param for optim_in_bwd)" | ||
|
||
self.states = states | ||
|
||
self.states.update( | ||
{ | ||
"model": ModelWrapper(model_parts), | ||
"optimizer": OptimizerWrapper( | ||
model_parts, | ||
optimizers, | ||
), | ||
"optimizer": optimizers, | ||
"dataloader": dataloader, | ||
} | ||
) | ||
# SchedulersInBackwardContainer has a different structure than SchedulersContainer, List[List[Scheduler]] rather | ||
# than List[Scheduler], but the schedulers are the same for each list inside, so here just store the first one. | ||
# TODO: Restructure SchedulersInBackwardContainer to be consistent with SchedulersContainer. | ||
if isinstance(lr_schedulers, SchedulersInBackwardContainer): | ||
if len(lr_schedulers.schedulers) == 1: | ||
self.states["lr_scheduler"] = lr_schedulers.schedulers[0][0] | ||
else: | ||
# For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler. | ||
# It should only support saving and loading a distributed checkpoint with the same number of pp ranks | ||
for idx, lr_scheduler in enumerate(lr_schedulers.schedulers): | ||
self.states[f"lr_scheduler_{idx}"] = lr_scheduler[0] | ||
else: | ||
if len(lr_schedulers.schedulers) == 1: | ||
self.states["lr_scheduler"] = lr_schedulers.schedulers[0] | ||
else: | ||
# For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler. | ||
# It should only support saving and loading a distributed checkpoint with the same number of pp ranks | ||
for idx, lr_scheduler in enumerate(lr_schedulers.schedulers): | ||
self.states[f"lr_scheduler_{idx}"] = lr_scheduler | ||
self.states.update(lr_schedulers.get_lr_scheduler_state()) | ||
Review comment: Discussed with @mori360 offline: I think as a next step, we should still have a single entry for
||
|
||
self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder) | ||
self.interval_type = ( | ||
|
(second changed file: the torchtitan.optimizer module)
@@ -5,18 +5,31 @@ | |
# LICENSE file in the root directory of this source tree. | ||
|
||
import functools | ||
from typing import Any, Dict, List | ||
|
||
import torch | ||
import torch.nn as nn | ||
from torch.distributed.checkpoint.state_dict import ( | ||
get_optimizer_state_dict, | ||
set_optimizer_state_dict, | ||
StateDictOptions, | ||
) | ||
from torch.distributed.checkpoint.stateful import Stateful | ||
from torch.optim.lr_scheduler import LambdaLR | ||
from torchtitan.config_manager import JobConfig | ||
|
||
|
||
class OptimizersContainer: | ||
"""Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages""" | ||
class OptimizersContainer(Stateful): | ||
"""Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages | ||
and saving/loading optimizer state_dict at checkpoint. | ||
""" | ||
|
||
def __init__(self, model_parts, optimizer_kwargs, name): | ||
def __init__( | ||
self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str | ||
) -> None: | ||
self.optimizers = [] | ||
for model in model_parts: | ||
self.model = model_parts | ||
||
for model in self.model: | ||
if name == "Adam": | ||
# TODO: make the optimizer options configurable by toml/cmd args | ||
optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs) | ||
|
@@ -26,21 +39,41 @@ def __init__(self, model_parts, optimizer_kwargs, name): | |
raise NotImplementedError(f"Optimizer {name} not added.") | ||
self.optimizers.append(optimizer) | ||
|
||
def step(self): | ||
def step(self) -> None: | ||
for optimizer in self.optimizers: | ||
optimizer.step() | ||
|
||
def zero_grad(self): | ||
def zero_grad(self) -> None: | ||
for optimizer in self.optimizers: | ||
optimizer.zero_grad() | ||
|
||
def state_dict(self) -> Dict[str, Any]: | ||
func = functools.partial( | ||
get_optimizer_state_dict, | ||
options=StateDictOptions(flatten_optimizer_state_dict=True), | ||
) | ||
return { | ||
k: v for sd in map(func, self.model, self.optimizers) for k, v in sd.items() | ||
} | ||
|
||
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: | ||
func = functools.partial( | ||
set_optimizer_state_dict, | ||
optim_state_dict=state_dict, | ||
options=StateDictOptions(flatten_optimizer_state_dict=True), | ||
) | ||
list(map(func, self.model, self.optimizers)) | ||
|
||
|
||
class OptimizersInBackwardContainer(OptimizersContainer): | ||
"""Optimiers in backward to skip .step() and .zero_grad()""" | ||
|
||
def __init__(self, model_parts, optimizer_kwargs, name): | ||
def __init__( | ||
self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str | ||
) -> None: | ||
self.optimizers = [] | ||
Review comment: ditto.
||
for model in model_parts: | ||
self.model = model_parts | ||
||
for model in self.model: | ||
if name == "Adam": | ||
# TODO: make the optimizer options configurable by toml/cmd args | ||
optim_dict = { | ||
|
@@ -63,17 +96,19 @@ def optim_hook(param) -> None: | |
if param.requires_grad: | ||
param.register_post_accumulate_grad_hook(optim_hook) | ||
|
||
self.optimizers.append([optim_dict[param] for param in model.parameters()]) | ||
self.optimizers.extend([optim_dict[param] for param in model.parameters()]) | ||
|
||
def step(self): | ||
def step(self) -> None: | ||
pass | ||
|
||
def zero_grad(self): | ||
def zero_grad(self) -> None: | ||
pass | ||
|
||
|
||
# consider split between PP and non-PP | ||
def build_optimizers(model_parts, job_config: JobConfig): | ||
def build_optimizers( | ||
model_parts: List[nn.Module], job_config: JobConfig | ||
) -> OptimizersContainer: | ||
"""Wrap one optimizer per model part in an OptimizersContainer which provides a single | ||
step() and zero_grad() method for all the child optimizers. | ||
""" | ||
|
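For context on OptimizersInBackwardContainer above, the "optimizers in backward" pattern can be reduced to a short, self-contained sketch (the toy model and learning rate are assumptions; the hook registration mirrors the diff above): each parameter gets its own optimizer, stepped from a post-accumulate-grad hook, so parameters are already updated when backward() returns and the container's step()/zero_grad() can be no-ops.

import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # stand-in for one model part

# one optimizer instance per parameter, as in OptimizersInBackwardContainer
optim_dict = {
    param: torch.optim.Adam([param], lr=3e-4) for param in model.parameters()
}

def optim_hook(param) -> None:
    # fires right after the gradient for `param` has been accumulated during backward
    optim_dict[param].step()
    optim_dict[param].zero_grad()

for param in model.parameters():
    if param.requires_grad:
        param.register_post_accumulate_grad_hook(optim_hook)

# parameters are updated inside backward(), so a container-level step()/zero_grad() is a no-op
model(torch.randn(2, 8)).sum().backward()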
@@ -121,44 +156,30 @@ def linear_warmup_linear_decay( | |
class SchedulersContainer: | ||
"""Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages""" | ||
|
||
def __init__(self, optimizers, lr_lambda): | ||
def __init__(self, optimizers, lr_lambda) -> None: | ||
self.schedulers = [] | ||
for optimizer in optimizers: | ||
self.schedulers.append(LambdaLR(optimizer, lr_lambda=lr_lambda)) | ||
|
||
def step(self): | ||
def step(self) -> None: | ||
for schedulers in self.schedulers: | ||
schedulers.step() | ||
||
|
||
def get_lr_scheduler_state(self) -> Dict[str, Any]: | ||
state_dict = {} | ||
if len(self.schedulers) == 1: | ||
state_dict["lr_scheduler"] = self.schedulers[0] | ||
else: | ||
# For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler. | ||
# It should only support saving and loading a distributed checkpoint with the same number of pp ranks | ||
Comment on lines +187 to +188:
Review comment: I feel this could be solved if we do flattening on the lr_scheduler state_dict, similar to what we did for models and optimizers.
Review comment: This is hard to achieve. LRScheduler's state_dict is very unstructured. It basically just returns
Review comment: @fegin I see. Then maybe we can just focus on LambdaLR, which should be straightforward -- since every scheduler has the same schedule, we can store only one state and recreate the schedulers for each optimizer on the fly when doing checkpoint loading.
Review comment: Yes, I want to emphasize that the direction of flattening the LRScheduler state is correct, which TorchRec also does, iirc. But since we may not have the bandwidth to support all of them, focusing on one or two and stating clearly what TorchTitan supports is a good idea.
||
for idx, lr_scheduler in enumerate(self.schedulers): | ||
state_dict[f"lr_scheduler_{idx}"] = lr_scheduler | ||
return state_dict | ||
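A hedged sketch of the direction discussed in the review thread above (the class name and the Stateful placement are assumptions, not part of this PR): because every LambdaLR follows the same schedule, the container could checkpoint a single scheduler state and replay it into each scheduler at load time, removing the per-rank lr_scheduler_{idx} entries.

import copy
from typing import Any, Dict, List

import torch
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim.lr_scheduler import LambdaLR


class FlattenedSchedulersContainer(Stateful):
    """Sketch: store one LambdaLR state and replay it into every scheduler on load."""

    def __init__(self, optimizers: List[torch.optim.Optimizer], lr_lambda) -> None:
        self.schedulers = [LambdaLR(optim, lr_lambda=lr_lambda) for optim in optimizers]

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()

    def state_dict(self) -> Dict[str, Any]:
        # all schedulers share the same schedule, so one copy is enough; the checkpoint
        # then no longer depends on how many schedulers exist locally (e.g. number of pp ranks)
        return self.schedulers[0].state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        for scheduler in self.schedulers:
            # deepcopy so schedulers do not share mutable state after loading
            scheduler.load_state_dict(copy.deepcopy(state_dict))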
|
||
class SchedulersInBackwardContainer(SchedulersContainer): | ||
"""Util for calling step on multiple learning rate schedulers when optimizers are in backward""" | ||
|
||
def __init__(self, optimizers, lr_lambda): | ||
# all the schedulers for each optimizer group are the same, here we only store the first one | ||
# to self.schedulers follow the same structure as SchedulersContainer, but store all of them | ||
# to self.all_schedulers for container.step() to call | ||
self.schedulers = [] | ||
for optim_group in optimizers: | ||
scheduler_group = [] | ||
for sub_optim in optim_group: | ||
scheduler_group.append(LambdaLR(sub_optim, lr_lambda=lr_lambda)) | ||
self.schedulers.append(scheduler_group) | ||
|
||
def step(self): | ||
for scheduler_group in self.schedulers: | ||
for scheduler in scheduler_group: | ||
scheduler.step() | ||
|
||
|
||
def build_lr_schedulers(optimizers, job_config: JobConfig): | ||
optim_in_bwd = job_config.optimizer.early_step_in_backward | ||
def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer: | ||
warmup_steps = int(job_config.training.warmup_steps) | ||
decay_steps = float(max(1, job_config.training.steps - warmup_steps)) | ||
lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps) | ||
|
||
return ( | ||
SchedulersContainer(optimizers, lr_lambda) | ||
if not optim_in_bwd | ||
else SchedulersInBackwardContainer(optimizers, lr_lambda) | ||
) | ||
return SchedulersContainer(optimizers, lr_lambda) |
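Putting the pieces together, a minimal usage sketch of the combined container (the toy model parts and optimizer kwargs are assumptions; in a real run this happens inside torchtitan's training loop under torchrun, since state_dict()/load_state_dict() call the DCP helpers): one step()/zero_grad() call drives all child optimizers, and because the container is now Stateful it is registered directly as the checkpoint's "optimizer" entry instead of going through the removed OptimizerWrapper.

import torch
import torch.nn as nn
from torchtitan.optimizer import OptimizersContainer

model_parts = [nn.Linear(8, 8), nn.Linear(8, 8)]  # stand-ins for pipeline model chunks
optimizers = OptimizersContainer(model_parts, {"lr": 3e-4, "betas": (0.9, 0.95)}, "Adam")

for model in model_parts:
    model(torch.randn(2, 8)).sum().backward()
optimizers.step()       # steps every child optimizer
optimizers.zero_grad()  # clears gradients for every model part

# the container itself is the Stateful checkpoint entry now
states = {"optimizer": optimizers}
full_sd = states["optimizer"].state_dict()   # flattened optimizer state for all parts
states["optimizer"].load_state_dict(full_sd)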
Review comment: We made optimizers.optimizers a plain (flat) list for both the in-backward and the regular case, so len(model_parts) does not work for the in-backward case. What do you think of the assert check here and the error message?
Review comment: I think we should move the sanity check to OptimizerContainer's __init__ constructors, rather than doing it here. Basically nothing would have changed after init, and we don't need to check it every time we do a checkpoint save.
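A hedged sketch of that last suggestion (not code from this PR; only the Adam branch visible in the diff is shown, and state_dict()/load_state_dict() are omitted): establish the one-optimizer-per-model-part invariant once in the constructor, so CheckpointManager no longer needs to re-check it on every save.

from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch.distributed.checkpoint.stateful import Stateful


class OptimizersContainer(Stateful):
    def __init__(
        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
    ) -> None:
        self.model = model_parts
        self.optimizers = []
        for model in self.model:
            if name == "Adam":
                optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
            else:
                raise NotImplementedError(f"Optimizer {name} not added.")
            self.optimizers.append(optimizer)
        # invariant checked once at construction instead of on every checkpoint save
        assert len(self.model) == len(
            self.optimizers
        ), "Must pass one optimizer per model part"

    # state_dict() / load_state_dict() unchanged from the PR (omitted in this sketch)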