Inspired by: https://github.com/sail-sg/poolformer - https://arxiv.org/pdf/2111.11418.pdf
In this section, we point out that the multi-headed scaled dot-product attention introduced in 2017 is equivalent to a general quadratic form that lends itself to a more efficient reformulation. Furthermore, we argue, on the grounds of efficiency, interpretability, and regularization, for imposing that this form be a metric or metric-like tensor.
What follows is a short exposition of scaled dot-product attention, using Ricci calculus to avoid underspecification, before transitioning into the proposed quadratic and metric attentions.
Let
Each query is dotted with every key, and the result is scaled by the inverse square root of the dimensionality of the projection space before being softmaxed along one of the directions, producing
where
and the result is reflattened and projected back to the original embedding space
Our focus is on the step right before the application of a softmax
By substituting the operations that produced the queries and keys,
and by defining
It is evident that the original group of equations is equivalent to a simple quadratic form.
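To make the equivalence concrete, here is a small numerical check of the per-head identity $Q K^{\top} = X W_q W_k^{\top} X^{\top} = X U X^{\top}$; the shapes and names below are illustrative assumptions, not the ones used elsewhere in this repository.

```python
import torch

# Per head: Q = X Wq and K = X Wk, so Q K^T = X (Wq Wk^T) X^T = X U X^T.
# seq_len, d_model and d_head are assumed values for illustration only.
torch.manual_seed(0)
seq_len, d_model, d_head = 6, 16, 4
X = torch.randn(seq_len, d_model)
Wq = torch.randn(d_model, d_head)
Wk = torch.randn(d_model, d_head)

scores_dot = (X @ Wq) @ (X @ Wk).T   # pre-softmax scores of dot-product attention
U = Wq @ Wk.T                        # the matrix of the quadratic form
scores_quad = X @ U @ X.T

print(torch.allclose(scores_dot, scores_quad, atol=1e-5))  # True
```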
The motivation for using multiple heads that operate on a lower-dimensional space is that, whereas the quadratic form makes use of
However, this is not the most efficient reformulation that can be squeezed out of the quadratic form. Let us assume that there exists
This restriction has now halved the number of parameters down to
Some additional things to note:
- the $U^n_{dd'} = P^{nk}_d P^{nk}_{d'}$ condition restricts the possible values of $U^n_{dd'}$, leading to a possible regularization effect
- the $U^n_{dd'} = P^{nk}_d P^{nk}_{d'}$ condition leads to metric-like properties such as non-negativity and symmetry (see the sketch after this list)
- moving further towards a true metric might mean venturing into more computationally complex operations; the missing properties are the identity of indiscernibles and the triangle inequality
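As a small illustration of the first two points (a sketch with assumed shapes, using generic names rather than the repository's): parameterizing the form as $U = P P^{\top}$ gives symmetry for free, makes the induced self-similarity non-negative, and uses a single projection where scaled dot product needs two.

```python
import torch

# U = P P^T is symmetric and positive semi-definite, so x U x^T >= 0.
# One projection P replaces the separate Wq and Wk, halving the parameters.
torch.manual_seed(0)
d_model, d_head = 16, 4
P = torch.randn(d_model, d_head)
U = P @ P.T

print(torch.allclose(U, U.T))                    # symmetry
x = torch.randn(8, d_model)
quad = torch.einsum('id,de,ie->i', x, U, x)      # x_i U x_i^T for each row
print((quad >= -1e-6).all())                     # non-negativity up to float error
```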
Let
At the heart of the proposed attention mechanism is a learnable dot product of each projected embedding with each other embedding. This is achieved using
The metric tensor is symmetric, so we can reduce the number of computations by grouping the terms strategically,
Let
and
Such an arrangement is easily achieved by storing two arrays to be used as a lookup table for
which we use to rewrite our original expression as
where
At this point, our expression already fits quite well within a CUDA kernel. Note how the
However, a further computational saving is unlocked by the usage of a metric tensor: since dot products are commutative, it follows that
To avoid repetition, I'll work through the following expression
and perform symbol substitution where necessary in order to place it back into the expression we're working with. Performing direct substitution, we get
which we can similarly split into two terms
Substituting this back, while attending to the relevant substitution on the first term of the original expression,
which we'll now group according to the
Note that for every combination of
To proceed with the rest of the attention mechanism,
but followed by the application of the scores on the same projection
The result is then reflattened, and a final transformation is applied to ensure mixing of the features and to align the dimensionality with the original embedding space
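For orientation, the tail end of the mechanism mirrors standard multi-head attention. The sketch below uses assumed shapes and assumes the packed pair index has already been unpacked into a full score matrix; it applies the softmaxed scores to the same projection, reflattens the heads, and projects back to the embedding dimension.

```python
import torch

# Assumed shapes: B batches, N heads, C positions, K per-head dims, E embedding dim.
B, N, C, K, E = 2, 4, 8, 16, 64
s_bncc = torch.softmax(torch.randn(B, N, C, C), dim=-1)   # scores, unpacked to (C, C)
p_bnck = torch.randn(B, N, C, K)                          # the same projection the scores came from

o_bnck = torch.einsum('bnij,bnjk->bnik', s_bncc, p_bnck)  # apply scores to the projection
o_bce = o_bnck.permute(0, 2, 1, 3).reshape(B, C, N * K)   # reflatten the heads
W_out = torch.randn(N * K, E)
y_bce = o_bce @ W_out                                     # final mixing back to embedding space
```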
To provide some clarity into how this fits together in a CUDA kernel: here `q_bnul` corresponds to $r^{bnul}$, which is then summed over $l$ afterwards to get $r^{bnu}$.
```cuda
template <typename scalar_t>
__global__ void metric_attention_forwards_kernel(
    CudaTensorView<scalar_t, 4> p_bnck,
    CudaTensorView<scalar_t, 2> M_nl,
    CudaTensorView<scalar_t, 4> q_bnul,
    CudaTensorView<size_t, 2> index_table_2l,
    CudaTensorView<size_t, 2> index_table_2u,
    const int max_global_idx
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x; // Global thread index
    if (idx > max_global_idx) return;

    // Decompose the flat thread index into the (b, n, u, l) coordinates of q_bnul
    size_t b;
    compute_index(idx, q_bnul.size(0), b);
    size_t n;
    compute_index(idx, q_bnul.size(1), n);
    size_t u;
    compute_index(idx, q_bnul.size(2), u);
    size_t l;
    compute_index(idx, q_bnul.size(3), l);

    // Look up the unordered index pairs packed into l and u
    size_t k = index_table_2l[0][l];
    size_t k_1 = index_table_2l[1][l];
    size_t c = index_table_2u[0][u];
    size_t c_1 = index_table_2u[1][u];

    // assign common factor
    q_bnul[b][n][u][l] = M_nl[n][l] * p_bnck[b][n][c][k];

    // multiply by the remaining projection factor, with multiplicities 2 and 4
    // accounting for the off-diagonal terms of the symmetric grouping
    if (k == k_1 && c == c_1) {
        q_bnul[b][n][u][l] *= p_bnck[b][n][c][k];
    } else if (k == k_1 && c != c_1) {
        q_bnul[b][n][u][l] *= 2 * p_bnck[b][n][c_1][k];
    } else if (k != k_1 && c == c_1) {
        q_bnul[b][n][u][l] *= 2 * p_bnck[b][n][c][k_1];
    } else if (k != k_1 && c != c_1) {
        q_bnul[b][n][u][l] *= 4 * p_bnck[b][n][c_1][k_1];
    }
}
```
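For testing, the kernel's per-element logic can be mirrored on the CPU with plain loops. The sketch below is a reference under assumptions rather than the project's code: it builds the pair lookup tables with `torch.triu_indices` (assuming they enumerate unordered index pairs) and reproduces the kernel's branches before reducing over $l$ to obtain $r^{bnu}$.

```python
import torch

def build_pair_table(dim):
    # Assumed construction of the lookup tables: enumerate unordered pairs (i, j), i <= j,
    # playing the role of index_table_2l / index_table_2u in the kernel above.
    return torch.triu_indices(dim, dim)              # shape (2, dim * (dim + 1) // 2)

def metric_attention_forward_reference(p_bnck, M_nl):
    # Loop-based mirror of metric_attention_forwards_kernel, followed by the sum over l.
    B, N, C, K = p_bnck.shape
    table_l = build_pair_table(K)                    # packed index l -> (k, k_1)
    table_u = build_pair_table(C)                    # packed index u -> (c, c_1)
    L, U = table_l.shape[1], table_u.shape[1]
    q_bnul = torch.zeros(B, N, U, L, dtype=p_bnck.dtype)
    for b in range(B):
        for n in range(N):
            for u in range(U):
                c, c_1 = table_u[0, u].item(), table_u[1, u].item()
                for l in range(L):
                    k, k_1 = table_l[0, l].item(), table_l[1, l].item()
                    common = M_nl[n, l] * p_bnck[b, n, c, k]
                    if k == k_1 and c == c_1:
                        q_bnul[b, n, u, l] = common * p_bnck[b, n, c, k]
                    elif k == k_1:
                        q_bnul[b, n, u, l] = common * 2 * p_bnck[b, n, c_1, k]
                    elif c == c_1:
                        q_bnul[b, n, u, l] = common * 2 * p_bnck[b, n, c, k_1]
                    else:
                        q_bnul[b, n, u, l] = common * 4 * p_bnck[b, n, c_1, k_1]
    return q_bnul.sum(dim=-1)                        # r_bnu
```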
In the backward pass, we're interested in calculating the following quantities,
and
where
Gradient with respect to the metric coordinates:
Gradient with respect to the input coordinates:
Which can be rewritten as
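A practical way to sanity-check the hand-derived gradients is to compare them against autograd applied to the loop-based reference sketched earlier (again, the shapes here are assumptions).

```python
import torch

# Assumed small shapes; L must match the number of unordered (k, k_1) pairs.
B, N, C, K = 2, 3, 4, 5
L = K * (K + 1) // 2
p_bnck = torch.randn(B, N, C, K, dtype=torch.double, requires_grad=True)
M_nl = torch.randn(N, L, dtype=torch.double, requires_grad=True)

r_bnu = metric_attention_forward_reference(p_bnck, M_nl)
r_bnu.sum().backward()

# p_bnck.grad and M_nl.grad now hold autograd's gradients, ready to be compared
# against the analytic expressions above.
print(p_bnck.grad.shape, M_nl.grad.shape)
```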
Note: all workflows have been removed; pipelines are being moved to Prefect.
| Name and Status | Dataset | Usability |
|---|---|---|
| Sentiment Analysis Task (Completed with success) | asa-v0.2.0 | Outdated |
| Sentiment Analysis Task (Completed without success, model overfits easily) | stanford dataset | Outdated |
| GPT Shakespeare Textgen (Completed with success) | sha-v0.1.0 | Outdated |
| GPT Array Sorter Experiment (Completed with success) | Generated | Outdated |
NanoGPT was trained to sort the tokens 1, 2 and 3.
- induced distances between the embeddings for each token
- position (i, j) = distance between token i and token j
- note how the first head is clearly encoding for the sort order
- scaled dot product doesn't really have an analogue to this, so there's nothing to compare
- it does, however, also have score tables, which we can compare
- scores for scaled dot product
- scores for metric-based attention
- we can also try to compare the weight matrices
- in the case of metric attention, they are metric tensors
- in the case of scaled dot product, we use WqWk.T as an analogue
- https://github.com/Digital-Defiance/IMBd-dataset
- Early results on this dataset strongly point to the attention mechanism not being important for the task
- Quadratic attention, straight average pooling, and even an identity map were able to substitute for scaled dot product with no signs of decreasing accuracy (one transformer block followed by a point-wise projection into the number of classes and an averaging of the embeddings, ruling out the capacity of the output layer as a possible explanation; the setup is sketched below)
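For context, the sketch below shows the kind of model this refers to; every name and dimension in it is an assumption rather than the project's code: one block whose token mixer can be swapped between attention, average pooling, or an identity map, followed by a point-wise projection into the classes and an average of the embeddings over the sequence.

```python
import torch
import torch.nn as nn

class IdentityMixer(nn.Module):
    # One of the drop-in replacements mentioned above: an identity map in place of attention.
    def forward(self, x_bte):
        return x_bte

class TinyClassifier(nn.Module):
    # Sketch of the described setup: one block, point-wise projection to the classes,
    # then an average of the embeddings over the sequence dimension.
    def __init__(self, vocab_size=1024, embed_dim=64, num_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.mixer = IdentityMixer()   # swap for average pooling or an attention variant
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)  # point-wise projection

    def forward(self, tokens_bt):
        x_bte = self.embed(tokens_bt)
        x_bte = x_bte + self.mixer(self.norm1(x_bte))
        x_bte = x_bte + self.ff(self.norm2(x_bte))
        return self.head(x_bte).mean(dim=1)            # average over the sequence

logits = TinyClassifier()(torch.randint(0, 1024, (8, 32)))  # (batch, num_classes)
```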
These are some results and explorations from earlier experiments; they will soon be replaced by final (and more intelligible) results.
- Modified Self Attention, Metric Tensor Heads (possible avenues to look at when trying to interpret what they are doing)
- Loss Graph Comparison between Transformer and Metric Tensor Network (not much difference)
- Output Comparison (not much difference)
- Transformer:
The meaning of life is full of me:
if
I spy an age be content: the sea excuse that very
are to achieve for our prisoner to Rome's wife,
'Sirrah's see, command, let twenty pound
Strive might now; since is about me than,
Were then but the point of death: he were a
them where I'll wear what to wash you, for
And copy of the action; and gave me down himself
For why I should give for these fashion of them
Whether but relished to the hand:
Then speak, old and pray, no when the petition
With what, by our petition you bear this, after;
Not writ we held him. When like subjects put out,
That six years would reap the will we more
To follow us fly the throne of heaven, as sad
Which had no further. There, gentle Paulina,
The same no bodes with our valiant power,
And that's the herd, there be an his certain
Nor private gentlemen with you,--O, and your
Long may answer, from us, to fly seeing remorse,
The mutinous their eyes, who hath slain!
His senate-face, and my life sent,
The dangerous lenity where she starts;
And all with the sea or mistaken;
For him from whence can I do.
SOMERSET:
No310 of fear, here comes it.
ARCHBUSHY:
Ay, give? it not fall of this:
If thy mother shall be seen the world
Might gently before thyself in time.
MeDecline image and look'd, then, take him:
'Shall I we see thee thy tongue.
GREEN:
All Edward again. Give me to France, madam, I.
- Metric tensor net:
The meaning of life is soaking,'er's friend,
For I will in some man. It were to Richmond,
But by the general made up,
And when he walks, make him yea,
Thou shalt teach thee will to give himself?
Than Lewis he did I think of infirm'd too.
HASTINGS:
Under whom me so I swear to deliver me?
HASTINGS:
Ghost that I, a kingdom this amongst us.
BUCKINGHAM:
His lie such an Cates, he fears you.
KING EDWARD IV:
But raise this stands giftedave.
QUEEN MARGARET:
The rest be not your crown?
QUEEN ELIZABETH:
Is this once, that I enforce his sign of four
Which be uncle, till I let me to have done,
And not privy friend to a grief weep.
An, and my husband's wife hath done a want of mine.
My frost may follow to love.
Y ANNE:
The high forehead Margaret of Warwick mans your tongue and Derby,
To prove it of Buckingham shall way the streets.
QUEEN ELIZABETH:
Ay, by this device are butcher of Glouces;
Poor high love kill it will--
QUEEN ELIZABETH: may awake Boling;
And unblown, unto the cause
Or once to her repeal'd in private.
InsTER:
Come, no, the dying sovereign to my son and this land what
And were for Edward to thither to kill'd.
The knights and no conquest of them?
But do you be nor bestow' sovereign, nor debt:
Our children of Clarence, if 'tis trueborn blood.
Thus till then, my Edward is like our course of scful!
In all the results from these very early experiments, despite the parameter reduction and the strong constraints, the network seemed to perform on par with the standard transformer both during and after training.
- https://paperswithcode.com/method/strided-attention
- https://paperswithcode.com/method/fixed-factorized-attention
- https://paperswithcode.com/method/dot-product-attention
- https://paperswithcode.com/method/scaled
In our code, we use a specific notation to denote the shape of tensors. Here's how it works:
- A tensor's shape is indicated by appending a suffix to the variable name. Each letter in the suffix corresponds to a dimension in the tensor's shape. For example, a tensor with shape `(a, b, c)` would be named `some_tensor_abc`:

  ```python
  import torch

  a, b, c = 10, 3, 5
  some_tensor_abc = torch.randn(a, b, c)
  ```

- If the dimensions of the tensor are not represented by single letters, we use their initials. For instance, a tensor with dimensions `batch_size` and `vocabulary_size` would be named `some_tensor_bv`:

  ```python
  batch_size, vocabulary_size = 32, 1024
  some_tensor_bv = torch.randn(batch_size, vocabulary_size)
  ```

- If a dimension has an explicit value, we include that value in the suffix. For example, `some_tensor_b2ac` indicates that the tensor has a second dimension (`dim=1`) with a size of 2. We only include explicit values in the suffix if they have no more than one digit.

- We also extend this notation to functions. A function name like `some_function_tq` indicates that the function transforms dimension `q` into size `t`:

  ```python
  result_abct = some_function_tq(input_abcq)
  ```
This notation helps us keep track of tensor shapes throughout our code, making it easier to understand and debug.