Dreaming of Prompts

Using DeepDream techniques to generate a language model prompt
Published: April 13, 2021

I’ve been thinking about Language Model prompts recently. They can be used to perform natural language tasks without retraining the model, thanks to the deep understanding that modern language models internalize during training. GPT-3 has even internalized enough to start being able to perform arithmetic.

The biggest problem is coming up with an appropriate prompt for the task. So maybe data science is the search for the best prompt?

(Image: prompt engineering)

I speak adequate English and no other languages. I am unlikely to write a poem or become a wordsmith. Inventing an appropriate prompt seems to be a hard problem.

I don’t want to spend time trying to come up with ways to trick the language model into providing the results I want. Tricking language models won’t make me a better writer. It would be better to come up with a way to produce the correct prompt from the input and the desired results, much like how a neural network is trained in the first place.


Deep Dreaming of Prompts

DeepDream was created in 2015 and involves using backpropagation to alter the input instead of the model. The input is changed to produce a certain kind of strong output, which leads to psychedelic imagery.

(Image: Deep Dreaming of Cats)

I want to perform this same approach using a language model. The prompt is part of the input to the language model and I want to “train” the prompt to perform the task that I desire.


Implementing Prompt Training

So how is this going to work? Let’s start with an idea of how a neural network is trained.

An optimizer collects the model parameters that are to be optimized. The input is passed through the model and the influence that each parameter has over the quality of the output is tracked. The quality measure must be a non-negative scalar value where lower is better.

This quality measure is referred to as the loss, and a loss of zero would be produced by the best model possible.

When deep dreaming, the model is not the target of the optimizer; instead, the input is.

This means that the adjustments that the optimizer makes change the input image instead of the model. In the same way, I want to change the prompt. The prompt is unusual, though, as it is normally only part of the input rather than the whole input.
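As a rough sketch of that distinction (using a toy linear model as a stand-in rather than GPT-2, so none of this is the real prompt training code), the only change is which tensors are handed to the optimizer:

import torch

# toy stand-ins for a real model and loss
model = torch.nn.Linear(8, 2)
loss_fn = torch.nn.CrossEntropyLoss()

# normal training would do: optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# deep-dream style training freezes the model and optimizes the input instead
for parameter in model.parameters():
    parameter.requires_grad_(False)

dream_input = torch.nn.Parameter(torch.randn(1, 8))  # the input is now the trainable tensor
optimizer = torch.optim.Adam([dream_input], lr=1e-1)

target = torch.tensor([1])
for _ in range(10):
    optimizer.zero_grad()
    loss = loss_fn(model(dream_input), target)
    loss.backward()   # gradients flow back to the input
    optimizer.step()  # the update changes the input, not the model weights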

The prompt being only part of the input is where GPT-2 comes in useful. GPT-2 can reuse previously computed attention keys and values, referred to as the past, and these are passed in as a separate input. By pushing the tweet text into the past I can easily define the prompt as a separate, optimizable input.

This is a good start but it’s not enough. The prompt is text, which can’t be optimized because it’s not a tensor. The tokenized prompt is a tensor of integer token ids (a long tensor), which can’t be optimized because it isn’t continuous.

The first stage of a language model is an embedding layer. This is a simple layer that converts each token into a 1-dimensional float tensor, so the sequence of tokens becomes a 2-dimensional float tensor. At this point it can be optimized.
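To make this concrete, here is a minimal sketch (it loads the same GPT-2 model and tokenizer that are used later in this post): the tokenized prompt is a tensor of integer ids, while the embedded prompt is a float tensor that gradients can flow through.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

token_ids = tokenizer("a good day", return_tensors="pt")["input_ids"]
embeddings = model.transformer.wte(token_ids)  # wte is the token embedding layer

print(token_ids.shape, token_ids.dtype)    # (1, sequence_length) torch.int64 - discrete, not optimizable
print(embeddings.shape, embeddings.dtype)  # (1, sequence_length, 768) torch.float32 - continuous, optimizable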

Using the embedding as the input involves altering the model though, as the GPT-2 model expects to receive the tokenized prompt as its input. So a little surgery is required.

Let’s start thinking about the dataset now.


Dataset

Since this is just a simple evaluation of my idea I’m going to use a simple dataset. A sentiment dataset should be sufficiently simple. Sentiment140 looks like a suitable dataset for my task.

Code
from pathlib import Path
import pandas as pd

SENTIMENT_DATASET = Path("/data/blog/2021-04-13-dreaming-of-prompts/training.1600000.processed.noemoticon.csv")

# there is a bad utf-8 character in the file :(
with open(SENTIMENT_DATASET, encoding="utf-8", errors="ignore") as handle:
    df = pd.read_csv(
        handle,
        names=["sentiment", "id", "date", "query", "user", "text"],
    )
df = df[["sentiment", "text"]]
df
         sentiment                                               text
0                0  @switchfoot http://twitpic.com/2y1zl - Awww, t...
1                0  is upset that he can't update his Facebook by ...
2                0  @Kenichan I dived many times for the ball. Man...
3                0  my whole body feels itchy and like its on fire
4                0  @nationwideclass no, it's not behaving at all....
...            ...                                                ...
1599995          4  Just woke up. Having no school is the best fee...
1599996          4  TheWDB.com - Very cool to hear old Walt interv...
1599997          4  Are you ready for your MoJo Makeover? Ask me f...
1599998          4  Happy 38th Birthday to my boo of alll time!!! ...
1599999          4  happy #charitytuesday @theNSPCC @SparksCharity...

1600000 rows × 2 columns

df.sentiment.value_counts()
0    800000
4    800000
Name: sentiment, dtype: int64

0 is negative sentiment and 4 is positive sentiment. So this is a binary classification problem, which should be pretty easy.

Since I want to train a language model, I’m going to map these to words.

df = df.copy()
df["sentiment"] = df.sentiment.map({0: "bad", 4: "good"})
df
        sentiment                                               text
0             bad  @switchfoot http://twitpic.com/2y1zl - Awww, t...
1             bad  is upset that he can't update his Facebook by ...
2             bad  @Kenichan I dived many times for the ball. Man...
3             bad  my whole body feels itchy and like its on fire
4             bad  @nationwideclass no, it's not behaving at all....
...           ...                                                ...
1599995      good  Just woke up. Having no school is the best fee...
1599996      good  TheWDB.com - Very cool to hear old Walt interv...
1599997      good  Are you ready for your MoJo Makeover? Ask me f...
1599998      good  Happy 38th Birthday to my boo of alll time!!! ...
1599999      good  happy #charitytuesday @theNSPCC @SparksCharity...

1600000 rows × 2 columns

I could clean these up to remove the @handles and urls and suchlike, but I’m not sure I need to bother. The one thing I want to be sure of is that each sentiment word I have chosen is a single token in the tokenizer. This is because I want to measure the quality of the prompt based on the relative ranking of these two tokens.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer("good")
{'input_ids': [11274], 'attention_mask': [1]}
tokenizer("bad")
{'input_ids': [14774], 'attention_mask': [1]}
GOOD_TOKEN = 11274
BAD_TOKEN = 14774

At this point we have our targets and the dataset. The next thing is to produce the past.


Splitting Tweet Text and Prompt

The past is produced when the model runs over some input, and it is quite straightforward to collect.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
from typing import *
import torch

Past = Tuple[Tuple[torch.Tensor, ...], ...]

@torch.no_grad()
def get_past(
    text: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
) -> Past:
    tokens = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
    ).to(model.device)
    return model(**tokens).past_key_values
len(get_past("hello world", model=model, tokenizer=tokenizer))
12

These past values will become the tweet input.
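To make the shape of the past concrete, here is a small inspection sketch built on the get_past function above (the exact sequence length depends on tokenization, so the shapes in the comments are indicative):

past = get_past("hello world", model=model, tokenizer=tokenizer)

print(len(past))      # 12 - one entry per transformer layer, matching the output above
key, value = past[0]  # each layer contributes a (key, value) pair of tensors
print(key.shape)      # (batch, heads, sequence_length, head_dim), e.g. torch.Size([1, 12, 2, 64])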


Working with a Trainable Prompt

The model expects the tokens for the prompt, which we cannot optimize, so we need to alter the model to accept the embedding instead. Since we are using the model to produce the past, above, this modification should be done in a way that doesn’t break the existing functionality.

After looking at the model carefully I can see that there is actually support for an inputs_embeds parameter in the model.transformer.forward method. When token ids are passed, these embeddings are produced using inputs_embeds = self.wte(input_ids), so I can first try reproducing that. The biggest problem is that the embeddings are not exposed through the model forward method.

with torch.no_grad():
    tokens = tokenizer("hello world", return_tensors="pt")["input_ids"]
    token_logits = model(input_ids=tokens).logits
    transformer_state = model.transformer(input_ids=tokens).last_hidden_state
    transformer_logits = model.lm_head(transformer_state)
torch.all(torch.eq(token_logits, transformer_logits))
tensor(True)

As you can see, the last hidden state of the transformer is collected and then passed to the classification head. Since this GPT-2 model has been loaded as a causal language model, that head is the language modeling head. I should now be able to compare the embedding approach.

with torch.no_grad():
    tokens = tokenizer("hello world", return_tensors="pt")["input_ids"]
    token_logits = model(input_ids=tokens).logits
    embeds = model.transformer.wte(tokens)
    embed_state = model.transformer(inputs_embeds=embeds).last_hidden_state
    embed_logits = model.lm_head(embed_state)
torch.all(torch.eq(token_logits, embed_logits))
tensor(True)

So we can see that inputs_embeds can be used as a substitute for input_ids quite easily.


Experiment

Now I can try training my prompt. The first thing is to have a train / validation split, and a dataloader that lets me prepare the data.

from sklearn.model_selection import train_test_split

train_df, validation_df = train_test_split(df, test_size=10_000)

Dataloader

The dataloader preprocesses the tweets to make it easier to train or perform inference. It has two modes of operation:

  • If you iterate over it, it will endlessly produce random batches. A batch contains unique tweets but future batches may include past tweets. This is useful for training.
  • If you access it by index then you will receive a fixed batch. This is useful for validation.

I could put more effort in and track the tweets that have already been issued; however, the training runs performed here are substantially smaller than the full dataset (most of them are 100 batches, so 3,200 tweets out of nearly 1.6 million). I don’t think that there is a big problem with occasionally selecting the same tweet twice.

Code
class PastDataloader:
    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        df: pd.DataFrame,
        batch_size: int,
        device: torch.device = torch.device("cuda")
    ) -> None:
        tokenizer.pad_token = tokenizer.eos_token # needed to enable padding
        model.to(device)
        
        self.model = model
        self.tokenizer = tokenizer
        self.df = df
        self.batch_size = batch_size
        self.device = device

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """ Returns an iterator that randomly samples preprocessed tweets.
            Used for training, this is sampling with replacement.
            An individual batch has unique elements. """
        while True:
            yield self.get()

    # Get a random selection of tweets
    def get(self) -> Dict[str, torch.Tensor]:
        """ Get a random selection of preprocessed tweets. """
        rows = self.df.sample(n=self.batch_size)
        return self._get(rows)

    def __len__(self) -> int:
        """ Returns the total number of full batches that can be returned. """
        return len(self.df) // self.batch_size

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """ Returns unshuffled tweets, allowing every tweet to be collected.
            Used for validation. """
        # get without shuffling
        start = idx * self.batch_size
        end = start + self.batch_size
        rows = self.df.loc[self.df.index[start:end]]
        return self._get(rows)

    def _get(self, rows: pd.DataFrame) -> Dict[str, torch.Tensor]:
        tokens = self.tokenizer(
            rows.text.tolist(),
            return_tensors="pt",
            padding=True
        ).to(self.device)
        past_key_values = self.model(**tokens).past_key_values
        labels = torch.tensor([
            GOOD_TOKEN if label == "good" else BAD_TOKEN
            for label in rows.sentiment
        ], dtype=torch.long, device=self.device)
        return {
            "past_key_values": past_key_values,
            "attention_mask": tokens["attention_mask"],
            "labels": labels
        }
train_dataloader = PastDataloader(model=model, tokenizer=tokenizer, df=train_df, batch_size=32)

Training Loop

Now we have some involved code to perform the training. It allows a custom loss function to be provided and initializes the prompt from randomly chosen token embeddings. If the loss becomes low enough then it stops training early.

Code
from tqdm.auto import tqdm

def train(
    dl: PastDataloader,
    model: AutoModelForCausalLM,
    prompt_tokens: int,
    rows: int,
    log_step: int,
    loss_fn: Callable[[torch.Tensor, str], torch.Tensor],
    early_stopping: bool = True
) -> torch.Tensor:
    batch_size = dl.batch_size
    prompt, prompt_attention = make_prompt(
        model=model,
        prompt_tokens=prompt_tokens,
        batch_size=batch_size,
        device=dl.device
    )
    
    # optimize just the prompt
    optimizer = torch.optim.Adam([prompt], lr=1e-3)

    total_loss = 0.

    for idx, batch in tqdm(zip(range(rows), dl), total=rows):
        optimizer.zero_grad()

        logits = get_output(
            model=model,
            prompt=prompt,
            attention_mask=prompt_attention,
            past=batch["past_key_values"],
            past_attention_mask=batch["attention_mask"],
            batch_size=batch_size,
        )
        loss = loss_fn(logits, batch["labels"])

        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        if idx and idx % log_step == 0:
            average_loss = total_loss/log_step
            print(f"loss: mean {average_loss:0.4f}, last {loss.item():0.4f}")

            if early_stopping and average_loss < 0.1:
                print("stopping early...")
                break
            total_loss = 0.
    return prompt.data

def make_prompt(
    model: AutoModelForCausalLM,
    prompt_tokens: int,
    batch_size: int,
    device: torch.device
) -> Tuple[torch.nn.Parameter, torch.Tensor]:
    """ Generate the prompt by randomly choosing tokens and then converting to embeddings """
    prompt_indexes = torch.randint(
        size=(prompt_tokens,),
        low=0,
        high=tokenizer.vocab_size,
        device=device
    )
    prompt_attention = torch.ones(
        size=(batch_size, prompt_tokens),
        dtype=torch.long,
        device=device
    )
    prompt = torch.nn.Parameter(
        model.transformer.wte(prompt_indexes).clone()[None, :, :]
    )
    return prompt, prompt_attention

def get_output(
    model: AutoModelForCausalLM,
    prompt: torch.nn.Parameter,
    attention_mask: torch.Tensor,
    past: Past,
    past_attention_mask: torch.Tensor,
    batch_size: int,
) -> torch.Tensor:
    """ Get the predictions for the next token after the prompt """
    # concatenate the past attention with the prompt attention
    attention_mask = torch.cat([
        past_attention_mask, attention_mask
    ], dim=-1)

    # expand the prompt to match the batch size
    input_ids = prompt.repeat_interleave(batch_size, dim=0)
    
    state = model.transformer(
        inputs_embeds=input_ids,
        attention_mask=attention_mask,
        past_key_values=past,
    ).last_hidden_state
    logits = model.lm_head(state)
    return logits[:, -1]

Evaluation

It’s important to know how the trained prompt performs. To calculate this I need a way to run the dataloader through the model and evaluate the accuracy of the classifications. Adjusting the training loop and using the non-random approach for the dataloader makes this quite easy.

I’ve got two approaches to training - one which only considers the “good” and “bad” tokens, and one which considers all tokens. As such I need two accuracy measures. If I evaluated the restricted prompts with the unrestricted measure their accuracy would be 0%, as the restricted training never adjusts the prompt based on the scores of other tokens.

Code
valid_dataloader_quick = PastDataloader(
    model=model,
    tokenizer=tokenizer,
    df=validation_df[:100],
    batch_size=32,
)
valid_dataloader = PastDataloader(
    model=model,
    tokenizer=tokenizer,
    df=validation_df,
    batch_size=32,
)
Code
from sklearn.metrics import classification_report

@torch.no_grad()
def restricted_accuracy(
    dl: PastDataloader,
    model: AutoModelForCausalLM,
    prompt: torch.Tensor,
) -> None:
    batch_size = dl.batch_size
    prompt_attention = torch.ones((dl.batch_size, prompt.shape[1]), device=dl.device)
    
    predictions = []
    targets = []

    for idx in tqdm(range(len(dl)), total=len(dl)):
        batch = dl[idx]

        logits = get_output(
            model=model,
            prompt=prompt,
            attention_mask=prompt_attention,
            past=batch["past_key_values"],
            past_attention_mask=batch["attention_mask"],
            batch_size=batch_size,
        )
        
        # restrict the accuracy measurement to just the good and bad tokens
        logits = logits[:, [GOOD_TOKEN, BAD_TOKEN]]
        
        targets.extend(
            # bad = 1
            (batch["labels"] == BAD_TOKEN).long().tolist()
        )
        predictions.extend(
            logits.argmax(dim=-1).tolist()
        )

    print(classification_report(y_true=targets, y_pred=predictions, target_names=["good", "bad"]))
Code
@torch.no_grad()
def accuracy(
    dl: PastDataloader,
    model: AutoModelForCausalLM,
    prompt: torch.Tensor,
) -> None:
    batch_size = dl.batch_size
    prompt_attention = torch.ones((dl.batch_size, prompt.shape[1]), device=dl.device)

    predictions = []
    targets = []

    for idx in tqdm(range(len(dl)), total=len(dl)):
        batch = dl[idx]

        logits = get_output(
            model=model,
            prompt=prompt,
            attention_mask=prompt_attention,
            past=batch["past_key_values"],
            past_attention_mask=batch["attention_mask"],
            batch_size=batch_size,
        )

        targets.extend(
            # bad = 1
            (batch["labels"] == BAD_TOKEN).long().tolist()
        )
        predictions.extend(
            0 if prediction == GOOD_TOKEN else 
            (1 if prediction == BAD_TOKEN else 2)
            for prediction in logits.argmax(dim=-1).tolist()
        )
    
    print(classification_report(y_true=targets, y_pred=predictions, target_names=["good", "bad"]))

Results

Here we can try out a few different training approaches and see how well they perform.

Difference Loss

The first one is just to take the difference between the confidence for the “good” token and the “bad” token. It’s conceptually simple but the implementation of the loss is somewhat involved.

def difference_loss(
    logits: torch.Tensor,
    label: torch.Tensor,
) -> torch.Tensor:
    difference = (logits[:, GOOD_TOKEN] - logits[:, BAD_TOKEN]).sigmoid()
    loss = ((1 - difference) * (label == GOOD_TOKEN)) + (difference * (label == BAD_TOKEN))
    return loss.sum() / loss.shape[0]
    

trained_prompt = train(
    dl=train_dataloader,
    model=model,
    prompt_tokens=5,
    rows=1_000,
    log_step=100,
    loss_fn=difference_loss
)
loss: mean 0.4742, last 0.4590
loss: mean 0.4298, last 0.4327
loss: mean 0.3962, last 0.4039
loss: mean 0.3566, last 0.3761
loss: mean 0.3510, last 0.4240
loss: mean 0.3245, last 0.3902
loss: mean 0.3212, last 0.2942
loss: mean 0.3082, last 0.2972
loss: mean 0.3064, last 0.3204
trained_prompt.max(), trained_prompt.min(), trained_prompt.mean(), trained_prompt.std()
(tensor(0.5556, device='cuda:0'),
 tensor(-0.6800, device='cuda:0'),
 tensor(0.0008, device='cuda:0'),
 tensor(0.1574, device='cuda:0'))
(
    model.transformer.wte.weight.max(),
    model.transformer.wte.weight.min(),
    model.transformer.wte.weight.mean(),
    model.transformer.wte.weight.std(),
)
(tensor(1.7852, device='cuda:0', grad_fn=<MaxBackward1>),
 tensor(-1.2698, device='cuda:0', grad_fn=<MinBackward1>),
 tensor(0.0004, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(0.1437, device='cuda:0', grad_fn=<StdBackward0>))

So the training looks like it works; however, it does result in a prompt whose values look quite different to those of the embedding matrix.

restricted_accuracy(
    dl=valid_dataloader_quick, model=model, prompt=trained_prompt
)

              precision    recall  f1-score   support

        good       0.86      0.65      0.74        46
         bad       0.74      0.90      0.81        50

    accuracy                           0.78        96
   macro avg       0.80      0.78      0.78        96
weighted avg       0.79      0.78      0.78        96

So for a restricted evaluation this isn’t terrible? Maybe the more traditional cross entropy loss would be better.


Restricted Cross Entropy Loss

Cross entropy is the loss most commonly used for classification. Here I use cross entropy over the “good” and “bad” tokens only.

def restricted_ce_loss(logits: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
    target = (label == BAD_TOKEN).long() # bad = 1
    logits = logits[:, [GOOD_TOKEN, BAD_TOKEN]]
    return torch.nn.functional.cross_entropy(logits, target)

trained_prompt = train(
    dl=train_dataloader,
    model=model,
    prompt_tokens=5,
    rows=1_000,
    log_step=100,
    loss_fn=restricted_ce_loss
)
loss: mean 0.7202, last 0.6900
loss: mean 0.6701, last 0.6914
loss: mean 0.6473, last 0.6631
loss: mean 0.6110, last 0.6409
loss: mean 0.5769, last 0.6355
loss: mean 0.5355, last 0.5165
loss: mean 0.5158, last 0.3714
loss: mean 0.5055, last 0.4442
loss: mean 0.4903, last 0.4672
trained_prompt.max(), trained_prompt.min(), trained_prompt.mean(), trained_prompt.std()
(tensor(0.4892, device='cuda:0'),
 tensor(-0.6306, device='cuda:0'),
 tensor(-0.0002, device='cuda:0'),
 tensor(0.1419, device='cuda:0'))
restricted_accuracy(
    dl=valid_dataloader_quick,
    model=model,
    prompt=trained_prompt,
)

              precision    recall  f1-score   support

        good       0.88      0.63      0.73        46
         bad       0.73      0.92      0.81        50

    accuracy                           0.78        96
   macro avg       0.80      0.78      0.77        96
weighted avg       0.80      0.78      0.78        96

Using cross entropy is almost indistinguishable from the direct comparison. The training is working; however, the prompt embedding still differs quite a bit from the normal token embeddings. I do wonder if this is a problem. I can try training with cross entropy loss over the entire vocabulary to see if that differs.


Cross Entropy Loss

Training is now quite easy, as the built-in cross entropy loss can be used directly as the loss function.

trained_prompt = train(
    dl=train_dataloader,
    model=model,
    prompt_tokens=5,
    rows=1_000,
    log_step=100,
    loss_fn=torch.nn.functional.cross_entropy
)
loss: mean 4.6426, last 0.7410
loss: mean 0.7333, last 0.6661
loss: mean 0.6904, last 0.6684
loss: mean 0.6539, last 0.8113
loss: mean 0.6363, last 0.7122
loss: mean 0.6151, last 0.8285
loss: mean 0.6153, last 0.4979
loss: mean 0.5884, last 0.6413
loss: mean 0.6095, last 0.6105
trained_prompt.max(), trained_prompt.min(), trained_prompt.mean(), trained_prompt.std()
(tensor(0.6912, device='cuda:0'),
 tensor(-0.8610, device='cuda:0'),
 tensor(0.0005, device='cuda:0'),
 tensor(0.1548, device='cuda:0'))

So this trained super fast as well. It’s interesting that the loss was comparable to the restricted cross entropy, so restricting the output didn’t seem to achieve that much.

Let’s calculate accuracy. Now I can use the full accuracy calculation as I’m no longer restricting the output classes.

accuracy(
    dl=valid_dataloader_quick, model=model, prompt=trained_prompt
)

              precision    recall  f1-score   support

        good       0.80      0.70      0.74        46
         bad       0.75      0.84      0.79        50

    accuracy                           0.77        96
   macro avg       0.78      0.77      0.77        96
weighted avg       0.77      0.77      0.77        96

It’s worse than the restricted case but not terrible.


Prompt Tokens

It would be nice to try to translate the prompt to the nearest tokens to see what it came up with. I’m going to find the closest token by using cosine similarity, which is a measurement between -1 (opposite) and 1 (parallel).

Code
def to_nearest_text(
    prompt: torch.Tensor,
    tokenizer: AutoTokenizer
) -> Tuple[str, List[float]]:
    # for each prompt position, find the token embedding with the highest cosine similarity
    tokens = [
        torch.nn.functional.cosine_similarity(
            prompt[0, i],
            model.transformer.wte.weight,
            dim=-1
        ).max(dim=-1)
        for i in range(prompt.shape[1])
    ]

    similarity = [token.values.item() for token in tokens]
    text = tokenizer.decode([token.indices.item() for token in tokens])
    return text, similarity
to_nearest_text(trained_prompt, tokenizer)
('udence tragedyiden vou mort',
 [0.9209538698196411,
  0.896049976348877,
  0.9400998950004578,
  0.9492326378822327,
  0.9255229234695435])

This looks like junk. Still, the tokens are scoring highly.


Manual Evaluation

The last thing is to just run some text through the model.

Code
@torch.no_grad()
def evaluate(
    text: str,
    prompt: torch.Tensor,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer
) -> Dict[str, Any]:
    past = get_past(text=text, model=model, tokenizer=tokenizer)
    state = model.transformer(
        inputs_embeds=prompt,
        past_key_values=past,
    ).last_hidden_state
    logits = model.lm_head(state)[0, -1]
    return {
        "prediction": tokenizer.decode(logits.argmax()),
        "good_score": logits[GOOD_TOKEN].item(),
        "bad_score": logits[BAD_TOKEN].item(),
        "prediction_score": logits[logits.argmax()].item()
    }
evaluate(
    "i hate you",
    prompt=trained_prompt, model=model, tokenizer=tokenizer,
)
{'prediction': 'bad',
 'good_score': -36.58251190185547,
 'bad_score': -36.02908706665039,
 'prediction_score': -36.02908706665039}
evaluate(
    "my whole body feels itchy and like its on fire",
    prompt=trained_prompt, model=model, tokenizer=tokenizer,
)
{'prediction': 'bad',
 'good_score': -40.91442108154297,
 'bad_score': -40.47472381591797,
 'prediction_score': -40.47472381591797}
evaluate(
    "yay awesome!",
    prompt=trained_prompt, model=model, tokenizer=tokenizer,
)
{'prediction': 'good',
 'good_score': -54.655311584472656,
 'bad_score': -55.99317169189453,
 'prediction_score': -54.655311584472656}
evaluate(
    "@princess_die good movie pick. ttyl nite",
    prompt=trained_prompt, model=model, tokenizer=tokenizer,
)
{'prediction': 'good',
 'good_score': -60.58497619628906,
 'bad_score': -61.500526428222656,
 'prediction_score': -60.58497619628906}

This spot check seems pretty solid.


Long Train

The loss is still quite high though. I’m just going to randomly change some things now.

I could train this for much longer and then evaluate against the whole validation set. I could also try increasing the number of tokens in the prompt.

trained_prompt = train(
    dl=train_dataloader,
    model=model,
    prompt_tokens=20,
    rows=30_000,
    log_step=3_000,
    loss_fn=torch.nn.functional.cross_entropy,
)
loss: mean 0.6002, last 0.4017
loss: mean 0.4553, last 0.5202
loss: mean 0.4379, last 0.5954
loss: mean 0.4301, last 0.4289
loss: mean 0.4255, last 0.3840
loss: mean 0.4203, last 0.4118
loss: mean 0.4135, last 0.4337
loss: mean 0.4121, last 0.3975
loss: mean 0.4097, last 0.3775
Code
valid_dataloader = PastDataloader(
    model=model, tokenizer=tokenizer, df=validation_df, batch_size=32
)
accuracy(
    dl=valid_dataloader, model=model, prompt=trained_prompt
)

              precision    recall  f1-score   support

        good       0.81      0.83      0.82      5036
         bad       0.82      0.80      0.81      4948

    accuracy                           0.82      9984
   macro avg       0.82      0.82      0.82      9984
weighted avg       0.82      0.82      0.82      9984

So this technique seems to work well. I wonder how it would perform on a more difficult problem. There will be some upper limit to the accuracy that any technique can achieve with this dataset; it would be good to see how close this technique gets to that.

evaluate(
    "i hate you",
    prompt=trained_prompt, model=model, tokenizer=tokenizer,
)
{'prediction': 'bad',
 'good_score': -42.061744689941406,
 'bad_score': -40.328826904296875,
 'prediction_score': -40.328826904296875}
evaluate(
    "my whole body feels itchy and like its on fire",
    prompt=trained_prompt, model=model, tokenizer=tokenizer,
)
{'prediction': 'bad',
 'good_score': -41.80435562133789,
 'bad_score': -38.674713134765625,
 'prediction_score': -38.674713134765625}
evaluate(
    "yay awesome!",
    prompt=trained_prompt, model=model, tokenizer=tokenizer,
)
{'prediction': 'good',
 'good_score': -51.909061431884766,
 'bad_score': -54.857200622558594,
 'prediction_score': -51.909061431884766}
evaluate(
    "@princess_die good movie pick. ttyl nite",
    prompt=trained_prompt, model=model, tokenizer=tokenizer,
)
{'prediction': 'good',
 'good_score': -38.031211853027344,
 'bad_score': -39.029014587402344,
 'prediction_score': -38.031211853027344}
Code
torch.save(trained_prompt, "/data/blog/2021-04-13-dreaming-of-prompts/trained-prompt-1-20-768.pt")

Final Thoughts

This dataset doesn’t seem to be well studied. I couldn’t find any results for it on Papers with Code. The best I can find is a Kaggle kernel that claims 87% accuracy with BERT (here); however, I can’t actually see that model’s results, so I’m not sure about the claim. This GitHub repo gets 80% accuracy.

Either way, this was about 1h 30m of training on approximately 1m tweets (likely fewer unique tweets because of the random sampling with replacement) and it looks like the results are competitive for this dataset.

model_parameters = sum(p.numel() for p in model.parameters())
prompt_parameters = trained_prompt.numel()

print(f"Fine tuning the prompt alters {100 * prompt_parameters / model_parameters:0.3f}% of the parameters")
print(f"The model has {model_parameters:,} parameters")
Fine tuning the prompt alters 0.012% of the parameters
The model has 124,439,808 parameters

It’s interesting that the model can produce results like these with such a small alteration (the prompt is 20 tokens × 768 embedding dimensions = 15,360 values). This is working with GPT2-small as well, so there could be performance improvements by moving to GPT2-{medium,large}.


Baseline Comparison

What is the performance of the model without any kind of prompt? Is the model already predicting “good” and “bad” correctly without any prompting at all?

Let’s try evaluating the model, first restricting it just to the two interesting tokens and then without restriction (so mirroring the training).

Code
@torch.no_grad()
def accuracy_no_prompt_highest_good_bad(
    df: pd.DataFrame,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer
) -> None:
    predictions = []
    targets = []

    for row in tqdm(df.iloc, total=len(df)):
        tokens = tokenizer(row.text, return_tensors="pt").to(model.device)
        logits = model(**tokens).logits
        predictions.append(logits[0, -1, [GOOD_TOKEN, BAD_TOKEN]].argmax(dim=-1).item())
        targets.append(0 if row.sentiment == "good" else 1)

    print(classification_report(y_true=targets, y_pred=predictions, target_names=["good", "bad"]))
accuracy_no_prompt_highest_good_bad(validation_df, model, tokenizer)

              precision    recall  f1-score   support

        good       0.50      0.96      0.66      5043
         bad       0.51      0.04      0.08      4957

    accuracy                           0.50     10000
   macro avg       0.51      0.50      0.37     10000
weighted avg       0.51      0.50      0.37     10000

An accuracy of 50% for a two-class classification problem is basically random. It looks like it heavily favours predicting “good” as well. How well does it perform when we just take the top token?

Code
@torch.no_grad()
def accuracy_no_prompt(
    df: pd.DataFrame,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer
) -> None:
    predictions = []
    targets = []

    for row in tqdm(df.iloc, total=len(df)):
        tokens = tokenizer(row.text, return_tensors="pt").to(model.device)
        logits = model(**tokens).logits
        prediction = logits[0, -1].argmax(dim=-1).item()
        predictions.append(
            0 if prediction == GOOD_TOKEN else
            (1 if prediction == BAD_TOKEN else 2)
        )
        targets.append(0 if row.sentiment == "good" else 1)

    print(classification_report(
        y_true=targets,
        y_pred=predictions,
        target_names=["good", "bad", "other"],
        zero_division=0
    ))
accuracy_no_prompt(validation_df, model, tokenizer)

              precision    recall  f1-score   support

        good       0.00      0.00      0.00    5043.0
         bad       0.00      0.00      0.00    4957.0
       other       0.00      0.00      0.00       0.0

    accuracy                           0.00   10000.0
   macro avg       0.00      0.00      0.00   10000.0
weighted avg       0.00      0.00      0.00   10000.0

Without a prompt the model never predicts either sentiment token as its top prediction. So we can see that the prompt training drastically alters the output. This seems to be a very promising technique!