Konsti resnet implementation #105

Merged: 27 commits from Konsti_resnet_implementation into main, Nov 17, 2023
Conversation

@KonstiNik (Member) commented Nov 3, 2023

Implementation of a flax ResNet from HuggingFace.

  • Implement a pre-defined model. That way, a pre-trained model can easily be fine-tuned (see the sketch below).
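
A minimal sketch of what loading such a pre-defined, pre-trained Flax ResNet could look like (the checkpoint name here is illustrative, not taken from this PR):

    # Load a pre-trained HuggingFace Flax ResNet whose parameters can be fine-tuned.
    from transformers import FlaxResNetForImageClassification

    model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50")
    params = model.params  # Flax parameter pytree, ready to be wrapped in a TrainState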

KonstiNik and others added 7 commits November 1, 2023 22:22
- Make small changes in the JaxModel class to allow the resnet implementation
- write huggingface Flax implementation
- test the NTK calculation

Todo:
- test for models beyond resnets
- update example script
- run black
- fix imports
@KonstiNik requested a review from SamTov, November 3, 2023 13:05
@SamTov (Member) commented Nov 3, 2023

Thanks for the PR! In this case, I would have said the Black part should have been done down the chain in a separate PR. It makes it very difficult to review the larger changes to the code, as there are now 109 files that need looking into. Can you highlight which modules you have changed in the ResNet PR? Alternatively, make a new PR to main where you only do the black formatting and then merge that one here.

@SamTov (Member) left a review:

I think this is one of the biggest (importance-wise) PRs that ZnNL has seen for a long time, so awesome work. I have a couple of points scattered throughout the code, but also two things I want to raise here:

  1. Can we add one or two training passes to the test?
  2. Did you not have to update the training procedure? When I got this working, I needed to take into consideration the batch statistics and all these other things being passed correctly. I don't see these changes here; what was the solution?

examples/ResNet-Example.ipynb
CI/unit_tests/models/test_huggingface_flax_model.py
znnl/models/jax_model.py
@KonstiNik (Member, Author) replied:

> Did you not have to update the training procedure? When I got this working, I needed to take into consideration the batch statistics and all these other things being passed correctly. I don't see these changes here; what was the solution?

No, I did not have to. The HF call method is directly compatible with constructing a TrainState, and after constructing it the rest is straightforward. Where exactly did you run into issues?

@SamTov (Member) commented Nov 6, 2023

> Did you not have to update the training procedure? When I got this working, I needed to take into consideration the batch statistics and all these other things being passed correctly. I don't see these changes here; what was the solution?

> No, I did not have to. The HF call method is directly compatible with constructing a TrainState, and after constructing it the rest is straightforward. Where exactly did you run into issues?

The call to the network has a new return signature? It should return the batch stats along with the logits, and these batch stats have to be propagated to the network in the forward passes and during updates. We deal with this in the NTK calculation, but unless it snuck through the last time I worked on it, there won't be any batch stats passed.

@KonstiNik (Member, Author) commented Nov 6, 2023

> Did you not have to update the training procedure? When I got this working, I needed to take into consideration the batch statistics and all these other things being passed correctly. I don't see these changes here; what was the solution?

> No, I did not have to. The HF call method is directly compatible with constructing a TrainState, and after constructing it the rest is straightforward. Where exactly did you run into issues?

> The call to the network has a new return signature? It should return the batch stats along with the logits, and these batch stats have to be propagated to the network in the forward passes and during updates. We deal with this in the NTK calculation, but unless it snuck through the last time I worked on it, there won't be any batch stats passed.

Batch_stats are included in our model_state.params. But good point, I have to check whether the batch_stats get handled properly.

@SamTov (Member) commented Nov 6, 2023

> Did you not have to update the training procedure? When I got this working, I needed to take into consideration the batch statistics and all these other things being passed correctly. I don't see these changes here; what was the solution?

> No, I did not have to. The HF call method is directly compatible with constructing a TrainState, and after constructing it the rest is straightforward. Where exactly did you run into issues?

> The call to the network has a new return signature? It should return the batch stats along with the logits, and these batch stats have to be propagated to the network in the forward passes and during updates. We deal with this in the NTK calculation, but unless it snuck through the last time I worked on it, there won't be any batch stats passed.

> Batch_stats are included in our model_state.params. But good point, I have to check whether the batch_stats get handled properly.

Here, for example, each time they call the model, they collect the other part of the output tuple:

    logits, new_model_state = state.apply_fn(
        {'params': params, 'batch_stats': state.batch_stats},
        batch['image'],
        mutable=['batch_stats'],
    )

Sorry, it isn't the batch stats, it is this model state part. This has to be passed to other functions in order for the model to train properly. From my memory, it had something to do with ensuring the batch stats are used and updated correctly. We don't do this in normal training; we just ignore this additional output.

Now in their weight update, they do the following:

    new_state = state.apply_gradients(
        grads=grads, batch_stats=new_model_state['batch_stats']
    )

so they need these stats.

They also seem to always pass it explicitly in model forward passes:

    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    logits = state.apply_fn(variables, batch['image'], train=False, mutable=False)

This is in the eval step, so nothing involved in training. It may not be necessary, as the initial object is a dict anyway, but we do need to be sure.

- extend the TrainState to capture batch statistics (see the sketch below)
- add a train_apply method to the jax model to distinguish between evaluating and training a model
- adapt the nt, flax and hfflax models to the changes
- rewrite the train step to account for batch statistics
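
A minimal sketch of what such an extended TrainState and train step could look like, following the pattern from the Flax example quoted above (illustrative only, not the ZnNL implementation; the loss is a placeholder):

    from typing import Any

    import jax
    from flax.training import train_state


    class TrainState(train_state.TrainState):
        # Extra field holding the BatchNorm running statistics.
        batch_stats: Any = None


    def train_step(state: TrainState, batch):
        # One update step that threads batch_stats through the apply function.
        def loss_fn(params):
            logits, new_model_state = state.apply_fn(
                {"params": params, "batch_stats": state.batch_stats},
                batch["image"],
                mutable=["batch_stats"],
            )
            loss = ((logits - batch["label"]) ** 2).mean()  # placeholder loss
            return loss, new_model_state

        (loss, new_model_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            state.params
        )
        # Apply the gradients and store the updated batch statistics.
        state = state.apply_gradients(grads=grads)
        state = state.replace(batch_stats=new_model_state["batch_stats"])
        return state, loss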
@KonstiNik (Member, Author) commented Nov 9, 2023

Further changes need to include:

  • Adapt Training Strategies to the updated TrainState
  • Jit larger training functions (sketched below)
  • Make the ResNet example more descriptive
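
For the jitting point, a possible shape, assuming a pure train step like the sketch above, is simply to compile the step once and reuse it (data_generator is illustrative):

    import jax

    # Compile the (pure) train step once; subsequent calls reuse the compiled version.
    jitted_train_step = jax.jit(train_step)

    # for batch in data_generator:
    #     state, loss = jitted_train_step(state, batch)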

- Fix the parameter handling for other strategies.
@SamTov (Member) commented Nov 10, 2023

> Further Changes need to include:
>
>   • Adapt Training Strategies to the updated TrainState
>   • Jit larger training functions
>   • Adapt trace opt to updated TrainState
>   • Make ResNet example more descriptive

You can remove trace opt from this; I can take care of it in my other trace opt PR. I have made enough changes to it there that this would just set the whole thing back.

knikolaou added 4 commits November 10, 2023 17:04
@KonstiNik (Member, Author) commented:

There was an issue with HF FlaxResNets when using smaller models with layer_type='basic' instead of layer_type='bottleneck':
huggingface/transformers#27257
It has been fixed and merged into the main branch of hf-transformers, so it should be released soon.

For my examples and tests to pass, I have therefore used layer_type='bottleneck' (see the snippet below).
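
A sketch of the workaround, building the HF Flax ResNet with bottleneck blocks instead of basic ones (the config values here are illustrative):

    from transformers import FlaxResNetForImageClassification, ResNetConfig

    config = ResNetConfig(
        layer_type="bottleneck",  # 'basic' triggers the bug linked above
        num_labels=10,
    )
    model = FlaxResNetForImageClassification(config)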

@SamTov (Member) left a review:

Such a cool PR! I am excited to see it in action. I do have a few more comments that it would be great to discuss.

znnl/training_strategies/training_steps.py
znnl/training_strategies/training_decorator.py
znnl/models/jax_model.py
znnl/models/huggingface_flax_model.py
znnl/models/flax_model.py
znnl/models/flax_model.py
CI/unit_tests/models/test_huggingface_flax_model.py
@SamTov (Member) left a review:

Noice

@KonstiNik merged commit ff88aca into main, Nov 17, 2023
6 checks passed
@KonstiNik deleted the Konsti_resnet_implementation branch, November 17, 2023 17:09
Labels: none
Projects: none
Participants: 2