[BE] Combine OptimizerWrapper and OptimizerContainer #738

Merged
merged 8 commits into pytorch:main on Dec 20, 2024

Conversation

@mori360 mori360 (Contributor) commented Dec 13, 2024

Combine state_dict and load_state_dict from OptimizerWrapper into OptimizerContainer, so that we only have one optimizer-related class.
Also, add get_lr_scheduler_state to SchedulersContainer for registering the lr_scheduler entries in self.states.
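To make the description concrete, here is a minimal sketch of what the merged container could look like, assuming the torch.distributed.checkpoint.state_dict helpers (get_optimizer_state_dict / set_optimizer_state_dict) that the wrapper presumably delegated to; the class layout and names are illustrative, not the exact merged code.

```python
import functools
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful


class OptimizersContainer(Stateful):
    """One optimizer per model part, plus the checkpointing API in one place."""

    def __init__(
        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
    ) -> None:
        self.model_parts = model_parts
        self.optimizers: List[torch.optim.Optimizer] = []
        for model in self.model_parts:
            if name == "AdamW":
                self.optimizers.append(
                    torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
                )
            else:
                raise NotImplementedError(f"Optimizer {name} not added.")

    def step(self) -> None:
        for optimizer in self.optimizers:
            optimizer.step()

    def state_dict(self) -> Dict[str, Any]:
        # Flatten each (model, optimizer) pair's state into one FQN-keyed dict.
        func = functools.partial(
            get_optimizer_state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        return {
            k: v
            for sd in map(func, self.model_parts, self.optimizers)
            for k, v in sd.items()
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Push the flattened state back into every (model, optimizer) pair.
        func = functools.partial(
            set_optimizer_state_dict,
            optim_state_dict=state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        list(map(func, self.model_parts, self.optimizers))
```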

@facebook-github-bot facebook-github-bot added the CLA Signed label Dec 13, 2024
@mori360 mori360 requested review from fegin and tianyu-l December 14, 2024 00:32
@mori360 mori360 marked this pull request as ready for review December 14, 2024 00:32
Resolved review threads: torchtitan/checkpoint.py (1), torchtitan/optimizer.py (4, outdated)
@mori360 mori360 marked this pull request as draft December 16, 2024 19:08
@fegin fegin (Contributor) left a comment

Mostly looks good to me, but I'd prefer to simplify the code by removing some duplicated variables.

@@ -25,22 +38,49 @@ def __init__(self, model_parts, optimizer_kwargs, name):
else:
raise NotImplementedError(f"Optimizer {name} not added.")
self.optimizers.append(optimizer)
self.plain_optim = (
Contributor commented:

Can we deduplicate self.plain_optim and self.optimizers? I don't see a reason why we need to keep both variables.

Contributor Author (mori360) commented:

Thank you for the comment. Yeah, self.plain_optim is only used for the get/load state_dict APIs during checkpointing.
I made self.optimizers plain this time, from List[List[optim]] to List[optim] for the optimizers-in-backward case (see the sketch below).
I also made related changes around optimizers.optimizers: removed SchedulersInBackwardContainer and modified the assert check in CheckpointManager.
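As a rough sketch of that flattening (assumed names, AdamW-only for brevity; not the exact merged code), an optimizers-in-backward container can keep self.optimizers as a flat List[Optimizer], with one single-parameter optimizer per parameter:

```python
from typing import Any, Dict, List

import torch
import torch.nn as nn


class OptimizersInBackwardContainer:
    """Runs optimizer steps inside backward; self.optimizers stays a flat list."""

    def __init__(
        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
    ) -> None:
        if name != "AdamW":
            raise NotImplementedError(f"Optimizer {name} not added.")
        self.optimizers: List[torch.optim.Optimizer] = []
        for model in model_parts:
            # One single-parameter optimizer per parameter, stepped from a
            # post-accumulate-grad hook during backward.
            optim_dict = {
                param: torch.optim.AdamW([param], **optimizer_kwargs)
                for param in model.parameters()
            }

            # Bind the current dict as a default argument so each model part's
            # hooks use their own optimizers.
            def optim_hook(param, optim_dict=optim_dict) -> None:
                optim_dict[param].step()
                optim_dict[param].zero_grad()

            for param in model.parameters():
                param.register_post_accumulate_grad_hook(optim_hook)

            # Flat list (extend), not a nested List[List[optim]] (append).
            self.optimizers.extend(optim_dict.values())
```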

def __init__(self, model_parts, optimizer_kwargs, name):
def __init__(
self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
) -> None:
self.optimizers = []
Contributor commented:

ditto.

@mori360 mori360 marked this pull request as ready for review December 17, 2024 19:48
@@ -220,44 +176,29 @@ def __init__(

TODO: This is currently unsolved and needs a fix.
"""
assert len(model_parts) == len(
if job_config.optimizer.early_step_in_backward:
Contributor Author (mori360) commented:

We made optimizers.optimizers plain for both the in-backward and the regular case.
Thus len(model_parts) does not work for the in-backward case.
What do you think of the assert check and error message here?

Contributor commented:

I think we should move the sanity check into the OptimizerContainer constructors (__init__) rather than doing it here. Basically nothing changes after init, so we don't need to check it every time we save a checkpoint.
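A minimal sketch of that suggestion (the helper name _validate_length is hypothetical): the count is checked once at construction time instead of on every save.

```python
from typing import List

import torch


class OptimizerCountCheck:
    """Mixin-style helper; a concrete container sets self.optimizers in __init__
    and calls self._validate_length(...) at the end of its constructor."""

    optimizers: List[torch.optim.Optimizer]

    def _validate_length(self, expected_length: int) -> None:
        # Regular container: expected_length == len(model_parts).
        # In-backward container: expected_length == total parameter count.
        assert expected_length == len(self.optimizers), (
            f"Expected {expected_length} optimizers, found {len(self.optimizers)}"
        )
```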

@mori360 mori360 requested review from fegin and tianyu-l December 18, 2024 01:04

Resolved review threads: torchtitan/optimizer.py (3, outdated)
# It should only support saving and loading a distributed checkpoint with the same number of pp ranks
for idx, lr_scheduler in enumerate(lr_schedulers.schedulers):
self.states[f"lr_scheduler_{idx}"] = lr_scheduler
self.states.update(lr_schedulers.get_lr_scheduler_state())
Contributor commented:

Discussed with @mori360 offline: I think as a next step, we should still have a single entry for lr_scheduler. To achieve that, we need to understand if any further flattening is needed. This could potentially solve the "PP multi-schedule doesn't support DCP resharding" problem.
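For reference, a hedged sketch of what get_lr_scheduler_state might return, inferred from the diff above (the key names and the single-scheduler special case are assumptions):

```python
from typing import Any, Dict


class SchedulersContainer:
    """Holds one lr_scheduler per optimizer."""

    def __init__(self, schedulers) -> None:
        self.schedulers = schedulers

    def get_lr_scheduler_state(self) -> Dict[str, Any]:
        state: Dict[str, Any] = {}
        if len(self.schedulers) == 1:
            state["lr_scheduler"] = self.schedulers[0]
        else:
            # Looped PP schedules keep one entry per scheduler; this layout does
            # not support DCP resharding across a different number of pp ranks.
            for idx, lr_scheduler in enumerate(self.schedulers):
                state[f"lr_scheduler_{idx}"] = lr_scheduler
        return state
```

CheckpointManager can then call self.states.update(lr_schedulers.get_lr_scheduler_state()), as shown in the diff.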

Comment on lines +173 to +174
# For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
# It should only support saving and loading a distributed checkpoint with the same number of pp ranks
Contributor commented:

I feel this could be solved if we flatten the lr_scheduler state_dict, similar to what we did for models and optimizers.

Contributor (fegin) commented:

This is hard to achieve. LRScheduler's state_dict is very unstructured; it basically just returns self.__dict__.items(), which can be anything. And LRScheduler doesn't define a parameter group structure, so different LRSchedulers may have different implementations. I'm not sure how to flatten LRScheduler in a general way, unless we only focus on one LRScheduler.
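A quick illustration of that point (toy model, not from the PR): the state dict is essentially the scheduler's attribute dict, so its keys depend on the subclass.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 4)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
sched = LambdaLR(optim, lr_lambda=lambda step: 1.0)

# Keys vary by scheduler subclass, e.g. base_lrs, last_epoch, _step_count,
# lr_lambdas -- there is no common parameter-group structure to flatten against.
print(sched.state_dict())
```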

Contributor commented:

@fegin I see. Then maybe we can just focus on the LambdaLR, which should be straightforward -- since every scheduler has the same schedule, we can store only one state and recreate schedulers for each optimizer on the fly when doing checkpoint loading.
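A minimal sketch of that follow-up idea (not part of this PR; the class name is hypothetical): save one LambdaLR state and fan it back out to every scheduler at load time.

```python
import copy
from typing import Any, Dict, List

from torch.optim.lr_scheduler import LambdaLR


class FlattenedLRSchedulerState:
    """Single checkpoint entry shared by all identical LambdaLR schedulers."""

    def __init__(self, schedulers: List[LambdaLR]) -> None:
        self.schedulers = schedulers

    def state_dict(self) -> Dict[str, Any]:
        # Every scheduler follows the same schedule, so one saved state suffices.
        return self.schedulers[0].state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Recreate each scheduler's state on the fly from the single saved entry.
        for scheduler in self.schedulers:
            scheduler.load_state_dict(copy.deepcopy(state_dict))
```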

Contributor commented:

Yes, I want to emphasize that flattening LRScheduler is the right direction, which TorchRec also does, iirc. But since we may not have the bandwidth to support all of them, focusing on one or two and being explicit about what TorchTitan supports is a good idea.

@tianyu-l tianyu-l (Contributor) left a comment

lgtm, thank you! Let's follow up in another PR for the lr scheduler flattening.
Please address final comment before merging.

Resolved review thread: torchtitan/optimizer.py (outdated)
@mori360 mori360 merged commit ba24697 into pytorch:main Dec 20, 2024
4 checks passed
Labels: CLA Signed (managed by the Meta Open Source bot)
4 participants