Explaining PyTorch Modules

How can we show the influence of the input to the output?
Published

February 20, 2023

In a previous post I investigated the relationship between images and the embedding that CLIP produces. I came up with a way to link the two together; however, all of the approaches that I tried had problems. Either the results did not make sense, the numbers broke down, or the mathematical approach underpinning the technique had flaws.

This post is a more basic exploration of the technique applied to individual PyTorch modules. Here I will run a wider range of tests over each approach, and hopefully compare a few different approaches.

Basic Principles

The aim here is to track the influence of the input values over the output values. What does this mean and how can we apply this to a given module?

The influence of a value is the amount of change in the output that can be attributed to it. If we apply this to basic mathematical operations first, we can then build up from there to more complex operations.

When describing this I am going to talk about values and composite values. A value is a scalar or tensor such as \(V\) (\(V\) for value). A composite value is that scalar or tensor decomposed into a representation of the influence of each input value, such as \(\left[ C_1, C_2, \dots, C_n \right]\) (\(C\) for composite).

We can start by defining our basic principles:

  • The composite value should be seen as the decomposition of the original value, that is \(V = \sum_{i = 1}^{n} C_i\).
  • Each entry in the composite output value should reflect the influence of the corresponding input value.
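In the code throughout this post, a composite value is stored as a tensor with one extra trailing dimension holding the entries \(C_i\), so summing that dimension recovers \(V\). A small illustration of that layout, using made-up numbers chosen only so the sums are easy to check:

```python
import torch

# three values, each decomposed over two inputs (made-up numbers)
composite = torch.tensor([
    [0.5, 0.25],
    [-1.0, 2.0],
    [0.0, 3.0],
])

# summing the trailing composite dimension recovers each original value V
value = composite.sum(dim=-1)
print(value.tolist())  # [0.75, 1.0, 3.0]
```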

Testing

Before proceeding I want a reliable way to empirically test my proposals. A flawed approach could still pass a test that only uses random values, because random inputs rarely hit edge cases such as zeros. What we need is a way to evaluate each approach more systematically, especially with edge cases.

I’ve defined a long test that passes two composite values with various numerical characteristics to each function, and checks whether the output of the value function matches the summed output of the composite value function. It accepts the two functions to compare: one which operates over the values and one which operates over the composite values. The first thing to do with this test is to test it.

from typing import Callable, Optional, Tuple
import torch
import pandas as pd

Function = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]

def compare_functions(
    value_function: Function,
    composite_function: Function,
    size: Optional[Tuple[int,...]] = None,
) -> pd.DataFrame:
    def rand_1k(
        value_function: Function,
        composite_function: Function,
        size: Optional[Tuple[int,...]] = None,
    ) -> bool:
        return all(
            compare_rand(
                value_function=value_function,
                composite_function=composite_function,
                size=size,
            )
            for _ in range(1_000)
        )
    rand_1k.__name__ = compare_rand.__name__

    return pd.DataFrame([
        {
            "name": f.__name__[len("compare_"):],
            "passed": f(
                value_function=value_function,
                composite_function=composite_function,
                size=size,
            )
        }
        for f in [
            compare_absolute_zeros,
            compare_sum_zeros,
            compare_has_zeros,
            rand_1k,
        ]
    ])

def compare_absolute_zeros(
    value_function: Function,
    composite_function: Function,
    size: Optional[Tuple[int,...]] = None,
) -> bool:
    if not size:
        size = (3, 2)
    zero = torch.zeros(*size)
    nonzero = (torch.rand(*size) * 2) - 1

    return all(
        compare(
            value_function=value_function,
            composite_function=composite_function,
            left=left,
            right=right,
        )
        for left, right in [(zero, zero), (zero, nonzero), (nonzero, zero)]
    )

def compare_sum_zeros(
    value_function: Function,
    composite_function: Function,
    size: Optional[Tuple[int,...]] = None,
) -> bool:
    if not size:
        size = (3, 2)
    zero = (torch.rand(*size) * 2) - 1
    nonzero = (torch.rand(*size) * 2) - 1

    # change the last entry in the last dimension to make the dimension sum to zero
    z = zero.view(-1, size[-1])
    z[:, -1] = -z[:, :-1].sum(dim=-1)

    check = zero.sum(dim=-1).abs().max()
    assert check <= 1e-6, f"assertion failed: {check:.4g}"
    
    return all(
        compare(
            value_function=value_function,
            composite_function=composite_function,
            left=left,
            right=right,
        )
        for left, right in [(zero, zero), (zero, nonzero), (nonzero, zero)]
    )

def compare_has_zeros(
    value_function: Function,
    composite_function: Function,
    size: Optional[Tuple[int,...]] = None,
) -> bool:
    if not size:
        size = (3, 2)
    zero = (torch.rand(*size) * 2) - 1
    nonzero = (torch.rand(*size) * 2) - 1

    # change a random entry in the last dimension to be zero
    z = zero.view(-1)
    count = z.shape[0] // size[-1]
    indices = torch.arange(count) * size[-1]
    indices += torch.randint(low=0, high=size[-1], size=(count,))
    z[indices] = 0.

    check = zero.abs().min(dim=-1)[0].max()
    assert check <= 1e-6, f"assertion failed: {check:.4g}"
    
    return all(
        compare(
            value_function=value_function,
            composite_function=composite_function,
            left=left,
            right=right,
        )
        for left, right in [(zero, zero), (zero, nonzero), (nonzero, zero)]
    )

def compare_rand(
    value_function: Function,
    composite_function: Function,
    size: Optional[Tuple[int,...]] = None,
) -> bool:
    if not size:
        size = (3, 2)
    left = (torch.rand(*size) * 2) - 1
    right = (torch.rand(*size) * 2) - 1
    
    return compare(
        value_function=value_function,
        composite_function=composite_function,
        left=left,
        right=right,
    )

def compare(
    *,
    value_function: Function,
    composite_function: Function,
    left: torch.Tensor,
    right: torch.Tensor,
    threshold: float = 1e-5,
) -> bool:
    value_output = value_function(left.sum(dim=-1), right.sum(dim=-1))
    composite_output = composite_function(left, right)

    difference = value_output - composite_output.sum(dim=-1)
    # every element must match, so compare against the largest difference
    difference = difference.abs().max().item()
    return difference <= threshold

We can check the test itself by running it on the simplest approaches first.

Multiplication by a Constant

If I have an input of \(V\) and I multiply that by \(\alpha\) then I get \(\alpha V\). If \(V\) is actually the composite of the values, \(\left[ C_1, C_2 \dots C_n \right]\), and I multiply that by \(\alpha\) then I get \(\left[ \alpha C_1, \alpha C_2 \dots \alpha C_n \right]\).

How can we tell if this is the correct approach? Does it maintain the principles that were defined earlier?

It’s pretty simple to show that the multiplication still maintains the relationship between the composite value and the original value:

\[ \begin{aligned} \alpha V &= \left[ \alpha C_1, \alpha C_2 \dots \alpha C_n \right] \\ \alpha V &= \sum_{i = 1}^{n} \alpha C_i \\ \alpha V &= \alpha \sum_{i = 1}^{n} C_i \\ \alpha V &= \alpha V \end{aligned} \]

My definition of influence was less precise, but the linear relationship between the input value and the output value is maintained between each composite input entry and the corresponding composite output entry.

This seems pretty straightforward so far. Let’s see if it passes the empirical test.

Remember that the test functions take two values. For this test we will ignore the second value and just multiply by a constant.

compare_functions(
    value_function=lambda left, right: left * 3,
    composite_function=lambda left, right: left * 3,
)
             name  passed
0  absolute_zeros    True
1       sum_zeros    True
2       has_zeros    True
3            rand    True

That worked, so I have some small confidence in the empirical test.

Addition of a Constant

If I have an input of \(V\) and I add \(\alpha\) then I get \(V + \alpha\). If \(V\) is actually the composite value \(\left[ C_1, C_2, \dots, C_n \right]\), and I add \(\alpha\), then I get \(\left[ C_1, C_2, \dots, C_n \right] + \alpha\).

Given the principles defined earlier, can we check that this is the correct approach?

  • The composite value has now been extended by an additional term. The sum of the composite value with this additional term is now \(\alpha + \sum_{i = 1}^{n} C_i\).

If we accept this additional term as an extension of the original then the decomposition is still valid. I do not want to add \(\alpha\) to any individual entry, as that would break the second principle.

  • The added value, \(\alpha\), cannot be assigned to any of the decomposed values. This is because none of the inputs is responsible for that part of the output; \(\alpha\) is.

This means that the decomposed value must be viewed as \(\left[ C_1, C_2 \dots C_n \right] + \alpha\). Programmatically, I think it will be easier to implement this by adding a single additional entry after the composite values, because for other operations the unattributed constant value behaves the same as the composite values.

Once again we can empirically test this. Remember that the last index of the composite value is the \(\alpha\), so for the composite operation we will add to that.

import torch

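def add(tensor: torch.Tensor, value: float) -> torch.Tensor:
    # clone so that the caller's tensor is not modified in place,
    # then add the constant to the last (unattributed) entry
    tensor = tensor.clone()
    tensor[..., -1] += value
    return tensor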

compare_functions(
    value_function=lambda left, right: left + 3,
    composite_function=lambda left, right: add(left, 3),
)
             name  passed
0  absolute_zeros    True
1       sum_zeros    True
2       has_zeros    True
3            rand    True

Multiplication of Two Composite Values

This is where it gets trickier.

I want the multiplication to maintain the principles, and the problem is easiest to see in code.

We are going to multiply two composite values in the simplest way possible:

import torch
import pandas as pd

left_composite = torch.normal(mean=0.0, std=1.0, size=(3,2))
right_composite = torch.normal(mean=0.0, std=1.0, size=(3,2))
left_value = left_composite.sum(dim=-1)
right_value = right_composite.sum(dim=-1)

value = left_value * right_value
composite_value = (left_composite * right_composite).sum(dim=-1)

pd.DataFrame([
    {
        "value": v.item(),
        "composite summed value": cv.item(),
        "absolute difference": (v - cv).abs().item(),
    }
    for v, cv in zip(value, composite_value)
])
      value  composite summed value  absolute difference
0 -0.369377               -0.280636             0.088740
1 -3.130352                0.146490             3.276842
2  0.494240                0.209202             0.285037

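The problem here is that splitting the values produces smaller parts, and multiplying the parts pairwise does not recover the original product: each part is only ever multiplied by its counterpart at the same index, never by the other parts of the other value. When summed, these products give a different value, which can even have a different sign, as in row 1.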

The error is not a fixed offset; it depends on how the value is distributed across the composite entries. If we put all of the value into a single index, leaving the other indices at zero, then we get our desired behaviour:

import torch
import pandas as pd

left_composite = torch.zeros(*left_value.shape, 2)
left_composite[:, 0] = left_value

right_composite = torch.zeros(*right_value.shape, 2)
right_composite[:, 0] = right_value

composite_value = (left_composite * right_composite).sum(dim=-1)

pd.DataFrame([
    {
        "value": v.item(),
        "composite summed value": cv.item(),
        "absolute difference": (v - cv).abs().item(),
    }
    for v, cv in zip(value, composite_value)
])
      value  composite summed value  absolute difference
0 -0.369377               -0.369377                  0.0
1 -3.130352               -3.130352                  0.0
2  0.494240                0.494240                  0.0

At this point it would be good to check that our empirical test rejects this approach. So far we have only seen it pass; maybe the empirical test accepts anything?

compare_functions(
    value_function=lambda left, right: left * right,
    composite_function=lambda left, right: left * right,
)
             name  passed
0  absolute_zeros    True
1       sum_zeros   False
2       has_zeros   False
3            rand   False

This gives me more confidence in the empirical test, as it does not accept everything. Now that we can show empirically that this approach fails, we can work out why it fails mathematically. The first thing to show is what actually gets multiplied.

If we have two vectors, \(\left[ a_1, a_2 \right]\) and \(\left[ b_1, b_2 \right]\), then multiplying them element-wise gives \(\left[ a_1 b_1, a_2 b_2 \right]\). We can show this quite simply:

import torch

torch.tensor([1, 2]) * torch.tensor([3, 4])
tensor([3, 8])

As tensors can have any number of dimensions, we can also try this with higher-dimensional tensors:

import torch

print("add one more dimension to both to make them 1,2")
display(torch.tensor([[1, 2]]) * torch.tensor([[3, 4]]))
print()

print("add different dimensions to make them 2,1 and 1,2")
display(torch.tensor([[1], [2]]) * torch.tensor([[3, 4]]))
print()

print("add different numbers of dimensions to make them 2,1 and 2")
torch.tensor([[1], [2]]) * torch.tensor([3, 4])
add one more dimension to both to make them 1,2
tensor([[3, 8]])

add different dimensions to make them 2,1 and 1,2
tensor([[3, 4],
        [6, 8]])

add different numbers of dimensions to make them 2,1 and 2
tensor([[3, 4],
        [6, 8]])

Broadcasting works by lining the shapes up from the last dimension: any missing leading dimensions are treated as size 1, and any dimension of size 1 is expanded to match the other tensor. Multiplying a tensor by a single value multiplies every value in the tensor by that single value. With this understanding we can reason through what is wrong with the multiplication.
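The broadcasting rule can be checked directly. This sketch uses torch.broadcast_shapes to compute the shape PyTorch settles on for the third example above, then expands both tensors to that shape explicitly to show that the product is unchanged:

```python
import torch

column = torch.tensor([[1], [2]])  # shape (2, 1)
row = torch.tensor([3, 4])         # shape (2,)

# the missing leading dimension of `row` is treated as size 1, then every
# size-1 dimension is expanded to match: (2, 1) and (1, 2) -> (2, 2)
shape = torch.broadcast_shapes(column.shape, row.shape)
expanded = column.expand(shape) * row.expand(shape)

print(shape)                                # torch.Size([2, 2])
print(torch.equal(expanded, column * row))  # True
```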

One way to show what the multiplication should be is to work through a substitution of the original \(A B\), as follows:

\[ \begin{aligned} A &= \left[ a_1, a_2 \right] \\ &= \sum_{i = 1}^{2} a_i \\ \\ B &= \left[ b_1, b_2 \right] \\ &= \sum_{i = 1}^{2} b_i \\ \\ A B &= \left( \sum_{i = 1}^{2} a_i \right) \left( \sum_{i = 1}^{2} b_i \right) \\ A B &= \left[ a_1, a_2 \right] \left( \sum_{i = 1}^{2} b_i \right) \\ A B &= \sum_{i = 1}^{2} a_i \sum_{j = 1}^{2} b_j \end{aligned} \]

The element-wise multiplication that we are actually performing gives \(\left[ a_1 b_1, a_2 b_2 \right]\), which is not even close: the cross terms \(a_1 b_2\) and \(a_2 b_1\) are missing entirely. I’m glad I worked through this.
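A concrete check with the vectors from the earlier example, \(A = \left[ 1, 2 \right]\) and \(B = \left[ 3, 4 \right]\). Here torch.outer produces every product \(a_i b_j\), whose total is the true product \(A B\), while element-wise multiplication keeps only the diagonal of that grid. This is a small numeric illustration rather than part of the tracing code:

```python
import torch

a = torch.tensor([1.0, 2.0])  # A = 1 + 2 = 3
b = torch.tensor([3.0, 4.0])  # B = 3 + 4 = 7

product = a.sum() * b.sum()     # the true value: A * B = 21
all_terms = torch.outer(a, b)   # every a_i * b_j: [[3, 4], [6, 8]]
elementwise = (a * b).sum()     # only the diagonal: a_1 b_1 + a_2 b_2 = 11
cross_terms = all_terms.sum() - all_terms.diagonal().sum()  # a_1 b_2 + a_2 b_1 = 10

print(product.item(), all_terms.sum().item())  # 21.0 21.0
print(elementwise.item(), cross_terms.item())  # 11.0 10.0
```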

There is a problem with this formulation: it attributes all of the influence to \(A\), because the \(B\) values are rolled up into a single sum. One way to fix this is to take an equal amount of influence from each side by averaging the two orderings:

\[ A B = \frac{ \left( \sum_{i = 1}^{2} a_i \sum_{j = 1}^{2} b_j \right) + \left( \sum_{i = 1}^{2} b_i \sum_{j = 1}^{2} a_j \right) }{2} \]

I like this as I stumbled upon this approach when I originally implemented the batch matrix multiplication step. Let’s see if it passes the tests.

import torch

def composite_multiply(
    left: torch.Tensor,
    right: torch.Tensor,
) -> torch.Tensor:
    rollup_left = left.sum(dim=-1).unsqueeze(-1)
    rollup_right = right.sum(dim=-1).unsqueeze(-1)
    left_multiply = left * rollup_right
    right_multiply = right * rollup_left
    return (left_multiply + right_multiply) / 2

compare_functions(
    value_function=lambda left, right: left * right,
    composite_function=composite_multiply,
)
             name  passed
0  absolute_zeros    True
1       sum_zeros    True
2       has_zeros    True
3            rand    True

This works!
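To see how the averaged formula spreads the influence, here is the same composite_multiply applied to \(A = \left[ 1, 2 \right]\) and \(B = \left[ 3, 4 \right]\), treating the two entries as the influence of the same two inputs. The function is repeated so that the snippet runs on its own:

```python
import torch

def composite_multiply(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
    # same as above: average "left times rolled-up right" and
    # "right times rolled-up left"
    rollup_left = left.sum(dim=-1).unsqueeze(-1)
    rollup_right = right.sum(dim=-1).unsqueeze(-1)
    return (left * rollup_right + right * rollup_left) / 2

a = torch.tensor([1.0, 2.0])  # A = 3
b = torch.tensor([3.0, 4.0])  # B = 7
result = composite_multiply(a, b)

# index 0: (1 * 7 + 3 * 3) / 2 = 8
# index 1: (2 * 7 + 4 * 3) / 2 = 13
print(result.tolist())       # [8.0, 13.0]
print(result.sum().item())   # 21.0, which is A * B
```

The entries still sum to \(A B = 21\), and each input is credited with half of its influence through \(A\) plus half through \(B\).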

Is this matrix multiplication, though? No:

import torch

left = torch.rand(2,2)
right = torch.rand(2,2)

(left * right) - (left @ right)
tensor([[-0.1304, -0.1619],
        [-0.4379, -0.0524]])

We will handle this when we move onto matrix multiplication.

Addition of Two Composite Values

This should be straightforward. Addition is linear, so adding the composite values entry by entry preserves the overall sum: \(\sum_{i = 1}^{n} (a_i + b_i) = \sum_{i = 1}^{n} a_i + \sum_{i = 1}^{n} b_i\).

compare_functions(
    value_function=lambda left, right: left + right,
    composite_function=lambda left, right: left + right,
)
             name  passed
0  absolute_zeros    True
1       sum_zeros    True
2       has_zeros    True
3            rand    True

Easy.

Matrix Multiplication

Matrix multiplication is built from multiplication and addition, so we should be able to reuse the composite versions of both. Let’s see if that works out for us.

A 2x3 by 3x2 matrix multiplication is computed as follows:

\[ \begin{bmatrix} x_{11} & x_{12} & x_{13} \\ x_{21} & x_{22} & x_{23} \end{bmatrix} \begin{bmatrix} y_{11} & y_{12} \\ y_{21} & y_{22} \\ y_{31} & y_{32} \end{bmatrix} = \begin{bmatrix} (x_{11} y_{11}) + (x_{12} y_{21}) + (x_{13} y_{31}) & (x_{11} y_{12}) + (x_{12} y_{22}) + (x_{13} y_{32}) \\ (x_{21} y_{11}) + (x_{22} y_{21}) + (x_{23} y_{31}) & (x_{21} y_{12}) + (x_{22} y_{22}) + (x_{23} y_{32}) \end{bmatrix} \]

This looks tricky when we use composite values. Does it have to be, though? What if we step through each of the composite value indices as separate matrices?

import torch

def composite_matrix_multiply(
    left: torch.Tensor,
    right: torch.Tensor,
) -> torch.Tensor:
    right = right.transpose(0,1)
    result = torch.zeros(left.shape[0], right.shape[1], left.shape[2])
    for i in range(left.shape[2]):
        result[:, :, i] = left[:, :, i] @ right[:, :, i]
    return result

compare_functions(
    value_function=lambda left, right: left @ right.T,
    composite_function=composite_matrix_multiply,
    size=(2,3,4),
)
             name  passed
0  absolute_zeros    True
1       sum_zeros   False
2       has_zeros   False
3            rand   False

This has the same kind of problem as the multiplication of two composite values. We can fix that in the same way:

import torch

def composite_matrix_multiply(
    left: torch.Tensor,
    right: torch.Tensor,
) -> torch.Tensor:
    right = right.transpose(0,1)
    left_rollup = left.sum(dim=-1)
    right_rollup = right.sum(dim=-1)

    result = torch.zeros(left.shape[0], right.shape[1], left.shape[2])
    for i in range(left.shape[2]):
        result[:, :, i] = left[:, :, i] @ right_rollup
        result[:, :, i] += left_rollup @ right[:, :, i]
        result[:, :, i] = result[:, :, i] / 2
    return result

compare_functions(
    value_function=lambda left, right: left @ right.T,
    composite_function=composite_matrix_multiply,
    size=(2,3,4),
)
             name  passed
0  absolute_zeros    True
1       sum_zeros    True
2       has_zeros    True
3            rand    True

Is it possible to express this without the loop? PyTorch supports Einstein summation through torch.einsum, which its documentation describes as:

Sums the product of the elements of the input operands along dimensions specified using a notation based on the Einstein summation convention.

Looking back at the matrix multiplication above, regular matrix multiplication is a sum of products. Furthermore, the composite value matrix multiplication is regular matrix multiplication with an additional dimension.
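In einsum notation each letter names a dimension, and a letter that appears in the inputs but not in the output is summed over. Written out as explicit loops, "nm,mp->np" is exactly the sum of products from the matrix multiplication above. A small sketch with arbitrary shapes:

```python
import torch

left = torch.rand(2, 3)
right = torch.rand(3, 2)

# "nm,mp->np": m appears in both inputs but not in the output, so it is summed
explicit = torch.zeros(2, 2)
for n in range(2):
    for p in range(2):
        for m in range(3):
            explicit[n, p] += left[n, m] * right[m, p]

print(torch.allclose(explicit, torch.einsum("nm,mp->np", left, right)))  # True
```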

Since this is a new notation that is being introduced we should take care to establish that it is equivalent to what came before. Let’s start by checking that matrix multiplication is equivalent when using einsum:

import torch

left = torch.rand(2, 3)
right = torch.rand(2, 3)

mm_output = left @ right.T
einsum_output = torch.einsum("nm,mp->np", left, right.T)

(mm_output - einsum_output).abs().max()
tensor(0.)

This has worked perfectly. Now we can check if the composite value version of this is equivalent to what we wrote before:

import torch

def composite_matrix_multiply(
    left: torch.Tensor,
    right: torch.Tensor,
) -> torch.Tensor:
    right = right.transpose(0,1)
    left_rollup = left.sum(dim=-1)
    right_rollup = right.sum(dim=-1)

    result = torch.zeros(left.shape[0], right.shape[1], left.shape[2])
    for i in range(left.shape[2]):
        result[:, :, i] = left[:, :, i] @ right_rollup
        result[:, :, i] += left_rollup @ right[:, :, i]
        result[:, :, i] = result[:, :, i] / 2
    return result

left = torch.rand(2, 3, 4)
right = torch.rand(2, 3, 4)

mm_output = composite_matrix_multiply(left, right)
einsum_output = torch.einsum("nmi,mpi->npi", left, right.transpose(0,1))

(mm_output - einsum_output).abs().max()
tensor(2.6074)

This is clearly not equivalent. The problem is that this einsum performs the naive per-index matrix multiplication: each composite entry is only multiplied by the entry at the same index, which is exactly the element-wise composite multiplication we saw fail earlier. Composite value multiplication has to operate over the rolled up values. Let’s try again.
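We can confirm this diagnosis: the naive einsum gives the same result as the first loop-based composite_matrix_multiply, which ran a separate matrix multiplication per composite index and failed the tests. A minimal check with random tensors of the same shapes:

```python
import torch

left = torch.rand(2, 3, 4)
right = torch.rand(2, 3, 4)
right_t = right.transpose(0, 1)  # shape (3, 2, 4)

# naive version: an independent matrix multiplication per composite index
loop_output = torch.stack(
    [left[:, :, i] @ right_t[:, :, i] for i in range(left.shape[2])],
    dim=-1,
)
einsum_output = torch.einsum("nmi,mpi->npi", left, right_t)

print(torch.allclose(loop_output, einsum_output))  # True
```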

import torch

def composite_matrix_multiply(
    left: torch.Tensor,
    right: torch.Tensor,
) -> torch.Tensor:
    right = right.transpose(0,1)
    left_rollup = left.sum(dim=-1)
    right_rollup = right.sum(dim=-1)

    result = torch.zeros(left.shape[0], right.shape[1], left.shape[2])
    for i in range(left.shape[2]):
        result[:, :, i] = left[:, :, i] @ right_rollup
        result[:, :, i] += left_rollup @ right[:, :, i]
        result[:, :, i] = result[:, :, i] / 2
    return result

left = torch.rand(2, 3, 4)
right = torch.rand(2, 3, 4)

left_rollup = left.sum(dim=-1).unsqueeze(-1)
right_rollup = right.sum(dim=-1).unsqueeze(-1)

mm_output = composite_matrix_multiply(left, right)
einsum_output = (
    torch.einsum("nmi,mpi->npi", left, right_rollup.transpose(0,1)) +
    torch.einsum("nmi,mpi->npi", left_rollup, right.transpose(0,1))
)
einsum_output = einsum_output / 2

(mm_output - einsum_output).abs().max()
tensor(0.)

Now it works great. We still need to check it with our more systematic tests.

import torch

def composite_matrix_multiply(
    left: torch.Tensor,
    right: torch.Tensor,
) -> torch.Tensor:
    right = right.transpose(0,1)
    left_rollup = left.sum(dim=-1).unsqueeze(-1)
    right_rollup = right.sum(dim=-1).unsqueeze(-1)

    result = torch.einsum("nmi,mpi->npi", left, right_rollup)
    result = result + torch.einsum("nmi,mpi->npi", left_rollup, right)
    result = result / 2
    return result

compare_functions(
    value_function=lambda left, right: left @ right.T,
    composite_function=composite_matrix_multiply,
    size=(2,3,4),
)
             name  passed
0  absolute_zeros    True
1       sum_zeros    True
2       has_zeros    True
3            rand    True

The best part about using einsum is that we can apply the same approach to torch.bmm (batch matrix multiply):

import torch

def composite_batch_matrix_multiply(
    left: torch.Tensor,
    right: torch.Tensor,
) -> torch.Tensor:
    right = right.transpose(1,2)
    left_rollup = left.sum(dim=-1).unsqueeze(-1)
    right_rollup = right.sum(dim=-1).unsqueeze(-1)

    result = torch.einsum("bnmi,bmpi->bnpi", left, right_rollup)
    result = result + torch.einsum("bnmi,bmpi->bnpi", left_rollup, right)
    result = result / 2
    return result

compare_functions(
    value_function=lambda left, right: (
        torch.bmm(left, right.transpose(1,2))
    ),
    composite_function=composite_batch_matrix_multiply,
    size=(2,3,4,5),
)
             name  passed
0  absolute_zeros    True
1       sum_zeros    True
2       has_zeros    True
3            rand    True

Softmax

This is where it gets tricky. The softmax function is a non-linear mapping of the values which preserves order. It’s possible to generalize the approach taken for softmax to GELU and layer normalization, as they share the same problem: the results are very different if the function is run over the individual composite values instead of the original value. Because the function is non-linear, passing in different values produces outputs that no longer sum to the output for the original value.

We can see this most clearly with ReLU, which is defined as \(\mathrm{ReLU}(V) = \max(0, V)\). If we have a value of 1 then \(\mathrm{ReLU}(1) = 1\). If the composite value is \(\left[ 2, -1 \right]\) then \(\mathrm{ReLU}(\left[ 2, -1 \right]) = \left[ 2, 0 \right]\), which sums to 2 instead of 1. This has clearly broken the sum.
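The same ReLU example in code, using torch.relu:

```python
import torch

value = torch.tensor(1.0)
composite = torch.tensor([2.0, -1.0])  # sums to 1

print(torch.relu(value).item())            # 1.0
print(torch.relu(composite).sum().item())  # 2.0, so the sum is broken
```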

Previously I was operating over the original values and then scaling the composite values accordingly. This is done by calculating the ratio of the output value to the input value (\(\frac{V_{out}}{V_{in}}\)) and then scaling the composite values by it. The problem with this approach is very small or zero \(V_{in}\) values. If \(V_{in}\) is very small then the ratio boosts the magnitude of the composite values, making them difficult to interpret. If \(V_{in}\) is zero then the ratio becomes infinite or NaN, which breaks further calculation completely.
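A sketch of that ratio scaling shows both failure modes. ratio_scale is a hypothetical reconstruction written for this illustration, not the code from the previous post; it scales every composite entry by \(\frac{V_{out}}{V_{in}}\) for an element-wise function:

```python
from typing import Callable

import torch

def ratio_scale(
    composite: torch.Tensor,
    function: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    # hypothetical sketch: scale every composite entry by V_out / V_in
    value_in = composite.sum(dim=-1, keepdim=True)
    value_out = function(value_in)
    return composite * (value_out / value_in)

# entries that almost cancel, so V_in is about 0.001
small = torch.tensor([[1.0, -0.999]])
# entries that cancel exactly, so V_in is 0
zero = torch.tensor([[1.0, -1.0]])

print(ratio_scale(small, torch.exp))  # entries blown up to roughly +/-1000
print(ratio_scale(zero, torch.exp))   # exp(0) / 0 is inf
print(ratio_scale(zero, torch.tanh))  # tanh(0) / 0 is 0 / 0, which is NaN
```

In the small case the entries still sum to \(V_{out}\), but their magnitudes are around a thousand times larger than the output, which is the interpretability problem described above.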

What is desired is a way to:

  • Maintain the constraint that the composite value should be seen as the decomposition of the original value, that is \(V = \sum_{i = 1}^{n} C_i\).
  • Keep the magnitude of the composite value similar to the original value.

It should be possible to use the unattributed constant value to ensure that these constraints hold. Let’s try it out.

import torch

def composite_softmax(
    composite_input_value: torch.Tensor
) -> torch.Tensor:
    input_value = composite_input_value.sum(dim=-1)
    output_value = input_value.softmax(dim=-1)

    # calculate the input magnitude and output magnitude 
    # for scaling each composite value set
    flat_composite_input_value = composite_input_value.reshape(
        -1, composite_input_value.shape[-1]
    )
    composite_input_magnitude = (
        flat_composite_input_value[:, :-1]
            .abs()
            .max(dim=-1)
            [0]
    )
    output_magnitude = output_value.abs().flatten()

    # calculate and apply the ratio
    ratio = output_magnitude / composite_input_magnitude
    flat_composite_output_value = (
        flat_composite_input_value * ratio[:, None]
    )
    # wipe out the rows with zero input magnitude: the ratio is infinite
    # or NaN there, and there is no input influence to track
    flat_composite_output_value[composite_input_magnitude == 0, :] = 0.

    # check that the magnitude has been correctly scaled
    composite_output_magnitude = (
        flat_composite_output_value[:, :-1]
            .abs()
            .max(dim=-1)
            [0]
    )
    magnitude_difference = (
        (composite_output_magnitude - output_magnitude)
            .max()
    )
    assert magnitude_difference <= 1e-6

    # use the unattributed constant value to ensure that
    # the sum of the composite value equals the original value
    flat_composite_output_value[:, -1] = (
        output_value.flatten() -
        flat_composite_output_value[:, :-1].sum(dim=-1)
    )

    composite_output_value = (
        flat_composite_output_value
            .reshape(composite_input_value.shape)
    )
    return composite_output_value
    

compare_functions(
    value_function=lambda left, right: left.softmax(dim=-1),
    composite_function=lambda left, right: composite_softmax(left),
    size=(2,3,4,5),
)
             name  passed
0  absolute_zeros    True
1       sum_zeros    True
2       has_zeros    True
3            rand    True

This is a much more complex approach than before. Since the only function-specific step is computing the output value from the original input value, it should be possible to pass in any of the activation functions without having it blow up.

from typing import Callable
import torch
import torch.nn.functional as F

def composite_nonlinear(
    composite_input_value: torch.Tensor,
    function: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
    input_value = composite_input_value.sum(dim=-1)
    output_value = function(input_value)

    # calculate the input magnitude and output magnitude
    # for scaling each composite value set
    flat_composite_input_value = composite_input_value.reshape(
        -1, composite_input_value.shape[-1]
    )
    composite_input_magnitude = (
        flat_composite_input_value[:, :-1]
            .abs()
            .max(dim=-1)
            [0]
    )
    output_magnitude = output_value.abs().flatten()

    # calculate and apply the ratio
    ratio = output_magnitude / composite_input_magnitude
    flat_composite_output_value = (
        flat_composite_input_value * ratio[:, None]
    )
    # wipe out the rows with zero input magnitude: the ratio is infinite or NaN there, and there is no input influence to track
    flat_composite_output_value[composite_input_magnitude == 0, :] = 0.

    # check that the magnitude has been correctly scaled
    composite_output_magnitude = (
        flat_composite_output_value[:, :-1]
            .abs()
            .max(dim=-1)
            [0]
    )
    magnitude_difference = (
        (composite_output_magnitude - output_magnitude)
            .max()
    )
    assert magnitude_difference <= 1e-6

    # use the unattributed constant value to ensure that the sum of the composite value equals the original value
    flat_composite_output_value[:, -1] = (
        output_value.flatten() -
        flat_composite_output_value[:, :-1].sum(dim=-1)
    )

    composite_output_value = (
        flat_composite_output_value
            .reshape(composite_input_value.shape)
    )
    return composite_output_value
    

print("softmax")
display(
    compare_functions(
        value_function=lambda left, right: left.softmax(dim=-1),
        composite_function=lambda left, right: (
            composite_nonlinear(left, lambda t: t.softmax(dim=-1))
        ),
        size=(2,3,4,5),
    )
)

print("gelu")
display(
    compare_functions(
        value_function=lambda left, right: F.gelu(left),
        composite_function=lambda left, right: (
            composite_nonlinear(left, F.gelu)
        ),
        size=(2,3,4,5),
    )
)
softmax
             name  passed
0  absolute_zeros    True
1       sum_zeros    True
2       has_zeros    True
3            rand    True
gelu
             name  passed
0  absolute_zeros    True
1       sum_zeros    True
2       has_zeros    True
3            rand    True
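To see what the rescaling does to a single composite value, here is a compact sketch of the same steps applied to the ReLU example from earlier, with a trailing zero appended as the unattributed constant. rescale_composite is a simplified illustration written for this purpose; unlike composite_nonlinear it handles one composite value at a time and only element-wise functions:

```python
from typing import Callable

import torch

def rescale_composite(
    composite: torch.Tensor,
    function: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    # simplified single-row sketch; the last entry is the unattributed constant
    value_in = composite.sum()
    value_out = function(value_in)
    # largest attributed magnitude, ignoring the constant entry
    magnitude_in = composite[:-1].abs().max()
    if magnitude_in == 0:
        # no input influence to track, so nothing is attributed
        result = torch.zeros_like(composite)
    else:
        # scale so the largest attributed entry matches |V_out|
        result = composite * (value_out.abs() / magnitude_in)
    # the constant entry absorbs whatever keeps the sum equal to V_out
    result[-1] = value_out - result[:-1].sum()
    return result

# the ReLU example: [2, -1] with a zero constant, so V = 1
composite = torch.tensor([2.0, -1.0, 0.0])
result = rescale_composite(composite, torch.relu)
print(result.tolist())  # [1.0, -0.5, 0.5]
```

The attributed entries keep their relative proportions (2 : -1), the largest of them matches the output magnitude of 1, and the 0.5 that the inputs cannot account for goes to the constant, so the sum is preserved.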

CLIP Tracing

Let’s have a go at using this new code to trace the CLIP activations. This is what I tried to do before, and it had a lot of problems. Previously the traced results differed quite a bit from the original model:

model type  maximum value  mean value
original           1.6939      0.0007
tracing            3.2656     -0.0154

The discrepancy here is huge, considering that the tracing model was designed to produce output consistent with the original model. I believe that this discrepancy is one of the reasons why the results from the traced output were inconclusive.

It’s easy enough to fit the code covered in this post into the structure created in the previous post. Doing so takes quite a bit of code.
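The first step in adapting CLIP is turning the patch embeddings into a composite value: each image patch gets its own composite index, and the class embedding, which is not derived from the image, goes into the unattributed constant (last) index. A toy-sized sketch of that initialisation, with arbitrary sizes far smaller than the real model:

```python
import torch

# toy sizes: 1 class embedding + 3 patches, hidden size 3
batch, regions, hidden = 1, 4, 3
embeddings = torch.randn(batch, regions, hidden)

# one composite index per patch plus a final constant index,
# which is `regions` indices in total
composite = torch.zeros(batch, regions, hidden, regions)
for i in range(regions - 1):
    # patch i + 1 (position 0 is the class embedding) is attributed to index i
    composite[:, i + 1, :, i] = embeddings[:, i + 1, :]
# the class embedding does not come from the image, so it is unattributed
composite[:, 0, :, -1] = embeddings[:, 0, :]

# summing the composite dimension recovers the original embeddings
print(torch.allclose(composite.sum(dim=-1), embeddings))  # True
```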

# from src/main/python/blog/tracing/v2023_02/layers.py
from typing import Optional, Tuple

import torch
from torch import nn
from transformers import CLIPModel
from transformers.models.clip.modeling_clip import (
    CLIPMLP,
    CLIPAttention,
    CLIPEncoder,
    CLIPEncoderLayer,
    CLIPVisionTransformer,
)


def load_tracing_image_model(
    model_name: str = "openai/clip-vit-base-patch32",
) -> "TracingCLIPVisionTransformer":
    model = CLIPModel.from_pretrained(model_name)
    tracing_model = TracingCLIPVisionTransformer(
        model=model.vision_model, projection=model.visual_projection
    )
    tracing_model.eval()
    return tracing_model


class TracingCLIPVisionTransformer(nn.Module):
    def __init__(self, model: CLIPVisionTransformer, projection: nn.Linear) -> None:
        super().__init__()
        self.embeddings = model.embeddings
        # misspelling present in model
        self.pre_layernorm = TracingActivation(model.pre_layrnorm)
        self.encoder = CLIPEncoder(model.encoder.config)
        self.encoder.layers = nn.ModuleList(
            [TracingCLIPEncoderLayer(layer) for layer in model.encoder.layers]
        )
        self.post_layernorm = TracingActivation(model.post_layernorm)
        self.visual_projection = TracingLinear(projection)

    def forward(self, pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
        embeddings = self.embeddings(pixel_values=pixel_values)
        _, image_regions, _ = embeddings.shape
        embeddings_traced = torch.zeros(
            *embeddings.shape, image_regions, device=embeddings.device
        )
        # 0 is the class embedding, that is not related to the image so it goes
        # in the constant index (last index)
        for i in range(image_regions - 1):
            embeddings_traced[:, i + 1, :, i] = embeddings[:, i + 1, :]
        embeddings_traced[:, 0, :, -1] = embeddings[:, 0, :]

        embeddings_normalized = self.pre_layernorm(embeddings_traced)
        encoded = self.encoder(
            inputs_embeds=embeddings_normalized,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=False,
        )
        encoded = encoded[0]
        encoded = encoded[:, 0, :]
        encoded = self.post_layernorm(encoded)
        return self.visual_projection(encoded)


class TracingCLIPEncoderLayer(nn.Module):
    def __init__(self, layer: CLIPEncoderLayer) -> None:
        super().__init__()
        self.self_attn = TracingCLIPAttention(layer.self_attn)
        self.layer_norm1 = TracingActivation(layer.layer_norm1)
        self.mlp = TracingCLIPMLP(layer.mlp)
        self.layer_norm2 = TracingActivation(layer.layer_norm2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        causal_attention_mask: Optional[torch.Tensor],
        output_attentions: bool,
    ) -> Tuple[torch.Tensor, ...]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class TracingCLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, layer: CLIPAttention) -> None:
        super().__init__()
        self.layer = layer

        self.k_proj = TracingLinear(layer.k_proj)
        self.v_proj = TracingLinear(layer.v_proj)
        self.q_proj = TracingLinear(layer.q_proj)
        self.out_proj = TracingLinear(layer.out_proj)
        self.bmm = TracingBMM()
        self.softmax = TracingActivation(nn.Softmax(dim=-1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel x Tracing"""

        bsz, tgt_len, embed_dim, tracking_dim = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.layer.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz, tracking_dim)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz, tracking_dim)

        proj_shape = (bsz * self.layer.num_heads, -1, self.layer.head_dim, tracking_dim)
        query_states = self._shape(query_states, tgt_len, bsz, tracking_dim).view(
            *proj_shape
        )
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = self.bmm(query_states, key_states.transpose(1, 2))

        # code has been cut out here which relates to the
        # causal_attention_mask, attention_mask and error checking
        assert causal_attention_mask is None
        assert attention_mask is None

        attn_weights = self.softmax(attn_weights)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(
                bsz, self.layer.num_heads, tgt_len, src_len, tracking_dim
            )
            attn_weights = attn_weights_reshaped.view(
                bsz * self.layer.num_heads, tgt_len, src_len, tracking_dim
            )
        else:
            attn_weights_reshaped = None

        # this is not intended for training so dropout is removed entirely
        # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_probs = attn_weights

        attn_output = self.bmm(attn_probs, value_states)
        attn_output = attn_output.view(
            bsz, self.layer.num_heads, tgt_len, self.layer.head_dim, tracking_dim
        )
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim, tracking_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int, tracking_dim: int):
        return (
            tensor.view(
                bsz, seq_len, self.layer.num_heads, self.layer.head_dim, tracking_dim
            )
            .transpose(1, 2)
            .contiguous()
        )


class TracingCLIPMLP(nn.Module):
    def __init__(self, layer: CLIPMLP) -> None:
        super().__init__()
        self.activation_fn = TracingActivation(layer.activation_fn)
        self.fc1 = TracingLinear(layer.fc1)
        self.fc2 = TracingLinear(layer.fc2)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class TracingLinear(nn.Module):
    def __init__(self, linear: nn.Linear) -> None:
        super().__init__()
        # (out, in, 1): the trailing axis broadcasts over the tracing dimension
        self.weight = linear.weight[:, :, None]
        self.bias = linear.bias

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        if len(xs.shape) == 3:
            mm = self._matrix_multiply_3d(xs)
        else:
            mm = self._matrix_multiply_4d(xs)
        if self.bias is not None:
            # the bias does not come from any input, so it is added to the
            # constant (last) tracing index
            flat = mm.reshape(-1, *mm.shape[-2:])
            flat[:, :, -1] += self.bias
            mm = flat.reshape(*mm.shape)
        return mm

    def _matrix_multiply_3d(self, xs: torch.Tensor) -> torch.Tensor:
        return torch.einsum("bik,jik->bjk", xs, self.weight)

    def _matrix_multiply_4d(self, xs: torch.Tensor) -> torch.Tensor:
        return torch.einsum("Bbik,Bjik->Bbjk", xs, self.weight[None])


class TracingBMM(nn.Module):
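    def forward(self, xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
        # Split the product symmetrically: each composite entry of xs is multiplied
        # by the rolled-up ys and vice versa, then halved, so the composite
        # entries still sum to the ordinary batched matrix product.
        x_rollup = xs.sum(dim=-1).unsqueeze(-1)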
        y_rollup = ys.sum(dim=-1).unsqueeze(-1)
        x_bmm = torch.einsum("bnmi,bmpi->bnpi", xs, y_rollup)
        y_bmm = torch.einsum("bnmi,bmpi->bnpi", x_rollup, ys)
        bmm = x_bmm + y_bmm
        bmm = bmm / 2
        return bmm


class TracingActivation(nn.Module):
    def __init__(self, activation_function: nn.Module) -> None:
        super().__init__()
        self.activation_function = activation_function

    def forward(self, composite_input_value: torch.Tensor) -> torch.Tensor:
        input_value = composite_input_value.sum(dim=-1)
        output_value = self.activation_function(input_value)

        # calculate the input magnitude and output magnitude for scaling each composite value set
        flat_composite_input_value = composite_input_value.reshape(
            -1, composite_input_value.shape[-1]
        )
        composite_input_magnitude = (
            flat_composite_input_value[:, :-1].abs().max(dim=-1)[0]
        )
        output_magnitude = output_value.abs().flatten()

        # calculate and apply the ratio
        ratio = output_magnitude / composite_input_magnitude
        flat_composite_output_value = flat_composite_input_value * ratio[:, None]
        # wipe out the values where the ratio would be nan, there is no input influence to track
        flat_composite_output_value[composite_input_magnitude == 0, :] = 0.0

        # check that the magnitude has been correctly scaled
        composite_output_magnitude = (
            flat_composite_output_value[:, :-1].abs().max(dim=-1)[0]
        )
        magnitude_difference = (composite_output_magnitude - output_magnitude).max()
        if magnitude_difference > 1e-5:
            raise AssertionError(f"magnitude check failed: {magnitude_difference:0.4g}")

        # use the unattributed constant value to ensure that the sum of the
        # composite value equals the original value
        composite_offset = flat_composite_output_value[:, :-1].sum(dim=-1)
        flat_composite_output_value[:, -1] = output_value.flatten() - composite_offset

        composite_output_value = flat_composite_output_value.reshape(
            composite_input_value.shape
        )
        return composite_output_value
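The modules above rely on three propagation rules. Linear layers scale every composite entry and send the bias to the constant slot. TracingBMM splits a product symmetrically. TracingActivation rescales the attributed entries and lets the constant slot absorb the remainder. As a rough sketch, here are the same rules restated for a single composite scalar on plain Python lists (no PyTorch). The function names are hypothetical, and the last list entry stands in for the constant index.

```python
import math
from typing import Callable, List

# Hypothetical illustration: a composite value as a plain list of contributions.
# The last entry is the constant (unattributed) slot, mirroring the last tracing index.
Composite = List[float]


def roll_up(c: Composite) -> float:
    """Recover the original value: V is the sum of all contributions."""
    return sum(c)


def linear(weight: float, bias: float, c: Composite) -> Composite:
    """Scale every contribution by the weight; the bias is unattributed."""
    out = [weight * v for v in c]
    out[-1] += bias
    return out


def bilinear(x: Composite, y: Composite) -> Composite:
    """Symmetric split as in TracingBMM: (x_i * Y + X * y_i) / 2."""
    x_total, y_total = roll_up(x), roll_up(y)
    return [(xi * y_total + x_total * yi) / 2 for xi, yi in zip(x, y)]


def activation(fn: Callable[[float], float], c: Composite) -> Composite:
    """Rescale the attributed entries so the largest magnitude equals |fn(V)|,
    then let the constant slot absorb the rest, as TracingActivation does.
    Only magnitudes enter the ratio, so attributed entries keep their signs."""
    output = fn(roll_up(c))
    attributed = c[:-1]
    magnitude = max((abs(v) for v in attributed), default=0.0)
    if magnitude == 0.0:
        # no attributed input influence to track
        scaled = [0.0] * len(attributed)
    else:
        ratio = abs(output) / magnitude
        scaled = [v * ratio for v in attributed]
    return scaled + [output - sum(scaled)]


x = [1.0, 2.0, 0.5]  # two attributed inputs plus the constant slot
y = [0.5, -1.0, 2.0]
print(roll_up(bilinear(x, y)), roll_up(x) * roll_up(y))  # 5.25 5.25
print(roll_up(linear(2.0, 1.0, x)))  # 8.0
print(roll_up(activation(math.tanh, x)), math.tanh(roll_up(x)))
```

Each rule preserves the basic principle that the composite entries sum to the original value. The rolled-up comparison later in this post checks exactly that property on the full model.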
from PIL import Image
import torch
from transformers import CLIPModel, CLIPProcessor

def preprocess_image(filename: str) -> torch.Tensor:
    image = Image.open(filename)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return processor.feature_extractor(image, return_tensors="pt").pixel_values

@torch.inference_mode()
def get_output(pixel_values: torch.Tensor) -> torch.Tensor:
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    vision_model = model.vision_model

    model.eval()

    encoder_outputs = vision_model(pixel_values=pixel_values)
    last_hidden_state = encoder_outputs[0]
    pooled_output = last_hidden_state[:, 0, :]
    pooled_output = vision_model.post_layernorm(pooled_output)
    return model.visual_projection(pooled_output)

@torch.inference_mode()
def get_traced_output(pixel_values: torch.Tensor) -> torch.Tensor:
    model = load_tracing_image_model()
    model.eval()

    return model(pixel_values=pixel_values)

@torch.inference_mode()
def get_text_embedding(prompt: str) -> torch.Tensor:
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

    model.eval()

    text_features = processor.tokenizer([prompt], return_tensors="pt")
    text_outputs = model.text_model(**text_features)
    text_embeds = text_outputs[1] # pooler_output
    return model.text_projection(text_embeds)
from PIL import Image
from torchvision.transforms import ToPILImage
import matplotlib.pyplot as plt

image_mean = torch.tensor([
    0.48145466,
    0.4578275,
    0.40821073
])
image_std = torch.tensor([
    0.26862954,
    0.26130258,
    0.27577711
])

def show_region_attention(pixel_values: torch.Tensor, attention: torch.Tensor, scale_attention: bool = True) -> None:
    # need to resize attention to match the values
    attention = attention.reshape(7, 7)
    attention = attention.to(float)
    attention = attention.repeat_interleave(32, dim=0).repeat_interleave(32, dim=1)
    if scale_attention:
        attention = attention - attention.min()
        attention = attention / attention.max()
    if len(attention.shape) == 2:
        attention = attention[None]
    
    pixel_values = pixel_values[0]
    pixel_values = (pixel_values * image_std[:, None, None]) + image_mean[:, None, None]
    scaled_values = pixel_values * attention

    # one row of three panels: the figure is three times wider than it is tall
    fig, axes = plt.subplots(1, 3, figsize=(224 * 3 / 10, 224 / 10))
    axes[0].imshow(ToPILImage()(pixel_values))
    axes[1].imshow(ToPILImage()(scaled_values))
    axes[2].imshow(ToPILImage()(attention), cmap="binary_r")
    for axis in axes:
        axis.axis('off')
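show_region_attention assumes a 224x224 input cut into a 7x7 grid of 32-pixel patches, which is the layout of the clip-vit-base-patch32 checkpoint used here. The plain-Python sketch below uses hypothetical helper names. It makes explicit which pixels a given tracing index covers, assuming patches are numbered row by row (the same order that reshape(7, 7) uses).

```python
from typing import List, Tuple

PATCH_SIZE = 32  # assumption: ViT-B/32 patches, matching repeat_interleave(32, ...)
GRID = 7         # 224 // 32 patches per side, matching reshape(7, 7)


def tracing_index_to_box(index: int) -> Tuple[int, int, int, int]:
    """Map a tracing index (0..48) to the (top, left, bottom, right) pixel box
    of the patch it tracks. Index i tracks patch token i + 1, because token 0
    (the class embedding) is assigned to the constant index."""
    if not 0 <= index < GRID * GRID:
        raise ValueError(f"index {index} is the constant slot or out of range")
    row, col = divmod(index, GRID)
    top, left = row * PATCH_SIZE, col * PATCH_SIZE
    return top, left, top + PATCH_SIZE, left + PATCH_SIZE


def upsample(weights: List[float]) -> List[List[float]]:
    """Nearest-neighbour upsampling, equivalent to reshape(7, 7) followed by
    repeat_interleave(32) along both axes."""
    size = GRID * PATCH_SIZE
    return [
        [weights[(y // PATCH_SIZE) * GRID + x // PATCH_SIZE] for x in range(size)]
        for y in range(size)
    ]


print(tracing_index_to_box(8))  # (32, 32, 64, 64)
```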

The first check is to run the model over the same cat image:

a photo of a cat

This will be run through both the original CLIP model and the new tracing CLIP model. Then we will compare the original output to the rolled-up traced output (the traced output summed over the tracing dimension). Ideally, this version of the tracing model will match the original model more closely than the previous attempts did.
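Rolling up means summing each output dimension over its tracing entries, which is what traced_embedding.sum(dim=-1) does. Here is a minimal sketch of the comparison on plain lists for a single image. The function names are hypothetical.

```python
from typing import List, Tuple


def roll_up(traced: List[List[float]]) -> List[float]:
    """Sum each output dimension over its tracing entries (the last axis)."""
    return [sum(entries) for entries in traced]


def compare(original: List[float], traced: List[List[float]]) -> Tuple[float, float]:
    """Maximum and mean absolute difference between original and rolled-up output."""
    differences = [abs(o - r) for o, r in zip(original, roll_up(traced))]
    return max(differences), sum(differences) / len(differences)


# two output dimensions, each traced over two regions plus the constant slot
traced = [[0.25, 0.5, 0.25], [-0.5, 0.25, -0.25]]
print(compare([1.0, -0.5], traced))  # (0.0, 0.0)
```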

import warnings
import pandas as pd

with warnings.catch_warnings():
    # transformers/models/clip/processing_clip.py:142: FutureWarning:
    # `feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.
    warnings.simplefilter("ignore")
    pixel_values = preprocess_image("cat.jpg")
    original_embedding = get_output(pixel_values)
    traced_embedding = get_traced_output(pixel_values)
    tracing_difference = (original_embedding - traced_embedding.sum(dim=-1)).abs()

pd.DataFrame([
    {
        "model type": "original",
        "maximum": original_embedding.max().item(),
        "mean": original_embedding.mean().item(),
    },
    {
        "model type": "tracing",
        "maximum": traced_embedding.sum(dim=-1).max().item(),
        "mean": traced_embedding.sum(dim=-1).mean().item(),
    },
    {
        "model type": "difference",
        "maximum": tracing_difference.max().item(),
        "mean": tracing_difference.mean().item(),
    }
])
   model type  maximum   mean
0  original    1.693860  0.000684
1  tracing     1.693868  0.000684
2  difference  0.000008  0.000002

This is wild. These values are nearly identical: the largest absolute difference is about 0.000008, against a maximum output value of about 1.69.

I feel way better about this approach now!

Finally, let’s try to find which parts of the image contribute to the cat classification. We originally found that the prompt “a rendering of a cat.” was the best match for the image. We can use the values of that text embedding to weight the traced image embedding and find which parts of the image contribute most to the cat classification.
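Concretely, show_weights scores each region r as the sum, over embedding dimensions d, of traced[d][r] * text[d], skipping the constant slot. The plain-Python sketch below follows that reading and uses hypothetical names. It shows that the region scores plus the constant slot's share add up to the unnormalized image/text dot product, which is the numerator of the cosine similarity.

```python
import math
from typing import List


def region_scores(traced: List[List[float]], text: List[float]) -> List[float]:
    """traced[d][r] is the contribution of tracing index r to embedding dimension d.
    The last index is the constant slot and is skipped, like traced_embedding[:, :-1]."""
    regions = len(traced[0]) - 1
    return [sum(row[r] * t for row, t in zip(traced, text)) for r in range(regions)]


def display_weights(scores: List[float]) -> List[float]:
    """sigmoid(s) * 2 - 1 followed by ReLU, as in show_weights.
    Note that sigmoid(s) * 2 - 1 equals tanh(s / 2)."""
    return [max(0.0, 1 / (1 + math.exp(-s)) * 2 - 1) for s in scores]


traced = [[1.0, -2.0, 0.5],   # hypothetical: 2 embedding dims, 2 regions + constant
          [0.5, 1.0, -1.0]]
text = [2.0, 1.0]
scores = region_scores(traced, text)
print(scores)  # [2.5, -3.0]
print(display_weights(scores))
```

Because the regions partition the dot product, a region with a negative score is actively pulling the image away from the prompt. The ReLU in show_weights hides those regions rather than showing them as neutral.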

import warnings

import pandas as pd
import torch
import torch.nn.functional as F

def show_weights(image: str, prompt: str) -> None:
    with warnings.catch_warnings():
        # transformers/models/clip/processing_clip.py:142: FutureWarning:
        # `feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.
        warnings.simplefilter("ignore")
        pixel_values = preprocess_image(image)
        traced_embedding = get_traced_output(pixel_values)
        traced_embedding = traced_embedding[0]

    text_embedding = get_text_embedding(prompt)
    text_embedding = text_embedding[0]

    similarity = F.cosine_similarity(traced_embedding.sum(dim=-1), text_embedding, dim=0)
    print(f"cosine similarity: {similarity:0.4g}")

    weights = traced_embedding[:, :-1] * text_embedding[:, None]

    weights = weights.T.sum(dim=-1)
    weights = weights.sigmoid()
    weights = weights * 2 - 1

    show_region_attention(
        pixel_values,
        F.relu(weights),
    )
show_weights("cat.jpg", "a rendering of a cat.")
cosine similarity: 0.2972

This is a far better result. The face is what I would expect to contribute most to the classification, and this supports that.

Are these just the areas of the image that had any activation at all? We can check that quite quickly by summing the traced contributions without weighting them by the text embedding:

import torch

weights = traced_embedding[0, :, :-1]
weights = weights.T.sum(dim=-1)
# could show just the positive activations with this:
# weights[weights <= 0] = 0.
weights = weights.sigmoid()

show_region_attention(
    pixel_values,
    weights,
    scale_attention=False,
)

Using the text prompt heavily alters which regions of the image are highlighted. This is consistent with the embedding containing both positive and negative values: a region contributes to the match only where the sign of its contribution agrees with the sign of the text embedding at that index.
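A toy example shows the effect. It uses hypothetical numbers in plain Python. With the same traced embedding, flipping the signs of the text embedding flips which region scores highest.

```python
from typing import List


def best_region(traced: List[List[float]], text: List[float]) -> int:
    """Index of the region whose contributions best align with the text embedding."""
    regions = len(traced[0]) - 1  # skip the constant slot
    scores = [sum(row[r] * t for row, t in zip(traced, text)) for r in range(regions)]
    return max(range(regions), key=scores.__getitem__)


# Hypothetical 2-dimensional embedding traced over 2 regions plus the constant slot:
# region 0 pushes dimension 0 up, region 1 pushes dimension 1 up.
traced = [[1.0, 0.0, 0.1],
          [0.0, 1.0, 0.1]]

print(best_region(traced, [1.0, -1.0]))  # 0
print(best_region(traced, [-1.0, 1.0]))  # 1
```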

Is this approach transferable? Does it work well for other subjects or images?

show_weights("baseball.jpg", "a photo of a baseball player.")
cosine similarity: 0.2679

show_weights("baseball.jpg", "a centered satellite photo of permanent crop land.")
cosine similarity: 0.1259

With the correct label, the activation over the interesting parts of the image is stronger and the cosine similarity is higher (0.2679 against 0.1259). When the wrong label is used, the overlap between the activations is weaker.

Ultimately, the image is processed without any consideration of the prompt, so the areas that produce a strong output for a given embedding index are likely to show up as interesting parts of the source image.

Further Work

This post has implemented influence tracking through attention. Attention is used extensively in NLP, so it would be very interesting to see how this works with something like a sentiment model: does it highlight the words or phrases correlated with the document sentiment?