[BE] Combine OptimizerWrapper and OptimizerContainer #738

Merged · 8 commits · Dec 20, 2024
87 changes: 14 additions & 73 deletions torchtitan/checkpoint.py
@@ -21,21 +21,14 @@
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.optimizer import (
OptimizersContainer,
OptimizersInBackwardContainer,
SchedulersContainer,
SchedulersInBackwardContainer,
)
from torchtitan.optimizer import OptimizersContainer, SchedulersContainer


class IntervalType(enum.Enum):
@@ -104,43 +97,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
list(map(func, self.model))


class OptimizerWrapper(Stateful):
def __init__(
self,
model: Union[nn.Module, List[nn.Module]],
optim: OptimizersContainer,
) -> None:
self.model = [model] if isinstance(model, nn.Module) else model
if isinstance(optim, OptimizersInBackwardContainer):
self.optim = [
sub_optim
for optim_group in optim.optimizers
for sub_optim in optim_group
]
else:
optimizers = optim.optimizers
self.optim = (
[optimizers]
if isinstance(optimizers, torch.optim.Optimizer)
else optimizers
)

def state_dict(self) -> Dict[str, Any]:
func = functools.partial(
get_optimizer_state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
return {k: v for sd in map(func, self.model, self.optim) for k, v in sd.items()}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
func = functools.partial(
set_optimizer_state_dict,
optim_state_dict=state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
list(map(func, self.model, self.optim))


class Terminate:
pass

@@ -204,7 +160,7 @@ def __init__(
restore its optimizer states, others will error.

The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerWrapper.
by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.

2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
requiring us to reason about multiple 'optim' objects locally.
@@ -220,44 +176,29 @@

TODO: This is currently unsolved and needs a fix.
"""
assert len(model_parts) == len(
if job_config.optimizer.early_step_in_backward:
Contributor Author:

We made optimizers.optimizers a plain (flat) list for both the in-backward and the regular case.
Thus len(model_parts) does not work for the in-backward case.
What do you think of the assert check and error message here?

Contributor:

I think we should move the sanity check to OptimizerContainer's __init__ constructors, rather than doing it here. Basically nothing would've changed after init, and we don't need to check it every time we do checkpoint save.
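
A minimal, self-contained sketch of that direction, with a toy container standing in for the TorchTitan class (the class name and the AdamW choice are illustrative only); the count is validated once at construction rather than on every checkpoint save:

```python
import torch
import torch.nn as nn


class ToyOptimizersContainer:
    """Stand-in for OptimizersContainer; the point is the constructor-time check."""

    def __init__(self, model_parts, optimizer_kwargs):
        self.model = list(model_parts)
        self.optimizers = [
            torch.optim.AdamW(m.parameters(), **optimizer_kwargs) for m in self.model
        ]
        # Validate once here instead of in CheckpointManager on every save.
        assert len(self.optimizers) == len(self.model), (
            "Must create one optimizer per model part"
        )


parts = [nn.Linear(4, 4), nn.Linear(4, 2)]
container = ToyOptimizersContainer(parts, {"lr": 1e-3})
print(len(container.optimizers))  # 2
```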

expected_optim_num = sum(
len([param for param in model.parameters()]) for model in model_parts
)
else:
expected_optim_num = len(model_parts)
assert expected_optim_num == len(
optimizers.optimizers
), "Must pass one optimizer per model part"
assert len(model_parts) == len(
), "Must pass one optimizer per model part (or per param for optim_in_bwd)"
assert expected_optim_num == len(
lr_schedulers.schedulers
), "Must pass one lr_scheduler per model part"
), "Must pass one lr_scheduler per model part (or per param for optim_in_bwd)"

self.states = states

self.states.update(
{
"model": ModelWrapper(model_parts),
"optimizer": OptimizerWrapper(
model_parts,
optimizers,
),
"optimizer": optimizers,
"dataloader": dataloader,
}
)
# SchedulersInBackwardContainer has a different structure than SchedulersContainer, List[List[Scheduler]] rather
# than List[Scheduler], but the schedulers are the same for each list inside, so here just store the first one.
# TODO: Restructure SchedulersInBackwardContainer to be consistent with SchedulersContainer.
if isinstance(lr_schedulers, SchedulersInBackwardContainer):
if len(lr_schedulers.schedulers) == 1:
self.states["lr_scheduler"] = lr_schedulers.schedulers[0][0]
else:
# For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
# It should only support saving and loading a distributed checkpoint with the same number of pp ranks
for idx, lr_scheduler in enumerate(lr_schedulers.schedulers):
self.states[f"lr_scheduler_{idx}"] = lr_scheduler[0]
else:
if len(lr_schedulers.schedulers) == 1:
self.states["lr_scheduler"] = lr_schedulers.schedulers[0]
else:
# For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
# It should only support saving and loading a distributed checkpoint with the same number of pp ranks
for idx, lr_scheduler in enumerate(lr_schedulers.schedulers):
self.states[f"lr_scheduler_{idx}"] = lr_scheduler
self.states.update(lr_schedulers.get_lr_scheduler_state())
Contributor:

Discussed with @mori360 offline: I think as a next step, we should still have a single entry for lr_scheduler. To achieve that, we need to understand if any further flattening is needed. This could potentially solve the "PP multi-schedule doesn't support DCP resharding" problem.
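
To make the resharding concern concrete, here is a small standalone reproduction of the key layout that get_lr_scheduler_state() produces (the toy optimizers exist only so the snippet runs; the helper body mirrors the diff below):

```python
from typing import Any, Dict

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR


def get_lr_scheduler_state(schedulers) -> Dict[str, Any]:
    # One key for a single scheduler, otherwise one indexed key per scheduler,
    # which ties the checkpoint layout to the number of schedulers (PP ranks).
    state_dict: Dict[str, Any] = {}
    if len(schedulers) == 1:
        state_dict["lr_scheduler"] = schedulers[0]
    else:
        for idx, lr_scheduler in enumerate(schedulers):
            state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
    return state_dict


optims = [torch.optim.AdamW(nn.Linear(4, 4).parameters(), lr=1e-3) for _ in range(2)]
scheds = [LambdaLR(o, lr_lambda=lambda step: 1.0) for o in optims]
print(list(get_lr_scheduler_state(scheds)))  # ['lr_scheduler_0', 'lr_scheduler_1']
```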


self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval_type = (
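
The docstring above attributes resharding support to the 'flatten_optimizer_state_dict' option; below is a minimal single-rank sketch of the DCP calls the container now wraps. The gloo world-size-1 setup exists only to make the snippet self-contained and is not how TorchTitan runs:

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    get_optimizer_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)

# Single-rank process group purely so the DCP state_dict helpers can run.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(8, 8)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(2, 8)).sum().backward()
optim.step()

opts = StateDictOptions(flatten_optimizer_state_dict=True)

# Flattening keys the optimizer state by parameter FQN rather than by
# positional param-group index, which is what makes resharding possible
# per the docstring above.
flat_sd = get_optimizer_state_dict(model, optim, options=opts)
print(sorted(flat_sd.keys()))

# Loading goes through the matching setter with the same option.
set_optimizer_state_dict(model, optim, optim_state_dict=flat_sd, options=opts)

dist.destroy_process_group()
```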
101 changes: 61 additions & 40 deletions torchtitan/optimizer.py
@@ -5,18 +5,31 @@
# LICENSE file in the root directory of this source tree.

import functools
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
get_optimizer_state_dict,
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim.lr_scheduler import LambdaLR
from torchtitan.config_manager import JobConfig


class OptimizersContainer:
"""Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages"""
class OptimizersContainer(Stateful):
"""Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages
and saving/loading optimizer state_dict at checkpoint.
"""

def __init__(self, model_parts, optimizer_kwargs, name):
def __init__(
self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
) -> None:
self.optimizers = []
for model in model_parts:
self.model = model_parts
for model in self.model:
if name == "Adam":
# TODO: make the optimizer options configurable by toml/cmd args
optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
@@ -26,21 +39,41 @@ def __init__(self, model_parts, optimizer_kwargs, name):
raise NotImplementedError(f"Optimizer {name} not added.")
self.optimizers.append(optimizer)

def step(self):
def step(self) -> None:
for optimizer in self.optimizers:
optimizer.step()

def zero_grad(self):
def zero_grad(self) -> None:
for optimizer in self.optimizers:
optimizer.zero_grad()

def state_dict(self) -> Dict[str, Any]:
func = functools.partial(
get_optimizer_state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
return {
k: v for sd in map(func, self.model, self.optimizers) for k, v in sd.items()
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
func = functools.partial(
set_optimizer_state_dict,
optim_state_dict=state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
list(map(func, self.model, self.optimizers))


class OptimizersInBackwardContainer(OptimizersContainer):
"""Optimiers in backward to skip .step() and .zero_grad()"""

def __init__(self, model_parts, optimizer_kwargs, name):
def __init__(
self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
) -> None:
self.optimizers = []
Contributor:

ditto.

for model in model_parts:
self.model = model_parts
for model in self.model:
if name == "Adam":
# TODO: make the optimizer options configurable by toml/cmd args
optim_dict = {
@@ -63,17 +96,19 @@ def optim_hook(param) -> None:
if param.requires_grad:
param.register_post_accumulate_grad_hook(optim_hook)

self.optimizers.append([optim_dict[param] for param in model.parameters()])
self.optimizers.extend([optim_dict[param] for param in model.parameters()])

def step(self):
def step(self) -> None:
pass

def zero_grad(self):
def zero_grad(self) -> None:
pass


# consider split between PP and non-PP
def build_optimizers(model_parts, job_config: JobConfig):
def build_optimizers(
model_parts: List[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
"""Wrap one optimizer per model part in an OptimizersContainer which provides a single
step() and zero_grad() method for all the child optimizers.
"""
@@ -121,44 +156,30 @@ def linear_warmup_linear_decay(
class SchedulersContainer:
"""Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages"""

def __init__(self, optimizers, lr_lambda):
def __init__(self, optimizers, lr_lambda) -> None:
self.schedulers = []
for optimizer in optimizers:
self.schedulers.append(LambdaLR(optimizer, lr_lambda=lr_lambda))

def step(self):
def step(self) -> None:
for schedulers in self.schedulers:
schedulers.step()

def get_lr_scheduler_state(self) -> Dict[str, Any]:
state_dict = {}
if len(self.schedulers) == 1:
state_dict["lr_scheduler"] = self.schedulers[0]
else:
# For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
# It should only support saving and loading a distributed checkpoint with the same number of pp ranks
Comment on lines +187 to +188
Contributor:

I feel this could be solved, if we do flattening on lr_scheduler state_dict similar to what we did to models and optimizers.

Contributor:

This is hard to achieve. LRScheduler's state_dict is very unstructured: it basically just returns self.__dict__.items(), which can contain anything. And LRScheduler doesn't define a parameter-group structure, so different LRSchedulers may have different implementations. I'm not sure how to flatten LRScheduler in a general way, unless we focus on only one LRScheduler.
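
A quick illustration of that point with stock LambdaLR; the state really is just the scheduler's attribute dict, with no param-group structure to flatten against:

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

optim = torch.optim.AdamW(nn.Linear(4, 4).parameters(), lr=1e-3)
sched = LambdaLR(optim, lr_lambda=lambda step: 1.0)

# LambdaLR.state_dict() is essentially self.__dict__ minus the optimizer
# (lr_lambdas are handled specially), so its layout depends on the scheduler class.
print(sorted(sched.state_dict().keys()))
```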

Contributor:

@fegin I see. Then maybe we can just focus on the LambdaLR, which should be straightforward -- since every scheduler has the same schedule, we can store only one state and recreate schedulers for each optimizer on the fly when doing checkpoint loading.
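
A rough sketch of that LambdaLR-only direction, assuming every scheduler in the group follows the same schedule; the class name and shape are hypothetical, not part of this PR:

```python
from typing import Any, Dict

from torch.distributed.checkpoint.stateful import Stateful
from torch.optim.lr_scheduler import LambdaLR


class SingleEntrySchedulers(Stateful):
    """Hypothetical container exposing one state entry for N identical LambdaLR schedulers."""

    def __init__(self, optimizers, lr_lambda):
        self.schedulers = [LambdaLR(o, lr_lambda=lr_lambda) for o in optimizers]

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()

    def state_dict(self) -> Dict[str, Any]:
        # Every scheduler follows the same schedule, so persisting one is enough.
        return self.schedulers[0].state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Re-apply the single saved state to however many schedulers exist on
        # this rank, decoupling the checkpoint entry from the PP world size.
        for scheduler in self.schedulers:
            scheduler.load_state_dict(state_dict)
```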

Contributor:

Yes, I want to emphasize that flattening LRScheduler is the correct direction, which TorchRec also does, IIRC. But since we may not have the bandwidth to support all of them, focusing on one or two and stating clearly what TorchTitan supports is a good idea.

for idx, lr_scheduler in enumerate(self.schedulers):
state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
return state_dict

class SchedulersInBackwardContainer(SchedulersContainer):
"""Util for calling step on multiple learning rate schedulers when optimizers are in backward"""

def __init__(self, optimizers, lr_lambda):
# all the schedulers for each optimizer group are the same, here we only store the first one
# to self.schedulers follow the same structure as SchedulersContainer, but store all of them
# to self.all_schedulers for container.step() to call
self.schedulers = []
for optim_group in optimizers:
scheduler_group = []
for sub_optim in optim_group:
scheduler_group.append(LambdaLR(sub_optim, lr_lambda=lr_lambda))
self.schedulers.append(scheduler_group)

def step(self):
for scheduler_group in self.schedulers:
for scheduler in scheduler_group:
scheduler.step()


def build_lr_schedulers(optimizers, job_config: JobConfig):
optim_in_bwd = job_config.optimizer.early_step_in_backward
def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer:
warmup_steps = int(job_config.training.warmup_steps)
decay_steps = float(max(1, job_config.training.steps - warmup_steps))
lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps)

return (
SchedulersContainer(optimizers, lr_lambda)
if not optim_in_bwd
else SchedulersInBackwardContainer(optimizers, lr_lambda)
)
return SchedulersContainer(optimizers, lr_lambda)