Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented the loss NTK calculation #109

Open
wants to merge 35 commits into
base: main
Choose a base branch
from

Conversation

ma-sauter
Copy link
Collaborator

No description provided.

Copy link
Member

@KonstiNik KonstiNik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good job in this PR. The first draft looks pretty nice.

I have two major comments:

  1. In general, a test should cover as many aspects of the code as possible. And a test should test the desired aspect as simply as possible. This includes trying to rely as little as possible on existing methods. E.g. including an existing data generator is not as good practice as creating some dummy test data.
  2. With your implementation, we would need to duplicate all observables for the loss ntk that we already have for the ntk. One can avoid this by having the option to either use a recorder for the loss ntk or the regular ntk. This could be done with one keyword at initialization e.g..

# Check if we need a loss NTK computation and update the class accordingly
if any(
[
"loss_ntk" in self._selected_properties,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I see, right now we would have to implement the trace and all other properties again to use them with the loss ntk.
I think it might be more reasonable you had one kwarg like use_loss_ntk with which all ntk calculations are now using the loss ntk, making the entire recorder a loss ntk recorder. With this, we could re-use all the properties we have already implemented.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that there's room for improvement, but I think if we introduce a flag like this here we should maybe also discuss more changes to the recorder. We should talk about this in person or in a meeting, but I'd like to make sure that the tests are working first because it's more urgent for the DPG if that's fine.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The flag got introduced in commit 1dac434

znnl/analysis/loss_ntk_calculation.py Show resolved Hide resolved
znnl/analysis/loss_ntk_calculation.py Outdated Show resolved Hide resolved
CI/unit_tests/analysis/test_loss_ntk_calculation.py Outdated Show resolved Hide resolved
CI/unit_tests/analysis/test_loss_ntk_calculation.py Outdated Show resolved Hide resolved
CI/unit_tests/analysis/test_loss_ntk_calculation.py Outdated Show resolved Hide resolved
znnl/training_strategies/simple_training.py Outdated Show resolved Hide resolved
Copy link
Member

@SamTov SamTov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a few comments. If you go through and address them all I can go back over it but in general, I like it and am happy to have it merged soon.


import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove from the test please.


# For LPNormLoss of order 2 and a 1D output Network, the NTK and the loss NTK
# should be the same up to a factor of +1 or -1.
assert_array_almost_equal(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this is an integration test, you will also want to check that the deployment has worked. You can check things like the shape of the stored values.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you just name this test_loss_ntk. The naming of the tests should mirror the main python package just with test in front. All integration tests using the loss ntk should be in this one module.


import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this from the tests

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename this to be inline with the package.

)

@staticmethod
def _unshape_data(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sam here.

"""

# Set the attributes
self.ntk_batch_size = model.ntk_batch_size
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to make all of these arguments in this calculation. Especially when we later move into the new measurement system, this will all need to be self contained. Things like store_on_device are only pertinent to this calculator.


Returns
-------
input: np.ndarray
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some shape information here? It can be (batch * input size, ) or anything, but just some information about what I will get back. What you mean by unshape is also very unclear. Is it flattening is it reshaping, unshape doesn't have a real meaning.

batch_length, *input_shape[1:]
), datapoint[:, input_dimension:].reshape(batch_length, *target_shape[1:])

def _function_for_loss_ntk(self, params, datapoint) -> float:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer different naming here. Is it an apply function on flattened data, a loss function. What do you mean by subloss? Loss between two data points is just loss. Function for loss ntk could be anything.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the notebook is not clear, can you clear it of outputs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants