Prompt Training - Clustering Tokens

Classification with trained prompts by clustering the token confidence
Published

May 6, 2021

In the previous post I investigated moving away from picking the tokens to target. I used that as an opportunity to skip the language model head, and with it the token confidences, as well. The results were not great.

Since the prompt training itself is viable, I think that stripping the language model head was a mistake. I’m now going to investigate clustering the outputs once again, but this time after converting them to token confidences.


Mandatory Block of Code

Here is the dataloader

Code
#collapse

from typing import Dict, Iterator, Optional, Tuple, Union

import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

Past = Tuple[Tuple[torch.Tensor, ...], ...]
TextBatch = Dict[str, torch.Tensor]
PastBatch = Dict[str, Union[torch.Tensor, Past]]


class TextDataloader:
    """Provides a dataloader over a text dataframe"""

    def __init__(
        self,
        df: pd.DataFrame,
        *,
        tokenizer: AutoTokenizer,
        batch_size: int,
        max_length: int,
        device: torch.device = torch.device("cuda"),
        shuffle: bool = True,
    ) -> None:
        self.tokenizer = tokenizer
        self.df = df
        self.batch_size = batch_size
        self.max_length = max_length
        self.device = device
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[TextBatch]:
        """Returns an iterator that returns batches.
        The final batch can be a partial batch."""
        if self.shuffle:
            df = self.df.sample(frac=1).reset_index(drop=True)
        else:
            df = self.df
        batch_size = self.batch_size

        for i in range(len(self)):
            start = i * batch_size
            end = start + batch_size
            yield self.to_batch(df[start:end])

    def __len__(self) -> int:
        """Returns the total number of batches that can be returned."""
        full_batches = len(self.df) // self.batch_size
        if len(self.df) % self.batch_size:
            return full_batches + 1
        return full_batches

    def to_batch(self, rows: pd.DataFrame) -> TextBatch:
        """Converts the rows into a batch"""
        tokens = self.tokenizer(
            rows.text.tolist(),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        ).to(self.device)
        labels = torch.tensor(rows.label.tolist(), dtype=torch.long, device=self.device)
        return {
            "input_ids": tokens["input_ids"],
            "attention_mask": tokens["attention_mask"],
            "labels": labels,
        }


class PastDataloader(TextDataloader):  # pylint: disable=too-few-public-methods
    """Provides a dataloader which converts the text into past tensors"""

    def __init__(
        self,
        df: pd.DataFrame,
        *,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        batch_size: int,
        max_length: int,
        label_map: Optional[Dict[str, int]] = None,
        device: torch.device = torch.device("cuda"),
        shuffle: bool = True,
    ) -> None:
        if label_map:
            df = df.copy()
            df["label"] = df.label.map(label_map)
        super().__init__(
            df=df,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_length=max_length,
            device=device,
            shuffle=shuffle,
        )
        model.to(device)
        self.model = model

    @torch.no_grad()
    def to_batch(self, rows: pd.DataFrame) -> PastBatch:
        batch = super().to_batch(rows)
        past_key_values = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch.get("attention_mask", None),
        ).past_key_values
        return {
            "past_key_values": past_key_values,
            "attention_mask": batch["attention_mask"],
            "labels": batch["labels"],
        }

Here is the training loop

Code
#collapse

from typing import Callable, Dict, Tuple, Union

import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

LossFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def train(
    *,
    dl: PastDataloader,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt_tokens: int,
    epochs: int,
    loss_fn: LossFunction,
) -> torch.Tensor:
    """Train the prompt"""
    prompt, prompt_attention = _make_prompt(
        model=model,
        tokenizer=tokenizer,
        prompt_tokens=prompt_tokens,
        device=dl.device,
    )

    # optimize just the prompt
    optimizer = torch.optim.Adam([prompt], lr=1e-3)

    total_loss = 0.0

    with tqdm(
        range(epochs), leave=False, bar_format="loss: {postfix[0]:>8.4f}", postfix=[0.0]
    ) as bar:
        for _epoch in bar:
            for batch in tqdm(dl, leave=False):
                total_loss += _process(
                    batch=batch,
                    model=model,
                    optimizer=optimizer,
                    prompt=prompt,
                    prompt_attention=prompt_attention,
                    loss_fn=loss_fn,
                )

            average_loss = total_loss / len(dl)
            bar.postfix[0] = average_loss
            print(f"Average loss: {average_loss:0.4f}")
            total_loss = 0.0

    return prompt.data


def _make_prompt(
    *,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt_tokens: int,
    device: torch.device,
) -> Tuple[torch.nn.Parameter, torch.Tensor]:
    """Generate the prompt by randomly choosing tokens and then converting to embeddings"""
    prompt_indexes = torch.randint(
        size=(prompt_tokens,), low=0, high=tokenizer.vocab_size, device=device
    )
    prompt_attention = torch.ones(
        size=(1, prompt_tokens), dtype=torch.long, device=device
    )
    prompt = torch.nn.Parameter(
        model.transformer.wte(prompt_indexes).clone()[None, :, :]
    )
    return prompt, prompt_attention


def _process(
    *,
    batch: Dict[str, Union[torch.Tensor, Past]],
    model: AutoModelForCausalLM,
    optimizer: torch.optim.Optimizer,
    prompt: torch.nn.Parameter,
    prompt_attention: torch.Tensor,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> float:
    optimizer.zero_grad()

    logits = _get_output_with_past(
        model=model,
        prompt=prompt,
        attention_mask=prompt_attention,
        past=batch["past_key_values"],
        past_attention_mask=batch["attention_mask"],
    )
    loss = loss_fn(logits, batch["labels"])

    loss.backward()
    optimizer.step()

    return loss.item()


def _get_output_with_past(
    *,
    model: AutoModelForCausalLM,
    prompt: torch.nn.Parameter,
    attention_mask: torch.Tensor,
    past: Past,
    past_attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Get the predictions for the next token after the prompt"""
    # concatenate the past attention with the prompt attention
    batch_size = past_attention_mask.shape[0]
    attention_mask = attention_mask.repeat_interleave(batch_size, dim=0)
    attention_mask = torch.cat([past_attention_mask, attention_mask], dim=-1)

    # expand the prompt to match the batch size
    input_ids = prompt.repeat_interleave(batch_size, dim=0)

    state = model.transformer(
        inputs_embeds=input_ids,
        attention_mask=attention_mask,
        past_key_values=past,
    ).last_hidden_state
    logits = model.lm_head(state)
    return logits[:, -1]

Here is a method to get the token confidence for a piece of text

Code
#collapse

@torch.no_grad()
def get_output(
    text: str,
    prompt: torch.Tensor,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer
) -> torch.Tensor:
    tokens = (
        tokenizer(text, return_tensors="pt")["input_ids"]
            .to("cuda")
    )
    token_embedding = model.transformer.wte(tokens)
    
    # join the tensors - dim 0 is the batch, 1 is the tokens, 2 is the specific embedding value
    full_embedding = torch.cat([token_embedding, prompt], dim=1)
    state = model.transformer(inputs_embeds=full_embedding).last_hidden_state
    logits = model.lm_head(state)
    return logits[0, -1]

Here is the IMDB data loading code

Code
#collapse

from pathlib import Path
import pandas as pd

def load(path: Path) -> pd.DataFrame:
    positive_files = sorted(path.glob("pos/*.txt"))
    negative_files = sorted(path.glob("neg/*.txt"))
    
    return pd.DataFrame(
        [
            {"label": "good", "text": file.read_text()}
            for file in positive_files
        ] +
        [
            {"label": "bad", "text": file.read_text()}
            for file in negative_files
        ]
    )

Now let’s load the deep learning model and datasets

Code
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.to("cuda")
model.eval()

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token # needed to enable padding
Code
train_df = load(Path("/data/sentiment/imdb-movie-reviews/train"))
validation_df = load(Path("/data/sentiment/imdb-movie-reviews/test"))
Code
BATCH_SIZE = 32
MAX_LENGTH = 1_000

train_dataloader = PastDataloader(
    model=model,
    tokenizer=tokenizer,
    df=train_df,
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
    shuffle=True,
    label_map={"bad": 0, "good": 1},
)
validation_dataloader = PastDataloader(
    model=model,
    tokenizer=tokenizer,
    df=validation_df,
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
    shuffle=False,
    label_map={"bad": 0, "good": 1},
)

We can finally get on with training.


Training

Now that we have the code and dataloaders, we can try out the different training approaches.

Cosine Similarity Loss Training

So now we have to consider the loss function. Ideally it would be fast to compute, as I will need to run it against every pair of rows in the batch. For now let’s just take the cosine similarity between every pair of outputs, pushing pairs with the same label together and pairs with different labels apart.

Code
def cosine_loss_fn(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    batch_size = output.shape[0]
    
    cycled_output = output.repeat((batch_size, 1))
    # This repeats the tensor as if using cycle()
    # [1, 2, 3] -> [1, 2, 3, 1, 2, 3...]
    interleaved_output = output.repeat_interleave(batch_size, dim=0)
    # This repeats each element of the tensor
    # [1, 2, 3] -> [1, 1.., 2, 2.., 3, 3..]

    cycled_labels = labels.repeat(batch_size)
    interleaved_labels = labels.repeat_interleave(batch_size, dim=0)
    
    repeated_labels = (cycled_labels == interleaved_labels).long()
    repeated_labels = (repeated_labels * 2) - 1
    # label needs to be -1 for different or 1 for same
    # true  -> 1 -> 1*2-1 ->  1
    # false -> 0 -> 0*2-1 -> -1
    
    return torch.nn.functional.cosine_embedding_loss(
        cycled_output,
        interleaved_output,
        repeated_labels
    )

I’m not 100% sure this is correct but I’m going to give it a go anyway.
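
As a quick sanity check of the pairing logic, here is a minimal sketch on made-up tensors (not part of the real training run). When rows that point in the same direction share a label and the orthogonal row has a different label, the loss should be close to zero; swapping the labels should make it clearly positive.

Code
# A toy check of cosine_loss_fn (defined above) on made-up tensors.
import torch

fake_output = torch.tensor([
    [1.0, 0.0, 0.0],  # row 0
    [1.0, 0.0, 0.0],  # row 1, same direction as row 0
    [0.0, 1.0, 0.0],  # row 2, orthogonal to the others
])

# matching labels for the matching directions -> loss should be ~0
print(cosine_loss_fn(fake_output, torch.tensor([0, 0, 1])))

# mismatched labels -> loss should be clearly positive
print(cosine_loss_fn(fake_output, torch.tensor([0, 1, 0])))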

Code
cosine_trained_prompt = train(
    dl=train_dataloader,
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=5,
    epochs=3,
    loss_fn=cosine_loss_fn
)
Average loss: 0.4848
Average loss: 0.4854
Average loss: 0.4842
Code
bad_output = get_output(
    "What a script, what a story, what a mess!",
    prompt=cosine_trained_prompt,
    model=model,
    tokenizer=tokenizer
)
good_output = get_output(
    "Brilliant and moving performances by Tom Courtenay and Peter Finch.",
    prompt=cosine_trained_prompt,
    model=model,
    tokenizer=tokenizer
)
Code
(
    tokenizer.decode(bad_output.argmax()),
    tokenizer.decode(good_output.argmax())
)
('otle', ' and')
Code
torch.cosine_similarity(bad_output, good_output, dim=0).item()
0.9999494552612305
Code
torch.pairwise_distance(bad_output[None, :], good_output[None, :]).item()
1122.8638916015625
Code
torch.save(
    cosine_trained_prompt,
    "/data/blog/2021-05-06-prompt-training-clustering-again/trained-prompt-1-20-768-cosine.pt"
)

Distance Loss Training

Instead of cosine similarity, which only depends on the direction of the vectors, I am going to try the Euclidean distance between the points that the vectors represent. If the points are for the same class then the distance should be minimized, and if they are for different classes then it should be maximized.

My initial thought is to have two functions:

\[ \begin{aligned} L_{\text{same}} &= \operatorname{dist}(x_a, x_b) \\ L_{\text{different}} &= \frac{1}{\operatorname{dist}(x_a, x_b)} \end{aligned} \]

For the same class, as the distance tends to zero the loss does too. Equally, for different classes, as the distance tends to infinity the loss tends to zero. These two terms may need balancing in some way.
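
To get a feel for the relative scales, here is a quick back-of-the-envelope sketch on made-up distances (the real outputs in this post end up hundreds to thousands of units apart): the same-class term grows without bound with the distance while the different-class term shrinks towards zero, so one can easily dominate the other.

Code
# Made-up distances to show how the two terms scale against each other.
for dist in [0.1, 1.0, 10.0, 100.0, 1000.0]:
    same_class_term = dist           # L_same = dist(x_a, x_b)
    different_class_term = 1 / dist  # L_different = 1 / dist(x_a, x_b)
    print(
        f"distance {dist:8.1f}: "
        f"same-class {same_class_term:10.1f}, "
        f"different-class {different_class_term:8.4f}"
    )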

Code
def distance_loss_fn(
    output: torch.Tensor,
    labels: torch.Tensor
) -> torch.Tensor:
    batch_size = output.shape[0]
    
    cycled_output = output.repeat((batch_size, 1))
    # This repeats the tensor as if using cycle()
    # [1, 2, 3] -> [1, 2, 3, 1, 2, 3...]
    interleaved_output = output.repeat_interleave(batch_size, dim=0)
    # This repeats each element of the tensor
    # [1, 2, 3] -> [1, 1.., 2, 2.., 3, 3..]

    cycled_labels = labels.repeat(batch_size)
    interleaved_labels = labels.repeat_interleave(batch_size, dim=0)
    
    different_labels_mask = cycled_labels != interleaved_labels
    
    distance = torch.pairwise_distance(
        cycled_output,
        interleaved_output
    )

    # there is a problem with this,
    # it is considered an in-place operation which causes gradient calculations to fail
    # distance[different_labels_mask] = 1 / distance[different_labels_mask]
    
    return (
        distance[~different_labels_mask].sum()
        + (1 / distance[different_labels_mask]).sum()
    )
Code
distance_trained_prompt = train(
    dl=train_dataloader,
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=5,
    epochs=3,
    loss_fn=distance_loss_fn
)
Average loss: 574885.1110
Average loss: 149492.0697
Average loss: 84900.0906
Code
bad_output = get_output(
    "What a script, what a story, what a mess!",
    prompt=distance_trained_prompt,
    model=model,
    tokenizer=tokenizer
)
good_output = get_output(
    "Brilliant and moving performances by Tom Courtenay and Peter Finch.",
    prompt=distance_trained_prompt,
    model=model,
    tokenizer=tokenizer
)
Code
(
    tokenizer.decode(bad_output.argmax()),
    tokenizer.decode(good_output.argmax())
)
('ed', 'ed')
Code
torch.cosine_similarity(bad_output, good_output, dim=0).item()
0.9999787211418152
Code
torch.pairwise_distance(bad_output[None, :], good_output[None, :]).item()
1688.7255859375
Code
torch.save(
    distance_trained_prompt,
    "/data/blog/2021-05-06-prompt-training-clustering-again/trained-prompt-1-20-768-distance.pt"
)

Evaluation

Now that we have a trained prompt we can try evaluating it. Evaluating the prompt is hard because we don’t know what output corresponds to a given class. The training just aims to separate the outputs for the two classes.

So I think the evaluation should visualize the outputs for the different classes so that we can see whether they are separated. The code in this section collects the outputs for the validation set and then uses PCA to reduce them to two dimensions, at which point they can be plotted.
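
As a side note on the reduction itself, `torch.pca_lowrank` returns the factors `(U, S, V)` of a centred low-rank approximation rather than the projected points directly; the code below uses the first factor as the two-dimensional coordinates. Here is a minimal shape sketch on random stand-in data (100 rows is arbitrary; 50257 is the GPT-2 vocabulary size):

Code
# A shape sketch on random stand-in data - not the real model outputs.
import torch

fake_outputs = torch.randn(100, 50257)          # rows x vocabulary-sized logits
U, S, V = torch.pca_lowrank(fake_outputs, q=2)  # centred low-rank approximation

print(U.shape)  # torch.Size([100, 2])   per-row coordinates, what gets plotted below
print(S.shape)  # torch.Size([2])        singular values
print(V.shape)  # torch.Size([50257, 2]) principal directions

# Projecting the centred data onto V gives the same layout, scaled by S per axis
projected = (fake_outputs - fake_outputs.mean(dim=0)) @ V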

Code
#collapse

@torch.no_grad()
def get_all_outputs(
    dl: PastDataloader,
    model: AutoModelForCausalLM,
    prompt: torch.Tensor
) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
    prompt_attention = torch.ones((1, prompt.shape[1]), device=dl.device)

    for batch in tqdm(dl):
        logits = _get_output_with_past(
            model=model,
            prompt=prompt,
            attention_mask=prompt_attention,
            past=batch["past_key_values"],
            past_attention_mask=batch["attention_mask"],
        )
        yield logits.cpu(), batch["labels"].cpu()
Code
#collapse

def show_clusters(
    outputs: Iterator[Tuple[torch.Tensor, torch.Tensor]],
    dl: PastDataloader,
    title: str,
    points: int = 500,
) -> None:
    outputs = list(outputs)
    all_outputs = torch.cat([
        output
        for output, _ in outputs
    ], dim=0)
    all_labels = torch.cat([
        label
        for _, label in outputs
    ])
    # torch.pca_lowrank returns (U, S, V); U holds the per-row 2d coordinates
    pca_outputs = torch.pca_lowrank(all_outputs, q=2)

    df = pd.DataFrame([
        {
            "x": output[0].item(),
            "y": output[1].item(),
            "label": "blue" if label.item() else "red",
            "text": text
        }
        for output, label, text in zip(
            pca_outputs[0],
            all_labels,
            dl.df.text
        )
    ])

    # plot a balanced sample of the points to keep the chart readable
    limited_df = pd.concat([
        df[df["label"] == label].sample(n=points//2)
        for label in df.label.unique()
    ])

    limited_df.plot.scatter(x="x", y="y", s=0.1, c="label", title=title)

Cosine Similarity Loss Evaluation

Let’s see how well the cosine similarity loss function has performed. The points are colored by label: blue is good and red is bad.

Code
cosine_trained_prompt = torch.load(
    "/data/blog/2021-05-06-prompt-training-clustering-again/trained-prompt-1-20-768-cosine.pt"
)
Code
cosine_validation_outputs = list(
    get_all_outputs(dl=validation_dataloader, model=model, prompt=cosine_trained_prompt)
)
Code
show_clusters(
    outputs=cosine_validation_outputs,
    dl=validation_dataloader,
    title="Cosine Similarity Clustering"
)

Well, this doesn’t look clearly separated. I think it would be worth trying a different loss function.


Distance Loss Evaluation

Let’s see how well the distance loss function has performed. Once again, the points are colored by label: blue is good and red is bad.

Code
distance_trained_prompt = torch.load(
    "/data/blog/2021-05-06-prompt-training-clustering-again/trained-prompt-1-20-768-distance.pt"
)
Code
distance_validation_outputs = list(
    get_all_outputs(dl=validation_dataloader, model=model, prompt=distance_trained_prompt)
)
Code
show_clusters(
    outputs=distance_validation_outputs,
    dl=validation_dataloader,
    title="Distance Clustering"
)

So it’s still having a rough time with a problem that prompt training handled extremely well before.


Label Clustering

I suspect that PCA is not separating the clusters because the directions along which the two classes differ are not the dominant directions of variation across all of the points. Since the original prompt worked quite well for this task, applying the same visualization to it should show whether PCA is actually helping. If it isn’t then I’ll have to come up with another assessment method.
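
As a small illustration of that suspicion, here is a made-up example (not the real outputs) where the two classes differ along a low-variance axis while two shared high-variance axes dominate; reducing to two dimensions with PCA then keeps the shared variation and drops the class signal.

Code
# Made-up data: the classes are 1 unit apart on a low-variance axis, but two
# shared axes vary by ~100 units, so PCA's top two components track the shared
# axes and the class signal is lost in the 2d plot.
import torch

n = 1000
labels = (torch.rand(n) > 0.5).long()

shared = torch.randn(n, 2) * 100                            # high variance, class-independent
signal = labels[:, None].float() + 0.1 * torch.randn(n, 1)  # low variance, class-dependent
data = torch.cat([shared, signal], dim=1)

U, S, V = torch.pca_lowrank(data, q=2)

print(S)                      # two large singular values from the shared axes
print(V.abs().argmax(dim=0))  # each component points along a shared axis, not the signal axis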

Code
original_prompt = torch.load(
    "/data/blog/2021-04-13-dreaming-of-prompts/trained-prompt-1-20-768-002.pt"
)
Code
original_validation_outputs = list(
    get_all_outputs(dl=validation_dataloader, model=model, prompt=original_prompt)
)
Code
show_clusters(
    outputs=original_validation_outputs,
    dl=validation_dataloader,
    title="Label Clustering"
)

So this isn’t separated either. I need to review how to evaluate this training approach.

As a last spot check I can review the two reviews I have been using for the manual evaluation. This is more of a sanity check to confirm that I am loading and using this prompt correctly.

Code
bad_output = get_output(
    "What a script, what a story, what a mess!",
    prompt=original_prompt,
    model=model,
    tokenizer=tokenizer
)
good_output = get_output(
    "Brilliant and moving performances by Tom Courtenay and Peter Finch.",
    prompt=original_prompt,
    model=model,
    tokenizer=tokenizer
)
Code
(
    tokenizer.decode(bad_output.argmax()),
    tokenizer.decode(good_output.argmax())
)
('bad', 'good')
Code
torch.cosine_similarity(bad_output, good_output, dim=0).item()
0.9998844265937805
Code
torch.pairwise_distance(bad_output[None, :], good_output[None, :]).item()
438.8519287109375

So the reviews are correctly classified. It’s interesting that the cosine similarity is so high and that the distance between the outputs is so low. This does suggest that the training is doing something. I need to determine more systematically whether these clusters are separable.
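
One way to do that more systematically, sketched here but not run in this post, would be to fit a simple linear probe on the collected validation outputs and check how well it can recover the labels; high accuracy would mean the clusters are at least linearly separable. The helper below is only a rough sketch, and to be fair the probe should really be scored on a held-out split rather than its own training data.

Code
# A rough sketch of a linear probe over the collected outputs (not run here).
import torch


def linear_probe_accuracy(
    outputs: torch.Tensor,  # (rows, vocab size) logits from get_all_outputs
    labels: torch.Tensor,   # (rows,) 0/1 labels
    epochs: int = 100,
) -> float:
    probe = torch.nn.Linear(outputs.shape[1], 2)
    optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(probe(outputs), labels)
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        predictions = probe(outputs).argmax(dim=-1)
    return (predictions == labels).float().mean().item()


# e.g. for the distance-trained prompt outputs gathered earlier:
# all_outputs = torch.cat([output for output, _ in distance_validation_outputs])
# all_labels = torch.cat([label for _, label in distance_validation_outputs])
# print(linear_probe_accuracy(all_outputs, all_labels))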