This is my implementation of the paper Discrete Key-Value Bottleneck by Träuble, Goyal et al.
Deep neural networks perform exceedingly well on classification tasks where data streams are i.i.d. and labelled data is abundant. Challenges emerge in a production-level scenario where the data streams are non-stationary.
One common approach is the fine-tuning paradigm: pre-train large encoders on large volumes of readily available data, followed by task-specific tuning. However, this approach faces a challenge: while fine-tuning a large number of weights, information about the previous task is lost in a process called Catastrophic Forgetting.
The authors build upon a discrete key-value bottleneck containing a number of separate, learnable key-value pairs. The paradigm followed is:
The input is fed to a pre-trained encoder, the output of the encoder is used to select the nearest keys, and the corresponding values are fed to the decoder to solve the task.
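A minimal sketch of this pipeline, assuming the pre-trained encoder is kept frozen (module and argument names are my own; the bottleneck and decoder pieces are sketched further below):

```python
import torch.nn as nn

class DKVBPipeline(nn.Module):
    """Pre-trained encoder -> discrete key-value bottleneck -> decoder."""

    def __init__(self, encoder: nn.Module, bottleneck: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.bottleneck = bottleneck
        self.decoder = decoder
        # Assumption: the pre-trained encoder stays frozen, so its weights
        # are never fine-tuned on the downstream tasks.
        for p in self.encoder.parameters():
            p.requires_grad = False

    def forward(self, x):
        z = self.encoder(x)          # continuous representation of the input
        values = self.bottleneck(z)  # values fetched for the nearest keys
        return self.decoder(values)  # solve the task from the fetched values
```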
GOAL: To learn a model that can be trained on a non-stationary stream of data without catastrophically forgetting what it learned on earlier parts of the stream.
Let the model be formulated as the composition of a pre-trained encoder, a discrete key-value bottleneck, and a decoder: input → encoder → nearest keys → fetched values → decoder → prediction.
In the first step, an input is fed to the pre-trained encoder, which produces a continuous feature representation; this representation is then compared against the keys of the codebooks.
A KEY-VALUE CODEBOOK is a bijection that maps each key code vector to a distinct value vector; the values are learnable. Within each codebook, a quantization process selects the key nearest to the encoder output and fetches the corresponding value code.
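A minimal sketch of a single codebook, assuming fixed (here randomly initialised) keys and learnable values; projecting the encoder output into the key space and running several codebooks in parallel are omitted for brevity:

```python
import torch
import torch.nn as nn

class KeyValueCodebook(nn.Module):
    def __init__(self, num_pairs: int, key_dim: int, value_dim: int):
        super().__init__()
        # Keys: fixed code vectors that the encoder output is snapped to.
        self.register_buffer("keys", torch.randn(num_pairs, key_dim))
        # Values: one learnable vector per key (the key -> value bijection).
        self.values = nn.Parameter(torch.zeros(num_pairs, value_dim))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (batch, key_dim) representation produced by the encoder.
        dists = torch.cdist(z, self.keys)   # (batch, num_pairs) distances to all keys
        nearest = dists.argmin(dim=-1)      # index of the nearest key per example
        return self.values[nearest]         # fetch the corresponding value codes
```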
For the purpose of classification, the authors propose a simple non-parametric decoder that average-pools the fetched value codes element-wise and then applies a softmax on top of the result.
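A sketch of that decoder, under the assumption that each value code carries one entry per class, so that pooling followed by a softmax is all that is needed:

```python
import torch

def non_parametric_decoder(fetched_values: torch.Tensor) -> torch.Tensor:
    # fetched_values: (batch, num_codebooks, num_classes) value codes.
    pooled = fetched_values.mean(dim=1)    # element-wise average over codebooks
    return torch.softmax(pooled, dim=-1)   # class probabilities
```

Because this decoder has no trainable weights of its own, all task-specific learning happens in the value codes.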
We perform a simple eight-class classification task in a class-incremental manner to show the efficacy of the bottleneck. In each stage, we sample 100 examples of two classes for 1000 training steps, using gradient descent to update the weights, then move on to two new classes for the next 1000 steps. The input features of each class follow spatially separated normal distributions.
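The stream for this toy experiment can be generated roughly as below; the feature dimension, the cluster spread, and the circular layout of the class means are assumptions I made for illustration, the rest follows the setup described above:

```python
import math
import torch

NUM_CLASSES, FEATURE_DIM, EXAMPLES_PER_CLASS, STEPS_PER_STAGE = 8, 2, 100, 1000

# One well-separated mean per class (assumed layout: evenly spaced on a circle).
angles = torch.arange(NUM_CLASSES, dtype=torch.float32) * (2 * math.pi / NUM_CLASSES)
class_means = 10.0 * torch.stack([angles.cos(), angles.sin()], dim=1)

def sample_stage(stage: int):
    """Sample 100 examples for each of the two classes introduced in this stage."""
    classes = [2 * stage, 2 * stage + 1]
    xs, ys = [], []
    for c in classes:
        xs.append(class_means[c] + torch.randn(EXAMPLES_PER_CLASS, FEATURE_DIM))
        ys.append(torch.full((EXAMPLES_PER_CLASS,), c))
    return torch.cat(xs), torch.cat(ys)

for stage in range(NUM_CLASSES // 2):
    x, y = sample_stage(stage)
    for step in range(STEPS_PER_STAGE):
        pass  # one gradient-descent update of the model on (x, y) goes here
```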
The results are clear: the naive models based on linear probes and a simple multi-layer perceptron simply overfit to the most recent training data, thus forgetting the previously learned information. However, at each stage the Discrete Key-Value Bottleneck holds on to the previously learned information while also learning the new classes.