Commit 7fe2928
Move assert check to optimizer container, optimize some names
mori360 committed Dec 18, 2024
1 parent 491d372 commit 7fe2928
Showing 2 changed files with 20 additions and 21 deletions.
13 changes: 0 additions & 13 deletions torchtitan/checkpoint.py
@@ -176,19 +176,6 @@ def __init__(
         TODO: This is currently unsolved and needs a fix.
         """
-        if job_config.optimizer.early_step_in_backward:
-            expected_optim_num = sum(
-                len([param for param in model.parameters()]) for model in model_parts
-            )
-        else:
-            expected_optim_num = len(model_parts)
-        assert expected_optim_num == len(
-            optimizers.optimizers
-        ), "Must pass one optimizer per model part (or per param for optim_in_bwd)"
-        assert expected_optim_num == len(
-            lr_schedulers.schedulers
-        ), "Must pass one lr_scheduler per model part (or per param for optim_in_bwd)"
-
         self.states = states

         self.states.update(
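
The count check removed above does not disappear: it moves into the optimizer containers themselves (see the optimizer.py hunks below), so each container validates its own invariant at construction time. A minimal, standalone sketch of the per-model-part case, using toy nn.Linear parts and illustrative hyperparameters rather than torchtitan's actual modules or configs:

import torch
import torch.nn as nn

# Two toy "model parts" standing in for pipeline stages.
model_parts = [nn.Linear(4, 4), nn.Linear(4, 2)]

# One optimizer per model part, mirroring how OptimizersContainer builds them.
optimizers = [
    torch.optim.AdamW(model.parameters(), lr=3e-4) for model in model_parts
]

# The invariant CheckpointManager used to assert is now checked where the
# optimizers are created: one optimizer per model part.
assert len(optimizers) == len(model_parts), "Must pass one optimizer per model part"
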
28 changes: 20 additions & 8 deletions torchtitan/optimizer.py
@@ -28,8 +28,8 @@ def __init__(
         self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
     ) -> None:
         self.optimizers = []
-        self.model = model_parts
-        for model in self.model:
+        self.model_parts = model_parts
+        for model in self.model_parts:
             if name == "Adam":
                 # TODO: make the optimizer options configurable by toml/cmd args
                 optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
@@ -38,6 +38,10 @@ def __init__(
             else:
                 raise NotImplementedError(f"Optimizer {name} not added.")
             self.optimizers.append(optimizer)
+        expected_optim_num = len(self.model_parts)
+        assert expected_optim_num == len(
+            self.optimizers
+        ), "Must pass one optimizer per model part"

     def step(self) -> None:
         for optimizer in self.optimizers:
@@ -53,7 +57,9 @@ def state_dict(self) -> Dict[str, Any]:
             options=StateDictOptions(flatten_optimizer_state_dict=True),
         )
         return {
-            k: v for sd in map(func, self.model, self.optimizers) for k, v in sd.items()
+            k: v
+            for sd in map(func, self.model_parts, self.optimizers)
+            for k, v in sd.items()
         }

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -62,7 +68,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             optim_state_dict=state_dict,
             options=StateDictOptions(flatten_optimizer_state_dict=True),
         )
-        list(map(func, self.model, self.optimizers))
+        list(map(func, self.model_parts, self.optimizers))


 class OptimizersInBackwardContainer(OptimizersContainer):
@@ -72,8 +78,8 @@ def __init__(
         self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
     ) -> None:
         self.optimizers = []
-        self.model = model_parts
-        for model in self.model:
+        self.model_parts = model_parts
+        for model in self.model_parts:
             if name == "Adam":
                 # TODO: make the optimizer options configurable by toml/cmd args
                 optim_dict = {
@@ -97,6 +103,12 @@ def optim_hook(param) -> None:
                 param.register_post_accumulate_grad_hook(optim_hook)

             self.optimizers.extend([optim_dict[param] for param in model.parameters()])
+        expected_optim_num = sum(
+            len([param for param in model.parameters()]) for model in self.model_parts
+        )
+        assert expected_optim_num == len(
+            self.optimizers
+        ), "Must pass one optimizer per model param part"

     def step(self) -> None:
         pass
@@ -162,8 +174,8 @@ def __init__(self, optimizers, lr_lambda) -> None:
         self.schedulers.append(LambdaLR(optimizer, lr_lambda=lr_lambda))

     def step(self) -> None:
-        for schedulers in self.schedulers:
-            schedulers.step()
+        for scheduler in self.schedulers:
+            scheduler.step()

     def get_lr_scheduler_state(self) -> Dict[str, Any]:
         state_dict = {}
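
For the optim-in-backward path, the new container-level assert expects one optimizer per parameter rather than per model part. A hedged sketch of that pattern with a single toy module follows; the module, learning rate, and input shape are illustrative, and only the per-parameter AdamW dict and register_post_accumulate_grad_hook usage mirror the diff above:

import torch
import torch.nn as nn

# One toy model part; in torchtitan this would be a pipeline stage.
model = nn.Linear(4, 2)

# One AdamW instance per parameter, as OptimizersInBackwardContainer does.
optim_dict = {
    param: torch.optim.AdamW([param], lr=3e-4) for param in model.parameters()
}

def optim_hook(param) -> None:
    # Step and clear the gradient as soon as this parameter's grad is ready.
    optim_dict[param].step()
    optim_dict[param].zero_grad()

for param in model.parameters():
    param.register_post_accumulate_grad_hook(optim_hook)

optimizers = [optim_dict[param] for param in model.parameters()]

# The container-level check added by this commit: one optimizer per parameter.
expected_optim_num = sum(len(list(m.parameters())) for m in [model])
assert expected_optim_num == len(optimizers)

# Backward now drives the optimizer steps through the hooks; no explicit
# optimizer.step() call is needed afterwards.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
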
