
Separate the metrics variables from non-trainable variables. #910

Merged
merged 2 commits from qlzh727:variables into keras-team:main on Sep 18, 2023

Conversation

@qlzh727 qlzh727 commented Sep 18, 2023

As discussed in #897, we separate the metrics-related variables from the non-trainable variables, so that we can properly leverage JAX memory donation.

This will also allow us to skip saving the metrics variables during checkpoint/saved-model export.
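The split described here can be sketched with plain Python classes (hypothetical names for illustration, not the actual Keras Core implementation):

```python
# Minimal sketch of separating metric state from non-trainable variables.
# All class and attribute names here are illustrative assumptions.

class Variable:
    def __init__(self, value, trainable=True):
        self.value = value
        self.trainable = trainable

class MeanMetric:
    """Toy metric holding two non-trainable state variables: total and count."""
    def __init__(self):
        self.variables = [
            Variable(0.0, trainable=False),  # running total
            Variable(0.0, trainable=False),  # sample count
        ]

class Layer:
    def __init__(self):
        self._variables = []  # weights created by the layer
        self._metrics = []    # metric objects attached to the layer

    @property
    def non_trainable_variables(self):
        # Metric state is no longer mixed in here...
        return [v for v in self._variables if not v.trainable]

    @property
    def metrics_variables(self):
        # ...it is exposed separately, so checkpointing can skip it and
        # JAX buffer donation can treat it as its own piece of state.
        variables = []
        for metric in self._metrics:
            variables.extend(metric.variables)
        return variables

layer = Layer()
layer._variables.append(Variable(1.0, trainable=True))
layer._metrics.append(MeanMetric())
print(len(layer.non_trainable_variables))  # 0 -- metric state excluded
print(len(layer.metrics_variables))        # 2
```

With metric state in its own bucket, a JAX train step can donate the trainable/non-trainable state buffers without invalidating (or needlessly copying) metric accumulators.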

codecov bot commented Sep 18, 2023

Codecov Report

Patch coverage: 100.00% and project coverage change: -3.58% ⚠️

Comparison is base (a465816) 76.82% compared to head (c0098fd) 73.25%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #910      +/-   ##
==========================================
- Coverage   76.82%   73.25%   -3.58%     
==========================================
  Files         329      329              
  Lines       31427    31434       +7     
  Branches     6112     6114       +2     
==========================================
- Hits        24144    23027    -1117     
- Misses       5719     6893    +1174     
+ Partials     1564     1514      -50     
Flag        Coverage           Δ
keras_core  73.17% <100.00%>   (-3.56%) ⬇️

Flags with carried forward coverage won't be shown.

Files Changed               Coverage           Δ
keras_core/layers/layer.py  86.92% <100.00%>   (-0.50%) ⬇️

... and 17 files with indirect coverage changes


@fchollet fchollet left a comment

Thanks for the PR!

@@ -602,6 +603,17 @@ def non_trainable_weights(self):
             return self.weights
         return [v for v in self.weights if not v.trainable]

+    @property
+    def metrics_variables(self):
Member
Does this work with compiled metrics? We should add support for that in the Trainer.

Member Author
The Trainer overrides this method (with its own `metrics_variables` property), which covers compiled metrics.
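That Trainer-level override can be sketched as follows (a simplified sketch with hypothetical names, not the actual Keras Core code; the real property walks the model's metric objects, including those created by `compile()`):

```python
# Sketch of a Trainer-level metrics_variables override that also picks up
# metrics passed to compile(). Names are illustrative assumptions.

class MeanMetric:
    """Toy metric with two state variables (total, count)."""
    def __init__(self, name):
        self.name = name
        self.variables = [0.0, 0.0]

class Trainer:
    def __init__(self):
        self._loss_tracker = MeanMetric("loss")  # built-in loss tracker
        self._compile_metrics = []               # metrics passed to compile()

    def compile(self, metrics=()):
        self._compile_metrics = list(metrics)

    @property
    def metrics(self):
        # Compiled metrics are surfaced alongside the loss tracker.
        return [self._loss_tracker] + self._compile_metrics

    @property
    def metrics_variables(self):
        # Flatten the state of every metric, compiled ones included.
        variables = []
        for metric in self.metrics:
            variables.extend(metric.variables)
        return variables

model = Trainer()
model.compile(metrics=[MeanMetric("mae")])
print(len(model.metrics_variables))  # 4: (loss tracker + 1 metric) x 2 variables
```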

Member

Got it. Is it tested?

Member Author

Yes, the existing test in trainer does cover that.

Member Author

E.g., in:

self.assertEqual(len(model.metrics_variables), 6)
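One plausible accounting for that 6 (an assumption about the test model, not stated in the PR): mean-style metrics keep two state variables each, a total and a count, so a loss tracker plus two compiled metrics yields:

```python
# Hypothetical accounting for the 6 metrics variables asserted above.
variables_per_mean_metric = 2  # total and count
num_metrics = 1 + 2            # loss tracker + two compiled metrics (assumed)
print(variables_per_mean_metric * num_metrics)  # 6
```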

@fchollet fchollet left a comment


LGTM

@fchollet fchollet merged commit 9d39e9a into keras-team:main Sep 18, 2023
7 of 8 checks passed
@qlzh727 qlzh727 deleted the variables branch September 18, 2023 22:53