[BE] Combine OptimizerWrapper and OptimizerContainer #738
Conversation
Mostly looks good to me, but I'd prefer to simplify the code by removing some duplicated variables.
torchtitan/optimizer.py (Outdated)

```diff
@@ -25,22 +38,49 @@ def __init__(self, model_parts, optimizer_kwargs, name):
             else:
                 raise NotImplementedError(f"Optimizer {name} not added.")
             self.optimizers.append(optimizer)
+        self.plain_optim = (
```
Can we deduplicate `self.plain_optim` and `self.optimizers`? I don't see a reason why we need to keep both variables.
Thank you for the comment. Yeah, `self.plain_optim` is only used for the get/load `state_dict` APIs at checkpointing.

Made `self.optimizers` plain this time, from `List[List[optim]]` to `List[optim]`, for the in-backward case.

Also made some changes related to `optimizers.optimizers`: removed `SchedulersInBackwardContainer` and modified the assert check in `CheckpointManager`.
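For context, a minimal sketch of the flattening for the in-backward case (illustrative, not the exact PR code; the class name and the set of supported optimizers are simplified):

```python
from typing import Any, Dict, List

import torch
import torch.nn as nn


class OptimizersInBackwardContainer:
    """Sketch: one optimizer per parameter, kept in a single flat list."""

    def __init__(
        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
    ) -> None:
        # Before: List[List[Optimizer]] -- one inner list per model part.
        # After:  a flat List[Optimizer], so checkpoint code can treat the
        # in-backward and regular containers uniformly.
        self.optimizers: List[torch.optim.Optimizer] = []
        for model in model_parts:
            for param in model.parameters():
                if name == "Adam":
                    self.optimizers.append(torch.optim.Adam([param], **optimizer_kwargs))
                elif name == "AdamW":
                    self.optimizers.append(torch.optim.AdamW([param], **optimizer_kwargs))
                else:
                    raise NotImplementedError(f"Optimizer {name} not added.")
```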
```diff
-    def __init__(self, model_parts, optimizer_kwargs, name):
+    def __init__(
+        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
+    ) -> None:
         self.optimizers = []
```
ditto.
torchtitan/checkpoint.py (Outdated)

```diff
@@ -220,44 +176,29 @@ def __init__(

         TODO: This is currently unsolved and needs a fix.
         """
-        assert len(model_parts) == len(
+        if job_config.optimizer.early_step_in_backward:
```
We made `optimizers.optimizers` plain for both the in-backward and regular cases, so `len(model_parts)` does not work for the in-backward case. What do you think of the assert check here and the error message?
I think we should move the sanity check to `OptimizerContainer`'s `__init__` constructor rather than doing it here. Basically nothing changes after init, so we don't need to check it every time we save a checkpoint.
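Something like this in the constructor (a sketch, hardcoding AdamW for brevity; the real container dispatches on `name`):

```python
from typing import Any, Dict, List

import torch
import torch.nn as nn


class OptimizerContainer:
    def __init__(
        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
    ) -> None:
        self.model_parts = model_parts
        self.optimizers = [
            torch.optim.AdamW(m.parameters(), **optimizer_kwargs) for m in model_parts
        ]
        # Check once at construction time; nothing changes after init, so
        # CheckpointManager no longer needs to re-validate on every save.
        assert len(self.optimizers) == len(self.model_parts), (
            "Must pass one optimizer per model part."
        )
```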
```diff
         # It should only support saving and loading a distributed checkpoint with the same number of pp ranks
-        for idx, lr_scheduler in enumerate(lr_schedulers.schedulers):
-            self.states[f"lr_scheduler_{idx}"] = lr_scheduler
+        self.states.update(lr_schedulers.get_lr_scheduler_state())
```
Discussed with @mori360 offline: I think as a next step, we should still have a single entry for `lr_scheduler`. To achieve that, we need to understand whether any further flattening is needed. This could potentially solve the "PP multi-schedule doesn't support DCP resharding" problem.
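For reference, a sketch of how `get_lr_scheduler_state` could keep the per-stage keys while leaving room for a single flat entry (hypothetical shape; the keys follow the existing `lr_scheduler_{idx}` convention):

```python
from typing import Any, Dict, List

from torch.optim.lr_scheduler import LRScheduler


class SchedulersContainer:
    def __init__(self, schedulers: List[LRScheduler]) -> None:
        self.schedulers = schedulers  # one per optimizer

    def get_lr_scheduler_state(self) -> Dict[str, Any]:
        state_dict: Dict[str, Any] = {}
        if len(self.schedulers) == 1:
            # Single stage: one flat "lr_scheduler" entry.
            state_dict["lr_scheduler"] = self.schedulers[0]
        else:
            # PP with looped schedules: one entry per stage; saving and loading
            # must use the same number of pp ranks (no resharding for now).
            for idx, lr_scheduler in enumerate(self.schedulers):
                state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
        return state_dict
```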
```python
# For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
# It should only support saving and loading a distributed checkpoint with the same number of pp ranks.
```
I feel this could be solved if we flatten the `lr_scheduler` state_dict, similar to what we did for models and optimizers.
This is hard to achieve. `LRScheduler`'s `state_dict` is very unstructured: it basically just returns `self.__dict__.items()`, which can contain anything. And `LRScheduler` doesn't define a parameter-group structure, so different `LRScheduler` subclasses may have different implementations. I'm not sure how to flatten `LRScheduler` in a general way, unless we only focus on one `LRScheduler`.
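For reference, the base `state_dict` in `torch.optim.lr_scheduler.LRScheduler` is essentially the following (paraphrased; subclasses like `LambdaLR` additionally special-case attributes such as `lr_lambdas`):

```python
class LRScheduler:
    def state_dict(self):
        # Everything in __dict__ except the optimizer reference -- there is
        # no declared structure (e.g. param groups) to flatten against.
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }
```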
@fegin I see. Then maybe we can just focus on `LambdaLR`, which should be straightforward: since every scheduler has the same schedule, we can store only one state and recreate the schedulers for each optimizer on the fly when loading a checkpoint.
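A minimal sketch of that idea (hypothetical container shape; assumes every scheduler is a `LambdaLR` built from the identical lambda):

```python
import copy
from typing import Any, Dict, List

from torch.optim.lr_scheduler import LambdaLR


class SchedulersContainer:
    """Sketch: store one state for N identical LambdaLR schedulers."""

    def __init__(self, schedulers: List[LambdaLR]) -> None:
        self.schedulers = schedulers  # one per optimizer, same schedule

    def state_dict(self) -> Dict[str, Any]:
        # All schedulers share the same schedule, so one entry suffices.
        return self.schedulers[0].state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Re-apply the single saved state to every scheduler on the fly.
        for scheduler in self.schedulers:
            scheduler.load_state_dict(copy.deepcopy(state_dict))
```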
Yes, I want to emphasize that flattening `LRScheduler` is the right direction, which TorchRec also does, iirc. But since we may not have the bandwidth to support all of them, focusing on one or two and clearly stating what TorchTitan supports is a good idea.
lgtm, thank you! Let's follow up in another PR for the lr scheduler flattening. Please address the final comment before merging.
Combine `state_dict` and `load_state_dict` from `OptimizerWrapper` into `OptimizerContainer` so that we only have one optimizer-related class.

Also, add `get_lr_scheduler_state` to `SchedulersContainer` when updating `lr_scheduler` in `self.states`.
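For reference, a sketch of what the combined container's checkpoint hooks could look like, built on `get_optimizer_state_dict` / `set_optimizer_state_dict` from `torch.distributed.checkpoint.state_dict`; the flattening option and the one-optimizer-per-model-part pairing are assumptions, not necessarily the exact PR code:

```python
import functools
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)


class OptimizerContainer:
    """Sketch: construction and checkpoint state_dict APIs in one class."""

    def __init__(
        self, model_parts: List[nn.Module], optimizers: List[torch.optim.Optimizer]
    ) -> None:
        self.model_parts = model_parts
        self.optimizers = optimizers  # flat, one per model part here

    def state_dict(self) -> Dict[str, Any]:
        # Flatten each part's optimizer state into a single dict that DCP
        # can save, load, and reshard.
        func = functools.partial(
            get_optimizer_state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        return {
            k: v
            for sd in map(func, self.model_parts, self.optimizers)
            for k, v in sd.items()
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        func = functools.partial(
            set_optimizer_state_dict,
            optim_state_dict=state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        list(map(func, self.model_parts, self.optimizers))
```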