Understanding Attention

It’s about time that I thoroughly explored the Attention mechanism
Published March 25, 2022

It’s about time that I thoroughly explored the Attention mechanism. The scaled dot-product form of attention covered here was introduced in the Attention Is All You Need paper (Vaswani et al. 2017). Attention is a deep learning building block, and it forms the basis of the Transformer architecture, which is used in almost all current NLP models.

Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. “Attention Is All You Need.” https://arxiv.org/abs/1706.03762.

This sort of post has been done many times before. One particularly good example is The Illustrated Transformer, along with its companion post Visualizing A Neural Machine Translation. I would recommend reading both.

Scaled Dot-Product Attention

The basic attention block is scaled dot-product attention, which takes three inputs. These inputs appear across attention implementations and are referred to as \(Q\), \(K\) and \(V\). From the paper:

An attention function can be described as mapping a query (\(Q\)) and a set of key-value (\(K\)-\(V\)) pairs to an output, where the query, keys, values, and output are all vectors

These form a block according to the equation:

\[ \text{Attention}(Q,K,V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V \]

The input consists of queries and keys of dimension \(d_k\), and values of dimension \(d_v\).

This sort of equation is not the easiest to read, so we can also express it as a graph of operations:

If we break down the equation we can see the following parts:

\[ \begin{aligned} \text{Attention}(Q,K,V) &= \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V \\ \text{Matrix Multiply}_1 &= QK^T \\ \text{Scale} &= \frac{1}{\sqrt{d_k}} \\ \text{Mask} &= \text{not present in equation} \\ \text{SoftMax} &= \text{softmax}\left( \ldots \right) \\ \text{Matrix Multiply}_2 &= \text{softmax}_{out} V \\ \end{aligned} \]

Which isn’t so bad.
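To make the breakdown concrete, here is a minimal sketch that runs each step separately and prints the intermediate shapes. The sizes (3 queries, 5 key-value pairs, \(d_k = 4\), \(d_v = 6\)) are arbitrary choices for illustration, and the optional mask step is skipped.

```python
import torch

torch.manual_seed(0)

n_queries, n_keys, d_k, d_v = 3, 5, 4, 6  # arbitrary illustrative sizes
q = torch.rand(n_queries, d_k)  # queries: (n_queries, d_k)
k = torch.rand(n_keys, d_k)     # keys:    (n_keys, d_k)
v = torch.rand(n_keys, d_v)     # values:  (n_keys, d_v)

scores = q @ k.T                  # Matrix Multiply 1: (n_queries, n_keys)
scaled = scores / d_k**0.5        # Scale by 1 / sqrt(d_k)
weights = scaled.softmax(dim=-1)  # SoftMax over the keys, so each row sums to 1
output = weights @ v              # Matrix Multiply 2: (n_queries, d_v)

print(scores.shape, weights.shape, output.shape)
# torch.Size([3, 5]) torch.Size([3, 5]) torch.Size([3, 6])
print(weights.sum(dim=-1))
```

Each output row is a weighted average of the rows of \(V\), where the weights measure how well that query matches each key.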

The questions that I have after looking at all this are:

  • How would we implement this?
  • Where do the \(Q\), \(K\) and \(V\) vectors come from?

Implementation

Given the description above, along with the paper, I have come up with the following:

```python
import torch
from torch import nn

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dk: int) -> None:
        super().__init__()
        self.scale = 1 / (dk**0.5)  # Scale: 1 / sqrt(d_k)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        attention = q @ k.T                    # Matrix Multiply 1
        attention = attention * self.scale     # Scale
        attention = attention.softmax(dim=-1)  # SoftMax over the keys
        return attention @ v                   # Matrix Multiply 2
```

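My implementation is not batched: it assumes that \(K\) has exactly two dimensions, because k.T is only a matrix transpose for 2D tensors (on higher-dimensional tensors it reverses the order of all dimensions).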

To check this implementation I found a reference implementation in the attention-is-all-you-need-pytorch repository (the link is in the code comment). The reference implementation is:

```python
import torch
import torch.nn.functional as F
from torch import nn

class ReferenceScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''
    # from https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Modules.py

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn
```

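This shows that the mask is applied before the softmax: masked positions are filled with -1e9, so they receive (almost) zero weight once the softmax is applied. The reference implementation also assumes 4-dimensional input, which can be seen in the way \(K^T\) is calculated with k.transpose(2, 3). It also divides q by the temperature before the matrix multiplication, whereas my implementation scales afterwards.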
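A small sketch of that masking step, assuming a causal (lower-triangular) mask of the kind used in a decoder, where position i may only attend to positions up to i. The -1e9 fill value mirrors the reference code; the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)

n, d_k = 4, 8
q = torch.rand(n, d_k)
k = torch.rand(n, d_k)

scores = (q @ k.transpose(-2, -1)) / d_k**0.5

# causal mask: 1 where attention is allowed, 0 where it is blocked
mask = torch.tril(torch.ones(n, n))

# fill blocked positions with a huge negative number *before* the softmax
masked = scores.masked_fill(mask == 0, -1e9)
weights = masked.softmax(dim=-1)

print(weights)  # the upper triangle is 0 and each row still sums to 1
```

One practical difference from filling with float('-inf'): if an entire row were masked, -inf would make the softmax compute 0/0 and return NaN, while -1e9 gives a uniform row instead.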

I’ve been using the @ operator, which calls torch.matmul for tensors, so at least that part is the same.

To support batched input we only have to replace k.T with k.transpose(-2, -1). Using negative indices swaps the last two dimensions no matter how many dimensions the tensor has.
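As an aside, newer PyTorch releases also provide the Tensor.mT property, a shorthand that transposes the last two dimensions. A quick check that it agrees with transpose(-2, -1):

```python
import torch

k = torch.rand(2, 4, 8, 16)

swapped = k.transpose(-2, -1)  # explicit swap of the last two dimensions
shorthand = k.mT               # property that does the same swap

print(swapped.shape)                    # torch.Size([2, 4, 16, 8])
print(torch.equal(swapped, shorthand))  # True
```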

```python
import torch
from torch import nn

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dk: int) -> None:
        super().__init__()
        self.scale = 1 / (dk**0.5)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # the transpose here swaps the last two dimensions
        attention = q @ k.transpose(-2, -1)
        attention = attention * self.scale
        attention = attention.softmax(dim=-1)
        return attention @ v
```

We can check that these are equivalent by running the same inputs through both. This involves floating point operations, which are sensitive to the order in which they are performed. Since the reference scales q before the matrix multiplication while mine scales afterwards, the comparison should be a little generous.

```python
q = torch.rand(2, 4, 8, 16)
k = torch.rand(2, 4, 8, 16)
v = torch.rand(2, 4, 8, 16)

# number of values in k per batch entry. The paper's d_k is actually the size
# of each key vector, k.shape[-1] (16 here); since both modules receive the
# same value, the comparison is unaffected by this choice.
dk = 4 * 8 * 16

my_attention = ScaledDotProductAttention(dk)
reference_attention = ReferenceScaledDotProductAttention(
    temperature=dk**0.5,
    attn_dropout=0. # disable dropout for the comparison
)

with torch.no_grad():
    my_out = my_attention(q=q, k=k, v=v)

    # the reference model returns the attention as well
    reference_out = reference_attention(q=q, k=k, v=v)[0]

difference = my_out - reference_out

mean_out = my_out.mean().item()
mean_difference = difference.mean().item()

print(f"average output: {mean_out:0.4f}, with a difference of {mean_difference:0.4g}")
```
average output: 0.5085, with a difference of -1.048e-09

That’s a difference of almost one part in a billion, which seems close enough to me.
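The mean difference can hide errors, because positive and negative differences cancel out; the maximum absolute difference, or torch.allclose, is a stricter check. As a further cross-check, PyTorch 2.0 and later ship torch.nn.functional.scaled_dot_product_attention, which by default scales by 1/sqrt of the last dimension of the query. This sketch assumes that PyTorch version and uses \(d_k\) = q.shape[-1], the per-key size from the paper; manual_attention is just a functional copy of the module above.

```python
import torch
import torch.nn.functional as F

def manual_attention(q, k, v, dk):
    # same steps as the ScaledDotProductAttention module above
    attention = q @ k.transpose(-2, -1)
    attention = attention * (1 / dk**0.5)
    attention = attention.softmax(dim=-1)
    return attention @ v

torch.manual_seed(0)
q = torch.rand(2, 4, 8, 16)
k = torch.rand(2, 4, 8, 16)
v = torch.rand(2, 4, 8, 16)

dk = q.shape[-1]  # 16: the size of each query/key vector

mine = manual_attention(q, k, v, dk)
builtin = F.scaled_dot_product_attention(q, k, v)  # requires PyTorch >= 2.0

# the largest single deviation cannot be hidden by cancellation
max_difference = (mine - builtin).abs().max().item()
print(f"max absolute difference: {max_difference:0.3g}")
print(torch.allclose(mine, builtin, atol=1e-5))
```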

Input Shape

Why is the input 4-dimensional? Is it possible to create something that works with any number of dimensions?

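The attention block was created for NLP. In the reference implementation, where it sits inside multi-head attention, the four dimensions are:

  • batch (for example, one entry per document)
  • attention head
  • token index
  • token embedding (the slice of the embedding handled by that head)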
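To see where the head dimension comes from, here is a sketch of how a batch of token embeddings of shape (batch, tokens, d_model) can be split into heads, in the style of the reference repository's multi-head attention. The sizes and variable names are made up for illustration, and the learned projections that normally come first are left out.

```python
import torch

batch, tokens, d_model, n_heads = 2, 8, 64, 4
d_head = d_model // n_heads  # each head sees a 16-dimensional slice

x = torch.rand(batch, tokens, d_model)

# split the embedding into heads, then move the head axis before the token axis
heads = x.view(batch, tokens, n_heads, d_head).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 8, 16])

# undoing the transpose and merging the heads recovers the original tensor
merged = heads.transpose(1, 2).reshape(batch, tokens, d_model)
print(torch.equal(merged, x))  # True
```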

I said that my implementation can handle inputs of any shape, because it only swaps the last two dimensions. This should make it easier to apply to other tasks.

Let’s check that. For this test we can require the outputs to match exactly, since the same operations are performed in the same order each time.

```python
q = torch.rand(8, 16)
k = torch.rand(8, 16)
v = torch.rand(8, 16)

# number of values in k. The paper's d_k would be k.shape[-1] (16);
# any constant works for this shape check.
dk = 8*16

my_attention = ScaledDotProductAttention(dk)

with torch.no_grad():
    out_2 = my_attention(q=q, k=k, v=v)
    out_3 = my_attention(
        q=q[None, :],
        k=k[None, :],
        v=v[None, :]
    )[0]
    out_4 = my_attention(
        q=q[None, None, :],
        k=k[None, None, :],
        v=v[None, None, :]
    )[0, 0]

two_matches_three = torch.all(torch.eq(out_2, out_3)).item()
two_matches_four = torch.all(torch.eq(out_2, out_4)).item()

print(f"2d input matches 3d input? {two_matches_three}")
print(f"2d input matches 4d input? {two_matches_four}")
```
2d input matches 3d input? True
2d input matches 4d input? True

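So this block works with inputs of any number of dimensions (two or more), since matrix multiplication treats every dimension before the last two as a batch dimension.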

Source of Query, Key and Value

Where do the \(Q\), \(K\) and \(V\) vectors come from?

The attention block takes three different arguments and produces only one output. It also lacks any trainable parameters, so if it doesn’t produce the perfect answer it never will.

So how can we solve both of these problems? We create the query, key and value vectors by multiplying the input by three matrices (linear projections). These three matrices are trainable and allow the attention block to learn different behaviours.

Going back to the original equation, what sizes should these matrices be? The \(Q\) and \(K\) projections must produce vectors of the same size, because \(QK^T\) multiplies them together (with \(K\) transposed). The result of \(QK^T\) is then multiplied by \(V\), so the rows of \(V\) must match the size of \(QK^T\), while the columns of \(V\) set the size of the output. If the \(V\) projection maps back to the input size, the output has the same shape as the input and the attention block can be repeated.
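A sketch of these size constraints using bias-free nn.Linear layers as the three projections. The sizes are illustrative assumptions; the point is that the query and key projections must share an output size, and a value projection that maps back to the input size keeps the output the same shape as the input. Because the weights live inside nn.Linear, they are registered as parameters and receive gradients.

```python
import torch
from torch import nn

torch.manual_seed(0)

d_model, d_qk = 16, 8  # input feature size and query/key size (illustrative)

to_q = nn.Linear(d_model, d_qk, bias=False)
to_k = nn.Linear(d_model, d_qk, bias=False)     # must match to_q's output size
to_v = nn.Linear(d_model, d_model, bias=False)  # maps back to the input size

xs = torch.rand(2, 8, d_model)  # (batch, tokens, d_model)

q, k, v = to_q(xs), to_k(xs), to_v(xs)
weights = (q @ k.transpose(-2, -1) / d_qk**0.5).softmax(dim=-1)  # (2, 8, 8)
ys = weights @ v                                                  # (2, 8, 16)

print(ys.shape)  # torch.Size([2, 8, 16]), the same shape as xs

# the projections are trainable: a dummy loss produces gradients for them
ys.sum().backward()
print(to_q.weight.grad is not None)  # True
```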

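```python
import torch
from torch import nn

class SymmetricAttention(nn.Module):
    def __init__(self, kdim: int, vdim: int) -> None:
        super().__init__()
        # scale by 1 / sqrt(d_k), where d_k is the size of each query/key vector
        self.scale = 1 / (kdim**0.5)
        # nn.Parameter registers the matrices so that they are actually trained
        self.q = nn.Parameter(torch.rand(vdim, kdim))  # input -> query
        self.k = nn.Parameter(torch.rand(vdim, kdim))  # input -> key
        self.v = nn.Parameter(torch.rand(vdim, vdim))  # input -> value, same size as input

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        q = xs @ self.q
        k = xs @ self.k
        v = xs @ self.v
        attention = q @ k.transpose(-2, -1)
        attention = attention * self.scale
        attention = attention.softmax(dim=-1)
        return attention @ v
```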

Now it’s possible to repeat this block:

```python
xs = torch.rand(2, 4, 8, 16)

my_attention = SymmetricAttention(kdim=8, vdim=16)

with torch.no_grad():
    ys = my_attention(xs)
    ys = my_attention(ys) # can repeat it
```

With this the overall structure changes:

This looks pretty complex.

Attention Simplification

Two consecutive linear transformations can be combined into a single transformation. The matrix multiplication and scaling steps are linear. With this knowledge we can re-express attention as follows:

Of these steps only the Mask and Softmax are not linear. If we adopt this new structure, can we show that it is equivalent?

The first step is to show that the matrix multiplication we are doing is associative, i.e. that \((A \cdot B) \cdot C = A \cdot (B \cdot C)\):

```python
a = torch.rand(3, 3)
b = torch.rand(3, 3)
c = torch.rand(3, 3)

left = (a @ b) @ c
right = a @ (b @ c)
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
```
(tensor(1.2551), tensor(0.3480), tensor(-1.9868e-08), tensor(4.2147e-08))

That is a difference of less than one part in a million, which is in line with the rounding error expected from 32-bit floats.

The other part is to be able to incorporate the scaling into the matrix:

```python
a = torch.rand(3, 3)
b = torch.rand(3, 3)
c = 2

left = (a @ b) / c
right = a @ (b / c)
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
```
(tensor(0.3903), tensor(0.2554), tensor(0.), tensor(0.))

With these (rather feeble) demonstrations of equivalence it should be possible to create a simplified attention block from any full attention block. The real work is to combine the \(K\) and \(Q\) matrices as well as the subsequent matrix multiplication.

The base equation is \(((Input \cdot Q) \cdot (Input \cdot K)^T) \cdot Scale\). The primary problem here is combining the \(K\) and \(Q\) matrices.

```python
x = torch.rand(3, 3)
k = torch.rand(3, 3)
q = torch.rand(3, 3)

left = (x @ q) @ (x @ k).T
# Python's @ is left-associative, so this is the same expression as left
right = x @ q @ (x @ k).T
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
```
(tensor(2.0970), tensor(0.2826), tensor(0.), tensor(0.))
```python
x = torch.rand(3, 3)
k = torch.rand(3, 3)
q = torch.rand(3, 3)

left = (x @ q) @ (x @ k).T
right = (x @ x) @ (q @ k).T
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
```
(tensor(4.1022), tensor(2.3298), tensor(-0.1896), tensor(1.8094))

The problem here is that matrix multiplication is not commutative. There might well be a way to combine these, but it doesn’t occur to me right now.

When I consulted someone very smart, they pointed out that the equation \(Input \cdot X = ((Input \cdot Q) \cdot (Input \cdot K)^T) \cdot Scale\), where \(X\) is the combined matrix we are looking for, has the form \(AX = B\) with \(A = Input\) and \(B = ((Input \cdot Q) \cdot (Input \cdot K)^T) \cdot Scale\). This can then be passed to the torch.linalg.solve function, which returns \(X\). (The code leaves out the scale, which can be folded into \(X\) as shown earlier.)

```python
a = torch.rand(3, 3)
k = torch.rand(3, 3)
q = torch.rand(3, 3)

left = (a @ q) @ (a @ k).T

# solve a @ x = left for x
x = torch.linalg.solve(a, left)
right = a @ x
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
```
(tensor(2.2289), tensor(0.7044), tensor(6.6227e-08), tensor(1.0513e-07))

Does this work with non-square matrices? Here \(Q\) and \(K\) are non-square; the input a has to stay square for torch.linalg.solve.

```python
a = torch.rand(3, 3)
k = torch.rand(3, 2)
q = torch.rand(3, 2)

left = (a @ q) @ (a @ k).T

x = torch.linalg.solve(a, left)
right = a @ x
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
```
(tensor(1.2785), tensor(0.4785), tensor(0.), tensor(0.))
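If the input itself is not square, for example with more tokens than features, one option is torch.linalg.lstsq, which returns a least-squares solution of a @ x = left. A sketch with arbitrarily chosen sizes; because left is built from a, an exact solution exists here and the least-squares fit should recover it up to rounding.

```python
import torch

torch.manual_seed(0)

a = torch.rand(5, 3)  # 5 tokens, 3 features: not square
k = torch.rand(3, 2)
q = torch.rand(3, 2)

left = (a @ q) @ (a @ k).T  # (5, 5)

# least-squares solve of a @ x = left; x has shape (3, 5)
x = torch.linalg.lstsq(a, left).solution
right = a @ x

print(x.shape)  # torch.Size([3, 5])
print((left - right).abs().max().item())
```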

Does the proposed \(X\) matrix contain traces of the input? After all, it was solved for one fixed input. One way to test this is to change the value of a in the code above.

```python
a = torch.rand(3, 3)
k = torch.rand(3, 2)
q = torch.rand(3, 2)

left = (a @ q) @ (a @ k).T
x = torch.linalg.solve(a, left)

# change a
a = torch.rand(3, 3)

left = (a @ q) @ (a @ k).T
right = a @ x
difference = left - right

left.mean(), left.std(), difference.mean(), difference.std()
```
(tensor(0.7653), tensor(0.4533), tensor(-0.0123), tensor(0.4644))

Boo hiss. It looks like the proposed simplification needs to incorporate the fact that the input appears twice in the unsimplified equation. I’m sure that there is an appropriate way to simplify this, it just escapes me.
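The transpose rule \((AB)^T = B^T A^T\) explains what happened. The scores can be rewritten as \((aQ)(aK)^T = a(QK^T)a^T\), so when a is invertible the solved matrix is \(X = QK^Ta^T\), which has the input baked into it. The input-independent part is \(W = QK^T\), but the input has to be applied on both sides of it. A minimal sketch checking both claims; the names w and new_a and the use of float64 are my own choices for illustration.

```python
import torch

torch.manual_seed(0)

# float64 keeps rounding error out of the comparisons
a = torch.rand(3, 3, dtype=torch.float64)
k = torch.rand(3, 2, dtype=torch.float64)
q = torch.rand(3, 2, dtype=torch.float64)

left = (a @ q) @ (a @ k).T

# the solved X equals Q K^T a^T, so it depends on the input a
x = torch.linalg.solve(a, left)
solve_matches = torch.allclose(x, q @ k.T @ a.T)
print(solve_matches)

# W = Q K^T does not depend on the input; the input appears on both sides
w = q @ k.T
matches = []
for _ in range(3):
    new_a = torch.rand(3, 3, dtype=torch.float64)  # a different input each time
    matches.append(torch.allclose((new_a @ q) @ (new_a @ k).T, new_a @ w @ new_a.T))
print(matches)
```

So \(Q\) and \(K\) can be merged into one matrix, but the scores are quadratic in the input, so they cannot be reduced to a single matrix applied to the input once.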

Visualizing Attention

It’s possible to use bertviz to visualize the attention given to the input. The attention that I have used so far is totally untrained, so visualizing it isn’t very informative.

It would be good to see how to apply this to a well trained layer though.

Multi Head Attention

Nah, don’t do this. Cover how to extract associations from attention. Cover comparisons to a linear layer and to an RNN.