diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py
index 6a14fc23..db54ccd9 100644
--- a/torchtitan/checkpoint.py
+++ b/torchtitan/checkpoint.py
@@ -176,19 +176,6 @@ def __init__(
 
         TODO: This is currently unsolved and needs a fix.
         """
-        if job_config.optimizer.early_step_in_backward:
-            expected_optim_num = sum(
-                len([param for param in model.parameters()]) for model in model_parts
-            )
-        else:
-            expected_optim_num = len(model_parts)
-        assert expected_optim_num == len(
-            optimizers.optimizers
-        ), "Must pass one optimizer per model part (or per param for optim_in_bwd)"
-        assert expected_optim_num == len(
-            lr_schedulers.schedulers
-        ), "Must pass one lr_scheduler per model part (or per param for optim_in_bwd)"
-
         self.states = states
 
         self.states.update(
diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py
index 4f30f289..d31f2826 100644
--- a/torchtitan/optimizer.py
+++ b/torchtitan/optimizer.py
@@ -28,8 +28,8 @@ def __init__(
         self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
     ) -> None:
         self.optimizers = []
-        self.model = model_parts
-        for model in self.model:
+        self.model_parts = model_parts
+        for model in self.model_parts:
             if name == "Adam":
                 # TODO: make the optimizer options configurable by toml/cmd args
                 optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
@@ -38,6 +38,10 @@ def __init__(
             else:
                 raise NotImplementedError(f"Optimizer {name} not added.")
             self.optimizers.append(optimizer)
+        expected_optim_num = len(self.model_parts)
+        assert expected_optim_num == len(
+            self.optimizers
+        ), "Must pass one optimizer per model part"
 
     def step(self) -> None:
         for optimizer in self.optimizers:
@@ -53,7 +57,9 @@ def state_dict(self) -> Dict[str, Any]:
             options=StateDictOptions(flatten_optimizer_state_dict=True),
         )
         return {
-            k: v for sd in map(func, self.model, self.optimizers) for k, v in sd.items()
+            k: v
+            for sd in map(func, self.model_parts, self.optimizers)
+            for k, v in sd.items()
         }
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -62,7 +68,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             optim_state_dict=state_dict,
             options=StateDictOptions(flatten_optimizer_state_dict=True),
         )
-        list(map(func, self.model, self.optimizers))
+        list(map(func, self.model_parts, self.optimizers))
 
 
 class OptimizersInBackwardContainer(OptimizersContainer):
@@ -72,8 +78,8 @@ def __init__(
         self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
     ) -> None:
         self.optimizers = []
-        self.model = model_parts
-        for model in self.model:
+        self.model_parts = model_parts
+        for model in self.model_parts:
             if name == "Adam":
                 # TODO: make the optimizer options configurable by toml/cmd args
                 optim_dict = {
@@ -97,6 +103,12 @@ def optim_hook(param) -> None:
                 param.register_post_accumulate_grad_hook(optim_hook)
 
             self.optimizers.extend([optim_dict[param] for param in model.parameters()])
+        expected_optim_num = sum(
+            len([param for param in model.parameters()]) for model in self.model_parts
+        )
+        assert expected_optim_num == len(
+            self.optimizers
+        ), "Must pass one optimizer per model param part"
 
     def step(self) -> None:
         pass
@@ -162,8 +174,8 @@ def __init__(self, optimizers, lr_lambda) -> None:
             self.schedulers.append(LambdaLR(optimizer, lr_lambda=lr_lambda))
 
     def step(self) -> None:
-        for schedulers in self.schedulers:
-            schedulers.step()
+        for scheduler in self.schedulers:
+            scheduler.step()
 
     def get_lr_scheduler_state(self) -> Dict[str, Any]:
         state_dict = {}
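
For reference, below is a minimal standalone sketch (not part of the diff) of the invariant the containers now assert for themselves instead of relying on CheckpointManager: one optimizer per model part in the regular container, and one optimizer per parameter in the optim-in-backward container. The toy model parts are hypothetical and only plain torch is assumed; the torchtitan classes themselves are not imported here.

# Standalone sketch (assumption: plain torch modules, no torchtitan imports).
from typing import List

import torch
import torch.nn as nn

model_parts: List[nn.Module] = [nn.Linear(4, 4), nn.Linear(4, 2)]

# OptimizersContainer invariant: one optimizer per model part.
optimizers = [torch.optim.Adam(m.parameters(), lr=1e-3) for m in model_parts]
assert len(optimizers) == len(model_parts), "Must pass one optimizer per model part"

# OptimizersInBackwardContainer invariant: one optimizer per parameter across
# all model parts (each is stepped from a post-accumulate-grad hook).
per_param_optimizers = [
    torch.optim.Adam([p], lr=1e-3) for m in model_parts for p in m.parameters()
]
expected_optim_num = sum(len(list(m.parameters())) for m in model_parts)
assert expected_optim_num == len(per_param_optimizers)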