Separate the metrics variables from non-trainable variables. #910
Conversation
Codecov Report

Patch coverage:

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #910      +/-   ##
==========================================
- Coverage   76.82%   73.25%    -3.58%
==========================================
  Files         329      329
  Lines       31427    31434       +7
  Branches     6112     6114       +2
==========================================
- Hits        24144    23027    -1117
- Misses       5719     6893    +1174
+ Partials     1564     1514      -50
```

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.
Thanks for the PR!
```diff
@@ -602,6 +603,17 @@ def non_trainable_weights(self):
             return self.weights
         return [v for v in self.weights if not v.trainable]

+    @property
+    def metrics_variables(self):
```
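For context, a minimal sketch of what the body of the added property might look like, assuming the layer keeps its metric objects in an internal `self._metrics` list (that attribute name is an assumption; the excerpt above cuts off before the body):

```python
@property
def metrics_variables(self):
    # Collect the state variables (e.g. each metric's total/count) owned
    # by the layer's metrics, kept separate from non-trainable weights.
    variables = []
    for metric in self._metrics:  # assumed internal list of Metric objects
        variables.extend(metric.variables)
    return variables
```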
Does this work with compiled metrics? We should add support for that in the Trainer.
The Trainer overrides this method in `keras-core/keras_core/trainers/trainer.py` (line 233 at a465816): `def metrics_variables(self):`
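A rough sketch of how such an override can pick up compiled metrics too, assuming the Trainer exposes all tracked metric objects (including compiled ones) through a `self.metrics` property; this is a sketch under that assumption, not the verbatim keras-core code:

```python
@property
def metrics_variables(self):
    # Aggregate state variables from every metric the trainer tracks,
    # including compiled metrics, so they are exposed in one place.
    variables = []
    for metric in self.metrics:
        variables.extend(metric.variables)
    return variables
```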
Got it. Is it tested?
Yes, the existing tests in the trainer already cover that.
E.g. in: `self.assertEqual(len(model.metrics_variables), 6)`
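For illustration, a hypothetical snippet in the same spirit as that test (the model, data shapes, and metric choice are made up; the variable count depends on which metrics are compiled):

```python
import numpy as np
import keras_core as keras

# A compiled model tracks metric state (e.g. the loss tracker's and each
# compiled metric's variables) via model.metrics_variables, separate from
# model.non_trainable_variables.
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse", metrics=["mean_absolute_error"])
model.fit(np.ones((4, 2)), np.ones((4, 1)), epochs=1, verbose=0)
print(len(model.metrics_variables))
```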
LGTM
As discussed in #897, this separates the metrics-related variables from the non-trainable variables so that we can properly leverage JAX memory donation.
It will also let us skip saving metrics variables during checkpointing / model saving.
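To illustrate the JAX side (a generic sketch of buffer donation, not the keras-core trainer itself): `jax.jit`'s `donate_argnums` lets XLA reuse an input buffer's memory for the output, and donation requires knowing precisely which state buffers are replaced each step, which is easier when metric state is tracked separately from non-trainable variables. Names below are illustrative.

```python
import functools
import jax
import jax.numpy as jnp

# Donating argument 0 tells XLA it may overwrite the state's buffers with
# the updated values instead of keeping two copies in device memory.
@functools.partial(jax.jit, donate_argnums=(0,))
def update(state, grads, lr=0.1):
    return jax.tree_util.tree_map(lambda s, g: s - lr * g, state, grads)

state = {"w": jnp.ones((3,)), "b": jnp.zeros((3,))}
grads = {"w": jnp.full((3,), 0.5), "b": jnp.full((3,), 0.5)}
state = update(state, grads)  # old buffers may be reused in place
```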