Prompt Training - Clustering Raw Output

Classification with trained prompts by clustering the raw transformer output instead of predicting tokens
Published

May 4, 2021

I’ve been trying out different forms of prompt training recently and I think the technique really has value. One problem that I have encountered is picking the appropriate tokens to represent the different classifications. I suspect that there are good choices and bad choices for this task.

When I did sentiment classification, using “good” and “bad” as target tokens worked well. I’ve also evaluated an arbitrary classification using the tokens “relevant” and “irr” (the start of “irrelevant”), and that performed significantly worse.

What I really need is a way to measure the output of the model and assign it to a category. I don’t actually care which concrete token it produces. Since the target output isn’t fixed, only the relationship between outputs, the loss cannot be measured for an individual row. Instead the loss is computed over each pair of rows: the distance between their outputs is compared to an ideal distance, which is small for rows that share a label and large for rows that do not.

Let’s start trying to make this concrete. I will continue using GPT-2 small as that is a nice reference point. This technique should work with any language model.


Model Surgery

The first thing is to strip the language model head from the GPT-2 model. The head translates the model output into scores for individual tokens, and since I no longer care about the tokens that translation is not useful.

Let’s check this assumption by reviewing the structure of the language model head.

Code
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.to("cuda")
model.eval() ; None
Code
model.lm_head
Linear(in_features=768, out_features=50257, bias=False)

So this is exactly what I expected. The language model head is just a linear layer that projects the 768-dimensional model output onto the 50,257 tokens of the vocabulary. That projection is not useful at this point. It would be better to just take the output embedding from the transformer directly.
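Written out, and assuming only the layer shapes printed above, the head computes one score per vocabulary token from the final hidden state \(h\):

\[ z = W h, \qquad W \in \mathbb{R}^{50257 \times 768}, \quad h \in \mathbb{R}^{768}, \quad z \in \mathbb{R}^{50257} \]

There is no bias term, so the predicted token is simply \(\arg\max_i z_i\). Clustering works on \(h\) itself and skips this projection.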

The next thing to test is to load the pretrained prompt that was created before and assess the difference in output for different inputs.

Code
from transformers import AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token # needed to enable padding
prompt = torch.load("/data/blog/2021-04-13-dreaming-of-prompts/trained-prompt-1-20-768-002.pt")
Code
@torch.no_grad()
def get_output(
    text: str,
    prompt: torch.Tensor,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer
) -> torch.Tensor:
    tokens = (
        tokenizer(text, return_tensors="pt")["input_ids"]
            .to("cuda")
    )
    token_embedding = model.transformer.wte(tokens)
    
    # join the tensors - dim 0 is the batch, 1 is the tokens, 2 is the specific embedding value
    full_embedding = torch.cat([token_embedding, prompt], dim=1)
    return (
        model.transformer(inputs_embeds=full_embedding)
            .last_hidden_state
            [0, -1] # single batch, last token
    )
Code
bad_output = get_output("i hate you", prompt=prompt, model=model, tokenizer=tokenizer)
good_output = get_output("yay awesome!", prompt=prompt, model=model, tokenizer=tokenizer)

Let’s first check that these produce the correct tokens.

Code
tokenizer.decode(model.lm_head(bad_output).argmax())
'bad'
Code
tokenizer.decode(model.lm_head(good_output).argmax())
'good'

Now we can actually see how different the output is.

Code
torch.cosine_similarity(bad_output, good_output, dim=0)
tensor(0.9981, device='cuda:0')

This is extremely interesting. My preferred measure of similarity suggests that these two outputs are extremely similar. I prefer cosine similarity because it is normalized (it produces output between -1 and 1) and it is unaffected by the magnitude of the vectors. A linear layer can still change it though, even without a bias. Are these two outputs really that similar?
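The claim about linear layers is worth a quick check. The sketch below uses plain Python with made-up two-dimensional vectors and a hypothetical weight matrix, none of which come from GPT-2. It suggests that cosine similarity survives rescaling but not a general bias-free linear map:

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


def apply_linear(weight: Sequence[Sequence[float]], x: Sequence[float]) -> List[float]:
    """Bias-free linear layer: weight @ x."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


# toy vectors, chosen only for illustration
a = [1.0, 0.0]
b = [1.0, 1.0]

original = cosine_similarity(a, b)
# rescaling one vector leaves the cosine similarity unchanged
rescaled = cosine_similarity([10 * v for v in a], b)

# a hypothetical bias-free layer that stretches the second axis
weight = [[1.0, 0.0], [0.0, 10.0]]
projected = cosine_similarity(apply_linear(weight, a), apply_linear(weight, b))

print(f"{original:.4f} {rescaled:.4f} {projected:.4f}")
```

The first two values agree while the third drops sharply, so a high similarity before the language model head says little about how similar the logits are after it.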

Code
(bad_output - good_output).abs().sum() / bad_output.abs().sum()
tensor(0.2914, device='cuda:0')
Code
((bad_output - good_output) ** 2).mean()
tensor(0.7285, device='cuda:0')

These two measurements (relative absolute difference and mean squared difference) suggest that cosine similarity is missing part of the difference between these two outputs. I wonder if the tensors are of different magnitude?

Code
bad_output.norm(), good_output.norm()
(tensor(171.1756, device='cuda:0'), tensor(192.0267, device='cuda:0'))

So these tensors are of different magnitude, which cosine similarity ignores. I still wonder if cosine similarity is a good loss measurement. Either way, I can work with both the mean squared difference and cosine similarity when it comes to training.
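The magnitudes seem to explain most of the disagreement between the measures. For any two vectors the squared distance splits into a magnitude term and an angle term, |a - b|² = (|a| - |b|)² + 2|a||b|(1 - cos θ). Here is a rough check in plain Python, plugging in the rounded values printed above, so the result is only approximate:

```python
# Values reported above (rounded to four decimal places)
norm_bad = 171.1756
norm_good = 192.0267
cosine = 0.9981
hidden_size = 768  # width of the GPT-2 small output

# |a - b|^2 = (|a| - |b|)^2 + 2 |a| |b| (1 - cos)
magnitude_part = (norm_bad - norm_good) ** 2
angle_part = 2 * norm_bad * norm_good * (1 - cosine)
squared_distance = magnitude_part + angle_part

# the mean squared difference is the squared distance averaged over dimensions
mean_squared_difference = squared_distance / hidden_size
magnitude_share = magnitude_part / squared_distance
print(f"{mean_squared_difference:.3f} {magnitude_share:.2f}")
```

The reconstructed mean squared difference comes out at roughly 0.73, in line with the 0.7285 measured directly, and roughly three quarters of the squared distance comes from the difference in magnitude rather than direction.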


Training

At this point we have enough to work with. Let’s revive the past dataloader from the previous iterations.

Code
#collapse

from typing import Dict, Iterator, Optional, Tuple, Union

import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

Past = Tuple[Tuple[torch.Tensor, ...], ...]
TextBatch = Dict[str, torch.Tensor]
PastBatch = Dict[str, Union[torch.Tensor, Past]]


class TextDataloader:
    """Provides a dataloader over a text dataframe"""

    def __init__(
        self,
        df: pd.DataFrame,
        *,
        tokenizer: AutoTokenizer,
        batch_size: int,
        max_length: int,
        device: torch.device = torch.device("cuda"),
        shuffle: bool = True,
    ) -> None:
        self.tokenizer = tokenizer
        self.df = df
        self.batch_size = batch_size
        self.max_length = max_length
        self.device = device
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[TextBatch]:
        """Returns an iterator that returns batches.
        The final batch can be a partial batch."""
        if self.shuffle:
            df = self.df.sample(frac=1).reset_index(drop=True)
        else:
            df = self.df
        batch_size = self.batch_size

        for i in range(len(self)):
            start = i * batch_size
            end = start + batch_size
            yield self.to_batch(df[start:end])

    def __len__(self) -> int:
        """Returns the total number of batches that can be returned."""
        full_batches = len(self.df) // self.batch_size
        if len(self.df) % self.batch_size:
            return full_batches + 1
        return full_batches

    def to_batch(self, rows: pd.DataFrame) -> TextBatch:
        """Converts the rows into a batch"""
        tokens = self.tokenizer(
            rows.text.tolist(),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        ).to(self.device)
        labels = torch.tensor(rows.label.tolist(), dtype=torch.long, device=self.device)
        return {
            "input_ids": tokens["input_ids"],
            "attention_mask": tokens["attention_mask"],
            "labels": labels,
        }


class PastDataloader(TextDataloader):  # pylint: disable=too-few-public-methods
    """Provides a dataloader which converts the text into past tensors"""

    def __init__(
        self,
        df: pd.DataFrame,
        *,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        batch_size: int,
        max_length: int,
        label_map: Optional[Dict[str, int]] = None,
        device: torch.device = torch.device("cuda"),
        shuffle: bool = True,
    ) -> None:
        if label_map:
            df = df.copy()
            df["label"] = df.label.map(label_map)
        super().__init__(
            df=df,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_length=max_length,
            device=device,
            shuffle=shuffle,
        )
        model.to(device)
        self.model = model

    @torch.no_grad()
    def to_batch(self, rows: pd.DataFrame) -> PastBatch:
        batch = super().to_batch(rows)
        past_key_values = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch.get("attention_mask", None),
        ).past_key_values
        return {
            "past_key_values": past_key_values,
            "attention_mask": batch["attention_mask"],
            "labels": batch["labels"],
        }

And then we can adjust the training loop. The most important part is a loss function that compares every row to every other row according to whether their labels match.

Code
#collapse

from typing import Callable, Dict, Tuple, Union

import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

LossFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def train(
    *,
    dl: PastDataloader,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt_tokens: int,
    epochs: int,
    loss_fn: LossFunction,
) -> torch.Tensor:
    """Train the prompt"""
    prompt, prompt_attention = _make_prompt(
        model=model,
        tokenizer=tokenizer,
        prompt_tokens=prompt_tokens,
        device=dl.device,
    )

    # optimize just the prompt
    optimizer = torch.optim.Adam([prompt], lr=1e-3)

    total_loss = 0.0

    with tqdm(
        range(epochs), leave=False, bar_format="loss: {postfix[0]:>8.4f}", postfix=[0.0]
    ) as bar:
        for _epoch in bar:
            for batch in tqdm(dl, leave=False):
                total_loss += _process(
                    batch=batch,
                    model=model,
                    optimizer=optimizer,
                    prompt=prompt,
                    prompt_attention=prompt_attention,
                    loss_fn=loss_fn,
                )

            average_loss = total_loss / len(dl)
            bar.postfix[0] = average_loss
            print(f"Average loss: {average_loss:0.4f}")
            total_loss = 0.0

    return prompt.data


def _make_prompt(
    *,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt_tokens: int,
    device: torch.device,
) -> Tuple[torch.nn.Parameter, torch.Tensor]:
    """Generate the prompt by randomly choosing tokens and then converting to embeddings"""
    prompt_indexes = torch.randint(
        size=(prompt_tokens,), low=0, high=tokenizer.vocab_size, device=device
    )
    prompt_attention = torch.ones(
        size=(1, prompt_tokens), dtype=torch.long, device=device
    )
    prompt = torch.nn.Parameter(
        model.transformer.wte(prompt_indexes).clone()[None, :, :]
    )
    return prompt, prompt_attention


def _process(
    *,
    batch: Dict[str, Union[torch.Tensor, Past]],
    model: AutoModelForCausalLM,
    optimizer: torch.optim.Optimizer,
    prompt: torch.nn.Parameter,
    prompt_attention: torch.Tensor,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> float:
    optimizer.zero_grad()

    logits = _get_output_with_past(
        model=model,
        prompt=prompt,
        attention_mask=prompt_attention,
        past=batch["past_key_values"],
        past_attention_mask=batch["attention_mask"],
    )
    loss = loss_fn(logits, batch["labels"])

    loss.backward()
    optimizer.step()

    return loss.item()


def _get_output_with_past(
    *,
    model: AutoModelForCausalLM,
    prompt: torch.nn.Parameter,
    attention_mask: torch.Tensor,
    past: Past,
    past_attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Get the final hidden state at the last prompt position"""
    # concatenate the past attention with the prompt attention
    batch_size = past_attention_mask.shape[0]
    attention_mask = attention_mask.repeat_interleave(batch_size, dim=0)
    attention_mask = torch.cat([past_attention_mask, attention_mask], dim=-1)

    # expand the prompt to match the batch size
    input_ids = prompt.repeat_interleave(batch_size, dim=0)

    state = model.transformer(
        inputs_embeds=input_ids,
        attention_mask=attention_mask,
        past_key_values=past,
    ).last_hidden_state
    return state[:, -1]
Code
from pathlib import Path
import pandas as pd

def load(path: Path) -> pd.DataFrame:
    positive_files = sorted(path.glob("pos/*.txt"))
    negative_files = sorted(path.glob("neg/*.txt"))
    
    return pd.DataFrame(
        [
            {"label": "good", "text": file.read_text()}
            for file in positive_files
        ] +
        [
            {"label": "bad", "text": file.read_text()}
            for file in negative_files
        ]
    )
Code
train_df = load(Path("/data/sentiment/imdb-movie-reviews/train"))
validation_df = load(Path("/data/sentiment/imdb-movie-reviews/test"))
Code
BATCH_SIZE = 32
MAX_LENGTH = 1_000

train_dataloader = PastDataloader(
    model=model,
    tokenizer=tokenizer,
    df=train_df,
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
    shuffle=True,
    label_map={"bad": 0, "good": 1},
)
validation_dataloader = PastDataloader(
    model=model,
    tokenizer=tokenizer,
    df=validation_df,
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
    shuffle=False,
    label_map={"bad": 0, "good": 1},
)

Cosine Similarity Loss Training

So now we have to consider the loss function. Ideally it would be fast to compute, as it needs to run against every pair of rows in the batch. For now let’s just take the cosine similarity of each pair and push it up when the labels match and down when they do not.

Code
def cosine_loss_fn(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    batch_size = output.shape[0]
    
    cycled_output = output.repeat((batch_size, 1))
    # This repeats the tensor as if using cycle()
    # [1, 2, 3] -> [1, 2, 3, 1, 2, 3...]
    interleaved_output = output.repeat_interleave(batch_size, dim=0)
    # This repeats each element of the tensor
    # [1, 2, 3] -> [1, 1.., 2, 2.., 3, 3..]

    cycled_labels = labels.repeat(batch_size)
    interleaved_labels = labels.repeat_interleave(batch_size, dim=0)
    
    repeated_labels = (cycled_labels == interleaved_labels).long()
    repeated_labels = (repeated_labels * 2) - 1
    # label needs to be -1 for different or 1 for same
    # true  -> 1 -> 1*2-1 ->  1
    # false -> 0 -> 0*2-1 -> -1
    
    return torch.nn.functional.cosine_embedding_loss(
        cycled_output,
        interleaved_output,
        repeated_labels
    )

I’m not 100% sure this is correct but I’m going to give it a go anyway.
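The pairing logic at least can be checked on a small example. This is a plain Python sketch that mirrors the `repeat` and `repeat_interleave` calls with lists. The helper name `pair_targets` is made up for illustration:

```python
from typing import List


def pair_targets(labels: List[int]) -> List[int]:
    """Targets for every ordered pair of rows: 1 if the labels match, else -1."""
    batch_size = len(labels)
    # like Tensor.repeat: [0, 1, 1] -> [0, 1, 1, 0, 1, 1, ...]
    cycled = labels * batch_size
    # like Tensor.repeat_interleave: [0, 1, 1] -> [0, 0, 0, 1, 1, 1, ...]
    interleaved = [label for label in labels for _ in range(batch_size)]
    return [1 if a == b else -1 for a, b in zip(cycled, interleaved)]


targets = pair_targets([0, 1, 1])
print(targets)  # [1, -1, -1, -1, 1, 1, -1, 1, 1]
```

Every ordered pair appears, so each pair of distinct rows is counted twice. The self-pairs (positions 0, 4 and 8 here) are always "same" pairs with a cosine similarity of 1, so they add nothing to the loss but do dilute the mean.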

Code
cosine_trained_prompt = train(
    dl=train_dataloader,
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=5,
    epochs=3,
    loss_fn=cosine_loss_fn
)
Average loss: 0.4847
Average loss: 0.4844
Average loss: 0.4841

So this trained relatively slowly and didn’t seem to change the loss much. I wonder how the output of the model has changed. It may also be that my loss function was insufficiently good at guiding the optimizer.
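One possible reading of these numbers, offered as a hypothesis rather than a diagnosis: `cosine_embedding_loss` scores a same-label pair as 1 - cos and a different-label pair as max(0, cos - margin), with a default margin of 0. If every pair of outputs had a cosine similarity of about 1, the mean loss would just be the fraction of pairs with different labels. The sketch below computes that fraction, assuming each row in a batch of 32 is independently "good" or "bad" with equal probability:

```python
from math import comb


def expected_collapsed_loss(batch_size: int) -> float:
    """Expected cosine embedding loss if every pair has cosine similarity 1.

    Assumes each row is independently "good" or "bad" with probability 1/2.
    Same-label pairs then cost 1 - 1 = 0 and different-label pairs cost
    max(0, 1 - 0) = 1, so the mean loss is the fraction of ordered pairs
    whose labels differ.
    """
    total = 0.0
    for good in range(batch_size + 1):
        probability = comb(batch_size, good) / 2 ** batch_size
        different_pairs = 2 * good * (batch_size - good)
        total += probability * different_pairs / batch_size ** 2
    return total


loss = expected_collapsed_loss(32)
print(f"{loss:.4f}")  # 0.4844
```

The result is 31/64, which is very close to the average losses reported above. That would be consistent with the outputs all pointing in nearly the same direction regardless of label, with training barely changing that.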

Code
bad_output = get_output(
    "i hate you",
    prompt=cosine_trained_prompt,
    model=model,
    tokenizer=tokenizer
)
good_output = get_output(
    "yay awesome!",
    prompt=cosine_trained_prompt,
    model=model,
    tokenizer=tokenizer
)
Code
(
    tokenizer.decode(model.lm_head(bad_output).argmax()),
    tokenizer.decode(model.lm_head(good_output).argmax())
)
(',', '!')
Code
torch.cosine_similarity(bad_output, good_output, dim=0)
tensor(0.9993, device='cuda:0')
Code
(bad_output - good_output).abs().sum() / bad_output.abs().sum()
tensor(0.2326, device='cuda:0')
Code
((bad_output - good_output) ** 2).mean()
tensor(1.2199, device='cuda:0')

So these are different to the outputs of the “good” and “bad” trained prompt. What I really want is to be able to visualize the spread of the values to see if they cluster by label. I guess PCA or similar could be used to show the layout of the outputs.

Code
torch.save(
    cosine_trained_prompt,
    "/data/blog/2021-05-04-prompt-training-clustering/trained-prompt-1-20-768-cosine.pt"
)

Distance Loss Training

Instead of cosine, which is based on the direction of the vector, I am going to try the absolute distance between the points that each vector represents. If the points are for the same class then the distance should be minimized, and if they are for different classes then it should be maximized.

My initial thought is to have two functions:

\[ \begin{aligned} L_{same} &= dist(x_a, x_b) \\ L_{different} &= \frac{1}{dist(x_a, x_b)} \end{aligned} \]

For the same class, as the distance tends to zero the loss does too. Equally, for different classes, as the distance tends to infinity the loss tends to zero. These two terms may need balancing in some way.
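To get a feel for why balancing might be needed, here is a small sketch comparing how hard each term pulls or pushes as the distance changes. The function names are made up for illustration, and the gradients are taken with respect to the distance itself:

```python
def same_class_gradient(distance: float) -> float:
    """d/dd of L_same = d: a constant pull of 1 regardless of distance."""
    return 1.0


def different_class_gradient(distance: float) -> float:
    """d/dd of L_different = 1 / d: a push of -1 / d^2 that fades quickly."""
    return -1.0 / distance ** 2


for distance in (0.5, 1.0, 5.0, 20.0):
    ratio = abs(different_class_gradient(distance)) / same_class_gradient(distance)
    print(f"distance={distance:5.1f}  push/pull ratio={ratio:.4f}")
```

Once points are more than one unit apart the pull between same-class pairs outweighs the push between different-class pairs, and the gap grows with the square of the distance. Under these two definitions the optimizer may find it easier to shrink everything together than to separate the classes.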

Code
def distance_loss_fn(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    batch_size = output.shape[0]
    
    cycled_output = output.repeat((batch_size, 1))
    # This repeats the tensor as if using cycle()
    # [1, 2, 3] -> [1, 2, 3, 1, 2, 3...]
    interleaved_output = output.repeat_interleave(batch_size, dim=0)
    # This repeats each element of the tensor
    # [1, 2, 3] -> [1, 1.., 2, 2.., 3, 3..]

    cycled_labels = labels.repeat(batch_size)
    interleaved_labels = labels.repeat_interleave(batch_size, dim=0)
    
    different_labels_mask = cycled_labels != interleaved_labels
    
    distance = torch.pairwise_distance(
        cycled_output,
        interleaved_output
    )

    # there is a problem with this,
    # it is considered an in-place operation which causes gradient calculations to fail
    # distance[different_labels_mask] = 1 / distance[different_labels_mask]
    
    return distance[~different_labels_mask].sum() + (1 / distance[different_labels_mask]).sum()
Code
distance_trained_prompt = train(
    dl=train_dataloader,
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=5,
    epochs=3,
    loss_fn=distance_loss_fn
)
Average loss: 2447.2548
Average loss: 1159.9298
Average loss: 1079.7484
Code
bad_output = get_output(
    "i hate you",
    prompt=distance_trained_prompt,
    model=model,
    tokenizer=tokenizer
)
good_output = get_output(
    "yay awesome!",
    prompt=distance_trained_prompt,
    model=model,
    tokenizer=tokenizer
)
Code
(
    tokenizer.decode(model.lm_head(bad_output).argmax()),
    tokenizer.decode(model.lm_head(good_output).argmax())
)
(' I', ' I')
Code
torch.cosine_similarity(bad_output, good_output, dim=0)
tensor(1.0000, device='cuda:0')
Code
torch.pairwise_distance(bad_output[None, :], good_output[None, :])
tensor([6.2066], device='cuda:0')
Code
torch.save(
    distance_trained_prompt,
    "/data/blog/2021-05-04-prompt-training-clustering/trained-prompt-1-20-768-distance.pt"
)

This is really interesting. The points are quite different by distance but identical by cosine similarity.


Evaluation

Now that we have a trained prompt we can try evaluating it. Evaluating the prompt is hard because we don’t know what output corresponds to a given class. The training just aims to separate the outputs for the two classes.

So I think the evaluation should try to visualize the outputs for the different classes and then we can see if they are separate. The code in this section will collect the outputs for the validation set and then use PCA to reduce them to two dimensions. At that point they can be visualized.
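A number can complement the picture. One option, sketched here as an assumption rather than something this post goes on to use, is a nearest-centroid check: average the outputs per label on one split, then see how often outputs from another split fall closest to the centroid of their own label. The helper names and the toy two-dimensional points are made up for illustration:

```python
import math
from typing import Dict, List, Sequence, Tuple

Vector = Sequence[float]


def centroids(outputs: Sequence[Vector], labels: Sequence[int]) -> Dict[int, List[float]]:
    """Mean output vector for each label."""
    grouped: Dict[int, List[Vector]] = {}
    for output, label in zip(outputs, labels):
        grouped.setdefault(label, []).append(output)
    return {
        label: [sum(column) / len(rows) for column in zip(*rows)]
        for label, rows in grouped.items()
    }


def nearest_centroid_accuracy(
    train: Tuple[Sequence[Vector], Sequence[int]],
    test: Tuple[Sequence[Vector], Sequence[int]],
) -> float:
    """Fraction of test outputs whose closest training centroid has their label."""
    centers = centroids(*train)
    correct = 0
    for output, label in zip(*test):
        predicted = min(centers, key=lambda c: math.dist(output, centers[c]))
        correct += predicted == label
    return correct / len(test[1])


# toy two-dimensional stand-ins for the 768-dimensional outputs
train = ([(0.0, 0.0), (0.0, 1.0), (4.0, 4.0), (5.0, 4.0)], [0, 0, 1, 1])
test = ([(0.5, 0.5), (4.5, 4.5), (1.0, 0.0)], [0, 1, 1])
accuracy = nearest_centroid_accuracy(train, test)
print(f"{accuracy:.2f}")
```

An accuracy near 0.5 on a balanced two-class problem would mean the clusters overlap almost completely, while a value near 1.0 would mean they are cleanly separated.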

Code
@torch.no_grad()
def get_all_outputs(
    dl: PastDataloader,
    model: AutoModelForCausalLM,
    prompt: torch.Tensor
) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
    # use the same integer dtype as the tokenizer attention mask it is joined to
    prompt_attention = torch.ones(
        (1, prompt.shape[1]), dtype=torch.long, device=dl.device
    )

    for batch in tqdm(dl):
        logits = _get_output_with_past(
            model=model,
            prompt=prompt,
            attention_mask=prompt_attention,
            past=batch["past_key_values"],
            past_attention_mask=batch["attention_mask"],
        )
        yield logits.cpu(), batch["labels"].cpu()
Code
#collapse

import plotly.graph_objects as go
from IPython.display import HTML

def show_clusters(
    outputs: Iterator[Tuple[torch.Tensor, torch.Tensor]],
    dl: PastDataloader,
    title: str
) -> None:
    outputs = list(outputs)
    all_outputs = torch.cat([
        output
        for output, _ in outputs
    ], dim=0)
    all_labels = torch.cat([
        label
        for _, label in outputs
    ])
    # pca_lowrank returns (U, S, V); the rows of U are used as the 2D coordinates
    pca_outputs = torch.pca_lowrank(all_outputs, q=2)
    
    df = pd.DataFrame([
        {
            "x": output[0].item(),
            "y": output[1].item(),
            "label": "blue" if label.item() else "red",
            "text": text
        }
        for output, label, text in zip(
            pca_outputs[0],
            all_labels,
            dl.df.text
        )
    ])
    
    limited_df = pd.concat([
        df[df["label"] == label].sample(n=500)
        for label in df.label.unique()
    ])
    
    fig = go.Figure(data=go.Scatter(
        x=limited_df.x,
        y=limited_df.y,
        mode='markers',
        marker_color=limited_df.label,
        text=limited_df.text.str[:100]
    ))
    fig.update_layout(title=title)
    fig.show()

Cosine Similarity Loss Evaluation

Let’s see how well the cosine similarity loss function has performed. The points are colored by label: blue is good and red is bad.

Code
cosine_trained_prompt = torch.load("/data/blog/2021-05-04-prompt-training-clustering/trained-prompt-1-20-768-cosine.pt")
Code
cosine_validation_outputs = list(
    get_all_outputs(dl=validation_dataloader, model=model, prompt=cosine_trained_prompt)
)
Code
show_clusters(
    outputs=cosine_validation_outputs,
    dl=validation_dataloader,
    title="Cosine Similarity Clustering"
)

Well this doesn’t look clearly separated. I think it would be worth trying a different loss function.


Distance Loss Evaluation

Let’s see how well the distance loss function has performed. Once again, the points are colored by label: blue is good and red is bad.

Code
distance_validation_outputs = list(
    get_all_outputs(dl=validation_dataloader, model=model, prompt=distance_trained_prompt)
)
Code
show_clusters(
    outputs=distance_validation_outputs,
    dl=validation_dataloader,
    title="Distance Clustering"
)

So the clusters are still not separated, on a problem that token-based prompt training dealt with extremely well before.


Label Clustering

I suspect that removing the language model head is actually a big problem, and that if I were to measure the clustering of the original trained prompt I would find that it clusters poorly. If this is the case then I should redo the training including the language model head.

Code
original_prompt = torch.load("/data/blog/2021-04-13-dreaming-of-prompts/trained-prompt-1-20-768-002.pt")
Code
original_validation_outputs = list(
    get_all_outputs(dl=validation_dataloader, model=model, prompt=original_prompt)
)
Code
show_clusters(
    outputs=original_validation_outputs,
    dl=validation_dataloader,
    title="Label Clustering"
)

So this isn’t separated either. I need to redo this evaluation while retaining the language model head. I’ll do that in another post as this has become quite long already.