[BE] Combine OptimizerWrapper and OptimizerContainer #738

Merged 8 commits on Dec 20, 2024

13 changes: 0 additions & 13 deletions torchtitan/checkpoint.py
@@ -176,19 +176,6 @@ def __init__(

        TODO: This is currently unsolved and needs a fix.
        """
-        if job_config.optimizer.early_step_in_backward:
-            expected_optim_num = sum(
-                len([param for param in model.parameters()]) for model in model_parts
-            )
-        else:
-            expected_optim_num = len(model_parts)
-        assert expected_optim_num == len(
-            optimizers.optimizers
-        ), "Must pass one optimizer per model part (or per param for optim_in_bwd)"
-        assert expected_optim_num == len(
-            lr_schedulers.schedulers
-        ), "Must pass one lr_scheduler per model part (or per param for optim_in_bwd)"
-
        self.states = states

        self.states.update(
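Note: the block removed here duplicated a sanity check that the diff below moves into the optimizer containers themselves: with optimizer-in-backward enabled there is one optimizer per parameter, otherwise one per model part. A minimal standalone sketch of that invariant, with illustrative names rather than a torchtitan API:

# Sketch of the count invariant the deleted block enforced. `model_parts` stands in
# for the list of model chunks; `optim_in_bwd` mirrors
# job_config.optimizer.early_step_in_backward.
from typing import List

import torch.nn as nn

def expected_optimizer_count(model_parts: List[nn.Module], optim_in_bwd: bool) -> int:
    if optim_in_bwd:
        # one optimizer per parameter when optimizers step inside backward
        return sum(len(list(m.parameters())) for m in model_parts)
    # otherwise one optimizer per model part
    return len(model_parts)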
28 changes: 20 additions & 8 deletions torchtitan/optimizer.py
@@ -28,8 +28,8 @@ def __init__(
        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
    ) -> None:
        self.optimizers = []
-        self.model = model_parts
-        for model in self.model:
+        self.model_parts = model_parts
+        for model in self.model_parts:
            if name == "Adam":
                # TODO: make the optimizer options configurable by toml/cmd args
                optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
@@ -38,6 +38,10 @@ def __init__(
            else:
                raise NotImplementedError(f"Optimizer {name} not added.")
            self.optimizers.append(optimizer)
+        expected_optim_num = len(self.model_parts)
+        assert expected_optim_num == len(
+            self.optimizers
+        ), "Must pass one optimizer per model part"

    def step(self) -> None:
        for optimizer in self.optimizers:
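Note: with the rename to self.model_parts, the container also takes over the one-optimizer-per-model-part check that used to live in checkpoint.py. A hedged usage sketch; the constructor signature and import path follow the diff above, while the two-layer split and the Adam kwargs are illustrative assumptions:

import torch.nn as nn

from torchtitan.optimizer import OptimizersContainer

# Illustrative model parts, e.g. two pipeline stages of a toy model.
model_parts = [nn.Linear(16, 16), nn.Linear(16, 4)]
optimizers = OptimizersContainer(
    model_parts,
    optimizer_kwargs={"lr": 3e-4, "betas": (0.9, 0.95), "weight_decay": 0.1},
    name="Adam",
)
# The invariant now lives inside the container: one optimizer per model part.
assert len(optimizers.optimizers) == len(model_parts)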
@@ -53,7 +57,9 @@ def state_dict(self) -> Dict[str, Any]:
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        return {
-            k: v for sd in map(func, self.model, self.optimizers) for k, v in sd.items()
+            k: v
+            for sd in map(func, self.model_parts, self.optimizers)
+            for k, v in sd.items()
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -62,7 +68,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            optim_state_dict=state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
-        list(map(func, self.model, self.optimizers))
+        list(map(func, self.model_parts, self.optimizers))
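Note: both state_dict methods above map functools.partial-wrapped helpers from torch.distributed.checkpoint.state_dict over (model_part, optimizer) pairs, and flatten_optimizer_state_dict is what lets the per-part dicts be merged with a plain dict union. A standalone sketch of the save side; the function name is illustrative:

import functools

from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
)

def merged_optim_state_dict(model_parts, optimizers):
    # Flattened optimizer state dicts are keyed by fully qualified parameter
    # names, so unioning the per-part dicts should not collide.
    func = functools.partial(
        get_optimizer_state_dict,
        options=StateDictOptions(flatten_optimizer_state_dict=True),
    )
    return {k: v for sd in map(func, model_parts, optimizers) for k, v in sd.items()}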


class OptimizersInBackwardContainer(OptimizersContainer):
@@ -72,8 +78,8 @@ def __init__(
        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
    ) -> None:
        self.optimizers = []
-        self.model = model_parts
-        for model in self.model:
+        self.model_parts = model_parts
+        for model in self.model_parts:
            if name == "Adam":
                # TODO: make the optimizer options configurable by toml/cmd args
                optim_dict = {
@@ -97,6 +103,12 @@ def optim_hook(param) -> None:
                param.register_post_accumulate_grad_hook(optim_hook)

            self.optimizers.extend([optim_dict[param] for param in model.parameters()])
+        expected_optim_num = sum(
+            len([param for param in model.parameters()]) for model in self.model_parts
+        )
+        assert expected_optim_num == len(
+            self.optimizers
+        ), "Must pass one optimizer per model param part"

    def step(self) -> None:
        pass
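Note: the container above implements the optimizer-in-backward pattern, one small optimizer per parameter stepped from a post-accumulate-grad hook, which is why step() and zero_grad() become no-ops at the training-loop level. A minimal standalone sketch of the core mechanism; the toy module and hyperparameters are illustrative:

import torch
import torch.nn as nn

model = nn.Linear(8, 8)

# One optimizer per parameter; foreach=False since each optimizer owns a single tensor.
optim_dict = {
    param: torch.optim.Adam([param], lr=3e-4, foreach=False)
    for param in model.parameters()
}

def optim_hook(param) -> None:
    # Runs as soon as this parameter's gradient has been accumulated.
    optim_dict[param].step()
    optim_dict[param].zero_grad()

for param in model.parameters():
    param.register_post_accumulate_grad_hook(optim_hook)

loss = model(torch.randn(2, 8)).sum()
loss.backward()  # parameters are updated during backward; no outer optimizer.step()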
@@ -162,8 +174,8 @@ def __init__(self, optimizers, lr_lambda) -> None:
            self.schedulers.append(LambdaLR(optimizer, lr_lambda=lr_lambda))

    def step(self) -> None:
-        for schedulers in self.schedulers:
-            schedulers.step()
+        for scheduler in self.schedulers:
+            scheduler.step()

    def get_lr_scheduler_state(self) -> Dict[str, Any]:
        state_dict = {}
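Note: the last hunk just fixes the loop variable name in step(); the container keeps one LambdaLR per optimizer and steps them in lockstep. A minimal sketch of that pattern; the warmup lambda is illustrative, not torchtitan's actual schedule:

import torch
from torch.optim.lr_scheduler import LambdaLR

def warmup_lambda(current_step: int, warmup_steps: int = 100) -> float:
    # Linear warmup to the base learning rate, then hold.
    return min(1.0, float(current_step + 1) / warmup_steps)

param = torch.nn.Parameter(torch.zeros(4))
optimizers = [torch.optim.Adam([param], lr=3e-4)]
schedulers = [LambdaLR(opt, lr_lambda=warmup_lambda) for opt in optimizers]

optimizers[0].step()           # step optimizers first to avoid the LR-order warning
for scheduler in schedulers:   # mirrors the container's step()
    scheduler.step()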