How does the past parameter work with bi-directional models?

GPT-2 can reuse previous activations when processing the following tokens in a sentence. How does this work with a bi-directional model like RoBERTa?
Published

December 11, 2021

The GPT-2 model can take a parameter called past_key_values, which allows you to partially process some text and save the resulting activations. You can then reuse those saved activations to run several subsequent sequences, which lets you try out many different prompts on a single piece of text. This is really handy if you have different prompts for different tasks, like sentiment and emotion.

This works because GPT-2 is a uni-directional model, which means that the processing of a token is only influenced by the tokens before it. A diagram of the model might help here:

The grey arrows within the model show the direction of information flow. Given this, we can see that if we split the model in two horizontally, we can save the activations from the first part and reuse them.

The split does not change the output at all. This all relies on the fact that the model is uni-directional - information flows from early activations to later ones only.

GPT-2 Past

GPT-2 is a uni-directional model that supports this behaviour. In transformers this is implemented using the past_key_values argument. We can see this here:

Code
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval() ; None
Code
import torch

@torch.no_grad()
def last_unsplit_logits(text: str) -> torch.Tensor:
    """Process the whole text in one forward pass and return the
    logits for the final token."""
    tokens = tokenizer(
        text, return_tensors="pt", add_special_tokens=False
    ).input_ids
    logits = model(tokens).logits
    return logits[0, -1]

@torch.no_grad()
def last_split_logits(text: str) -> torch.Tensor:
    """Process all but the last token, save the activations, then
    process the final token on top of them."""
    tokens = tokenizer(
        text, return_tensors="pt", add_special_tokens=False
    ).input_ids
    # run the model over the prefix and keep the cached activations
    past_key_values = model(tokens[:, :-1]).past_key_values
    # process only the final token, reusing the cached prefix
    logits = model(tokens[:, -1:], past_key_values=past_key_values).logits

    return logits[0, 0]
Code
unsplit = last_unsplit_logits("I like to eat")
split = last_split_logits("I like to eat")

average = torch.abs(unsplit).mean().item()
difference = torch.abs(unsplit - split).mean().item()

print(f"Average output is {average} and average difference is {difference}")
Average output is 120.7281265258789 and average difference is 1.722940396575723e-05

While there is some difference in the output, relative to the average magnitude it is less than one part in a million.

RoBERTa Past

RoBERTa is a BERT-based model, which is bi-directional. This means that tokens receive information from future tokens as well as past tokens. I think that this breaks the use of past_key_values, as you cannot segment the model activations in the same way.

The transformers documentation for the RoBERTa model shows that it can also accept the past_key_values argument. This surprises me and I want to see if it works.

Code
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModelForCausalLM.from_pretrained("roberta-base")
model.eval() ; None
If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`
Code
unsplit = last_unsplit_logits("I like to eat")
split = last_split_logits("I like to eat")

average = torch.abs(unsplit).mean().item()
difference = torch.abs(unsplit - split).mean().item()

print(f"Average output is {average} and average difference is {difference}")
Average output is 3.3664541244506836 and average difference is 2.536721706390381

On the plus side, the model didn’t error. However, the difference between the outputs is so large that this technique is worthless as a multi-prompting approach.

Perhaps a prefix prompt could work within the limitations inherent in this approach? Either way, getting a bi-directional model to work with splits seems significantly harder and much less valuable.