How can we show the influence of the input to the output?
Published
February 20, 2023
In a previous post I was investigating the relationship between images and the embedding that CLIP produces. I came up with a way to link the two together; however, every approach that I tried had problems: either the results did not make sense, the numbers broke, or the mathematical approach underpinning the technique was flawed.
This post is a more basic exploration of the technique, applied to individual PyTorch modules. A wider range of tests will be performed against each operation, and hopefully a few different approaches can be compared.
Basic Principles
The aim here is to track the influence of the input values over the output values. What does this mean and how can we apply this to a given module?
The influence of a value is the amount of change in the output that can be attributed to it. If we apply this idea to basic mathematical operations, we can see how to build up from there to more complex operations.
When describing this I am going to talk about values and composite values. A value is a scalar or tensor such as \(V\) (\(V\) for value). A composite value is that scalar or tensor decomposed into a representation of the influence of the input values, such as \(\left[ C_1, \dots, C_n \right]\) (\(C\) for composite).
We can start by defining our basic principles:
The composite value should be seen as the decomposition of the original value, that is \(V = \sum_{i = 1}^{n} C_i\).
The individual entry in the composite output value should be related to the influence of the individual input value.
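As a minimal sketch of these two principles (the numbers here are made up for illustration):
import torch

value = torch.tensor([1.5, -0.5])          # the original value V

# each entry of V decomposed into n = 3 composite entries
composite = torch.tensor([
    [ 1.0, 0.25,  0.25],                   # sums to  1.5
    [-1.0, 0.75, -0.25],                   # sums to -0.5
])

# principle one: summing the composite entries recovers the original value
assert torch.allclose(composite.sum(dim=-1), value)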
Testing
Before proceeding we want a reliable way to empirically test these proposals. A single spot check on random values could pass by luck. What we need is a way to evaluate each proposal more systematically, especially against edge cases.
I’ve defined a long test that passes in two composite values with various numerical characteristics and checks whether the value function output matches the composite value function output. It accepts the two functions to compare: one that operates over plain values and one that operates over composite values. The first thing to do with this harness is to test the harness itself.
Code
from typing import Callable, Optional, Tuple

import torch
import pandas as pd

Function = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def compare_functions(
    value_function: Function,
    composite_function: Function,
    size: Optional[Tuple[int, ...]] = None,
) -> pd.DataFrame:
    def rand_1k(
        value_function: Function,
        composite_function: Function,
        size: Optional[Tuple[int, ...]] = None,
    ) -> bool:
        return all(
            compare_rand(
                value_function=value_function,
                composite_function=composite_function,
                size=size,
            )
            for _ in range(1_000)
        )
    rand_1k.__name__ = compare_rand.__name__

    return pd.DataFrame([
        {
            "name": f.__name__[len("compare_"):],
            "passed": f(
                value_function=value_function,
                composite_function=composite_function,
                size=size,
            )
        }
        for f in [
            compare_absolute_zeros,
            compare_sum_zeros,
            compare_has_zeros,
            rand_1k,
        ]
    ])


def compare_absolute_zeros(
    value_function: Function,
    composite_function: Function,
    size: Optional[Tuple[int, ...]] = None,
) -> bool:
    if not size:
        size = (3, 2)
    zero = torch.zeros(*size)
    nonzero = (torch.rand(*size) * 2) - 1
    return all(
        compare(
            value_function=value_function,
            composite_function=composite_function,
            left=left,
            right=right,
        )
        for left, right in [(zero, zero), (zero, nonzero), (nonzero, zero)]
    )


def compare_sum_zeros(
    value_function: Function,
    composite_function: Function,
    size: Optional[Tuple[int, ...]] = None,
) -> bool:
    if not size:
        size = (3, 2)
    zero = (torch.rand(*size) * 2) - 1
    nonzero = (torch.rand(*size) * 2) - 1
    # change the last entry in the last dimension to make the dimension sum to zero
    z = zero.view(-1, size[-1])
    z[:, -1] = -z[:, :-1].sum(dim=-1)
    check = zero.sum(dim=-1).abs().max()
    assert check <= 1e-6, f"assertion failed: {check:.4g}"
    return all(
        compare(
            value_function=value_function,
            composite_function=composite_function,
            left=left,
            right=right,
        )
        for left, right in [(zero, zero), (zero, nonzero), (nonzero, zero)]
    )


def compare_has_zeros(
    value_function: Function,
    composite_function: Function,
    size: Optional[Tuple[int, ...]] = None,
) -> bool:
    if not size:
        size = (3, 2)
    zero = (torch.rand(*size) * 2) - 1
    nonzero = (torch.rand(*size) * 2) - 1
    # change a random entry in the last dimension to be zero
    z = zero.view(-1)
    count = z.shape[0] // size[-1]
    indices = torch.tensor(range(count)) * size[-1]
    indices += torch.randint(low=0, high=size[-1], size=(count,))
    z[indices] = 0.
    check = zero.abs().min(dim=-1)[0].max()
    assert check <= 1e-6, f"assertion failed: {check:.4g}"
    return all(
        compare(
            value_function=value_function,
            composite_function=composite_function,
            left=left,
            right=right,
        )
        for left, right in [(zero, zero), (zero, nonzero), (nonzero, zero)]
    )


def compare_rand(
    value_function: Function,
    composite_function: Function,
    size: Optional[Tuple[int, ...]] = None,
) -> bool:
    if not size:
        size = (3, 2)
    left = (torch.rand(*size) * 2) - 1
    right = (torch.rand(*size) * 2) - 1
    return compare(
        value_function=value_function,
        composite_function=composite_function,
        left=left,
        right=right,
    )


def compare(
    *,
    value_function: Function,
    composite_function: Function,
    left: torch.Tensor,
    right: torch.Tensor,
    threshold: float = 1e-5,
) -> bool:
    value_output = value_function(left.sum(dim=-1), right.sum(dim=-1))
    composite_output = composite_function(left, right)
    difference = value_output - composite_output.sum(dim=-1)
    difference = difference.abs().min().item()
    return difference <= threshold
We can check this by testing it on the simplest approaches that we come up with first.
Multiplication by a Constant
If I have an input of \(V\) and I multiply that by \(\alpha\) then I get \(\alpha V\). If \(V\) is actually the composite of the values, \(\left[ C_1, C_2 \dots C_n \right]\), and I multiply that by \(\alpha\) then I get \(\left[ \alpha C_1, \alpha C_2 \dots \alpha C_n \right]\).
How can we tell if this is the correct approach? Does it maintain the principles that were defined earlier?
It’s pretty simple to show that the multiplication still maintains the relationship between the composite value and the original value:
\[
\begin{aligned}
\alpha V &= \left[ \alpha C_1, \alpha C_2 \dots \alpha C_n \right] \\
\alpha V &= \sum_{i = 1}^{n} \alpha C_i \\
\alpha V &= \alpha \sum_{i = 1}^{n} C_i \\
\alpha V &= \alpha V
\end{aligned}
\]
I was less precise with my definition of influence; however, the linear relationship between the input value and the output value is still maintained between the composite input value and the composite output value.
This seems pretty straightforward so far. Let’s see if it passes the empirical test.
Remember that this takes two values. For this test we will ignore the second value and just multiply by a constant.
compare_functions(
    value_function=lambda left, right: left * 3,
    composite_function=lambda left, right: left * 3,
)
   name            passed
0  absolute_zeros  True
1  sum_zeros       True
2  has_zeros       True
3  rand            True
That worked, so I have some small confidence in the empirical test.
Addition of a Constant
If I have an input of \(V\) and I add \(\alpha\) then I get \(V + \alpha\). If \(V\) is actually the composite values, \(\left[ C_1, C_2, \dots C_n \right]\), and I add \(\alpha\) then I get \(\left[ C_1, C_2, \dots, C_n \right] + \alpha\).
Given the constraints that were defined for multiplication, can we check that this is the correct approach?
The composite value has now been extended by an additional term. The sum of the composite value with this additional term is now \(\alpha + \sum_{i = 1}^{n} C_i\).
If we accept this additional term as an extension of the original composite then the decomposition is still valid. I do not want to fold the constant into any individual composite entry, as that would break the second constraint.
The added value, \(\alpha\), cannot be assigned to any of the decomposed values. This is because none of the decomposed input values influence this part of the output - \(\alpha\) does.
This means that the composite value must be viewed as \(\left[ C_1, C_2 \dots C_n \right] + \alpha\). Programmatically it will be easier to implement this by appending a single additional entry after the composite values, because for the other operations this unattributed constant value behaves the same as the composite values.
Once again we can empirically test this. Remember that the last index of the composite value is the \(\alpha\), so for the composite operation we will add to that.
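A minimal sketch of that test, using the comparison harness from above (composite_add_constant is an illustrative helper, not part of the earlier code):
import torch

def composite_add_constant(composite: torch.Tensor, alpha: float) -> torch.Tensor:
    # the constant is unattributed, so it only touches the last index
    result = composite.clone()
    result[..., -1] = result[..., -1] + alpha
    return result

compare_functions(
    value_function=lambda left, right: left + 3,
    composite_function=lambda left, right: composite_add_constant(left, 3),
)
The constant only ever moves through the unattributed index, so the sum of the composite output still matches the plain output.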
Multiplication of Two Composite Values
The next operation is multiplying two composite values together, and the obvious approach is to multiply them element-wise. The problem is that by splitting a value into a composite we end up with smaller individual entries. These smaller entries are then multiplied together, which compounds the difference in magnitude, and when summed the compounded differences lead to a different value.
This isn’t a fixed shift; it is a ratio that depends on how well distributed the value is across the composite entries. If we put all of the value into a single index, leaving the other indices at zero, then we get the desired behaviour:
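A quick numerical illustration, decomposing the values 2 and 3 in two different ways:
import torch

# the value 2 times the value 3 is 6
torch.tensor(2.) * torch.tensor(3.)                          # tensor(6.)

# both composites concentrated in the same single index: the sum still works
(torch.tensor([2., 0.]) * torch.tensor([3., 0.])).sum()      # tensor(6.)

# the value spread evenly across the composite: the sum is now wrong
(torch.tensor([1., 1.]) * torch.tensor([1.5, 1.5])).sum()    # tensor(3.)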
At this point it would be good to check that our empirical test rejects this approach. So far we have only seen it pass; maybe it accepts anything?
compare_functions(
    value_function=lambda left, right: left * right,
    composite_function=lambda left, right: left * right,
)
   name            passed
0  absolute_zeros  True
1  sum_zeros       False
2  has_zeros       False
3  rand            False
This gives me more confidence in the empirical test, as it does not accept everything. Now that we can show empirically that this approach fails, we can try to discuss why it fails in a mathematical sense. The first thing to show is what actually gets multiplied.
If we have two vectors, \(\left[ a_1, a_2 \right]\) and \(\left[ b_1, b_2 \right]\), then the result of multiplying them element-wise is \(\left[ a_1 b_1, a_2 b_2 \right]\). We can show this quite simply:
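A one-line check in torch:
import torch

# element-wise multiplication of two 1-dimensional tensors
torch.tensor([1, 2]) * torch.tensor([3, 4])
# tensor([3, 8])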
As tensors can have arbitrary numbers of dimensions, we can try this with larger shapes:
import torch

print("add one more dimension to both to make them 1,2")
display(torch.tensor([[1, 2]]) * torch.tensor([[3, 4]]))
print()

print("add different dimensions to make them 2,1 and 1,2")
display(torch.tensor([[1], [2]]) * torch.tensor([[3, 4]]))
print()

print("add different numbers of dimensions to make them 2,1 and 2")
torch.tensor([[1], [2]]) * torch.tensor([3, 4])
add one more dimension to both to make them 1,2
tensor([[3, 8]])
add different dimensions to make them 2,1 and 1,2
tensor([[3, 4],
[6, 8]])
add different numbers of dimensions to make them 2,1 and 2
tensor([[3, 4],
[6, 8]])
The way this works is that broadcasting treats any missing dimension as if it had size one, and a size-one dimension is repeated to match the other tensor. Multiplying a tensor by a single value multiplies every element of the tensor by that value. With this understood we can reason through what is wrong with the composite multiplication.
One way to show what the multiplication should be is to work through a substitution of the original \(A B\), as follows:
\[
\begin{aligned}
A &= \left[ a_1, a_2 \right] \\
&= \sum_{i = 1}^{2} a_i \\
\\
B &= \left[ b_1, b_2 \right] \\
&= \sum_{i = 1}^{2} b_i \\
\\
A B &= \left( \sum_{i = 1}^{2} a_i \right) \left( \sum_{i = 1}^{2} b_i \right) \\
A B &= \left[ a_1, a_2 \right] \left( \sum_{i = 1}^{2} b_i \right) \\
A B &= \sum_{i = 1}^{2} a_i \sum_{j = 1}^{2} b_j
\end{aligned}
\]
The multiplication that we are actually performing is \(A B = \left[ a_1 b_1, a_2 b_2 \right]\) which is not even close. I’m glad I worked through this.
There is a problem with this as a proposed implementation - it loses all influence from \(B\), because the \(B\) values are rolled up into a single sum. One way to fix this would be to take an equal amount of influence from each side, by changing the equation slightly.
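One way to write that adjustment (a sketch that matches the rolled-up averaging used by the implementations later in the post) is:
\[
A B = \frac{1}{2} \left( \left[ a_1, a_2 \right] \sum_{j = 1}^{2} b_j + \left[ b_1, b_2 \right] \sum_{i = 1}^{2} a_i \right)
\]
Summing the entries of the right hand side recovers \(\left( \sum_i a_i \right) \left( \sum_j b_j \right)\), so the first principle still holds, and each entry now carries influence from both sides.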
We will handle this when we move on to matrix multiplication.
Addition of Two Composite Values
This should be straightforward: we can just add the individual composite entries, and there should be no funny business. Summing the added composite values still maintains the overall sum.
compare_functions(
    value_function=lambda left, right: left + right,
    composite_function=lambda left, right: left + right,
)
   name            passed
0  absolute_zeros  True
1  sum_zeros       True
2  has_zeros       True
3  rand            True
Easy.
Matrix Multiplication
Matrix multiplication is a combination of multiplication and addition. Let’s see if that works out for us.
A 3x2 by 2x3 matrix multiplication is implemented as follows:
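As a sketch of the plain (non-composite) case, each output entry is a sum of products over the shared dimension:
import torch

left = torch.rand(3, 2)
right = torch.rand(2, 3)

# each output entry sums the products over the shared dimension
result = torch.zeros(3, 3)
for n in range(3):
    for p in range(3):
        for m in range(2):
            result[n, p] += left[n, m] * right[m, p]

assert torch.allclose(result, left @ right)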
This gets tricky when we use composite values. Does it have to be though? What if we step through each of those composite value indices as separate matrices?
import torch

def composite_matrix_multiply(
    left: torch.Tensor,
    right: torch.Tensor,
) -> torch.Tensor:
    right = right.transpose(0, 1)
    result = torch.zeros(left.shape[0], right.shape[1], left.shape[2])
    for i in range(left.shape[2]):
        result[:, :, i] = left[:, :, i] @ right[:, :, i]
    return result

compare_functions(
    value_function=lambda left, right: left @ right.T,
    composite_function=composite_matrix_multiply,
    size=(2, 3, 4),
)
   name            passed
0  absolute_zeros  True
1  sum_zeros       False
2  has_zeros       False
3  rand            False
This has the same kind of problem as the multiplication of two composite values. We can fix that in the same way:
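A sketch of that fix keeps the loop but multiplies each composite index against the rolled-up other side and averages the two contributions (this mirrors the einsum version further down):
import torch

def composite_matrix_multiply(
    left: torch.Tensor,
    right: torch.Tensor,
) -> torch.Tensor:
    right = right.transpose(0, 1)
    left_rollup = left.sum(dim=-1)
    right_rollup = right.sum(dim=-1)
    result = torch.zeros(left.shape[0], right.shape[1], left.shape[2])
    for i in range(left.shape[2]):
        # each composite index is multiplied against the rolled up other side,
        # and the two sides are averaged to share the influence equally
        result[:, :, i] = (
            (left[:, :, i] @ right_rollup) + (left_rollup @ right[:, :, i])
        ) / 2
    return result

compare_functions(
    value_function=lambda left, right: left @ right.T,
    composite_function=composite_matrix_multiply,
    size=(2, 3, 4),
)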
Is it possible to express this without the loop? Torch has support for Einstein Summation, which is described as:
Sums the product of the elements of the input operands along dimensions specified using a notation based on the Einstein summation convention.
If we review the matrix multiplication above, we can see that regular matrix multiplication is a sum of products. Furthermore, the composite value matrix multiplication is regular matrix multiplication with an additional dimension.
Since this is a new notation that is being introduced we should take care to establish that it is equivalent to what came before. Let’s start by checking that matrix multiplication is equivalent when using einsum:
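A first attempt might translate the loop directly into an einsum over the composite dimension (composite_matrix_multiply_naive is an illustrative name):
import torch

def composite_matrix_multiply_naive(
    left: torch.Tensor,
    right: torch.Tensor,
) -> torch.Tensor:
    # multiply the composite values directly inside the summation -
    # this is the composite multiplication that we already know is wrong
    return torch.einsum("nmi,mpi->npi", left, right.transpose(0, 1))

compare_functions(
    value_function=lambda left, right: left @ right.T,
    composite_function=composite_matrix_multiply_naive,
    size=(2, 3, 4),
)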
This is clearly not equivalent. The problem is that the Einstein summation performs the composite value multiplication inside the matrix multiplication. Composite value multiplication is tricky to do right; we need to operate over the rolled-up values instead. Let’s try again.
Now it works great. We still need to check it with our more systematic tests.
import torch

def composite_matrix_multiply(
    left: torch.Tensor,
    right: torch.Tensor,
) -> torch.Tensor:
    right = right.transpose(0, 1)
    left_rollup = left.sum(dim=-1).unsqueeze(-1)
    right_rollup = right.sum(dim=-1).unsqueeze(-1)
    result = torch.einsum("nmi,mpi->npi", left, right_rollup)
    result = result + torch.einsum("nmi,mpi->npi", left_rollup, right)
    result = result / 2
    return result

compare_functions(
    value_function=lambda left, right: left @ right.T,
    composite_function=composite_matrix_multiply,
    size=(2, 3, 4),
)
   name            passed
0  absolute_zeros  True
1  sum_zeros       True
2  has_zeros       True
3  rand            True
The best part about using einsum is that we can apply the same approach to torch.bmm (batch matrix multiply):
import torch

def composite_batch_matrix_multiply(
    left: torch.Tensor,
    right: torch.Tensor,
) -> torch.Tensor:
    right = right.transpose(1, 2)
    left_rollup = left.sum(dim=-1).unsqueeze(-1)
    right_rollup = right.sum(dim=-1).unsqueeze(-1)
    result = torch.einsum("bnmi,bmpi->bnpi", left, right_rollup)
    result = result + torch.einsum("bnmi,bmpi->bnpi", left_rollup, right)
    result = result / 2
    return result

compare_functions(
    value_function=lambda left, right: (
        torch.bmm(left, right.transpose(1, 2))
    ),
    composite_function=composite_batch_matrix_multiply,
    size=(2, 3, 4, 5),
)
   name            passed
0  absolute_zeros  True
1  sum_zeros       True
2  has_zeros       True
3  rand            True
Softmax
This is where it gets tricky. The softmax function is a non-linear mapping of the values that preserves order. The approach taken for softmax can be generalized to GELU and layer normalization, as they share the same problem: the results are very different when the function is run over the individual composite values compared to the original value. Because the function is non-linear, passing in different values produces very different outputs.
We can see this most clearly with ReLU which is defined as \(ReLU(V) = max(0, V)\). If we have a value of 1 then \(ReLU(1) = 1\). If the composite value is \(\left[ 2, -1 \right]\) then \(ReLU(\left[ 2, -1 \right]) = \left[ 2, 0 \right]\). This has clearly broken the sum.
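A quick demonstration of that break in torch:
import torch
import torch.nn.functional as F

value = torch.tensor(1.0)
composite = torch.tensor([2.0, -1.0])

print(F.relu(value))            # tensor(1.)
print(F.relu(composite).sum())  # tensor(2.) - no longer matches the original output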
Previously I was operating over the original values and then scaling the composite values accordingly. This is done by calculating the ratio of the input value to the output value (\(\frac{V_{out}}{V_{in}}\)) and then scaling the composite values by this. The problem with this approach is related to very small or zero \(V_{in}\) values. If \(V_{in}\) is very small then the ratio results in boosting the magnitude of the composite values, making it difficult to interpret them. If \(V_{in}\) is zero then the ratio becomes nan and this breaks further calculation completely.
What is desired is a way to:
Maintain the constraint that the composite value should be seen as the decomposition of the original value, that is \(V = \sum_{i = 1}^{n} C_i\).
Keep the magnitude of the composite value similar to the original value.
It should be possible to use the unattributed constant value to ensure that these constraints hold. Let’s try it out.
import torch

def composite_softmax(composite_input_value: torch.Tensor) -> torch.Tensor:
    input_value = composite_input_value.sum(dim=-1)
    output_value = input_value.softmax(dim=-1)

    # calculate the input magnitude and output magnitude
    # for scaling each composite value set
    flat_composite_input_value = composite_input_value.reshape(
        -1, composite_input_value.shape[-1]
    )
    composite_input_magnitude = (
        flat_composite_input_value[:, :-1].abs().max(dim=-1)[0]
    )
    output_magnitude = output_value.abs().flatten()

    # calculate and apply the ratio
    ratio = output_magnitude / composite_input_magnitude
    flat_composite_output_value = flat_composite_input_value * ratio[:, None]

    # wipe out the values where the ratio would be nan,
    # there is no input influence to track
    flat_composite_output_value[composite_input_magnitude == 0, :] = 0.

    # check that the magnitude has been correctly scaled
    composite_output_magnitude = (
        flat_composite_output_value[:, :-1].abs().max(dim=-1)[0]
    )
    magnitude_difference = (composite_output_magnitude - output_magnitude).max()
    assert magnitude_difference <= 1e-6

    # use the unattributed constant value to ensure that
    # the sum of the composite value equals the original value
    flat_composite_output_value[:, -1] = (
        output_value.flatten() - flat_composite_output_value[:, :-1].sum(dim=-1)
    )
    composite_output_value = flat_composite_output_value.reshape(
        composite_input_value.shape
    )
    return composite_output_value

compare_functions(
    value_function=lambda left, right: left.softmax(dim=-1),
    composite_function=lambda left, right: composite_softmax(left),
    size=(2, 3, 4, 5),
)
   name            passed
0  absolute_zeros  True
1  sum_zeros       True
2  has_zeros       True
3  rand            True
This is a much more complex approach to the problem than before. It should be possible to apply this to all of the activation functions without having it blow up.
from typing import Callable

import torch
import torch.nn.functional as F

def composite_nonlinear(
    composite_input_value: torch.Tensor,
    function: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    input_value = composite_input_value.sum(dim=-1)
    output_value = function(input_value)

    # calculate the input magnitude and output magnitude
    # for scaling each composite value set
    flat_composite_input_value = composite_input_value.reshape(
        -1, composite_input_value.shape[-1]
    )
    composite_input_magnitude = (
        flat_composite_input_value[:, :-1].abs().max(dim=-1)[0]
    )
    output_magnitude = output_value.abs().flatten()

    # calculate and apply the ratio
    ratio = output_magnitude / composite_input_magnitude
    flat_composite_output_value = flat_composite_input_value * ratio[:, None]

    # wipe out the values where the ratio would be nan, there is no input influence to track
    flat_composite_output_value[composite_input_magnitude == 0, :] = 0.

    # check that the magnitude has been correctly scaled
    composite_output_magnitude = (
        flat_composite_output_value[:, :-1].abs().max(dim=-1)[0]
    )
    magnitude_difference = (composite_output_magnitude - output_magnitude).max()
    assert magnitude_difference <= 1e-6

    # use the unattributed constant value to ensure that the sum of the
    # composite value equals the original value
    flat_composite_output_value[:, -1] = (
        output_value.flatten() - flat_composite_output_value[:, :-1].sum(dim=-1)
    )
    composite_output_value = flat_composite_output_value.reshape(
        composite_input_value.shape
    )
    return composite_output_value

print("softmax")
display(
    compare_functions(
        value_function=lambda left, right: left.softmax(dim=-1),
        composite_function=lambda left, right: (
            composite_nonlinear(left, lambda t: t.softmax(dim=-1))
        ),
        size=(2, 3, 4, 5),
    )
)

print("gelu")
display(
    compare_functions(
        value_function=lambda left, right: F.gelu(left),
        composite_function=lambda left, right: (
            composite_nonlinear(left, F.gelu)
        ),
        size=(2, 3, 4, 5),
    )
)
softmax

   name            passed
0  absolute_zeros  True
1  sum_zeros       True
2  has_zeros       True
3  rand            True

gelu

   name            passed
0  absolute_zeros  True
1  sum_zeros       True
2  has_zeros       True
3  rand            True
CLIP Tracing
Let’s have a go at using this new code to trace the CLIP activations. This is what I tried to do before, and it had a lot of problems. Previously the traced results varied quite a bit:
model type   maximum value   mean value
original     1.6939          0.0007
tracing      3.2656          -0.0154
The variance here is huge considering that the tracing was designed to produce output consistent with the original model. I believe that this variance is one of the reasons why the results from the traced output were inconclusive.
It’s easy enough to fit the code that we have covered in this post into the structure that was created in that previous post. I’m going to do that now, which is quite a bit of code.
Code
# from src/main/python/blog/tracing/v2023_02/layers.py
from typing import Optional, Tuple

import torch
from torch import nn
from transformers import CLIPModel
from transformers.models.clip.modeling_clip import (
    CLIPMLP,
    CLIPAttention,
    CLIPEncoder,
    CLIPEncoderLayer,
    CLIPVisionTransformer,
)


def load_tracing_image_model(
    model_name: str = "openai/clip-vit-base-patch32",
) -> CLIPModel:
    model = CLIPModel.from_pretrained(model_name)
    tracing_model = TracingCLIPVisionTransformer(
        model=model.vision_model, projection=model.visual_projection
    )
    tracing_model.eval()
    return tracing_model


class TracingCLIPVisionTransformer(nn.Module):
    def __init__(self, model: CLIPVisionTransformer, projection: nn.Linear) -> None:
        super().__init__()
        self.embeddings = model.embeddings
        # misspelling present in model
        self.pre_layernorm = TracingActivation(model.pre_layrnorm)
        self.encoder = CLIPEncoder(model.encoder.config)
        self.encoder.layers = nn.ModuleList(
            [TracingCLIPEncoderLayer(layer) for layer in model.encoder.layers]
        )
        self.post_layernorm = TracingActivation(model.post_layernorm)
        self.visual_projection = TracingLinear(projection)

    def forward(self, pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
        embeddings = self.embeddings(pixel_values=pixel_values)
        _, image_regions, _ = embeddings.shape
        embeddings_traced = torch.zeros(
            *embeddings.shape, image_regions, device=embeddings.device
        )
        # 0 is the class embedding, that is not related to the image so it goes
        # in the constant index (last index)
        for i in range(image_regions - 1):
            embeddings_traced[:, i + 1, :, i] = embeddings[:, i + 1, :]
        embeddings_traced[:, 0, :, -1] = embeddings[:, 0, :]

        embeddings_normalized = self.pre_layernorm(embeddings_traced)
        encoded = self.encoder(
            inputs_embeds=embeddings_normalized,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=False,
        )
        encoded = encoded[0]
        encoded = encoded[:, 0, :]
        encoded = self.post_layernorm(encoded)
        return self.visual_projection(encoded)


class TracingCLIPEncoderLayer(nn.Module):
    def __init__(self, layer: CLIPEncoderLayer) -> None:
        super().__init__()
        self.self_attn = TracingCLIPAttention(layer.self_attn)
        self.layer_norm1 = TracingActivation(layer.layer_norm1)
        self.mlp = TracingCLIPMLP(layer.mlp)
        self.layer_norm2 = TracingActivation(layer.layer_norm2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        causal_attention_mask: Optional[torch.Tensor],
        output_attentions: bool,
    ) -> torch.Tensor:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


class TracingCLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, layer: CLIPAttention) -> None:
        super().__init__()
        self.layer = layer
        self.k_proj = TracingLinear(layer.k_proj)
        self.v_proj = TracingLinear(layer.v_proj)
        self.q_proj = TracingLinear(layer.q_proj)
        self.out_proj = TracingLinear(layer.out_proj)
        self.bmm = TracingBMM()
        self.softmax = TracingActivation(nn.Softmax(dim=-1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel x Tracing"""
        bsz, tgt_len, embed_dim, tracking_dim = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.layer.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz, tracking_dim)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz, tracking_dim)

        proj_shape = (bsz * self.layer.num_heads, -1, self.layer.head_dim, tracking_dim)
        query_states = self._shape(query_states, tgt_len, bsz, tracking_dim).view(
            *proj_shape
        )
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = self.bmm(query_states, key_states.transpose(1, 2))

        # code has been cut out here which relates to the
        # causal_attention_mask, attention_mask and error checking
        assert causal_attention_mask is None
        assert attention_mask is None

        attn_weights = self.softmax(attn_weights)

        if output_attentions:
            # this operation is a bit akward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(
                bsz, self.layer.num_heads, tgt_len, src_len, tracking_dim
            )
            attn_weights = attn_weights_reshaped.view(
                bsz * self.layer.num_heads, tgt_len, src_len, tracking_dim
            )
        else:
            attn_weights_reshaped = None

        # this is not intended for training so dropout is removed entirely
        # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_probs = attn_weights

        attn_output = self.bmm(attn_probs, value_states)

        attn_output = attn_output.view(
            bsz, self.layer.num_heads, tgt_len, self.layer.head_dim, tracking_dim
        )
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim, tracking_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int, tracking_dim: int):
        return (
            tensor.view(
                bsz, seq_len, self.layer.num_heads, self.layer.head_dim, tracking_dim
            )
            .transpose(1, 2)
            .contiguous()
        )


class TracingCLIPMLP(nn.Module):
    def __init__(self, layer: CLIPMLP) -> None:
        super().__init__()
        self.activation_fn = TracingActivation(layer.activation_fn)
        self.fc1 = TracingLinear(layer.fc1)
        self.fc2 = TracingLinear(layer.fc2)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class TracingLinear(nn.Module):
    def __init__(self, linear: nn.Linear) -> None:
        super().__init__()
        self.weight = linear.weight[:, :, None]
        self.bias = linear.bias

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        if len(xs.shape) == 3:
            mm = self._matrix_multiply_3d(xs)
        else:
            mm = self._matrix_multiply_4d(xs)
        if self.bias is not None:
            flat = mm.reshape(-1, *mm.shape[-2:])
            flat[:, :, -1] += self.bias
            mm = flat.reshape(*mm.shape)
        return mm

    def _matrix_multiply_3d(self, xs: torch.Tensor) -> torch.Tensor:
        return torch.einsum("bik,jik->bjk", xs, self.weight)

    def _matrix_multiply_4d(self, xs: torch.Tensor) -> torch.Tensor:
        return torch.einsum("Bbik,Bjik->Bbjk", xs, self.weight[None])


class TracingBMM(nn.Module):
    def forward(self, xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
        x_rollup = xs.sum(dim=-1).unsqueeze(-1)
        y_rollup = ys.sum(dim=-1).unsqueeze(-1)
        x_bmm = torch.einsum("bnmi,bmpi->bnpi", xs, y_rollup)
        y_bmm = torch.einsum("bnmi,bmpi->bnpi", x_rollup, ys)
        bmm = x_bmm + y_bmm
        bmm = bmm / 2
        return bmm


class TracingActivation(nn.Module):
    def __init__(self, activation_function: nn.Module) -> None:
        super().__init__()
        self.activation_function = activation_function

    def forward(self, composite_input_value: torch.Tensor) -> torch.Tensor:
        input_value = composite_input_value.sum(dim=-1)
        output_value = self.activation_function(input_value)

        # calculate the input magnitude and output magnitude for scaling each composite value set
        flat_composite_input_value = composite_input_value.reshape(
            -1, composite_input_value.shape[-1]
        )
        composite_input_magnitude = (
            flat_composite_input_value[:, :-1].abs().max(dim=-1)[0]
        )
        output_magnitude = output_value.abs().flatten()

        # calculate and apply the ratio
        ratio = output_magnitude / composite_input_magnitude
        flat_composite_output_value = flat_composite_input_value * ratio[:, None]

        # wipe out the values where the ratio would be nan, there is no input influence to track
        flat_composite_output_value[composite_input_magnitude == 0, :] = 0.0

        # check that the magnitude has been correctly scaled
        composite_output_magnitude = (
            flat_composite_output_value[:, :-1].abs().max(dim=-1)[0]
        )
        magnitude_difference = (composite_output_magnitude - output_magnitude).max()
        if magnitude_difference > 1e-5:
            raise AssertionError(f"magnitude check failed: {magnitude_difference:0.4g}")

        # use the unattributed constant value to ensure that the sum of the
        # composite value equals the original value
        composite_offset = flat_composite_output_value[:, :-1].sum(dim=-1)
        flat_composite_output_value[:, -1] = output_value.flatten() - composite_offset
        composite_output_value = flat_composite_output_value.reshape(
            composite_input_value.shape
        )
        return composite_output_value
The first check is to run the model over the same cat image as before.
This will be run through the original CLIP model and the new tracing CLIP model. Then we will compare the original output to the rolled up traced output. Ideally this version of the traced model will be more consistent with the original model.
Code
import warnings

import pandas as pd

with warnings.catch_warnings():
    # transformers/models/clip/processing_clip.py:142: FutureWarning:
    # `feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.
    warnings.simplefilter("ignore")
    pixel_values = preprocess_image("cat.jpg")
    original_embedding = get_output(pixel_values)
    traced_embedding = get_traced_output(pixel_values)

tracing_difference = (original_embedding - traced_embedding.sum(dim=-1)).abs()

pd.DataFrame([
    {
        "model type": "original",
        "maximum": original_embedding.max().item(),
        "mean": original_embedding.mean().item(),
    },
    {
        "model type": "tracing",
        "maximum": traced_embedding.sum(dim=-1).max().item(),
        "mean": traced_embedding.sum(dim=-1).mean().item(),
    },
    {
        "model type": "difference",
        "maximum": tracing_difference.max().item(),
        "mean": tracing_difference.mean().item(),
    },
])
   model type   maximum    mean
0  original     1.693860   0.000684
1  tracing      1.693868   0.000684
2  difference   0.000008   0.000002
This is wild. These values are nearly identical.
I feel way better about this approach now!
Finally let’s try to find which parts of the image contribute to the cat classification. We originally found that the prompt “a rendering of a cat.” was the best match for the image. We can use the weights from that text embedding to find what parts of the image contribute most to the cat classification.
Code
import warnings

import pandas as pd
import torch
import torch.nn.functional as F

def show_weights(image: str, prompt: str) -> None:
    with warnings.catch_warnings():
        # transformers/models/clip/processing_clip.py:142: FutureWarning:
        # `feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.
        warnings.simplefilter("ignore")
        pixel_values = preprocess_image(image)
        traced_embedding = get_traced_output(pixel_values)
        traced_embedding = traced_embedding[0]
        text_embedding = get_text_embedding(prompt)
        text_embedding = text_embedding[0]

    similarity = F.cosine_similarity(traced_embedding.sum(dim=-1), text_embedding, dim=0)
    print(f"cosine similarity: {similarity:0.4g}")

    weights = traced_embedding[:, :-1] * text_embedding[:, None]
    weights = weights.T.sum(dim=-1)
    weights = weights.sigmoid()
    weights = weights * 2 - 1
    show_region_attention(
        pixel_values,
        F.relu(weights),
    )
show_weights("cat.jpg", "a rendering of a cat.")
cosine similarity: 0.2972
This is a way better result. The face is what I would expect to contribute most to the classification, and this supports that.
Is this just showing the areas of the image that had any activation at all? We can check that quite quickly:
import torch

weights = traced_embedding[0, :, :-1]
weights = weights.T.sum(dim=-1)
# could show just the positive activations with this:
# weights[weights <= 0] = 0.
weights = weights.sigmoid()

show_region_attention(
    pixel_values,
    weights,
    scale_attention=False,
)
The use of the text prompt does heavily alter the regions of the image that are used. This is consistent with the fact that there are both positive and negative values in the embedding - when the signs align, that index is being selected.
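A tiny made-up illustration of that sign alignment:
import torch

image_contribution = torch.tensor([0.5, -0.5, 0.5])
text_embedding = torch.tensor([1.0, -1.0, -1.0])

# positive products mean the signs align and the index is selected,
# negative products mean the image works against that part of the prompt
print(image_contribution * text_embedding)  # tensor([ 0.5000,  0.5000, -0.5000])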
Is this approach transferable? Does it work well for other subjects or images?
show_weights("baseball.jpg", "a photo of a baseball player.")
cosine similarity: 0.2679
show_weights("baseball.jpg", "a centered satellite photo of permanent crop land.")
cosine similarity: 0.1259
The classification of the correct label produces a stronger activation over the interesting parts of the image. When the wrong label is used the overlap between the activations is weaker.
Ultimately the image is processed without consideration for the prompt, so the areas that produce strong output for a given embedding index are likely to show up as being interesting parts of the source image.
Further Work
This has implemented tracking for attention. Attention is used extensively in NLP. It would be very interesting to see how this works with something like a sentiment model - does it show the words or phrases correlated with the document sentiment?