It’s about time that I thoroughly explored the attention mechanism, which was introduced in the Attention is All You Need paper (Vaswani et al. 2017). Attention is the name of a deep learning block that forms the basis of the Transformer, which is used in almost all modern NLP models.
This sort of post has been done many times before. One particularly good one is The Illustrated Transformer, along with its companion post Visualizing A Neural Machine Translation. I would recommend reading both.
Scaled Dot-Product Attention
The basic attention block is the scaled dot-product, which takes three inputs. These inputs are used across attention implementations and are referred to as \(Q\), \(K\) and \(V\). From the paper:
An attention function can be described as mapping a query (\(Q\)) and a set of key-value (\(K\)-\(V\)) pairs to an output, where the query, keys, values, and output are all vectors
These form a block according to the equation:
\[ \text{Attention}(Q,K,V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V \]
The input consists of queries and keys of dimension \(d_k\), and values of dimension \(d_v\).
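To keep track of the shapes, here is a quick sketch (using \(n\) for the number of queries and \(m\) for the number of key-value pairs, which are my own symbols rather than the paper’s):

\[ \begin{aligned} Q &\in \mathbb{R}^{n \times d_k}, \quad K \in \mathbb{R}^{m \times d_k}, \quad V \in \mathbb{R}^{m \times d_v} \\ QK^T &\in \mathbb{R}^{n \times m} \\ \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V &\in \mathbb{R}^{n \times d_v} \end{aligned} \]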
This sort of equation is not the easiest to read, so we can also express it as a graph of operations:
If we break down the equation we can see the following parts:
\[ \begin{aligned} \text{Attention}(Q,K,V) &= \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V \\ \text{Matrix Multiply}_1 &= QK^T \\ \text{Scale} &= \frac{1}{\sqrt{d_k}} \\ \text{Mask} &= \text{not present in equation} \\ \text{SoftMax} &= \text{softmax}\left( \ldots \right) \\ \text{Matrix Multiply}_2 &= \text{softmax}_{out} V \\ \end{aligned} \]
Which isn’t so bad.
The questions that I have after looking at all this are:

- How would we implement this?
- Where do the \(Q\), \(K\) and \(V\) vectors come from?
Implementation
Given the description above, along with the paper, I have come up with the following:
Code
import torch
from torch import nn

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dk: int) -> None:
        super().__init__()
        self.scale = 1 / (dk**0.5)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        attention = q @ k.T
        attention = attention * self.scale
        attention = attention.softmax(dim=-1)
        return attention @ v
Now, my implementation is not batched and assumes that \(K\) only has two dimensions.
To check this implementation I have found a reference implementation, which is available here. The reference implementation is:
Code
import torch.nn.functional as F

class ReferenceScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''
    # from https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Modules.py

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn
This shows that the mask is used to block out part of the attention: masked positions are filled with a large negative value so that the softmax assigns them almost zero weight. The reference implementation also assumes batched input, which can be seen in the way \(K^T\) is calculated.
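As a minimal sketch of how that masking behaves (the scores and mask here are made-up values, not taken from the reference code), filling a position with a large negative number before the softmax drives its attention weight to almost zero:

import torch

scores = torch.tensor([1.0, 2.0, 3.0])
mask = torch.tensor([1, 1, 0])

# masked positions become -1e9, which the softmax turns into ~0
masked = scores.masked_fill(mask == 0, -1e9)
print(masked.softmax(dim=-1))  # tensor([0.2689, 0.7311, 0.0000])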
I’ve been using the matrix multiplication operator, which calls torch.matmul for tensors, so that part isn’t different at least. To support batched input we just have to change k.T to k.transpose(-2, -1). Using negative indices swaps the last two dimensions no matter what shape the tensor is.
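A quick check on that claim, using a couple of arbitrary shapes:

import torch

# .transpose(-2, -1) swaps the last two dimensions regardless of rank
print(torch.rand(8, 16).transpose(-2, -1).shape)        # torch.Size([16, 8])
print(torch.rand(2, 4, 8, 16).transpose(-2, -1).shape)  # torch.Size([2, 4, 16, 8])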
Code
import torch
from torch import nn

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dk: int) -> None:
        super().__init__()
        self.scale = 1 / (dk**0.5)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # the transpose here swaps the last two dimensions
        attention = q @ k.transpose(-2, -1)
        attention = attention * self.scale
        attention = attention.softmax(dim=-1)
        return attention @ v
We can prove that these are equivalent by running the same inputs through both. This does involve floating point operations, which are very sensitive to the order in which they are performed. Since the scaling is applied at a different point in each implementation, the comparison should be a little bit generous.
Code
q = torch.rand(2, 4, 8, 16)
k = torch.rand(2, 4, 8, 16)
v = torch.rand(2, 4, 8, 16)

# number of parameters in k per batch entry
dk = 4 * 8 * 16

my_attention = ScaledDotProductAttention(dk)
reference_attention = ReferenceScaledDotProductAttention(
    temperature=dk**0.5,
    attn_dropout=0.  # disable dropout for the comparison
)

with torch.no_grad():
    my_out = my_attention(q=q, k=k, v=v)

    # the reference model returns the attention as well
    reference_out = reference_attention(q=q, k=k, v=v)[0]

difference = my_out - reference_out

mean_out = my_out.mean().item()
mean_difference = difference.mean().item()

print(f"average output: {mean_out:0.4f}, with a difference of {mean_difference:0.4g}")
average output: 0.5085, with a difference of -1.048e-09
That’s a difference of almost one part in a billion, which seems close enough to me.
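Newer versions of PyTorch (2.0 onwards) also ship a built-in torch.nn.functional.scaled_dot_product_attention, which makes for another point of comparison. A minimal sketch, assuming a recent enough install; note that the built-in scales by the last dimension of the query (16 here), so I construct a matching block rather than reusing the dk from above:

import torch.nn.functional as F

with torch.no_grad():
    builtin_out = F.scaled_dot_product_attention(q, k, v)
    my_out_16 = ScaledDotProductAttention(dk=16)(q=q, k=k, v=v)

# should be a tiny floating point difference
print((builtin_out - my_out_16).abs().max().item())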
Input Shape
Why is the input 4 dimensional? Is it possible to create something that works with any number of dimensions?
The attention block was created for NLP, where the dimensions can be seen as:
- document
- token index
- 2x token embedding
I said that my implementation can handle inputs of any shape, because it only swaps the last two dimensions. This should make it easier to apply attention to other tasks.
Let’s try proving that. For this proof we can test that the output matches exactly, as the order of operations will always be the same.
Code
q = torch.rand(8, 16)
k = torch.rand(8, 16)
v = torch.rand(8, 16)

# number of parameters in k per batch entry
dk = 8*16

my_attention = ScaledDotProductAttention(dk)

with torch.no_grad():
    out_2 = my_attention(q=q, k=k, v=v)
    out_3 = my_attention(
        q=q[None, :],
        k=k[None, :],
        v=v[None, :]
    )[0]
    out_4 = my_attention(
        q=q[None, None, :],
        k=k[None, None, :],
        v=v[None, None, :]
    )[0, 0]

two_matches_three = torch.all(torch.eq(out_2, out_3)).item()
two_matches_four = torch.all(torch.eq(out_2, out_4)).item()

print(f"2d input matches 3d input? {two_matches_three}")
print(f"2d input matches 4d input? {two_matches_four}")
2d input matches 3d input? True
2d input matches 4d input? True
So this block works with any dimension input.
Source of Query, Key and Value
Where do the \(Q\), \(K\) and \(V\) vectors come from?
The attention block takes three different arguments and produces only one output. It also lacks any trainable parameters, so if it doesn’t produce the perfect answer it never will.
So how can we solve both of these? We create the query, key and value vectors by passing the input vector through three matrix transformations. These three matrices are trainable and allow the attention block to learn different behaviours.
Going back to the original equations, what are the sizes of the matrices? The \(Q\) and \(K\) matrices have the same dimensions, as they are multiplied together (\(K\) being transposed to do so). The \(V\) matrix is then multiplied by the result of that, which means that one side must match the size of \(QK^T\) and the other should match the size of the input. If this holds then the attention block will be repeatable.
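As a sketch of that wiring, writing the three trainable matrices as \(W_Q\), \(W_K\) and \(W_V\) (my own names, not the paper’s) and the input as \(X\):

\[ Q = X W_Q, \quad K = X W_K, \quad V = X W_V \]

If \(W_V\) is square then \(\text{Attention}(XW_Q, XW_K, XW_V)\) has the same shape as \(X\), which is what makes the block repeatable.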
Code
import torch
from torch import nn

class SymmetricAttention(nn.Module):
    def __init__(self, kdim: int, vdim: int) -> None:
        super().__init__()
        self.scale = 1 / ((kdim*vdim)**0.5)
        self.q = torch.rand(vdim, kdim)
        self.k = torch.rand(vdim, kdim)
        self.v = torch.rand(vdim, vdim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        q = xs @ self.q
        k = xs @ self.k
        v = xs @ self.v
        attention = q @ k.transpose(-2, -1)
        attention = attention * self.scale
        attention = attention.softmax(dim=-1)
        return attention @ v
Now it’s possible to repeat this block:
Code
xs = torch.rand(2, 4, 8, 16)

my_attention = SymmetricAttention(kdim=8, vdim=16)

with torch.no_grad():
    ys = my_attention(xs)
    ys = my_attention(ys)  # can repeat it
With this the overall structure changes:
This looks pretty complex.
Attention Simplification
Two linear transformations can be combined into a single transformation, and the matrix multiplication and scaling steps are both linear. With this knowledge we can re-express attention as follows:
Of these steps only the Mask and SoftMax are not linear. Having proposed this new structure, can we prove that it is equivalent?
The first step is to show that the matrix multiplication we are doing is associative, i.e. that \((A \cdot B) \cdot C = A \cdot (B \cdot C)\):
Code
a = torch.rand(3, 3)
b = torch.rand(3, 3)
c = torch.rand(3, 3)

left = (a @ b) @ c
right = a @ (b @ c)
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
(tensor(1.2551), tensor(0.3480), tensor(-1.9868e-08), tensor(4.2147e-08))
This looks like a difference of less than 1 in a million, which I previously accepted as good enough.
The other part is to be able to incorporate the scaling into the matrix:
Code
a = torch.rand(3, 3)
b = torch.rand(3, 3)
c = 2

left = (a @ b) / c
right = a @ (b / c)
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
(tensor(0.3903), tensor(0.2554), tensor(0.), tensor(0.))
With these (rather feeble) demonstrations of equivalence it should be possible to create a simplified attention block from any full attention block. The real work is to be able to combine the \(K\) and \(Q\) matrices as well as the subsequent matrix multiplication.
The base equation is \(((Input \cdot Q) \cdot (Input \cdot K)^T) \cdot Scale\). The primary problem here is combining the \(K\) and \(Q\) matrices.
Code
x = torch.rand(3, 3)
k = torch.rand(3, 3)
q = torch.rand(3, 3)

left = (x @ q) @ (x @ k).T
right = x @ q @ (x @ k).T
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
(tensor(2.0970), tensor(0.2826), tensor(0.), tensor(0.))
Code
x = torch.rand(3, 3)
k = torch.rand(3, 3)
q = torch.rand(3, 3)

left = (x @ q) @ (x @ k).T
right = (x @ x) @ (q @ k).T
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
(tensor(4.1022), tensor(2.3298), tensor(-0.1896), tensor(1.8094))
The problem here is that matrix multiplication is not commutative. There might well be a way to combine these, but it doesn’t occur to me right now.
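As a quick check on that claim, in the same style as the other comparisons (two arbitrary random matrices):

a = torch.rand(3, 3)
b = torch.rand(3, 3)

# if matrix multiplication were commutative this difference would be zero
difference = (a @ b) - (b @ a)

difference.mean(), difference.std()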
After consulting someone very smart, I learned that the equation \(Input \cdot KQ = ((Input \cdot K) \cdot (Input \cdot Q)) \cdot Scale\) can be expressed as \(AX = B\) where \(A = Input\), \(X = KQ\) and \(B = ((Input \cdot K) \cdot (Input \cdot Q)) \cdot Scale\). This can then be provided to the torch.linalg.solve function, which will return \(X\).
Code
a = torch.rand(3, 3)
k = torch.rand(3, 3)
q = torch.rand(3, 3)

left = (a @ q) @ (a @ k).T

x = torch.linalg.solve(a, left)
right = a @ x
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
(tensor(2.2289), tensor(0.7044), tensor(6.6227e-08), tensor(1.0513e-07))
Does this work with non-square matrices?
Code
a = torch.rand(3, 3)
k = torch.rand(3, 2)
q = torch.rand(3, 2)

left = (a @ q) @ (a @ k).T

x = torch.linalg.solve(a, left)
right = a @ x
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
(tensor(1.2785), tensor(0.4785), tensor(0.), tensor(0.))
Does the proposed \(X\) matrix contain some traces of the input? After all, this has been proposed based on a fixed input. One way to test this is to vary the value of a in the above code.
Code
a = torch.rand(3, 3)
k = torch.rand(3, 2)
q = torch.rand(3, 2)

left = (a @ q) @ (a @ k).T
x = torch.linalg.solve(a, left)

# change a
a = torch.rand(3, 3)

left = (a @ q) @ (a @ k).T
right = a @ x
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
(tensor(0.7653), tensor(0.4533), tensor(-0.0123), tensor(0.4644))
Boo hiss. It looks like the proposed simplification needs to incorporate the fact that the input appears twice in the unsimplified equation. I’m sure that there is an appropriate way to simplify this; it just escapes me.
Visualizing Attention
It’s possible to use bertviz to visualize the attention given to the input. The attention that I have used so far is totally untrained and so visualizing it isn’t that great.
It would be good to see how to apply this to a well trained layer though.
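As a rough sketch of how that might look, assuming the transformers and bertviz packages, a notebook environment to render the view, and bert-base-uncased as an arbitrary choice of trained model:

from transformers import AutoModel, AutoTokenizer
from bertviz import head_view

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", output_attentions=True)

inputs = tokenizer("the cat sat on the mat", return_tensors="pt")
outputs = model(**inputs)

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
head_view(outputs.attentions, tokens)  # interactive view of the trained attention heads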
Multi Head Attention
Nah, I’m not going to cover this here. Topics for a follow-up: how to extract associations from attention, and how attention compares to a linear layer and to an RNN.