Explaining CLIP Output

What parts of an image contribute to the CLIP classification?
image classification
clip
Published November 19, 2022

CLIP (Radford et al. 2021) is a pair of models that produce an embedding from an image or from some text. They were trained to align the embedding of an image with the embedding of its caption.

Radford, Alec, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, et al. 2021. “Learning Transferable Visual Models from Natural Language Supervision.” arXiv. https://doi.org/10.48550/ARXIV.2103.00020.

We can use this to classify images by writing a prompt for each candidate class and measuring how closely each prompt matches a given image. This is how the original code classifies datasets like ImageNet (see the example notebook for details).
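Here is a minimal sketch of that idea. It assumes we already have one embedding per prompt and one embedding for the image. Random tensors stand in for real CLIP outputs, the scale of 100 is an assumption about the learned logit scale, and `zero_shot_probabilities` is a hypothetical helper name:

```python
import torch

def zero_shot_probabilities(
    image_embed: torch.Tensor,   # shape [dim]
    text_embeds: torch.Tensor,   # shape [num_prompts, dim]
    scale: float = 100.0,        # assumed logit scale; CLIP learns this value during training
) -> torch.Tensor:
    # L2-normalize both sides so the dot product becomes a cosine similarity
    image_embed = image_embed / image_embed.norm()
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    # one scaled similarity per prompt, turned into probabilities over the prompts
    logits = scale * (text_embeds @ image_embed)
    return logits.softmax(dim=0)

torch.manual_seed(0)
text_embeds = torch.randn(3, 512)                      # stand-ins for three prompt embeddings
image_embed = text_embeds[1] + 0.1 * torch.randn(512)  # an "image" that is close to prompt 1
probs = zero_shot_probabilities(image_embed, text_embeds)
print(probs.argmax().item(), probs.tolist())
```

The predicted class is the prompt with the highest probability. The real model does the same thing with learned encoders in place of the random tensors.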

What I would like to do is determine which parts of an image contribute most to a given classification.

Example Inference

Here we are going to calculate the similarity between the image and a lot of different prompts. Then we can recreate the similarity calculation step by step to understand more about how CLIP works.

“Dataset”

To perform these inferences I need a dataset. I’m going to use this photo of a cat:

a cat

Image Classification

I am going to describe the content of this image using the prompt templates from the CLIP ImageNet prompts, combined with 10 different classes (big code block…):

Code
imagenet_classes = [
    "cat",
    "dog",
    "bird",
    "rat",
    "building",
    "tree",
    "man",
    "woman",
    "boy",
    "girl",
]

imagenet_templates = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]

print(f"{len(imagenet_classes)} classes, {len(imagenet_templates)} templates")
10 classes, 80 templates

I’ve had to cut this down to 10 classes to save memory. The next step is to calculate the similarity for each of these 800 prompts:

Code
from PIL import Image
import pandas as pd
import torch

from transformers import CLIPProcessor, CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# url = "http://images.cocodataset.org/val2017/000000039769.jpg" two cats!
# image = Image.open(requests.get(url, stream=True).raw)
image = Image.open("cat.jpg")

descriptions = {
    name: [template.format(name) for template in imagenet_templates]
    for name in imagenet_classes
}
prompts = {
    prompt: label
    for label, prompts in descriptions.items()
    for prompt in prompts
}

with torch.inference_mode():
    inputs = processor(
        text=list(prompts.keys()),
        images=image,
        return_tensors="pt",
        padding=True,
    )

    outputs = model(**inputs)

logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
probability_df = pd.DataFrame([
    {"label": label, "type": prompts[label], "probability": probability}
    for label, probability in zip(prompts.keys(), logits_per_image.softmax(dim=1)[0].tolist())
])

probability_df.sort_values(by="probability", ascending=False).head()
label type probability
5 a rendering of a cat. cat 0.090087
39 a photo of a cat. cat 0.076132
70 a black and white photo of a cat. cat 0.071223
78 a photo of a small cat. cat 0.063755
45 a low resolution photo of a cat. cat 0.060572
Code
logits_per_image[0, 5], logits_per_image.softmax(dim=1)[0, 5]
(tensor(29.7150), tensor(0.0901))

Out of the 800 prompts, the top 5 are all variations of the cat prompt, so CLIP has correctly classified the image as a cat.
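Because the softmax runs over all 800 prompts, the probability for a class is spread across its 80 templates, which is why the best single prompt only reaches about 9%. Below is a sketch of summing the probabilities back up per class. It assumes a DataFrame with the same `type` and `probability` columns as `probability_df`, and the rows and numbers are made up for illustration:

```python
import pandas as pd

# toy stand-in for probability_df: one row per prompt, with made-up probabilities
toy_df = pd.DataFrame({
    "label": ["a photo of a cat.", "a drawing of a cat.", "a photo of a dog.", "a drawing of a dog."],
    "type": ["cat", "cat", "dog", "dog"],
    "probability": [0.5, 0.3, 0.15, 0.05],
})

# the softmax covered every prompt, so summing within a class gives a class-level probability
class_probability = (
    toy_df.groupby("type")["probability"]
        .sum()
        .sort_values(ascending=False)
)
print(class_probability)
```

For the real data the equivalent would be `probability_df.groupby("type")["probability"].sum()`.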

Understanding the Model

How does this actually work?

The important thing to realise is that there are two models under the hood, one for text and one for images. That also means that the processor is really two preprocessors, one for images and one for text. Each model embeds its own preprocessed input, and then the similarity between each text embedding and the image embedding is calculated.

We can break this down into the individual parts and reproduce the score of 29.7 for the prompt “a rendering of a cat.”

Code
image_features = processor.feature_extractor(image, return_tensors="pt")
text_features = processor.tokenizer(["a rendering of a cat."], return_tensors="pt")
/home/matthew/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/transformers/models/clip/processing_clip.py:142: FutureWarning: `feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.
  warnings.warn(
Code
import torch

with torch.inference_mode():
    vision_outputs = model.vision_model(**image_features)
    text_outputs = model.text_model(**text_features)

    image_embeds = vision_outputs[1] # pooler_output
    image_embeds = model.visual_projection(image_embeds)

    text_embeds = text_outputs[1] # pooler_output
    text_embeds = model.text_projection(text_embeds)

image_embeds.shape, text_embeds.shape
(torch.Size([1, 512]), torch.Size([1, 512]))
Code
import torch

with torch.inference_mode():
    # normalized features
    normalized_image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    normalized_text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

    similarity = torch.matmul(
        normalized_text_embeds,
        normalized_image_embeds.t()
    )
    print(f"unscaled similarity: {similarity[0,0]:0.3f}")
    print(f"scaled similarity: {(similarity * model.logit_scale.exp())[0,0]:0.3f}")
unscaled similarity: 0.297
scaled similarity: 29.715

This is calculating the cosine similarity and then scaling it by the model.logit_scale.exp() factor, which is about 100 here (29.715 / 0.297). We can easily confirm this by calculating the cosine similarity of the unnormalized embeddings directly.

Code
import torch.nn.functional as F

F.cosine_similarity(text_embeds, image_embeds)
tensor([0.2971])
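To make the relationship explicit, here is a small self-contained check that uses random vectors rather than real CLIP embeddings. Normalizing both vectors and taking the dot product gives the same number as `F.cosine_similarity`, and the logit is that number multiplied by the scale. The scale of 100 is an assumption that matches the ratio seen above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
text_vec = torch.randn(1, 512)   # stand-in for a text embedding
image_vec = torch.randn(1, 512)  # stand-in for an image embedding

# cosine similarity written as a dot product of L2-normalized vectors
normalized_dot = (F.normalize(text_vec, dim=-1) * F.normalize(image_vec, dim=-1)).sum(dim=-1)
cosine = F.cosine_similarity(text_vec, image_vec)

scale = 100.0  # assumed value of model.logit_scale.exp()
logit = scale * cosine
print(normalized_dot.item(), cosine.item(), logit.item())
```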

For some reason they have built the cosine similarity calculation into the model. Including the scaling factor makes more sense, because it changes the probabilities produced by the softmax and makes them more extreme, as we can see below:

Code
import torch
import pandas as pd

basic_similarity = (torch.rand(10) - 0.5) * 2 # simulated similarities
pd.DataFrame({
    "similarity": basic_similarity.tolist(),
    "unscaled": basic_similarity.softmax(dim=0).tolist(),
    "scaled": (basic_similarity * 100).softmax(dim=0).tolist(),
}).sort_values(by="unscaled")
similarity unscaled scaled
6 -0.762224 0.032671 0.000000e+00
7 -0.416272 0.046175 0.000000e+00
8 0.038008 0.072727 1.147126e-35
3 0.075670 0.075518 4.957201e-34
0 0.267258 0.091466 1.037075e-25
9 0.484861 0.113701 2.925587e-16
1 0.502547 0.115729 1.715108e-15
2 0.716907 0.143397 3.497896e-06
5 0.735059 0.146023 2.148565e-05
4 0.842540 0.162593 9.999751e-01

By scaling the random cosine similarities, we move from a 16% confidence in the top class to over 99.99% confidence.

What Contributed to the Cat Classification?

If we treat the cosine similarity between the text embedding and the image embedding as a linear classification head, like the final fully connected layer in ResNet, then we could try to work out what contributed to the classification. This has been done for ResNet before and produced nice heatmaps that show which parts of the image activate for a class. These heatmaps are made by weighting the unpooled output of the final convolutional layer by the classifier weights and mapping the result back onto the image.
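For reference, the ResNet heatmaps are usually called class activation maps. Below is a minimal sketch. It assumes the network ends in a global average pool followed by a single linear layer, and random tensors stand in for a real ResNet and its weights:

```python
import torch

torch.manual_seed(0)
channels, height, width, num_classes = 512, 7, 7, 10
feature_map = torch.randn(channels, height, width)      # unpooled output of the last conv block
classifier_weight = torch.randn(num_classes, channels)  # weights of the final linear layer (bias ignored)

# the normal classification path: global average pool, then the linear layer
pooled = feature_map.mean(dim=(1, 2))
logits = classifier_weight @ pooled

# class activation map: weight each channel of the unpooled map by that class's weights
target_class = logits.argmax().item()
cam = torch.einsum("c,chw->hw", classifier_weight[target_class], feature_map)

# because pooling is an average, the mean of the map equals the class logit
print(cam.shape, cam.mean().item(), logits[target_class].item())
```

The 7x7 map is then upsampled to the image size to produce the overlay.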

The image part of CLIP is not like ResNet, though. It is a vision transformer: a convolution layer first cuts the image into 32x32 pixel patches, and then a stack of self-attention blocks processes them. Self-attention lets every patch attend to every other patch, so a single output of a layer can draw on any part of the image.
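To make the patch step concrete, here is a shape-only sketch of a ViT-B/32 style patch embedding. It assumes a 224x224 input, 32x32 patches, and a hidden size of 768, which match `openai/clip-vit-base-patch32`. The weights are random, so only the shapes are meaningful:

```python
import torch
from torch import nn

# a convolution whose kernel size equals its stride cuts the image into non-overlapping patches
patch_embedding = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=32, stride=32, bias=False)

pixels = torch.randn(1, 3, 224, 224)         # one preprocessed image
patches = patch_embedding(pixels)            # [1, 768, 7, 7]: one 768-dim vector per patch
tokens = patches.flatten(2).transpose(1, 2)  # [1, 49, 768]: a sequence of 49 patch tokens

# a class token is prepended, giving 50 positions for the attention layers to mix
class_token = torch.zeros(1, 1, 768)
sequence = torch.cat([class_token, tokens], dim=1)
print(patches.shape, tokens.shape, sequence.shape)
```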

Reversing this by hand seems feasible but involves a lot of math. It would be nice if there were a quicker way to calculate it.

Cosine similarity is the dot product of the two vectors after each has been normalized to unit length. A dot product is like a linear layer with a single output and no bias. So if we treat the text embedding as the weights (it is the classifier, after all) and the image embedding as the activations, then the most extreme text embedding values should be a good indication of which “features” contribute most to the classification.
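Here is a small sketch of that view using random stand-in vectors. The dot product is the sum of the element-wise products, so the contribution of index `i` is `text[i] * image[i]`. This means an index only contributes strongly when both vectors have a large value there. The helper `top_contributions` is a hypothetical name used for illustration:

```python
import torch

def top_contributions(text_embed: torch.Tensor, image_embed: torch.Tensor, k: int = 5):
    # element-wise products: their sum is exactly the (unnormalized) dot product
    contributions = text_embed * image_embed
    # indices ordered by the size of their contribution, largest first
    order = contributions.abs().argsort(descending=True)[:k]
    return contributions, order

torch.manual_seed(0)
text_embed = torch.randn(512)   # stand-in for a text embedding
image_embed = torch.randn(512)  # stand-in for an image embedding
contributions, order = top_contributions(text_embed, image_embed)
print(order.tolist(), contributions[order].tolist())
```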

Code
import pandas as pd

text_series = pd.Series(text_embeds[0].tolist())
(
    text_series
        .sort_values()
        .reset_index(drop=True)
        .plot()
)
(
    text_series[text_series.abs() >= 0.75]
        .to_frame()
        .rename(columns={0: "value"})
)
value
7 -0.909124
92 -1.150873
121 -1.046691
133 5.752840
211 0.824418
312 5.752370
329 -1.073854
493 0.913976

Most of the values cluster around zero and very few are large. If I could find a way to amplify those values and trace them back to the image, that might show what is contributing to the classification.

We can review the image embedding as well, to see whether it is similarly dominated by a few extreme values.

Code
import pandas as pd

image_series = pd.Series(image_embeds[0].tolist())
(
    image_series
        .sort_values()
        .reset_index(drop=True)
        .plot()
)
(
    image_series[image_series.abs() >= 1]
        .to_frame()
        .rename(columns={0: "value"})
)
value
39 1.293888
92 -7.000751
98 -1.338176
198 -1.162172
258 -1.042851
321 1.693860
376 1.234921
389 -1.041473

Code
overlapping_strong_indices = set(image_series[image_series.abs() >= 1].index) & set(text_series[text_series.abs() >= 1].index)
print(f"The overlapping strong indices are: {sorted(overlapping_strong_indices)}")
The overlapping strong indices are: [92]

The image embedding has more strongly negative values. What’s interesting here is that the strong values only overlap at index 92. I had thought that index 133 was the most important, but it seems not!
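Multiplying the values from the two tables makes this concrete. Index 92 contributes roughly +8.06 to the unnormalized dot product, because both values are negative. Index 133 is missing from the image table, so its image value has a magnitude below 1, and its contribution is at most about 5.75 in magnitude. This sketch only uses the numbers printed above. The other indices are not accounted for:

```python
# values copied from the text and image embedding tables above
text_values = {92: -1.150873, 133: 5.752840}
image_values = {92: -7.000751}

# contribution of index 92 to the unnormalized dot product (negative times negative)
contribution_92 = text_values[92] * image_values[92]

# index 133 is absent from the image table (|value| >= 1), so |image[133]| < 1 gives an upper bound
max_contribution_133 = abs(text_values[133]) * 1.0

print(f"index 92: {contribution_92:.3f}, index 133: at most {max_contribution_133:.3f}")
```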

My plan is to use gradient descent on the input to change these embedding values. If I asked the model to increase the output values for these indices, it might change the areas that currently do not contribute to the classification. If I instead ask it to reduce the output values, it should change the areas that do contribute, because those are the ones it needs to suppress.
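Here is a toy version of this idea, with a linear "model" `weights @ inputs` standing in for CLIP (everything in it is hypothetical). A single gradient step that shrinks the squared output changes each input in proportion to its weight. The inputs that move the most are therefore the ones that contribute the most:

```python
import torch

weights = torch.tensor([0.0, 0.1, 2.0, -3.0])  # toy linear "model"; not CLIP
inputs = torch.nn.Parameter(torch.ones(4))
original = inputs.detach().clone()

optimizer = torch.optim.SGD([inputs], lr=0.1)
optimizer.zero_grad()
loss = (weights @ inputs) ** 2  # push the output towards zero, as regions_of_interest does
loss.backward()
optimizer.step()

# the step is lr * 2 * output * weights, so the change is largest where |weight| is largest
change = (original - inputs.detach()).abs()
print(change.tolist())
```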

This is fundamentally similar to the approach I used for prompt training, except in reverse. Let’s try it out.

Broad Deoptimization Tracking

The aim is to find the parts of the input image that contribute to the classification. This does not have to be pixel-precise, because image features come from combinations of pixels rather than from individual ones. This makes me think that working with the output of the strided convolution, where each position summarizes one 32x32 patch, would be a good level of detail.

To remind ourselves, the structure of the model is as follows:

Each output position of the 32x32 convolution corresponds to one patch of the source image, while smoothing over the individual pixels. If I “optimize” this output and track which positions change, that should give an idea of what the model is focusing on. To do this I need to pass the convolved image into the model directly, which means replacing the forward pass of the embedding layer so that it skips the convolution.

Code
from typing import Optional, Union, Tuple
import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.modeling_clip import CLIPVisionTransformer, CLIPVisionEmbeddings

# Copied from CLIPVisionTransformer.forward
def forward(
    vision_model: CLIPVisionTransformer,
    # pixel_values: Optional[torch.FloatTensor] = None,
    convoluted_pixels: torch.FloatTensor,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
    output_attentions = output_attentions if output_attentions is not None else vision_model.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else vision_model.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else vision_model.config.use_return_dict

    # hidden_states = vision_model.embeddings(pixel_values)
    hidden_states = _embeddings(
        vision_model.embeddings,
        convoluted_pixels,
    )
    hidden_states = vision_model.pre_layrnorm(hidden_states)

    encoder_outputs = vision_model.encoder(
        inputs_embeds=hidden_states,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    last_hidden_state = encoder_outputs[0]
    pooled_output = last_hidden_state[:, 0, :]
    pooled_output = vision_model.post_layernorm(pooled_output)

    if not return_dict:
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]

    return BaseModelOutputWithPooling(
        last_hidden_state=last_hidden_state,
        pooler_output=pooled_output,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )

# Copied from CLIPVisionEmbeddings.forward
def _embeddings(
    embedding_model: CLIPVisionEmbeddings,
    # pixel_values: torch.FloatTensor,
    convoluted_pixels: torch.FloatTensor,
) -> torch.Tensor:
    # batch_size = pixel_values.shape[0]
    # patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
    batch_size = convoluted_pixels.shape[0]
    patch_embeds = convoluted_pixels.flatten(2).transpose(1, 2)

    class_embeds = embedding_model.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + embedding_model.position_embedding(embedding_model.position_ids)
    return embeddings
Code
from typing import Callable
import torch
from torch.optim import SGD
from transformers.models.clip.modeling_clip import CLIPModel
from transformers.models.clip.processing_clip import CLIPProcessor

def regions_of_interest(
    processor: CLIPProcessor,
    model: CLIPModel,
    image: Image.Image,
    loss_fn: Callable = lambda output: output[0, 92] ** 2, # overlapping index between image and text embedding
    steps: int = 1,
) -> torch.Tensor:
    image_features = processor.feature_extractor(image, return_tensors="pt")

    with torch.no_grad():
        convoluted_pixels = model.vision_model.embeddings.patch_embedding(
            image_features.pixel_values
        )
    original_convoluted_pixels = convoluted_pixels.detach().clone()
    convoluted_pixels = torch.nn.Parameter(convoluted_pixels)

    optimizer = SGD(
        [convoluted_pixels],
        lr=0.1,
    )
    for _ in range(steps):
        optimizer.zero_grad()

        interim_output = forward(
            vision_model=model.vision_model,
            convoluted_pixels=convoluted_pixels,
        ).pooler_output
        output = model.visual_projection(interim_output)

        loss = loss_fn(output)
        loss.backward()
        optimizer.step()

    return (original_convoluted_pixels - convoluted_pixels.data)
Code
interest_92 = regions_of_interest(
    processor=processor,
    model=model,
    image=image,
    steps=1,
)

interest_92.max()
/home/matthew/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/transformers/models/clip/processing_clip.py:142: FutureWarning: `feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.
  warnings.warn(
tensor(1.1655)
Code
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
im = ax.imshow(
    interest_92
        .abs()
        .max(dim=1)
        .values[0]
        .numpy()
)

Looking at this doesn’t clearly show me what in the image the model is looking at, so I need a way to overlay it onto the original image. I am also unsure about using the post-convolution values, because there are 768 channels per region: how should I pick the most appropriate value for each region?

Let’s start by converting the pixel values back into an image. The preprocessor resizes the image so that its shortest side is 224 pixels, center crops it to 224x224, and then normalizes the pixel values. Undoing the normalization lets me see what the model sees.

Code
import torch
from PIL import Image
from torchvision.transforms import ToPILImage

image_mean = torch.tensor([
    0.48145466,
    0.4578275,
    0.40821073
])
image_std = torch.tensor([
    0.26862954,
    0.26130258,
    0.27577711
])

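def show_pixels(values: torch.Tensor) -> Image.Image:
    # undo the processor's per-channel normalization for the first image in the batch
    values = values[0]
    values = (values * image_std[:, None, None]) + image_mean[:, None, None]
    # clamp away small float errors so the PIL conversion stays in range
    return ToPILImage()(values.clamp(0, 1))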
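The normalization being reversed is `(pixel - mean) / std` per channel, so undoing it is `normalized * std + mean`. Here is a small round-trip check that uses a random image in [0, 1] rather than the real preprocessing pipeline. The mean and std values are the ones used above:

```python
import torch

image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])
image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])

pixels = torch.rand(3, 224, 224)  # an image already rescaled to [0, 1]
normalized = (pixels - image_mean[:, None, None]) / image_std[:, None, None]
restored = normalized * image_std[:, None, None] + image_mean[:, None, None]

# float error can push values just outside [0, 1], so clamp before converting to an image
restored = restored.clamp(0.0, 1.0)
print(torch.allclose(restored, pixels, atol=1e-6))
```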
Code
show_pixels(
    processor.feature_extractor(image, return_tensors="pt").pixel_values
)

The preprocessing has cropped this image well, because it is clearly still a cat. If I scale the different regions of this image so that the parts the model does not want to change are faded, then we should be able to see what it does want to change.

Code
import torch
from IPython.display import display
from PIL import Image
from torchvision.transforms import ToPILImage

image_mean = torch.tensor([
    0.48145466,
    0.4578275,
    0.40821073
])
image_std = torch.tensor([
    0.26862954,
    0.26130258,
    0.27577711
])

def show_region_attention(values: torch.Tensor, attention: torch.Tensor) -> None:
    # need to resize attention to match the values
    attention = attention.to(float)
    attention = attention[0]
    attention = attention.repeat_interleave(32, dim=0).repeat_interleave(32, dim=1)
    attention = attention - attention.min()
    attention = attention / attention.max()
    if len(attention.shape) == 2:
        attention = attention[None]
    
    values = values[0]
    values = (values * image_std[:, None, None]) + image_mean[:, None, None]
    values = values * attention
    display(ToPILImage()(values))

def most_varied_channel(attention: torch.Tensor) -> int:
    # this finds the channel that has the greatest variation between min and max value
    attention = attention.reshape(768, -1)
    value_range = attention.max(dim=1).values - attention.min(dim=1).values
    return value_range.argmax().item()
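`show_region_attention` upsamples the 7x7 patch scores to 224x224 with `repeat_interleave`, which gives each 32x32 patch a single flat value. The sketch below shows that this is the same as nearest-neighbour interpolation with an integer scale factor. Bilinear interpolation is included for comparison only. It is not used in this post, but it would give a smoother overlay:

```python
import torch
import torch.nn.functional as F

scores = torch.rand(7, 7)  # stand-in for one score per 32x32 patch

# the block approach used above: repeat every score 32 times along each axis
blocky = scores.repeat_interleave(32, dim=0).repeat_interleave(32, dim=1)

# F.interpolate expects [batch, channels, height, width]
nearest = F.interpolate(scores[None, None], scale_factor=32, mode="nearest")[0, 0]
smooth = F.interpolate(scores[None, None], size=(224, 224), mode="bilinear", align_corners=False)[0, 0]

print(blocky.shape, torch.equal(blocky, nearest), smooth.shape)
```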
Code
print("92 feature: mean of all channels")
show_region_attention(
    image_features.pixel_values,
    interest_92.mean(dim=1),
)
print("92 feature: max of all channels")
show_region_attention(
    image_features.pixel_values,
    interest_92.max(dim=1).values,
)
print("92 feature: most varied channel")
show_region_attention(
    image_features.pixel_values,
    interest_92[:, most_varied_channel(interest_92)],
)
92 feature: mean of all channels

92 feature: max of all channels

92 feature: most varied channel

When I use the mean, it seems that the tips of the ears are most interesting. The most varied channel is harder to interpret. None of these gives the clear indication that I was hoping for.

I wonder if targeting only index 92 makes the loss too selective, so next I try losses based on the whole text embedding.

Code
import torch

with torch.no_grad():
    text_features = processor.tokenizer(["a rendering of a cat."], return_tensors="pt")
    text_outputs = model.text_model(**text_features)
    text_embeds = text_outputs[1] # pooler_output
    text_embeds = model.text_projection(text_embeds)

def no_cat_loss(output: torch.Tensor) -> torch.Tensor:
    return (torch.matmul(output, text_embeds.T) ** 2).sum()

interest_cat = regions_of_interest(
    processor=processor,
    model=model,
    image=image,
    loss_fn=no_cat_loss,
    steps=1,
)

print("anti cat loss: mean of all channels")
show_region_attention(
    image_features.pixel_values,
    interest_cat.mean(dim=1),
)
print("anti cat loss: max of all channels")
show_region_attention(
    image_features.pixel_values,
    interest_cat.max(dim=1).values,
)
print("anti cat loss: most varied channel")
show_region_attention(
    image_features.pixel_values,
    interest_cat[:, most_varied_channel(interest_cat)],
)

interest_cat.max()
anti cat loss: mean of all channels

anti cat loss: max of all channels

anti cat loss: most varied channel

tensor(30.0327)
Code
import torch

with torch.no_grad():
    text_features = processor.tokenizer(["a rendering of a cat."], return_tensors="pt")
    text_outputs = model.text_model(**text_features)
    text_embeds = text_outputs[1] # pooler_output
    text_embeds = model.text_projection(text_embeds)

def positive_cat_loss(output: torch.Tensor) -> torch.Tensor:
    dot_product = torch.matmul(output, text_embeds.T)
    dot_product[dot_product <= 0] = 0.
    return (dot_product ** 2).sum()

interest_cat = regions_of_interest(
    processor=processor,
    model=model,
    image=image,
    loss_fn=positive_cat_loss,
)

print("only cat contributors loss: mean of all channels")
show_region_attention(
    image_features.pixel_values,
    interest_cat.mean(dim=1),
)
print("only cat contributors loss: max of all channels")
show_region_attention(
    image_features.pixel_values,
    interest_cat.max(dim=1).values,
)
print("only cat contributors loss: boosted max of all channels")
show_region_attention(
    image_features.pixel_values,
    (interest_cat.max(dim=1).values * 100).softmax(dim=2)
)

interest_cat.max()
only cat contributors loss: mean of all channels

only cat contributors loss: max of all channels

only cat contributors loss: boosted max of all channels

tensor(30.0327)
Code
import torch

with torch.no_grad():
    text_features = processor.tokenizer(["a rendering of a cat."], return_tensors="pt")
    text_outputs = model.text_model(**text_features)
    text_embeds = text_outputs[1] # pooler_output
    text_embeds = model.text_projection(text_embeds)

def strong_cat_loss(output: torch.Tensor) -> torch.Tensor:
    # keep only the strongly positive text values, without mutating the global text_embeds
    strong_text_embeds = text_embeds.clone()
    strong_text_embeds[strong_text_embeds < 0.75] = 0.
    return (torch.matmul(output, strong_text_embeds.T) ** 2).sum()

interest_cat = regions_of_interest(
    processor=processor,
    model=model,
    image=image,
    loss_fn=strong_cat_loss,
)

print("strong prompt contributors loss: mean of all channels")
show_region_attention(
    image_features.pixel_values,
    interest_cat.mean(dim=1),
)
print("strong prompt contributors loss: max of all channels")
show_region_attention(
    image_features.pixel_values,
    interest_cat.max(dim=1).values,
)
print("strong prompt contributors loss: boosted max of all channels")
show_region_attention(
    image_features.pixel_values,
    (interest_cat.max(dim=1).values * 100).softmax(dim=2)
)

interest_cat.max()
strong prompt contributors loss: mean of all channels

strong prompt contributors loss: max of all channels

strong prompt contributors loss: boosted max of all channels

tensor(2.3929)
Code
import torch

with torch.no_grad():
    text_features = processor.tokenizer(["a rendering of a cat."], return_tensors="pt")
    text_outputs = model.text_model(**text_features)
    text_embeds = text_outputs[1] # pooler_output
    text_embeds = model.text_projection(text_embeds)

def strong_cat_loss(output: torch.Tensor) -> torch.Tensor:
    # keep only the strongly negative text values, without mutating the global text_embeds
    strong_text_embeds = text_embeds.clone()
    strong_text_embeds[strong_text_embeds > -0.75] = 0.
    return (torch.matmul(output, strong_text_embeds.T) ** 2).sum()

interest_cat = regions_of_interest(
    processor=processor,
    model=model,
    image=image,
    loss_fn=strong_cat_loss,
)

print("strong prompt detractors loss: mean of all channels")
show_region_attention(
    image_features.pixel_values,
    interest_cat.mean(dim=1),
)
print("strong prompt detractors loss: max of all channels")
show_region_attention(
    image_features.pixel_values,
    interest_cat.max(dim=1).values,
)
print("strong prompt detractors loss: boosted max of all channels")
show_region_attention(
    image_features.pixel_values,
    (interest_cat.max(dim=1).values * 100).softmax(dim=2)
)

interest_cat.max()
strong prompt detractors loss: mean of all channels

strong prompt detractors loss: max of all channels

strong prompt detractors loss: boosted max of all channels

tensor(1.1274)

I could keep fiddling with this for some time. It’s tricky to reduce the 768 channels to a single value per region. I’ve had some success, and the model seems to favour the eye and the chin of the cat.

If I instead optimize the original pixel values directly, the code should be simpler, and the result should be considerably easier to map onto the source image.

Precise Deoptimization Tracking

This time I am going to optimize the pixel values directly. This should give a perturbation over the three color channels, which I could even visualize as a picture itself.

Code
from typing import Callable
import torch
from torch.optim import SGD
from transformers.models.clip.modeling_clip import CLIPModel
from transformers.models.clip.processing_clip import CLIPProcessor

def pixels_of_interest(
    processor: CLIPProcessor,
    model: CLIPModel,
    image: Image.Image,
    loss_fn: Callable = lambda output: output[0, 92] ** 2, # overlapping index between image and text embedding
    steps: int = 1,
) -> torch.Tensor:
    image_features = processor.feature_extractor(image, return_tensors="pt")
    original_pixels = image_features.pixel_values.detach().clone()
    pixels = torch.nn.Parameter(image_features.pixel_values)

    optimizer = SGD(
        [pixels],
        lr=0.1,
    )
    for _ in range(steps):
        optimizer.zero_grad()

        interim_output = model.vision_model(pixel_values=pixels).pooler_output
        output = model.visual_projection(interim_output)

        loss = loss_fn(output)
        loss.backward()
        optimizer.step()

    return (original_pixels - pixels.data)
Code
from torchvision.transforms import ToPILImage

interest_92 = pixels_of_interest(
    processor=processor,
    model=model,
    image=image,
)

interest_92 = interest_92 - interest_92.min()
interest_92 = interest_92 / interest_92.max()
ToPILImage()(interest_92[0])

This is not an encouraging start, and it shows the problem that I wanted to avoid. The convolution layer appears to focus on a very specific area within each 32x32 region. I can try rerunning this process with other losses, but I am not hopeful.
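One way to see why a pattern can repeat across regions: for a convolution whose stride equals its kernel size, each input pixel feeds exactly one patch. The input gradient inside a patch is therefore a weighted mix of the kernels. The minimal sketch below uses a small random convolution with 8 channels instead of 768. Its loss sums a single output channel, so the gradient in every patch is exactly that channel's kernel. In the real model the mix of kernels differs per patch, because the gradient coming back from the transformer differs, so this is a simplification:

```python
import torch
from torch import nn

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=32, stride=32, bias=False)

pixels = torch.randn(1, 3, 224, 224, requires_grad=True)
loss = conv(pixels)[:, 0].sum()  # sum output channel 0 over all 7x7 patches
loss.backward()

# every 32x32 patch of the input gradient is a copy of the channel-0 kernel
kernel = conv.weight[0]  # shape [3, 32, 32]
first_patch = pixels.grad[0, :, :32, :32]
last_patch = pixels.grad[0, :, 192:224, 192:224]
print(torch.allclose(first_patch, kernel), torch.allclose(last_patch, kernel))
```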

Code
import torch
from IPython.display import display
from torchvision.transforms import ToPILImage

image_mean = torch.tensor([
    0.48145466,
    0.4578275,
    0.40821073
])
image_std = torch.tensor([
    0.26862954,
    0.26130258,
    0.27577711
])

def show_pixel_attention(values: torch.Tensor, attention: torch.Tensor) -> None:
    attention = attention.to(float)
    attention = attention[0]
    attention = attention - attention.min()
    attention = attention / attention.max()
    if attention.dim() == 2:  # add a channel axis so it broadcasts over RGB
        attention = attention[None]
    
    values = values[0]
    values = (values * image_std[:, None, None]) + image_mean[:, None, None]
    values = values * attention
    display(ToPILImage()(values))
Code
print("92 feature: RGB interest")
show_pixel_attention(
    image_features.pixel_values,
    interest_92,
)
92 feature: RGB interest

Code
import torch

with torch.no_grad():
    text_features = processor.tokenizer(["a rendering of a cat."], return_tensors="pt")
    text_outputs = model.text_model(**text_features)
    text_embeds = text_outputs[1] # pooler_output
    text_embeds = model.text_projection(text_embeds)
    text_embeds[text_embeds < 0.75] = 0.

def cat_loss(output: torch.Tensor) -> torch.Tensor:
    return (torch.matmul(output, text_embeds.T) ** 2).sum()

interest_cat = pixels_of_interest(
    processor=processor,
    model=model,
    image=image,
    loss_fn=cat_loss,
)

print("cat loss: RGB interest")
show_pixel_attention(
    image_features.pixel_values,
    interest_cat,
)

print("cat loss: threshold RGB interest")
show_pixel_attention(
    image_features.pixel_values,
    (interest_cat > 0.0).to(float).mean(dim=1),
)
cat loss: RGB interest

cat loss: threshold RGB interest

Code
import torch

with torch.no_grad():
    text_features = processor.tokenizer(["a rendering of a cat."], return_tensors="pt")
    text_outputs = model.text_model(**text_features)
    text_embeds = text_outputs[1] # pooler_output
    text_embeds = model.text_projection(text_embeds)

def positive_cat_loss(output: torch.Tensor) -> torch.Tensor:
    dot_product = torch.matmul(output, text_embeds.T)
    dot_product[dot_product <= 0] = 0.
    return (dot_product ** 2).sum()

interest_cat = pixels_of_interest(
    processor=processor,
    model=model,
    image=image,
    loss_fn=positive_cat_loss,
)

print("only cat contributors loss: mean of all channels")
show_pixel_attention(
    image_features.pixel_values,
    interest_cat.mean(dim=1),
)
print("only cat contributors loss: max of all channels")
show_pixel_attention(
    image_features.pixel_values,
    interest_cat.max(dim=1).values,
)
print("only cat contributors loss: boosted max of all channels")
show_pixel_attention(
    image_features.pixel_values,
    (interest_cat.max(dim=1).values * 100).softmax(dim=2)
)

interest_cat.max()
only cat contributors loss: mean of all channels

only cat contributors loss: max of all channels

only cat contributors loss: boosted max of all channels

tensor(5.7169)
Code
import torch

with torch.no_grad():
    text_features = processor.tokenizer(["a rendering of a cat."], return_tensors="pt")
    text_outputs = model.text_model(**text_features)
    text_embeds = text_outputs[1] # pooler_output
    text_embeds = model.text_projection(text_embeds)

def strong_cat_loss(output: torch.Tensor) -> torch.Tensor:
    # keep only the strongly positive text values, without mutating the global text_embeds
    strong_text_embeds = text_embeds.clone()
    strong_text_embeds[strong_text_embeds < 0.75] = 0.
    return (torch.matmul(output, strong_text_embeds.T) ** 2).sum()

interest_cat = pixels_of_interest(
    processor=processor,
    model=model,
    image=image,
    loss_fn=strong_cat_loss,
)

print("strong prompt contributors loss: mean of all channels")
show_pixel_attention(
    image_features.pixel_values,
    interest_cat.mean(dim=1),
)
print("strong prompt contributors loss: max of all channels")
show_pixel_attention(
    image_features.pixel_values,
    interest_cat.max(dim=1).values,
)
print("strong prompt contributors loss: boosted max of all channels")
show_pixel_attention(
    image_features.pixel_values,
    (interest_cat.max(dim=1).values * 100).softmax(dim=2)
)

interest_cat.max()
strong prompt contributors loss: mean of all channels

strong prompt contributors loss: max of all channels

strong prompt contributors loss: boosted max of all channels

tensor(0.7629)

So this hasn’t worked. The per-pixel values are tied too heavily to the specific convolutions being applied. Furthermore, I’m not confident that this approach is the best way in general.

Instead of relying on backpropagation to find the important regions, I need a way to track the contribution of each region as it moves through the model.

Forward Tracing

I want to be able to determine the influence of each region on the outcome. To do this I need to track the contribution that each region makes at each point in the model. This feels like a large task, so I wonder whether there is a way to do it in PyTorch that preserves this information.

The aim is to express each tensor value as a vector of values that sum to the original value, where each entry is the contribution from one specific region. With a bit of broadcasting wrangling it should be possible to make this work with the original model parameters.
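The following minimal sketch makes this invariant concrete. It uses illustrative sizes (16 input features, 8 outputs) and random shares standing in for a real attribution. Any bias-free linear map applied to each region's component separately and then summed gives the same result as applying it to the original value. The bias, the nonlinearities and products of two tracked tensors are the parts that need more care.

```python
import torch

torch.manual_seed(0)

regions = 7 * 7
x = torch.randn(4, 16)                                # original values: (batch, features)
shares = torch.rand(4, 16, regions).softmax(dim=-1)   # random shares that sum to 1
x_tracked = x.unsqueeze(-1) * shares                  # (batch, features, regions)

# the tracked representation sums back to the original values
print(torch.allclose(x_tracked.sum(dim=-1), x, atol=1e-5))  # True

weight = torch.randn(8, 16)  # an illustrative bias-free linear map from 16 to 8 features
# apply the weight to every region's component independently
y_tracked = torch.einsum("bir,ji->bjr", x_tracked, weight)

# because the map is linear, summing the components afterwards
# matches applying the map to the untracked input
print(torch.allclose(y_tracked.sum(dim=-1), x @ weight.T, atol=1e-4))  # True
```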

The problem is that the torch code is highly optimized so “merely” patching the tensor type is not going to work out. This may involve reimplementing the attention layers. Luckily the model has very little variation in it, so if I can reimplement one layer then I can reimplement all of them.

To understand what reimplementing the model involves, let’s review it. The part that we are most concerned with right now is the repeating body of the model, which looks like this:

Code
model.vision_model.encoder.layers[0]
CLIPEncoderLayer(
  (self_attn): CLIPAttention(
    (k_proj): Linear(in_features=768, out_features=768, bias=True)
    (v_proj): Linear(in_features=768, out_features=768, bias=True)
    (q_proj): Linear(in_features=768, out_features=768, bias=True)
    (out_proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): CLIPMLP(
    (activation_fn): QuickGELUActivation()
    (fc1): Linear(in_features=768, out_features=3072, bias=True)
    (fc2): Linear(in_features=3072, out_features=768, bias=True)
  )
  (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

To be able to trace through this, I want to pass in a tensor where each value is split across the source regions. If I can operate on that by multiplying it by the appropriately broadcast model parameters, then tracing should work.

The following kinds of operations need to be replicated:

  • Linear Layer
  • Batch Matrix Multiplication
  • View / Transpose / Contiguous
  • Activation Functions (Softmax / GELU)
  • Layer Normalization

Traceable Linear Layer

Most of these layers are linear. Attention also performs some internal operations that are not represented by these modules, and those will need to be replicated too.

A linear layer is a matrix multiplication followed by an addition. The multiplication can be expressed with Einstein summation, carrying the region index alongside the feature index. The addition (the bias) needs to be scaled so that each contribution is changed by the appropriate amount. Another approach to the addition is to have an additional index which holds such fixed offsets.
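The second option, a dedicated index for fixed offsets, could look something like the sketch below. It assumes a hypothetical layout with one extra slot after the 49 regions that collects every bias. The region slots therefore only ever hold contributions that came from the input. It uses a freshly initialised `nn.Linear(16, 8)` rather than a CLIP layer.

```python
import torch
from torch import nn

torch.manual_seed(0)

regions = 7 * 7
linear = nn.Linear(16, 8)

x = torch.randn(2, 16)
shares = torch.rand(2, 16, regions).softmax(dim=-1)
# hypothetical layout: 49 region slots plus one trailing slot for fixed offsets
x_tracked = torch.cat([x.unsqueeze(-1) * shares, torch.zeros(2, 16, 1)], dim=-1)

with torch.no_grad():
    # the weight acts on every slot, including the offset slot
    y_tracked = torch.einsum("bis,ji->bjs", x_tracked, linear.weight)
    # the whole bias goes into the offset slot instead of being split across regions
    y_tracked[..., -1] += linear.bias
    expected = linear(x)
    input_part = x @ linear.weight.T

# all slots together reproduce the layer's output
print(torch.allclose(y_tracked.sum(dim=-1), expected, atol=1e-5))  # True
# the region slots alone hold only the input-driven part
print(torch.allclose(y_tracked[..., :-1].sum(dim=-1), input_part, atol=1e-5))  # True
```

This keeps the bias out of the per-region attribution entirely. The trade-off is that the offset slot has to be carried through every later layer as if it were one more region.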

Code
import torch
from torch import nn

linear_layer = model.vision_model.encoder.layers[0].mlp.fc1
input_tensor = torch.rand(2, 768)

@torch.no_grad()
def track(t: torch.Tensor) -> torch.Tensor:
    # split each value of t across 7*7 = 49 image regions: the softmax makes the
    # random shares sum to 1 per value, so the tracked components sum back to t
    tracking = torch.rand(*t.shape, 7*7)
    tracking = tracking.softmax(dim=-1)
    t_tracked = tracking * torch.unsqueeze(t, -1)
    return t_tracked

@torch.no_grad()
def untracked_linear(t: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    return linear(t)

@torch.no_grad()
def einsum_linear_with_bias(t: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    return torch.einsum("bi,ji->bj", t, linear.weight) + linear.bias

@torch.no_grad()
def einsum_linear(t: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    return torch.einsum("bi,ji->bj", t, linear.weight)

@torch.no_grad()
def tracked_linear(t: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    t_tracked = track(t)
    weight = linear.weight[:, :, None]
    result = torch.einsum("bik,jik->bjk", t_tracked, weight)
    return result.sum(dim=2)

@torch.no_grad()
def einsum_linear_with_bias_3d(t: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    return torch.einsum("Bbi,Bji->Bbj", t, linear.weight[None]) + linear.bias

@torch.no_grad()
def einsum_linear_3d(t: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    return torch.einsum("Bbi,Bji->Bbj", t, linear.weight[None])

@torch.no_grad()
def tracked_linear_3d(t: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    t_tracked = track(t)
    weight = linear.weight[:, :, None]
    result = torch.einsum("Bbik,Bjik->Bbjk", t_tracked, weight[None])
    return result.sum(dim=-1)

print(f"linear layer has bias:   {linear_layer.bias is not None}")
print(f"untracked linear:        {untracked_linear(input_tensor, linear_layer)[0, :5]}")
print(f"einsum linear with bias: {einsum_linear_with_bias(input_tensor, linear_layer)[0, :5]}")
print(f"einsum linear:           {einsum_linear(input_tensor, linear_layer)[0, :5]}")
print(f"tracked linear:          {tracked_linear(input_tensor, linear_layer)[0, :5]}")

input_tensor = input_tensor.reshape(1, 2, 768)
print(f"3d untracked linear:        {untracked_linear(input_tensor, linear_layer)[0, 0, :5]}")
print(f"3d einsum linear with bias: {einsum_linear_with_bias_3d(input_tensor, linear_layer)[0, 0, :5]}")
print(f"3d einsum linear:           {einsum_linear_3d(input_tensor, linear_layer)[0, 0, :5]}")
print(f"3d tracked linear:          {tracked_linear_3d(input_tensor, linear_layer)[0, 0, :5]}")
linear layer has bias:   True
untracked linear:        tensor([-0.8449,  0.5872, -0.2979, -0.7776, -0.2705])
einsum linear with bias: tensor([-0.8449,  0.5872, -0.2979, -0.7776, -0.2705])
einsum linear:           tensor([ 0.2498,  0.8545, -0.1213,  0.4675,  0.1067])
tracked linear:          tensor([ 0.2498,  0.8545, -0.1213,  0.4675,  0.1067])
3d untracked linear:        tensor([-0.8449,  0.5872, -0.2979, -0.7776, -0.2705])
3d einsum linear with bias: tensor([-0.8449,  0.5872, -0.2979, -0.7776, -0.2705])
3d einsum linear:           tensor([ 0.2498,  0.8545, -0.1213,  0.4675,  0.1067])
3d tracked linear:          tensor([ 0.2498,  0.8545, -0.1213,  0.4675,  0.1067])

This Einstein summation approach reproduces the core operation (the einsum and tracked rows match each other, and both omit the bias), but the bias needs some special handling.

We can start by splitting the bias evenly, adding an equal share of it to each tracked component, and see whether we need to get more complicated later. It may well be that the bias should be assigned to the components in proportion to their current values.

Code
import torch
from torch import nn

input_tensor = torch.rand(2, 768)

@torch.no_grad()
def tracked_linear_with_bias(t: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    # 49 values for each region of the input image, sum to 1 per t value
    t_tracked = track(t)
    tracking_size = t_tracked.shape[-1]
    t_mm = torch.einsum("bik,jik->bjk", t_tracked, linear.weight[:, :, None])
    t_bias = t_mm + (linear.bias[:, None] / tracking_size)
    return t_bias.sum(dim=-1)

@torch.no_grad()
def tracked_linear_with_bias_3d(t: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    # 49 values for each region of the input image, sum to 1 per t value
    t_tracked = track(t)
    tracking_size = t_tracked.shape[-1]
    t_mm = torch.einsum("Bbik,Bjik->Bbjk", t_tracked, linear.weight[None, :, :, None])
    t_bias = t_mm + (linear.bias[:, None] / tracking_size)
    return t_bias.sum(dim=-1)

print(f"linear layer has bias:    {linear_layer.bias is not None}")
print(f"untracked linear:         {untracked_linear(input_tensor, linear_layer)[0, :5]}")
print(f"einsum linear with bias:  {einsum_linear_with_bias(input_tensor, linear_layer)[0, :5]}")
print(f"tracked linear with bias: {tracked_linear_with_bias(input_tensor, linear_layer)[0, :5]}")

input_tensor = input_tensor.reshape(1, 2, 768)
print(f"3d tracked linear with bias: {tracked_linear_with_bias_3d(input_tensor, linear_layer)[0, 0, :5]}")
linear layer has bias:    True
untracked linear:         tensor([-0.6019,  0.5975, -0.1942, -0.9617, -0.5879])
einsum linear with bias:  tensor([-0.6019,  0.5975, -0.1942, -0.9617, -0.5879])
tracked linear with bias: tensor([-0.6019,  0.5975, -0.1942, -0.9617, -0.5879])
3d tracked linear with bias: tensor([-0.6019,  0.5975, -0.1942, -0.9617, -0.5879])

This is good enough to get going with. The even split of the bias may need work, though, as it will skew the attribution quite heavily: every region receives the same share of the bias regardless of how much it contributed.
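The text above suggests the bias might instead be assigned in proportion to each component's current value. The sketch below shows one way to read that, as a hypothetical helper `proportional_bias` that is not part of the code in this post. It splits the bias in proportion to the magnitude of each region's component and falls back to the even split when every component is zero. Whichever split is used, the components still sum to the true layer output; only the attribution changes.

```python
import torch

def proportional_bias(mm: torch.Tensor, bias: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Hypothetical alternative to the even bias split.

    mm:   (..., out_features, regions) tracked output of the matrix multiplication
    bias: (out_features,) bias of the linear layer
    Each region receives a share of the bias proportional to the magnitude of
    its component; where every component is zero the bias is split evenly.
    """
    magnitude = mm.abs()
    total = magnitude.sum(dim=-1, keepdim=True)
    even = torch.full_like(mm, 1.0 / mm.shape[-1])
    # clamp avoids dividing by zero in the branch that torch.where discards
    shares = torch.where(total > eps, magnitude / total.clamp_min(eps), even)
    return mm + shares * bias[:, None]

torch.manual_seed(0)
mm = torch.randn(2, 8, 49)
bias = torch.randn(8)
out = proportional_bias(mm, bias)
# the components still sum to the layer output, only the split changes
print(torch.allclose(out.sum(dim=-1), mm.sum(dim=-1) + bias, atol=1e-5))  # True
```

Using magnitudes rather than signed values avoids dividing by a signed total that can be close to zero when components cancel each other out.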

Code
import torch
from torch import nn

class TrackingLinear(nn.Module):
    def __init__(self, linear: nn.Linear, image_regions: int) -> None:
        super().__init__()
        # image_regions is not used: the number of regions is read from the input
        # add a trailing region axis so the weight lines up with the tracked input
        self.weight = linear.weight[:, :, None]
        if linear.bias is not None:
            self.bias = linear.bias[:, None]
        else:
            self.bias = torch.zeros(self.weight.shape[0], 1)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        # xs is (batch, features, regions) or (batch, tokens, features, regions)
        if len(xs.shape) == 3:
            mm = torch.einsum("bik,jik->bjk", xs, self.weight)
        else:
            mm = torch.einsum("Bbik,Bjik->Bbjk", xs, self.weight[None])
        # split the bias evenly across the regions
        return mm + self.bias / xs.shape[-1]

@torch.no_grad()
def tracked_linear_class(t: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    t_tracked = track(t)
    layer = TrackingLinear(linear, t_tracked.shape[-1])
    return layer(t_tracked).sum(dim=-1)

input_tensor = torch.rand(2, 768)
print(f"linear layer has bias:    {linear_layer.bias is not None}")
print(f"untracked linear:         {untracked_linear(input_tensor, linear_layer)[0, :5]}")
print(f"tracked linear class:     {tracked_linear_class(input_tensor, linear_layer)[0, :5]}")

input_tensor = input_tensor.reshape(1, 2, 768)
print(f"tracked linear class 3d:  {tracked_linear_class(input_tensor, linear_layer)[0, 0, :5]}")
linear layer has bias:    True
untracked linear:         tensor([-0.8051,  0.5180, -0.3463, -1.0407, -0.4379])
tracked linear class:     tensor([-0.8051,  0.5180, -0.3463, -1.0407, -0.4379])
tracked linear class 3d:  tensor([-0.8051,  0.5180, -0.3463, -1.0407, -0.4379])

This works, then. Next up is batch matrix multiplication.

Traceable Batch Matrix Multiplication

Batch matrix multiplication is an adjustment to what we have been doing so far. The difference is that both sides of the multiplication are now tracked.

I’ve tried to get this working with a single einsum operation, but multiplying the two attributions together means that the original values are not preserved. That may not matter when calculating attribution, but until I have a working system I don’t want to change the results.
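The failure of the single einsum comes down to an identity. For values split as x = Σ x_i and y = Σ y_i, the product is xy = Σ_i Σ_j x_i y_j. A pointwise product over the tracking index keeps only the diagonal terms Σ_i x_i y_i. Multiplying each tracked side by the other side's untracked total and averaging keeps the sum exact, since Σ_i (x_i y + x y_i) / 2 = (xy + xy) / 2 = xy. This is what `tracked_bmm` in the next cell does. A small numeric check:

```python
import torch

# a value of 1.0 on each side, split over two regions
left = torch.tensor([0.5, 0.5])
right = torch.tensor([0.5, 0.5])

# multiplying the tracked parts pointwise drops the cross terms left[i] * right[j], i != j
print((left * right).sum().item())  # 0.5, not 1.0

# tracked left times untracked right, plus untracked left times tracked right, halved
averaged = (left * right.sum() + left.sum() * right) / 2
print(averaged.sum().item())  # 1.0

# the same identity holds for arbitrary signed splits
torch.manual_seed(0)
x, y = torch.randn(49), torch.randn(49)
mixed = (x * y.sum() + x.sum() * y) / 2
print(torch.isclose(mixed.sum(), x.sum() * y.sum(), atol=1e-4).item())  # True
```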

Code
import torch

left_tensor = torch.rand(2, 100, 768, requires_grad=False)
right_tensor = torch.rand(2, 768, 100, requires_grad=False)

@torch.no_grad()
def untracked_bmm(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
    return torch.bmm(left, right)

@torch.no_grad()
def einsum_bmm(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
    return torch.einsum("bnm,bmp->bnp", left, right)

@torch.no_grad()
def tracked_bmm(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
    left_tracked = track(left)
    right_tracked = track(right)

    # A single einsum over both tracked tensors (bnmi,bmpi->bnpi) does not preserve
    # the values, because the decomposed values are multiplied on both sides.
    # For example, if left_tracked is [0.5, 0.5] and right_tracked is [0.5, 0.5]
    # (each summing to 1), the pointwise product is [0.25, 0.25], which sums to 0.5.
    # Instead, multiply each tracked side by the untracked other side and average:
    # each half sums to the true product, so their mean does too.
    # (An alternative idea: use a least squares inverse to find the closest point
    # to the untracked result.)

    left_bmm = torch.einsum("bnmi,bmpi->bnpi", left_tracked, right.unsqueeze(-1))
    right_bmm = torch.einsum("bnmi,bmpi->bnpi", left.unsqueeze(-1), right_tracked)
    bmm = left_bmm + right_bmm
    bmm = bmm / 2
    return bmm.sum(dim=-1)

print(f"untracked bmm: {untracked_bmm(left_tensor, right_tensor)[0, 0, :5]}")
print(f"einsum bmm:    {einsum_bmm(left_tensor, right_tensor)[0, 0, :5]}")
print(f"tracked bmm:   {tracked_bmm(left_tensor, right_tensor)[0, 0, :5]}")
untracked bmm: tensor([186.5563, 196.8026, 192.7868, 189.7014, 189.0686])
einsum bmm:    tensor([186.5563, 196.8026, 192.7868, 189.7014, 189.0686])
tracked bmm:   tensor([186.5563, 196.8026, 192.7869, 189.7014, 189.0686])
Code
@torch.no_grad()
def tracked_bmm_full(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
    left_tracked = track(left)
    right_tracked = track(right)

    # The full tracked product (bnmi,bmpi->bnpi) does not preserve the values:
    # e.g. [0.5, 0.5] * [0.5, 0.5] gives [0.25, 0.25], which sums to 0.5 instead of 1.
    # Here the full product is kept and rescaled so that the summed magnitude of the
    # components matches the magnitude of the true product.
    # This is only exact when every component shares the sign of the true product,
    # as it does for these non-negative random inputs; with mixed signs it can differ.

    tracked_bmm = torch.einsum("bnmi,bmpi->bnpi", left_tracked, right_tracked)
    tracked_bmm_ratio = tracked_bmm.abs().sum(dim=-1)
    true_bmm = torch.einsum("bnm,bmp->bnp", left, right)
    true_bmm_ratio = true_bmm.abs()

    scaled_bmm = tracked_bmm * (true_bmm_ratio / tracked_bmm_ratio).unsqueeze(-1)
    return scaled_bmm.sum(dim=-1)

print(f"untracked bmm: {untracked_bmm(left_tensor, right_tensor)[0, 0, :5]}")
print(f"tracked bmm:   {tracked_bmm_full(left_tensor, right_tensor)[0, 0, :5]}")
untracked bmm: tensor([186.5563, 196.8026, 192.7868, 189.7014, 189.0686])
tracked bmm:   tensor([186.5563, 196.8026, 192.7868, 189.7014, 189.0687])
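The exact match from `tracked_bmm_full` relies on `torch.rand` producing only non-negative inputs. When every component shares the sign of the true product, the absolute-value rescaling reproduces that product exactly. With mixed signs, the rescaled components sum to |true| × Σ c_i / Σ |c_i|, which can be smaller and can even have the wrong sign. Below is a small self-contained check on a single value split over two regions. It uses hypothetical helpers that mirror `tracked_bmm_full` and `TrackingBMM`.

```python
from typing import Tuple

import torch

def abs_ratio_rescale(x_tracked: torch.Tensor, y_tracked: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # mirrors tracked_bmm_full: full tracked product, rescaled by |true| / sum(|components|)
    tracked = torch.einsum("bnmi,bmpi->bnpi", x_tracked, y_tracked)
    true = torch.einsum("bnm,bmp->bnp", x_tracked.sum(-1), y_tracked.sum(-1))
    scale = true.abs() / tracked.abs().sum(dim=-1)
    return (tracked * scale.unsqueeze(-1)).sum(dim=-1), true

def averaged_cross_terms(x_tracked: torch.Tensor, y_tracked: torch.Tensor) -> torch.Tensor:
    # mirrors TrackingBMM: each tracked side times the other side's untracked total, halved
    xs = torch.einsum("bnmi,bmp->bnpi", x_tracked, y_tracked.sum(-1))
    ys = torch.einsum("bnm,bmpi->bnpi", x_tracked.sum(-1), y_tracked)
    return ((xs + ys) / 2).sum(dim=-1)

# one value on each side, split over two regions with mixed signs
x_tracked = torch.tensor([1.0, -2.0]).reshape(1, 1, 1, 2)  # total -1
y_tracked = torch.tensor([3.0, 1.0]).reshape(1, 1, 1, 2)   # total 4

rescaled, true = abs_ratio_rescale(x_tracked, y_tracked)
print(true.item())      # -4.0
print(rescaled.item())  # about 0.8: both the sign and the magnitude are off
print(averaged_cross_terms(x_tracked, y_tracked).item())  # -4.0
```

The averaged cross-term version preserves the total for any signs. This is worth keeping in mind, since real hidden states are signed.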
Code
import torch
from torch import nn

class TrackingBMM(nn.Module):
    def forward(self, xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
        x_untracked = xs.sum(dim=-1).unsqueeze(-1)
        y_untracked = ys.sum(dim=-1).unsqueeze(-1)
        
        xs_bmm = torch.einsum("bnmi,bmpi->bnpi", xs, y_untracked)
        ys_bmm = torch.einsum("bnmi,bmpi->bnpi", x_untracked, ys)
        bmm = xs_bmm + ys_bmm
        bmm = bmm / 2
        return bmm

@torch.no_grad()
def tracked_bmm_class(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
    layer = TrackingBMM()
    return layer(track(left), track(right)).sum(dim=-1)

print(f"untracked bmm: {untracked_bmm(left_tensor, right_tensor)[0, 0, :5]}")
print(f"einsum bmm:    {einsum_bmm(left_tensor, right_tensor)[0, 0, :5]}")
print(f"tracked bmm:   {tracked_bmm_class(left_tensor, right_tensor)[0, 0, :5]}")
untracked bmm: tensor([186.5563, 196.8026, 192.7868, 189.7014, 189.0686])
einsum bmm:    tensor([186.5563, 196.8026, 192.7868, 189.7014, 189.0686])
tracked bmm:   tensor([186.5563, 196.8026, 192.7869, 189.7014, 189.0686])

Traceable Quick GELU

This should be more straightforward. The quick GELU activation involves a sigmoid, which has to be computed over the aggregated (untracked) value. Its effect on the tracked values can then be applied by multiplying each of them by the ratio of the activation's output to its input.
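A minimal sketch of the idea follows. It assumes the quick GELU formula x · sigmoid(1.702x), which is what the Hugging Face `QuickGELUActivation` computes. `tracked_activation` is an illustrative helper. It also guards against a zero total, as `TrackingActivation` does further down.

```python
from typing import Callable

import torch

def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # the quick GELU approximation: x * sigmoid(1.702 * x)
    return x * torch.sigmoid(1.702 * x)

def tracked_activation(x_tracked: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    # run the activation on the untracked total, then scale every component
    # by output / input so the components sum to the activation's output
    total = x_tracked.sum(dim=-1)
    gain = torch.where(total != 0, fn(total) / total, torch.zeros_like(total))
    return x_tracked * gain.unsqueeze(-1)

x_tracked = torch.tensor([[0.4, 0.4]])  # one value of 0.8 split over two regions
y = tracked_activation(x_tracked, quick_gelu)
print(y)  # both components are about 0.318
print(torch.allclose(y.sum(dim=-1), quick_gelu(torch.tensor([0.8])), atol=1e-6))  # True
```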

Code
import torch
from torch import nn

gelu = model.vision_model.encoder.layers[0].mlp.activation_fn
input_tensor = torch.rand(2, 768, requires_grad=False)

@torch.no_grad()
def untracked_gelu(xs: torch.Tensor, activation_fn: nn.Module) -> torch.Tensor:
    # this is the quick gelu algorithm used by clip
    return activation_fn(xs)

@torch.no_grad()
def tracked_gelu(xs: torch.Tensor, activation_fn: nn.Module) -> torch.Tensor:
    xs_tracked = track(xs)
    # quick GELU would change a lot if it were run over the individual tracked parts,
    # so it has to be run on the overall value.
    # To apply the result back to the tracked values, scale them by output / input.
    # For example, tracked values [0.4, 0.4] sum to 0.8 and quick GELU(0.8) is about 0.637,
    # so each becomes 0.4 * 0.637 / 0.8, about 0.318, and together they sum to 0.637.
    gelu = activation_fn(xs)
    normalized_gelu = gelu / xs
    xs_tracked = xs_tracked * normalized_gelu.unsqueeze(-1)
    return xs_tracked.sum(dim=-1)

print(f"untracked gelu: {untracked_gelu(input_tensor, gelu)[0, :5]}")
print(f"tracked gelu:   {tracked_gelu(input_tensor, gelu)[0, :5]}")
untracked gelu: tensor([0.0706, 0.0366, 0.7400, 0.3145, 0.1581])
tracked gelu:   tensor([0.0706, 0.0366, 0.7400, 0.3145, 0.1581])
Code
import torch
from torch import nn

class TrackingActivation(nn.Module):
    def __init__(self, activation_function: nn.Module) -> None:
        super().__init__()
        self.activation_function = activation_function

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
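        x_untracked = xs.sum(dim=-1)
        activation = self.activation_function(x_untracked)
        # scale each tracked component by the ratio of output to input
        normalized_activation = activation / x_untracked
        # where the input sum is zero the ratio is undefined (nan or inf), so the tracked
        # values are zeroed; this is exact for GELU, whose output at 0 is 0, but it drops
        # the output of softmax or layer norm at an exactly-zero input (exact zeros should be rare)
        normalized_activation[x_untracked == 0.] = 0.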
        return xs * normalized_activation.unsqueeze(-1)

@torch.no_grad()
def tracked_gelu_class(xs: torch.Tensor, activation_function: nn.Module) -> torch.Tensor:
    layer = TrackingActivation(activation_function)
    return layer(track(xs)).sum(dim=-1)

print(f"untracked gelu: {untracked_gelu(input_tensor, gelu)[0, :5]}")
print(f"tracked gelu:   {tracked_gelu_class(input_tensor, gelu)[0, :5]}")
untracked gelu: tensor([0.0706, 0.0366, 0.7400, 0.3145, 0.1581])
tracked gelu:   tensor([0.0706, 0.0366, 0.7400, 0.3145, 0.1581])

Traceable Layer Normalization

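Layer normalization needs a very similar approach to GELU. It normalizes over the last (feature) dimension using the mean and variance of all the features, so it cannot be run over the tracked values directly. Instead, the tracked data is summed to get the untracked value, that value is normalized, and the tracked values are scaled by the ratio of output to input. Since it’s so similar, this can use the same TrackingActivation class.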
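The sketch below shows why layer normalization can't simply be applied per region. It uses a freshly initialised `nn.LayerNorm(8)`, illustrative sizes and random shares. Normalising each region's slice separately and summing gives a very different result, while the ratio-scaling approach reproduces the untracked output.

```python
import torch
from torch import nn

torch.manual_seed(0)
norm = nn.LayerNorm(8)  # freshly initialised: weight 1, bias 0
x = torch.randn(2, 8)
x_tracked = x.unsqueeze(-1) * torch.rand(2, 8, 49).softmax(dim=-1)  # (batch, features, regions)

with torch.no_grad():
    expected = norm(x)
    # naive: normalise each region's slice of the features on its own, then sum
    naive = norm(x_tracked.transpose(-1, -2)).sum(dim=-2)
    # ratio scaling: normalise the total, then scale each component by output / input
    scaled = (x_tracked * (expected / x).unsqueeze(-1)).sum(dim=-1)

print(torch.allclose(naive, expected, atol=1e-4))   # False
print(torch.allclose(scaled, expected, atol=1e-4))  # True
```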

Code
import torch
from torch import nn

input_tensor = torch.rand(2, 768, requires_grad=False)
layer_norm = model.vision_model.encoder.layers[0].layer_norm1

@torch.no_grad()
def untracked_layer_norm(xs: torch.Tensor, norm: nn.LayerNorm) -> torch.Tensor:
    return norm(xs)

@torch.no_grad()
def tracked_layer_norm(xs: torch.Tensor, norm: nn.LayerNorm) -> torch.Tensor:
    xs_tracked = track(xs)
    # layer normalization uses the mean and variance of the whole feature dimension,
    # so running it over each region's tracked values separately would fundamentally
    # change them; instead the scaling approach from gelu is used
    layer_norm = untracked_layer_norm(xs, norm)
    normalized_layer_norm = layer_norm / xs
    xs_tracked = xs_tracked * normalized_layer_norm.unsqueeze(-1)
    return xs_tracked.sum(dim=-1)

print(f"untracked layer norm: {untracked_layer_norm(input_tensor, layer_norm)[0, :5]}")
print(f"tracked layer norm:   {tracked_layer_norm(input_tensor, layer_norm)[0, :5]}")
untracked layer norm: tensor([ 0.1438, -0.3158,  0.1158, -0.2560, -0.2967])
tracked layer norm:   tensor([ 0.1438, -0.3158,  0.1158, -0.2560, -0.2967])
Code
import torch
from torch import nn

@torch.no_grad()
def tracked_layer_norm_class(xs: torch.Tensor, norm: nn.LayerNorm) -> torch.Tensor:
    layer = TrackingActivation(norm)
    return layer(track(xs)).sum(dim=-1)

print(f"untracked layer norm: {untracked_layer_norm(input_tensor, layer_norm)[0, :5]}")
print(f"tracked layer norm:   {tracked_layer_norm_class(input_tensor, layer_norm)[0, :5]}")
untracked layer norm: tensor([ 0.1438, -0.3158,  0.1158, -0.2560, -0.2967])
tracked layer norm:   tensor([ 0.1438, -0.3158,  0.1158, -0.2560, -0.2967])

Traceable Softmax

This is much the same as the GELU and Layer Norm.

Code
import torch

input_tensor = torch.rand(2, 768, requires_grad=False)

@torch.no_grad()
def untracked_softmax(xs: torch.Tensor) -> torch.Tensor:
    return torch.softmax(xs, dim=-1)

@torch.no_grad()
def tracked_softmax(xs: torch.Tensor) -> torch.Tensor:
    xs_tracked = track(xs)
    softmax = untracked_softmax(xs)
    normalized_softmax = softmax / xs
    xs_tracked = xs_tracked * normalized_softmax.unsqueeze(-1)
    return xs_tracked.sum(dim=-1)

print(f"untracked softmax: {untracked_softmax(input_tensor)[0, :5]}")
print(f"tracked softmax:   {tracked_softmax(input_tensor)[0, :5]}")
untracked softmax: tensor([0.0010, 0.0014, 0.0014, 0.0014, 0.0016])
tracked softmax:   tensor([0.0010, 0.0014, 0.0014, 0.0014, 0.0016])
Code
import torch
from torch import nn

@torch.no_grad()
def tracked_softmax_class(xs: torch.Tensor) -> torch.Tensor:
    layer = TrackingActivation(nn.Softmax(dim=-1))
    return layer(track(xs)).sum(dim=-1)

print(f"untracked softmax: {untracked_softmax(input_tensor)[0, :5]}")
print(f"tracked softmax:   {tracked_softmax_class(input_tensor)[0, :5]}")
untracked softmax: tensor([0.0010, 0.0014, 0.0014, 0.0014, 0.0016])
tracked softmax:   tensor([0.0010, 0.0014, 0.0014, 0.0014, 0.0016])

Traceable CLIP Attention

This is the smallest chunk of the CLIP layers that should be recoverable. It’s an attention block, and it involves some reshaping and transposition.
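The reshaping is safe as long as the tracking dimension stays last. `view`, `transpose` and `contiguous` then only rearrange the leading axes, so they commute with summing over the regions. The small sketch below uses illustrative sizes and a hypothetical `split_heads` helper shaped like `_shape` in the class below.

```python
import torch

torch.manual_seed(0)
# illustrative sizes; the CLIP ViT-B/32 vision layer uses 12 heads of 64 dimensions
bsz, seq_len, num_heads, head_dim, regions = 2, 5, 4, 3, 49
embed_dim = num_heads * head_dim

x = torch.randn(bsz, seq_len, embed_dim)
x_tracked = x.unsqueeze(-1) * torch.rand(bsz, seq_len, embed_dim, regions).softmax(dim=-1)

def split_heads(t: torch.Tensor) -> torch.Tensor:
    # (bsz, seq, embed[, regions]) -> (bsz, heads, seq, head_dim[, regions])
    return t.view(bsz, seq_len, num_heads, head_dim, *t.shape[3:]).transpose(1, 2).contiguous()

heads = split_heads(x)
heads_tracked = split_heads(x_tracked)
print(heads_tracked.shape)  # torch.Size([2, 4, 5, 3, 49])
# reshaping only moves the leading axes, so summing over regions still recovers the heads
print(torch.allclose(heads_tracked.sum(dim=-1), heads, atol=1e-5))  # True
```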

Code
from typing import Optional, Tuple

import torch
from torch import nn
from transformers.models.clip.modeling_clip import CLIPAttention

class TrackingCLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, layer: CLIPAttention, tracking_dim: int) -> None:
        super().__init__()
        self.layer = layer

        self.k_proj = TrackingLinear(layer.k_proj, tracking_dim)
        self.v_proj = TrackingLinear(layer.v_proj, tracking_dim)
        self.q_proj = TrackingLinear(layer.q_proj, tracking_dim)
        self.out_proj = TrackingLinear(layer.out_proj, tracking_dim)
        self.bmm = TrackingBMM()
        self.softmax = TrackingActivation(nn.Softmax(dim=-1))

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int, tracking_dim: int):
        return tensor.view(bsz, seq_len, self.layer.num_heads, self.layer.head_dim, tracking_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel x Tracking"""

        bsz, tgt_len, embed_dim, tracking_dim = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.layer.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz, tracking_dim)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz, tracking_dim)

        proj_shape = (bsz * self.layer.num_heads, -1, self.layer.head_dim, tracking_dim)
        query_states = self._shape(query_states, tgt_len, bsz, tracking_dim).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = self.bmm(query_states, key_states.transpose(1, 2))

        # if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
        #     raise ValueError(
        #         f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
        #         f" {attn_weights.size()}"
        #     )

        # # apply the causal_attention_mask first
        # if causal_attention_mask is not None:
        #     if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
        #         raise ValueError(
        #             f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
        #             f" {causal_attention_mask.size()}"
        #         )
        #     attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
        #     attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # if attention_mask is not None:
        #     if attention_mask.size() != (bsz, 1, tgt_len, src_len):
        #         raise ValueError(
        #             f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
        #         )
        #     attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
        #     attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = self.softmax(attn_weights)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.layer.num_heads, tgt_len, src_len, tracking_dim)
            attn_weights = attn_weights_reshaped.view(bsz * self.layer.num_heads, tgt_len, src_len, tracking_dim)
        else:
            attn_weights_reshaped = None

        # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_probs = attn_weights

        attn_output = self.bmm(attn_probs, value_states)

        # if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
        #     raise ValueError(
        #         f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
        #         f" {attn_output.size()}"
        #     )

        attn_output = attn_output.view(bsz, self.layer.num_heads, tgt_len, self.layer.head_dim, tracking_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim, tracking_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
Code
import torch
from transformers.models.clip.modeling_clip import CLIPAttention

input_tensor = torch.rand(2, 2, 768, requires_grad=False)
layer_attention = model.vision_model.encoder.layers[0].self_attn

@torch.no_grad()
def untracked_attention(xs: torch.Tensor, attention: CLIPAttention) -> torch.Tensor:
    return attention(xs)[0]

@torch.no_grad()
def tracked_attention(xs: torch.Tensor, attention: CLIPAttention) -> torch.Tensor:
    tracked_attention = TrackingCLIPAttention(attention, 7*7)
    xs_tracked = track(xs)
    xs_tracked = tracked_attention(xs_tracked)[0]
    return xs_tracked.sum(dim=-1)

print(f"untracked attention: {untracked_attention(input_tensor, layer_attention)[0, 0, :5]}")
print(f"tracked attention:   {tracked_attention(input_tensor, layer_attention)[0, 0, :5]}")
untracked attention: tensor([0.0519, 0.0514, 0.0384, 0.0903, 0.1239])
tracked attention:   tensor([0.0519, 0.0514, 0.0384, 0.0903, 0.1239])

Traceable CLIP MLP

This is a simpler module, needed to build the overall encoder layer. It uses the quick GELU activation function.

Code
from transformers.models.clip.modeling_clip import CLIPMLP

class TrackingCLIPMLP(nn.Module):
    def __init__(self, layer: CLIPMLP, tracking_dim: int) -> None:
        super().__init__()
        self.layer = layer

        self.activation_fn = TrackingActivation(layer.activation_fn)
        self.fc1 = TrackingLinear(layer.fc1, tracking_dim)
        self.fc2 = TrackingLinear(layer.fc2, tracking_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
Code
import torch
from transformers.models.clip.modeling_clip import CLIPMLP

input_tensor = torch.rand(2, 2, 768, requires_grad=False)
layer_mlp = model.vision_model.encoder.layers[0].mlp

@torch.no_grad()
def untracked_mlp(xs: torch.Tensor, mlp: CLIPMLP) -> torch.Tensor:
    return mlp(xs)

@torch.no_grad()
def tracked_mlp(xs: torch.Tensor, mlp: CLIPMLP) -> torch.Tensor:
    tracked_mlp = TrackingCLIPMLP(mlp, 7*7)
    xs_tracked = track(xs)
    xs_tracked = tracked_mlp(xs_tracked)
    return xs_tracked.sum(dim=-1)

print(f"untracked mlp: {untracked_mlp(input_tensor, layer_mlp)[0, 0, :5]}")
print(f"tracked mlp:   {tracked_mlp(input_tensor, layer_mlp)[0, 0, :5]}")
untracked mlp: tensor([-0.3790,  0.0192,  0.0858,  0.3703,  0.2209])
tracked mlp:   tensor([-0.3790,  0.0192,  0.0858,  0.3703,  0.2209])

Traceable CLIP Encoder

Finally! With this, each layer of the image encoder can be traced.

Code
from typing import Optional
from transformers.models.clip.modeling_clip import CLIPEncoderLayer

class TrackingCLIPEncoderLayer(nn.Module):
    def __init__(self, layer: CLIPEncoderLayer, tracking_dim: int) -> None:
        super().__init__()
        self.layer = layer

        self.self_attn = TrackingCLIPAttention(layer.self_attn, tracking_dim=tracking_dim)
        self.layer_norm1 = TrackingActivation(layer.layer_norm1)
        self.mlp = TrackingCLIPMLP(layer.mlp, tracking_dim=tracking_dim)
        self.layer_norm2 = TrackingActivation(layer.layer_norm2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        causal_attention_mask: Optional[torch.Tensor],
        output_attentions: bool,
    ) -> torch.Tensor:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
Code
import torch
from transformers.models.clip.modeling_clip import CLIPEncoderLayer

input_tensor = torch.rand(2, 2, 768, requires_grad=False)
layer_encoder = model.vision_model.encoder.layers[0]

@torch.no_grad()
def untracked_encoder_layer(xs: torch.Tensor, encoder_layer: CLIPEncoderLayer) -> torch.Tensor:
    return encoder_layer(xs, attention_mask=None, causal_attention_mask=None)[0]

@torch.no_grad()
def tracked_encoder_layer(xs: torch.Tensor, encoder_layer: CLIPEncoderLayer) -> torch.Tensor:
    tracked_encoder_layer = TrackingCLIPEncoderLayer(encoder_layer, 7*7)
    xs_tracked = track(xs)
    xs_tracked = tracked_encoder_layer(
        xs_tracked,
        attention_mask=None,
        causal_attention_mask=None,
        output_attentions=False,
    )[0]
    return xs_tracked.sum(dim=-1)

print(f"untracked encoder: {untracked_encoder_layer(input_tensor, layer_encoder)[0, 0, :5]}")
print(f"tracked encoder:   {tracked_encoder_layer(input_tensor, layer_encoder)[0, 0, :5]}")
untracked encoder: tensor([ 0.7604, -0.0935,  0.4202,  0.4328,  1.1894])
tracked encoder:   tensor([ 0.7604, -0.0935,  0.4202,  0.4328,  1.1894])

Woo!

With this we should have a fully traceable encoder: the encoder layers in the CLIPEncoder class can be replaced directly with the tracking versions, and it should still work.

Trace CLIP Inference

Now it’s time to trace the cat image as it passes through the model. We need to generate the traced inputs and then follow them as they move through each layer.
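One possible way to generate the traced inputs is sketched below as a hypothetical helper, not the code used in this post. It assumes embeddings of shape (batch, 1 + regions, dim) with the class token first, as in CLIP's vision embeddings; for ViT-B/32 at 224×224 that is 1 + 7×7 = 50 tokens. Each patch token is assigned entirely to its own region, and the class token is split evenly across all of them. Position embeddings are folded into the tokens rather than tracked separately.

```python
import torch

def initial_tracking(embeddings: torch.Tensor) -> torch.Tensor:
    """Hypothetical tracking initialisation.

    embeddings: (batch, 1 + regions, dim), class token first, then one token per patch.
    Returns (batch, 1 + regions, dim, regions) whose last axis sums to the embeddings.
    """
    _, tokens, _ = embeddings.shape
    regions = tokens - 1
    assignment = torch.zeros(tokens, regions, dtype=embeddings.dtype, device=embeddings.device)
    assignment[0] = 1.0 / regions  # class token: an even split over all regions
    # token 1 + k (patch k) belongs entirely to region k
    assignment[1:] = torch.eye(regions, dtype=embeddings.dtype, device=embeddings.device)
    # (batch, tokens, dim, 1) * (1, tokens, 1, regions) -> (batch, tokens, dim, regions)
    return embeddings.unsqueeze(-1) * assignment[None, :, None, :]

embeddings = torch.randn(1, 50, 768)  # stand-in for real ViT-B/32 embeddings
tracked = initial_tracking(embeddings)
print(tracked.shape)  # torch.Size([1, 50, 768, 49])
print(torch.allclose(tracked.sum(dim=-1), embeddings, atol=1e-5))  # True
```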

I’ve cleaned up the code quite a bit and put it all together so we can now see the final result.

Code
# from src/main/python/blog/tracing/v2022/layers.py
from typing import Optional, Tuple

import torch
from torch import nn
from transformers import CLIPModel
from transformers.models.clip.modeling_clip import (
    CLIPMLP,
    CLIPAttention,
    CLIPEncoder,
    CLIPEncoderLayer,
    CLIPVisionTransformer,
)


def load_tracing_image_model(
    model_name: str = "openai/clip-vit-base-patch32",
) -> "TracingCLIPVisionTransformer":
    model = CLIPModel.from_pretrained(model_name)
    tracing_model = TracingCLIPVisionTransformer(
        model=model.vision_model, projection=model.visual_projection
    )
    tracing_model.eval()
    return tracing_model


class TracingLinear(nn.Module):
    def __init__(self, linear: nn.Linear) -> None:
        super().__init__()
        self.weight = linear.weight[:, :, None]
        if linear.bias is not None:
            self.bias = linear.bias[:, None]
        else:
            self.bias = torch.zeros(self.weight.shape[0], 1)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        if len(xs.shape) == 3:
            mm = self._matrix_multiply_3d(xs)
        else:
            mm = self._matrix_multiply_4d(xs)
        *_, image_regions = xs.shape
        # the bias is not tied to any region, so split it evenly across all of them
        return mm + self.bias / image_regions

    def _matrix_multiply_3d(self, xs: torch.Tensor) -> torch.Tensor:
        return torch.einsum("bik,jik->bjk", xs, self.weight)

    def _matrix_multiply_4d(self, xs: torch.Tensor) -> torch.Tensor:
        return torch.einsum("Bbik,Bjik->Bbjk", xs, self.weight[None])


class TracingBMM(nn.Module):
    def forward(self, xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
        x_untracked = xs.sum(dim=-1)
        y_untracked = ys.sum(dim=-1)

        tracked_bmm = torch.einsum("bnmi,bmpi->bnpi", xs, ys)
        tracked_bmm_ratio = tracked_bmm.abs().sum(dim=-1)
        true_bmm = torch.einsum("bnm,bmp->bnp", x_untracked, y_untracked)
        true_bmm_ratio = true_bmm.abs()

        scaled_bmm = tracked_bmm * (true_bmm_ratio / tracked_bmm_ratio).unsqueeze(-1)
        return scaled_bmm


class TracingActivation(nn.Module):
    def __init__(self, activation_function: nn.Module) -> None:
        super().__init__()
        self.activation_function = activation_function

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        x_untracked = xs.sum(dim=-1)
        # if the sum is zero then dividing by it will result in nan
        # x_untracked[x_untracked == 0.0] = 1.0
        activation = self.activation_function(x_untracked)
        normalized_activation = activation / x_untracked

        # When values are zero after the application of the function then they
        # have no further contribution. That means we can set the values in xs to zero.
        normalized_activation[x_untracked == 0.0] = 0.0

        result = xs * normalized_activation.unsqueeze(-1)
        mask = normalized_activation.abs() < 0.01
        result[mask] = (normalized_activation[mask] / xs.shape[-1]).unsqueeze(-1)
        return result


class TracingCLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, layer: CLIPAttention) -> None:
        super().__init__()
        self.layer = layer

        self.k_proj = TracingLinear(layer.k_proj)
        self.v_proj = TracingLinear(layer.v_proj)
        self.q_proj = TracingLinear(layer.q_proj)
        self.out_proj = TracingLinear(layer.out_proj)
        self.bmm = TracingBMM()
        self.softmax = TracingActivation(nn.Softmax(dim=-1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel x Tracing"""

        bsz, tgt_len, embed_dim, tracking_dim = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.layer.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz, tracking_dim)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz, tracking_dim)

        proj_shape = (bsz * self.layer.num_heads, -1, self.layer.head_dim, tracking_dim)
        query_states = self._shape(query_states, tgt_len, bsz, tracking_dim).view(
            *proj_shape
        )
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = self.bmm(query_states, key_states.transpose(1, 2))

        # code has been cut out here which relates to the
        # causal_attention_mask, attention_mask and error checking
        assert causal_attention_mask is None
        assert attention_mask is None

        attn_weights = self.softmax(attn_weights)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(
                bsz, self.layer.num_heads, tgt_len, src_len, tracking_dim
            )
            attn_weights = attn_weights_reshaped.view(
                bsz * self.layer.num_heads, tgt_len, src_len, tracking_dim
            )
        else:
            attn_weights_reshaped = None

        # this is not intended for training so dropout is removed entirely
        # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_probs = attn_weights

        attn_output = self.bmm(attn_probs, value_states)
        attn_output = attn_output.view(
            bsz, self.layer.num_heads, tgt_len, self.layer.head_dim, tracking_dim
        )
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim, tracking_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int, tracking_dim: int):
        return (
            tensor.view(
                bsz, seq_len, self.layer.num_heads, self.layer.head_dim, tracking_dim
            )
            .transpose(1, 2)
            .contiguous()
        )


class TracingCLIPMLP(nn.Module):
    def __init__(self, layer: CLIPMLP) -> None:
        super().__init__()
        self.activation_fn = TracingActivation(layer.activation_fn)
        self.fc1 = TracingLinear(layer.fc1)
        self.fc2 = TracingLinear(layer.fc2)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class TracingCLIPEncoderLayer(nn.Module):
    def __init__(self, layer: CLIPEncoderLayer) -> None:
        super().__init__()
        self.self_attn = TracingCLIPAttention(layer.self_attn)
        self.layer_norm1 = TracingActivation(layer.layer_norm1)
        self.mlp = TracingCLIPMLP(layer.mlp)
        self.layer_norm2 = TracingActivation(layer.layer_norm2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        causal_attention_mask: Optional[torch.Tensor],
        output_attentions: bool,
    ) -> Tuple[torch.Tensor, ...]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class TracingCLIPVisionTransformer(nn.Module):
    def __init__(self, model: CLIPVisionTransformer, projection: nn.Linear) -> None:
        super().__init__()
        self.embeddings = model.embeddings
        # misspelling present in model
        self.pre_layernorm = TracingActivation(model.pre_layrnorm)
        self.encoder = CLIPEncoder(model.encoder.config)
        self.encoder.layers = nn.ModuleList(
            [TracingCLIPEncoderLayer(layer) for layer in model.encoder.layers]
        )
        self.post_layernorm = TracingActivation(model.post_layernorm)
        self.visual_projection = TracingLinear(projection)

    def forward(self, pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
        embeddings = self.embeddings(pixel_values=pixel_values)
        _, image_regions, _ = embeddings.shape
        embeddings_traced = torch.zeros(
            *embeddings.shape, image_regions, device=embeddings.device
        )
        for i in range(image_regions):
            embeddings_traced[:, i, :, i] = embeddings[:, i, :]

        embeddings_normalized = self.pre_layernorm(embeddings_traced)
        encoded = self.encoder(
            inputs_embeds=embeddings_normalized,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=False,
        )
        encoded = encoded[0]
        encoded = encoded[:, 0, :]
        encoded = self.post_layernorm(encoded)
        return self.visual_projection(encoded)
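The loop in `TracingCLIPVisionTransformer.forward` that builds `embeddings_traced` places each token's embedding in its own tracking slot, so slot `i` of the last dimension holds only the contribution of token `i`. A sketch of an equivalent vectorized construction with `torch.diag_embed`, checked against the loop on a random tensor shaped like the ViT-B/32 embeddings (50 tokens of width 768). The function names here are hypothetical:

```python
import torch


def trace_inputs_loop(embeddings: torch.Tensor) -> torch.Tensor:
    # same construction as the loop in TracingCLIPVisionTransformer.forward
    _, regions, _ = embeddings.shape
    traced = torch.zeros(*embeddings.shape, regions)
    for i in range(regions):
        traced[:, i, :, i] = embeddings[:, i, :]
    return traced


def trace_inputs_vectorized(embeddings: torch.Tensor) -> torch.Tensor:
    # (batch, regions, dim) -> (batch, dim, regions) -> (batch, dim, regions, regions)
    diagonal = torch.diag_embed(embeddings.transpose(1, 2))
    # reorder to (batch, regions, dim, regions) so traced[b, i, d, i] = embeddings[b, i, d]
    return diagonal.permute(0, 2, 1, 3)


embeddings = torch.randn(1, 50, 768)
traced = trace_inputs_vectorized(embeddings)
print(traced.shape)                                   # torch.Size([1, 50, 768, 50])
print(torch.equal(traced, trace_inputs_loop(embeddings)))  # True
print(torch.equal(traced.sum(dim=-1), embeddings))    # True
```

The last line is the starting point of the tracing invariant: before any layer runs, summing over the tracking dimension gives back the original embeddings exactly.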
Code
tracing_model = load_tracing_image_model()

image_features = processor.feature_extractor(image, return_tensors="pt")
with torch.inference_mode():
    image_embedding = tracing_model(**image_features)
image_embedding.shape
torch.Size([1, 512, 50])

This is the original shape of 1 (one image) by 512 (the image embedding), with an additional dimension of 50 that tracks the contribution of each input token. Index 0 is the class embedding, so it doesn’t correspond to any region of the image. The 49 image patches (a 7×7 grid) come after it, so we can now try to correlate what the model is paying attention to with the image itself.
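For `openai/clip-vit-base-patch32` the preprocessed image is 224×224 pixels and each patch is 32×32, which is where the 7×7 = 49 patches plus the class embedding come from. A small sketch of mapping a tracking index back to its patch, assuming the patches are numbered in row-major order (which the `reshape(1, 7, 7)` calls in this post rely on). `patch_box` is a hypothetical helper:

```python
from typing import Optional, Tuple

IMAGE_SIZE = 224  # CLIP ViT-B/32 input resolution
PATCH_SIZE = 32   # ViT-B/32 patch size
GRID = IMAGE_SIZE // PATCH_SIZE  # 7 patches per side


def patch_box(tracking_index: int) -> Optional[Tuple[int, int, int, int]]:
    """Hypothetical helper: map a tracking index to (top, left, bottom, right) in pixels.

    Index 0 is the class embedding and has no location, so None is returned.
    """
    if tracking_index == 0:
        return None
    patch = tracking_index - 1  # patches start after the class embedding
    row, col = divmod(patch, GRID)  # row-major order
    top, left = row * PATCH_SIZE, col * PATCH_SIZE
    return top, left, top + PATCH_SIZE, left + PATCH_SIZE


print(GRID * GRID + 1)  # 50 tracking slots
print(patch_box(0))     # None
print(patch_box(1))     # (0, 0, 32, 32): top-left patch
print(patch_box(49))    # (192, 192, 224, 224): bottom-right patch
```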

The original embedding indices of interest were:

Negative:

index   value
92      -7.000749
98      -1.338176
198     -1.162172
258     -1.042850
389     -1.041473

Positive:

index   value
321     1.693858
39      1.293888
376     1.234921
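Every value listed has a magnitude above 1, so a simple threshold is one way to pull out indices like these (the original selection criterion isn’t shown here, so treat this as an assumption). A sketch with a hypothetical `notable_indices` helper, run on a toy 512-dimensional embedding rather than the real `image_embeds`:

```python
import torch


def notable_indices(embedding: torch.Tensor, threshold: float = 1.0):
    """Split indices whose |value| exceeds threshold into negative and positive lists.

    Negative entries are ordered most negative first, positive entries largest first.
    """
    indices = (embedding.abs() > threshold).nonzero().flatten()
    values = embedding[indices]
    order = values.argsort().tolist()  # ascending: most negative first
    pairs = [(int(indices[j]), float(values[j])) for j in order]
    negative = [(i, v) for i, v in pairs if v < 0]
    positive = [(i, v) for i, v in reversed(pairs) if v > 0]
    return negative, positive


# toy stand-in for a 512-dimensional image embedding
embedding = torch.zeros(512)
for index, value in {92: -7.0, 98: -1.5, 321: 1.75, 39: 1.25}.items():
    embedding[index] = value

negative, positive = notable_indices(embedding)
print(negative)  # [(92, -7.0), (98, -1.5)]
print(positive)  # [(321, 1.75), (39, 1.25)]
```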

Looking at index 92 first: its value is negative, so we look for the input regions that make a negative contribution to it:

Code
show_region_attention(
    image_features.pixel_values,
    image_embedding[0, 92, 1:].reshape(1, 7, 7) < 0,
)

Code
image_embedding[0, 92, 1:]
tensor([ 19249.2695,  51099.9062,  65452.1328,  70444.7578,  49403.6406,
         57936.8125,  65320.5234,  48276.9531,  56572.6172,  57694.8203,
         47658.5234,  63807.6797,  63375.8672,  45262.5195,  57963.7734,
         66889.7109,  55779.8516,  49412.5859,  87784.8828,  71330.4375,
         56336.2578,  75340.0859,  55011.8203,  70323.7734,  59636.6562,
         68673.2500,  78986.4766,  59559.0938,  55952.9844,  54741.3594,
         62718.3750,  69416.1484,  70280.3750,  73366.3672,  58585.6250,
         30614.5195,  62875.6953,  84539.3516,  48815.9688,  59098.1719,
         65311.5469,  64728.7734,  45546.6758,  47916.9766,  57140.4609,
         66092.4141, -65479.3672,  53917.0234,  68347.8438])

This seems to be looking at the tail and whiskers of the cat. The next most interesting index is 321, which is positive, so this time we look for regions with a positive contribution:

Code
show_region_attention(
    image_features.pixel_values,
    image_embedding[0, 321, 1:].reshape(1, 7, 7) > 0,
)

This covers the whole image.
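The masks passed to `show_region_attention` are at patch resolution (7×7). To see how much of the 224×224 image a mask covers, it can be expanded so that each patch value fills its own 32×32 block of pixels. A sketch with a hypothetical `mask_to_pixels` helper; this is not necessarily how `show_region_attention` (used above) draws its overlay:

```python
import torch

PATCH_SIZE = 32  # ViT-B/32 patch size in pixels


def mask_to_pixels(region_mask: torch.Tensor, patch_size: int = PATCH_SIZE) -> torch.Tensor:
    """Hypothetical helper: expand a (batch, 7, 7) patch mask to (batch, 224, 224)."""
    batch, rows, cols = region_mask.shape
    # give every patch its own patch_size x patch_size block of pixels
    expanded = region_mask[:, :, None, :, None].expand(batch, rows, patch_size, cols, patch_size)
    return expanded.reshape(batch, rows * patch_size, cols * patch_size)


region_values = torch.randn(49)  # stand-in for image_embedding[0, 321, 1:]
region_mask = region_values.reshape(1, 7, 7) > 0
pixel_mask = mask_to_pixels(region_mask)
print(pixel_mask.shape)  # torch.Size([1, 224, 224])
# fraction of the image covered by the mask
print(pixel_mask.float().mean().item())
```

Because every patch covers the same area, the covered fraction is just the fraction of `True` patches, so a mask that "covers the whole thing" corresponds to all 49 patch values having the same sign.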

When reviewing this, the most important question is how faithfully the tracing has followed the activations as they pass through the model. I’ve had to adjust some aspects of the tracing calculation, both for correctness (avoiding division by zero) and for mathematical simplicity (the multiplication of two tracked matrices). These changes are likely to cause compounding differences between the original outputs and the traced ones.
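To see how the simplified multiplication can drift, here is the `TracingBMM.forward` computation pulled out into a standalone function (`scaled_tracked_bmm` is just a renamed copy) and applied to 1×1 matrices with two tracked regions. Multiplying only entries that share a tracked region drops the cross terms between different regions, and the rescaling then makes each output's absolute slot values add up to the magnitude of the true product. When all the kept terms share the sign of the true product the slots still sum to it, but when the signs are mixed they don’t:

```python
import torch


def scaled_tracked_bmm(xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
    # same computation as TracingBMM.forward; shapes are (batch, n, m, regions)
    # and (batch, m, p, regions)
    tracked = torch.einsum("bnmi,bmpi->bnpi", xs, ys)  # only matching regions multiply
    true = torch.einsum("bnm,bmp->bnp", xs.sum(dim=-1), ys.sum(dim=-1))
    scale = true.abs() / tracked.abs().sum(dim=-1)
    return tracked * scale.unsqueeze(-1)


def one_by_one(values):
    # a 1x1 matrix in a batch of one, split across len(values) tracked regions
    return torch.tensor(values).reshape(1, 1, 1, -1)


ys = one_by_one([1.0, 2.0])  # untracked value 3

same_signs = scaled_tracked_bmm(one_by_one([2.0, 1.0]), ys)    # true product 3 * 3 = 9
mixed_signs = scaled_tracked_bmm(one_by_one([2.0, -1.0]), ys)  # true product 1 * 3 = 3

print(same_signs.sum().item())   # 9.0: slots sum to the true product
print(mixed_signs.sum().item())  # 0.0: slots no longer sum to the true product
```

In the mixed case the matching terms (2 × 1 and -1 × 2) cancel, while the whole true product comes from the dropped cross terms, so no rescaling of the kept terms can recover it.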

Code
tracing_difference = (image_embedding.sum(dim=-1) - image_embeds).abs()

tracing_difference.max(), tracing_difference.mean()
(tensor(6.8679), tensor(0.6289))
Code
image_embeds.max(), image_embeds.mean()
(tensor(1.6939), tensor(0.0007))

This shows me that the tracing has unfortunately lost its way. The difference between the original output and the traced output is significant: the largest absolute difference (6.87) is about four times the largest value in the original embedding (1.69), so these results are not reliable.
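To put a single scale-free number on “significant”, the difference can be divided by the largest absolute value in the original embedding. A minimal sketch with a hypothetical `tracing_error` helper, shown on toy tensors; on the real outputs it would be called as `tracing_error(image_embedding.sum(dim=-1), image_embeds)`:

```python
import torch


def tracing_error(traced: torch.Tensor, original: torch.Tensor) -> dict:
    """Hypothetical helper comparing a traced output (already summed over regions)
    with the original, untracked output."""
    difference = (traced - original).abs()
    scale = original.abs().max()
    return {
        "max_abs": difference.max().item(),
        "mean_abs": difference.mean().item(),
        "max_relative": (difference.max() / scale).item(),
    }


original = torch.tensor([[-2.0, 1.0, 0.5, 0.0]])
traced = torch.tensor([[-1.0, 1.0, 0.5, 0.5]])
print(tracing_error(traced, original))
# {'max_abs': 1.0, 'mean_abs': 0.375, 'max_relative': 0.5}
```

Using the absolute maximum as the scale avoids the misleading signed mean, which is close to zero for an embedding with both large positive and large negative entries.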