Konsti resnet implementation #105
Conversation
- Make small changes in the JaxModel class to allow the ResNet implementation
- Write the HuggingFace Flax implementation
- Test the NTK calculation

Todo:
- Test for models beyond ResNets
- Update example script
- Run Black
- Fix imports
Thanks for the PR! In this case, I would have said the Black part should have been done down the chain in a separate PR. It makes it very difficult to review the larger changes to the code, as there are now 109 files that need looking into. Can you highlight which modules you have changed in the ResNet PR? Alternatively, make a new PR to …
I think this is one of the biggest (importance-wise) PRs that ZnNL has seen for a long time, so awesome work. I have a couple of points scattered throughout the code, but also two things I want to raise here:
- Can we add one or two training passes to the test?
- Did you not have to update the training procedure? When I got this working, I needed to take into consideration the batch statistics and all these other things being passed correctly. I don't see those changes here; what was the solution?
No, I did not have to. The HF call method is directly compatible with constructing a TrainState. After constructing it, the rest is straightforward. Where exactly did you run into issues?
The call to the network has a new return signature: it should return the batch stats along with the logits, and these batch stats have to be propagated to the network in the forward passes and during updates. We deal with this in the NTK calculation, but unless it snuck through the last time I worked on it, there won't be batch stats passed …
Batch_stats are included in our model_state.params. But good point, I have to check whether the batch_stats get handled properly.
Here, for example, each time they call the model, they collect the other part of the output tuple:

```python
logits, new_model_state = state.apply_fn(
    {'params': params, 'batch_stats': state.batch_stats},
    batch['image'],
    mutable=['batch_stats'],
)
```

Sorry, it isn't the batch stats; it is this model-state part. This has to be passed to other functions in order for the model to train properly. From my memory, it had something to do with ensuring the batch stats are used and updated correctly. We don't do this in normal training; we just ignore this additional output.

Now, in their weight update, they do the following:

```python
new_state = state.apply_gradients(
    grads=grads, batch_stats=new_model_state['batch_stats']
)
```

so they need these stats. They also seem to always pass them explicitly in model forward passes:

```python
variables = {'params': state.params, 'batch_stats': state.batch_stats}
logits = state.apply_fn(variables, batch['image'], train=False, mutable=False)
```

This is in the eval step, so nothing involved in training. This may not be necessary, as the initial object is a …
- Extend the TrainState to capture batch statistics
- Add a `train_apply` method to the Jax model to distinguish between evaluating and training a model
- Adapt the nt, flax, and hfflax models to the changes
- Rewrite the train step to account for batch statistics

Further changes need to include:
- Fix the parameter handling for other strategies.
You can remove traceopt from this; I can take care of it in my other traceopt PR. I have made enough changes to it there that this would just set the whole thing back.
- Make the example clearer
- Create an example for something that is not clear yet
There was an issue with HF FlaxResNets when using smaller models with layer_type='basic' instead of layer_type='bottleneck'. For my examples and tests to pass, I have therefore used layer_type='bottleneck'.
Such a cool PR! I am excited to see it in action. I did have a few more comments that it would be great if we could discuss.
CI/integration_tests/models/test_huggingface_flax_model_deployment.py
Noice
Implementation of a Flax ResNet from HuggingFace.