A PyTorch implementation of Centered Kernel Alignment (CKA) with GPU support.
This code was used for the CKA analysis in our CVPR 2023 paper, "On the Stability-Plasticity Dilemma of Class-Incremental Learning".
```python
model1 = ...  # Some model, already moved to the GPU
model2 = ...  # Another model, already moved to the GPU
dataloader = ...  # Your dataloader
calculator = CKACalculator(model1, model2, dataloader)
cka_matrix = calculator.calculate_cka_matrix()
```
Rather than caching intermediate feature representations, this code computes CKA on the fly (during the models' forward passes) using mini-batch CKA, as described by Nguyen et al. Because the computation runs on the GPU, this implementation is typically much faster than NumPy-based implementations.
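As a rough illustration of the mini-batch approach, the sketch below accumulates linear CKA for a single pair of layers over k mini-batches. It assumes a linear kernel (Gram matrices K_i = X_i X_iᵀ and L_i = Y_i Y_iᵀ) and the unbiased HSIC estimator HSIC₁ of Song et al. (2012), and computes CKA = Σ_i HSIC₁(K_i, L_i) / sqrt(Σ_i HSIC₁(K_i, K_i) · Σ_i HSIC₁(L_i, L_i)); the 1/k averaging factors in Nguyen et al.'s definition cancel in this ratio. The names `unbiased_hsic` and `MinibatchCKA` are hypothetical and not part of this repo's API, and `CKACalculator` may differ in its details.

```python
import torch


def unbiased_hsic(K: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    """Unbiased HSIC estimator (Song et al., 2012) for n x n Gram matrices, n > 3."""
    n = K.shape[0]
    # The unbiased estimator uses Gram matrices with zeroed diagonals.
    K = K.clone().fill_diagonal_(0)
    L = L.clone().fill_diagonal_(0)
    trace_term = torch.trace(K @ L)
    sum_term = K.sum() * L.sum() / ((n - 1) * (n - 2))  # (1ᵀK1)(1ᵀL1) term
    cross_term = 2 * (K @ L).sum() / (n - 2)             # 1ᵀKL1 term
    return (trace_term + sum_term - cross_term) / (n * (n - 3))


class MinibatchCKA:
    """Hypothetical accumulator for mini-batch linear CKA between two layers."""

    def __init__(self) -> None:
        # Running sums of HSIC over mini-batches; no features are cached.
        self.hsic_xy = 0.0
        self.hsic_xx = 0.0
        self.hsic_yy = 0.0

    @torch.no_grad()
    def update(self, x: torch.Tensor, y: torch.Tensor) -> None:
        # Flatten each example's features: (batch, ...) -> (batch, features),
        # and use float64 to limit round-off in the HSIC sums.
        x = x.flatten(start_dim=1).double()
        y = y.flatten(start_dim=1).double()
        K = x @ x.T  # linear-kernel Gram matrices, shape (batch, batch)
        L = y @ y.T
        self.hsic_xy += unbiased_hsic(K, L)
        self.hsic_xx += unbiased_hsic(K, K)
        self.hsic_yy += unbiased_hsic(L, L)

    def compute(self) -> torch.Tensor:
        # The 1/k averages cancel in the ratio, so plain sums suffice.
        return self.hsic_xy / torch.sqrt(self.hsic_xx * self.hsic_yy)


torch.manual_seed(0)
Q, _ = torch.linalg.qr(torch.randn(128, 128))  # random orthogonal matrix
cka = MinibatchCKA()
for _ in range(10):  # ten mini-batches of 64 examples each
    x = torch.randn(64, 128)
    cka.update(x, 3.0 * x @ Q)  # rotated and rescaled copy of x
# Prints a value very close to 1.0: linear CKA ignores rotations and isotropic scaling.
print(float(cka.compute()))
```

Only three running sums are kept per layer pair, so memory does not grow with the size of the dataset; this is what makes computing CKA alongside the forward pass practical.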
I haven't added a requirements.txt, since the exact version of each package is not that important 🤷♂️
- Python 3.7+
- torch (any reasonably recent version should be OK)
- torchvision
- tqdm
- torchmetrics
- jupyter
- matplotlib
- numpy
Try out the example notebook in `example.ipynb`.
- If you found this repo helpful, please give it a ⭐
- If you find any bugs or possible improvements, feel free to open a new issue.
- This code is mostly tested on ResNets
- Ditch hooks; switch to a `torch.fx` implementation
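The last item implies that intermediate features are currently collected with forward hooks. For readers unfamiliar with that mechanism, here is a minimal, hypothetical sketch of hook-based feature capture in plain PyTorch on a toy model (`register_feature_hooks` is an illustrative name, not this repo's API). A `torch.fx`-based version would instead trace the model's graph and return the chosen intermediate nodes as extra outputs, avoiding the need to register and remove hooks.

```python
from typing import Dict, List

import torch
from torch import nn


def register_feature_hooks(model: nn.Module, layer_names: List[str]):
    """Attach forward hooks that store each named submodule's output."""
    features: Dict[str, torch.Tensor] = {}
    modules = dict(model.named_modules())
    handles = []
    for name in layer_names:
        # Bind `name` as a default argument so each hook keeps its own layer name.
        def hook(module, inputs, output, name=name):
            features[name] = output.detach()

        handles.append(modules[name].register_forward_hook(hook))
    return features, handles


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
features, handles = register_feature_hooks(model, ["0", "2"])
with torch.no_grad():
    model(torch.randn(5, 8))  # hooks fire during this forward pass
for handle in handles:
    handle.remove()  # hooks stay attached until explicitly removed
print({k: tuple(v.shape) for k, v in features.items()})
# → {'0': (5, 16), '2': (5, 4)}
```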