Prompt Training - IMDB Movie Review Sentiment

Evaluating prompt training on a more popular dataset
Published

April 28, 2021

I’ve recently investigated prompt training and reviewed a paper by Google on the same topic. The biggest problem with my evaluation was that I couldn’t compare my results to the work of others - the Sentiment140 dataset just doesn’t seem to be widely studied.

So I’ve found a more popular dataset {% cite maas-EtAl:2011:ACL-HLT2011 %}. There are several entries for it on Papers with Code and it’s another binary sentiment classification problem, so I should be able to apply the same techniques as before.

This is going to be a comparatively focused evaluation of prompt training GPT-2 small on this dataset. If you want more details on the technique then consider reading my previous posts (my proposal or the Google paper review).


Dataset

The dataset is 50,000 movie reviews, split into equal numbers of positive and negative reviews. A train/test split has already been provided (50/50 ratio), which I will honor because I want to compare the results of prompt training to other work on this dataset. Each review is in a separate file.

The first thing to do will be to load the data into dataframes.

Code
from pathlib import Path
import pandas as pd

def load(path: Path) -> pd.DataFrame:
    positive_files = sorted(path.glob("pos/*.txt"))
    negative_files = sorted(path.glob("neg/*.txt"))
    
    return pd.DataFrame(
        [
            {"sentiment": "good", "text": file.read_text()}
            for file in positive_files
        ] +
        [
            {"sentiment": "bad", "text": file.read_text()}
            for file in negative_files
        ]
    )
Code
train_df = load(Path("/data/sentiment/imdb-movie-reviews/train"))
validation_df = load(Path("/data/sentiment/imdb-movie-reviews/test"))
Code
train_df
sentiment text
0 good Bromwell High is a cartoon comedy. It ran at t...
1 good Homelessness (or Houselessness as George Carli...
2 good Brilliant over-acting by Lesley Ann Warren. Be...
3 good This is easily the most underrated film inn th...
4 good This is not the typical Mel Brooks film. It wa...
... ... ...
24995 bad Towards the end of the movie, I felt it was to...
24996 bad This is the kind of movie that my enemies cont...
24997 bad I saw 'Descent' last night at the Stockholm Fi...
24998 bad Some films that you pick up for a pound turn o...
24999 bad This is one of the dumbest films, I've ever se...

25000 rows × 2 columns


Data Loader

I can reuse the dataloader from my previous evaluation as the dataframe has the same shape. Once again I am using GPT-2, so the past key values are available. I’ve adjusted the dataloader to better match regular dataloaders - it now works by epoch, because the dataset is small enough to want to iterate over several times.

Code
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval() ; None
Code
from typing import *
import torch

Past = Tuple[Tuple[torch.Tensor, ...], ...]

class PastDataloader:
    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        df: pd.DataFrame,
        batch_size: int,
        max_length: int,
        device: torch.device = torch.device("cuda"),
        shuffle: bool = True,
    ) -> None:
        tokenizer.pad_token = tokenizer.eos_token # needed to enable padding
        model.to(device)
        
        self.model = model
        self.tokenizer = tokenizer
        self.df = df
        self.batch_size = batch_size
        self.max_length = max_length
        self.device = device
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """ Returns an iterator that returns the batched rows in a random order.
            This always returns full batches. """
        if self.shuffle:
            df = self.df.sample(frac=1).reset_index(drop=True)
        else:
            df = self.df
        batch_size = self.batch_size

        for i in range(len(self)):
            start = i * batch_size
            end = start + batch_size
            yield self._get(df[start:end])

    def __len__(self) -> int:
        """ Returns the total number of full batches that can be returned. """
        return len(self.df) // self.batch_size

    @torch.no_grad()
    def _get(self, rows: pd.DataFrame) -> Dict[str, Union[torch.Tensor, Past]]:
        tokens = self.tokenizer(
            rows.text.tolist(),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length
        ).to(self.device)
        past_key_values = self.model(**tokens).past_key_values
        # GOOD_TOKEN and BAD_TOKEN are defined in the Training section below
        labels = torch.tensor([
            GOOD_TOKEN if label == "good" else BAD_TOKEN
            for label in rows.sentiment
        ], dtype=torch.long, device=self.device)
        return {
            "past_key_values": past_key_values,
            "attention_mask": tokens["attention_mask"],
            "labels": labels
        }
Code
BATCH_SIZE = 32
MAX_LENGTH = 128

train_dataloader = PastDataloader(
    model=model,
    tokenizer=tokenizer,
    df=train_df,
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
    shuffle=True,
)
validation_dataloader = PastDataloader(
    model=model,
    tokenizer=tokenizer,
    df=validation_df,
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
    shuffle=False,
)

Training

Now we need some code for training and evaluation. This is broadly a copy of the code from the previous evaluation, with a couple of changes:

  • Training now works by epoch instead of batch as the dataset is smaller
  • Accuracy takes a function to extract the results from the batch
Code
GOOD_TOKEN = 11274 # GPT-2 vocabulary id used as the label for positive reviews
BAD_TOKEN = 14774  # GPT-2 vocabulary id used as the label for negative reviews
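
To double-check what text these ids map to, decoding them with the tokenizer is a quick sanity check (the exact surface forms depend on the GPT-2 BPE vocabulary, so inspect the output rather than take my word for it).

Code
# Decode the label ids back to text to see exactly what the model is being
# asked to predict. The printed strings depend on the GPT-2 vocabulary.
print(repr(tokenizer.decode([GOOD_TOKEN])), repr(tokenizer.decode([BAD_TOKEN])))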
Code
from tqdm.auto import tqdm

def train(
    dl: PastDataloader,
    model: AutoModelForCausalLM,
    prompt_tokens: int,
    epochs: int,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    batch_size = dl.batch_size
    prompt, prompt_attention = make_prompt(
        model=model,
        prompt_tokens=prompt_tokens,
        batch_size=batch_size,
        device=dl.device
    )
    
    # optimize just the prompt
    optimizer = torch.optim.Adam([prompt], lr=1e-3)

    total_loss = 0.

    for epoch in tqdm(range(epochs)):
        for batch in tqdm(dl):
            optimizer.zero_grad()

            logits = get_output(
                model=model,
                prompt=prompt,
                attention_mask=prompt_attention,
                past=batch["past_key_values"],
                past_attention_mask=batch["attention_mask"],
                batch_size=batch_size,
            )
            loss = loss_fn(logits, batch["labels"])

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        average_loss = total_loss/len(dl)
        print(f"epoch {epoch}: mean loss {average_loss:0.4f}, last loss {loss.item():0.4f}")

        total_loss = 0.

    return prompt.data

def make_prompt(
    model: AutoModelForCausalLM,
    prompt_tokens: int,
    batch_size: int,
    device: torch.device
) -> Tuple[torch.nn.Parameter, torch.Tensor]:
    """ Generate the prompt by randomly choosing tokens and then converting to embeddings """
    prompt_indexes = torch.randint(
        size=(prompt_tokens,),
        low=0,
        high=tokenizer.vocab_size,
        device=device
    )
    prompt_attention = torch.ones(
        size=(batch_size, prompt_tokens),
        dtype=torch.long,
        device=device
    )
    prompt = torch.nn.Parameter(
        model.transformer.wte(prompt_indexes).clone()[None, :, :]
    )
    return prompt, prompt_attention

def get_output(
    model: AutoModelForCausalLM,
    prompt: torch.nn.Parameter,
    attention_mask: torch.Tensor,
    past: Past,
    past_attention_mask: torch.Tensor,
    batch_size: int,
) -> torch.Tensor:
    """ Get the predictions for the next token after the prompt """
    # concatenate the past attention with the prompt attention
    attention_mask = torch.cat([
        past_attention_mask, attention_mask
    ], dim=-1)

    # expand the prompt embeddings to match the batch size
    inputs_embeds = prompt.repeat_interleave(batch_size, dim=0)

    state = model.transformer(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        past_key_values=past,
    ).last_hidden_state
    logits = model.lm_head(state)
    return logits[:, -1]
Code
from sklearn.metrics import classification_report

@torch.no_grad()
def accuracy(
    dl: PastDataloader,
    model: AutoModelForCausalLM,
    prompt: torch.Tensor,
    label_fn: Callable[[torch.Tensor, torch.Tensor], Tuple[List[int], List[int]]]
) -> None:
    batch_size = dl.batch_size
    prompt_attention = torch.ones((dl.batch_size, prompt.shape[1]), device=dl.device)
    
    labels = []
    predictions = []

    for batch in tqdm(dl):
        logits = get_output(
            model=model,
            prompt=prompt,
            attention_mask=prompt_attention,
            past=batch["past_key_values"],
            past_attention_mask=batch["attention_mask"],
            batch_size=batch_size,
        )

        current_labels, current_predictions = label_fn(
            batch["labels"], logits
        )
        
        labels.extend(current_labels)
        predictions.extend(current_predictions)

    used_labels = sorted(set(predictions))
    if 2 in used_labels:
        # other is present
        print(classification_report(
            y_true=labels,
            y_pred=predictions,
            labels=[0, 1, 2],
            target_names=["good", "bad", "other"],
            zero_division=0
        ))
    else:
        print(classification_report(
            y_true=labels,
            y_pred=predictions,
            target_names=["good", "bad"],
            zero_division=0
        ))

def restricted_labels(labels: torch.Tensor, logits: torch.Tensor) -> Tuple[List[int], List[int]]:
    targets = (
        (labels == BAD_TOKEN) # bad == 1
            .long()
            .tolist()
    )
    predictions = (
        logits[:, [GOOD_TOKEN, BAD_TOKEN]]
            .argmax(dim=-1)
            .tolist()
    )
    return targets, predictions

def full_labels(labels: torch.Tensor, logits: torch.Tensor) -> Tuple[List[int], List[int]]:
    targets = (
        (labels == BAD_TOKEN) # bad == 1
            .long()
            .tolist()
    )
    predictions = [
        0 if prediction == GOOD_TOKEN else
        (1 if prediction == BAD_TOKEN else 2)
        for prediction in logits.argmax(dim=-1).tolist()
    ]
    return targets, predictions

Results

Now let’s see how the model performs with our custom prompts. This will vary the number of tokens taken from the review text (some reviews are longer than the model can handle, so the input will sometimes be truncated no matter what we do). We will also evaluate 5 and 20 token prompts, as prompts of that size performed strongly in the Google evaluation.

Baseline Comparison

We can start with a baseline comparison to see what the model is like without a prompt.

Code
@torch.no_grad()
def accuracy_no_prompt_highest_good_bad(
    df: pd.DataFrame,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer
) -> None:
    predictions = []
    labels = []

    for row in tqdm(df.iloc, total=len(df)):
        tokens = tokenizer(
            row.text,
            return_tensors="pt",
            truncation=True
        ).to(model.device)
        logits = model(**tokens).logits

        prediction = logits[0, -1, [GOOD_TOKEN, BAD_TOKEN]].argmax(dim=-1).item()
        label = 0 if row.sentiment == "good" else 1

        predictions.append(prediction)
        labels.append(label)

    print(classification_report(y_true=labels, y_pred=predictions, target_names=["good", "bad"]))
Code
accuracy_no_prompt_highest_good_bad(validation_df, model, tokenizer)

              precision    recall  f1-score   support

        good       0.59      0.84      0.69     12500
         bad       0.71      0.41      0.52     12500

    accuracy                           0.62     25000
   macro avg       0.65      0.62      0.60     25000
weighted avg       0.65      0.62      0.60     25000

It’s quite interesting that the accuracy of the unaltered model is better than 50%. This suggests to me that the problem is relatively easy.

Restricted Labels

Let’s start by comparing only the good and bad token logits and ignoring all the others.

Code
def restricted_ce_loss(logits: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
    target = (label == BAD_TOKEN).long() # bad = 1
    logits = logits[:, [GOOD_TOKEN, BAD_TOKEN]]
    return torch.nn.functional.cross_entropy(logits, target)

trained_prompt = train(
    dl=train_dataloader,
    model=model,
    prompt_tokens=5,
    epochs=3,
    loss_fn=restricted_ce_loss
)

epoch 0: mean loss 0.4296, last loss 0.3525

epoch 1: mean loss 0.3503, last loss 0.4284

epoch 2: mean loss 0.3400, last loss 0.2807
Code
accuracy(
    dl=validation_dataloader,
    model=model,
    prompt=trained_prompt,
    label_fn=restricted_labels
)

              precision    recall  f1-score   support

        good       0.84      0.89      0.87     12497
         bad       0.88      0.83      0.86     12495

    accuracy                           0.86     24992
   macro avg       0.86      0.86      0.86     24992
weighted avg       0.86      0.86      0.86     24992

The current SOTA on this dataset is 97.4 with other training data, or 94.5 without other training data (Al-Shedivat, Dubey, and Xing 2020), so I have some way to go. This result is near the bottom of the Papers with Code leaderboard. I think that using GPT-2 counts as using other training data, since it builds on a pretrained language model. It’s still interesting to compare to the no-extra-data evaluations, as that suggests how well this approach could perform on novel tasks that may previously have been handled by statistical methods (e.g. random forests).

Al-Shedivat, Maruan, Avinava Dubey, and Eric P. Xing. 2020. “Contextual Explanation Networks.” https://arxiv.org/abs/1705.10301.

The next thing would be to test increasing the length of the prompt.

Code
trained_prompt = train(
    dl=train_dataloader,
    model=model,
    prompt_tokens=20,
    epochs=10,
    loss_fn=restricted_ce_loss
)

epoch 0: mean loss 0.4090, last loss 0.3148

epoch 1: mean loss 0.3453, last loss 0.3484

epoch 2: mean loss 0.3325, last loss 0.2168

epoch 3: mean loss 0.3270, last loss 0.2135

epoch 4: mean loss 0.3202, last loss 0.2446

epoch 5: mean loss 0.3174, last loss 0.2110

epoch 6: mean loss 0.3119, last loss 0.3064

epoch 7: mean loss 0.3073, last loss 0.3566

epoch 8: mean loss 0.3038, last loss 0.3662

epoch 9: mean loss 0.3010, last loss 0.1374
Code
accuracy(
    dl=validation_dataloader,
    model=model,
    prompt=trained_prompt,
    label_fn=restricted_labels
)

              precision    recall  f1-score   support

        good       0.89      0.84      0.86     12496
         bad       0.85      0.90      0.87     12496

    accuracy                           0.87     24992
   macro avg       0.87      0.87      0.87     24992
weighted avg       0.87      0.87      0.87     24992

That didn’t improve things very much.


Full Labels

Now let’s try training with the full cross entropy over the entire vocabulary.

Code
trained_prompt = train(
    dl=train_dataloader,
    model=model,
    prompt_tokens=5,
    epochs=3,
    loss_fn=torch.nn.functional.cross_entropy
)

epoch 0: mean loss 1.3601, last loss 0.4808

epoch 1: mean loss 0.4444, last loss 0.2940

epoch 2: mean loss 0.3792, last loss 0.6485
Code
accuracy(
    dl=validation_dataloader,
    model=model,
    prompt=trained_prompt,
    label_fn=full_labels
)

              precision    recall  f1-score   support

        good       0.83      0.86      0.85     12497
         bad       0.86      0.82      0.84     12495

    accuracy                           0.84     24992
   macro avg       0.84      0.84      0.84     24992
weighted avg       0.84      0.84      0.84     24992
Code
trained_prompt = train(
    dl=train_dataloader,
    model=model,
    prompt_tokens=20,
    epochs=10,
    loss_fn=torch.nn.functional.cross_entropy
)

epoch 0: mean loss 0.7189, last loss 0.3417

epoch 1: mean loss 0.3510, last loss 0.3746

epoch 2: mean loss 0.3398, last loss 0.3530

epoch 3: mean loss 0.3316, last loss 0.2980

epoch 4: mean loss 0.3267, last loss 0.2666

epoch 5: mean loss 0.3209, last loss 0.3737

epoch 6: mean loss 0.3135, last loss 0.1405

epoch 7: mean loss 0.3125, last loss 0.3509

epoch 8: mean loss 0.3072, last loss 0.3693

epoch 9: mean loss 0.3021, last loss 0.1862
Code
accuracy(
    dl=validation_dataloader,
    model=model,
    prompt=trained_prompt,
    label_fn=full_labels
)

              precision    recall  f1-score   support

        good       0.90      0.83      0.86     12496
         bad       0.84      0.90      0.87     12496

    accuracy                           0.87     24992
   macro avg       0.87      0.87      0.87     24992
weighted avg       0.87      0.87      0.87     24992

Looks like it is hard to move the dial. Let’s try with more text.


1,000 Tokens

I’ve had to limit the text length because some of the tokenized reviews exceed what the model can handle. Limiting it to 128 tokens also makes training faster, as I can run a batch size of 32. Inferring sentiment from fewer tokens is harder though, so I might be able to improve accuracy by making more of the review available.

The token limit needs to account for the size of the prompt, as that contributes to the overall token count. Currently I’m varying the prompt between 5 and 20 tokens, and since the maximum context is 1,024 tokens I can let the text go to 1,000 tokens.
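
As a rough check of that budget (a minimal sketch; it assumes the `n_positions` field on the GPT-2 config, which holds the 1,024 token context size):

Code
# The prompt and the review share the same context window, so the review
# budget is the context size minus the largest prompt I intend to use.
context_size = model.config.n_positions  # 1,024 for GPT-2
largest_prompt = 20
print(context_size - largest_prompt)     # 1,004 - so 1,000 leaves a small margin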

Code
BATCH_SIZE = 4
MAX_LENGTH = 1_000

train_dataloader = PastDataloader(
    model=model,
    tokenizer=tokenizer,
    df=train_df,
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
)
validation_dataloader = PastDataloader(
    model=model,
    tokenizer=tokenizer,
    df=validation_df,
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
)
Code
trained_prompt = train(
    dl=train_dataloader,
    model=model,
    prompt_tokens=5,
    epochs=3,
    loss_fn=restricted_ce_loss
)

epoch 0: mean loss 0.3105, last loss 0.0847

epoch 1: mean loss 0.2330, last loss 0.1886

epoch 2: mean loss 0.2248, last loss 0.0420
Code
accuracy(
    dl=validation_dataloader,
    model=model,
    prompt=trained_prompt,
    label_fn=restricted_labels
)

              precision    recall  f1-score   support

        good       0.89      0.94      0.91     12500
         bad       0.94      0.88      0.91     12500

    accuracy                           0.91     25000
   macro avg       0.91      0.91      0.91     25000
weighted avg       0.91      0.91      0.91     25000

This is now just 3.5 points off the SOTA for a model that doesn’t use other data.

Code
trained_prompt = train(
    dl=train_dataloader,
    model=model,
    prompt_tokens=20,
    epochs=10,
    loss_fn=restricted_ce_loss
)

epoch 0: mean loss 0.3323, last loss 0.8553

epoch 1: mean loss 0.2383, last loss 0.9740

epoch 2: mean loss 0.2266, last loss 0.3056

epoch 3: mean loss 0.2204, last loss 0.0884

epoch 4: mean loss 0.2163, last loss 0.7145

epoch 5: mean loss 0.2132, last loss 0.2740

epoch 6: mean loss 0.2108, last loss 0.8848

epoch 7: mean loss 0.2071, last loss 0.0230

epoch 8: mean loss 0.2144, last loss 0.2019

epoch 9: mean loss 0.2112, last loss 0.2228
Code
accuracy(
    dl=validation_dataloader,
    model=model,
    prompt=trained_prompt,
    label_fn=restricted_labels
)

              precision    recall  f1-score   support

        good       0.92      0.92      0.92     12500
         bad       0.92      0.92      0.92     12500

    accuracy                           0.92     25000
   macro avg       0.92      0.92      0.92     25000
weighted avg       0.92      0.92      0.92     25000
Code
torch.save(trained_prompt, "/data/blog/2021-04-28-imbd-prompt-training/trained-prompt-1-20-768.pt")

So it edges slightly closer with a longer train. It’s possible that this has overfitted the training data and that a better result could have been achieved at a different point during training. I also think that using a bigger model could lead to better performance. It’s good to understand the limitations of GPT-2 small though.
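
For reference, swapping in a larger checkpoint would mostly be a matter of loading it - the prompt is built from the model’s own embedding table (`model.transformer.wte`) in `make_prompt`, so its width follows the model automatically. A minimal sketch, not something I’ve run for this post (gpt2-medium is just one option):

Code
# Load a larger GPT-2 checkpoint. The tokenizer is shared across GPT-2 sizes
# and make_prompt reads the embedding table from the model, so the prompt
# width changes from 768 to 1024 without any other code changes.
bigger_model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
bigger_model.eval()
print(bigger_model.transformer.wte.embedding_dim)  # 1024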

All in all this is very encouraging.


Sanity Check

I just want to load the tokenizer / model / prompt from disk and check that the accuracy is still the same. Previously when doing this kind of thing I have found that the model can change even though it is not touched by the optimizer. If I have made such a mistake then I want to know about it and fix it.

Have I messed this up? Has the model been altered? I guess there is one way to check - compare it to the pretrained model.

Code
def changed_layers(model: AutoModelForCausalLM) -> List[str]:
    """ Returns the names of any parameters that differ from the pretrained gpt2 weights """
    model.cpu()
    model_state_dict = model.state_dict()
    base_state_dict = (
        AutoModelForCausalLM.from_pretrained("gpt2")
            .state_dict()
    )
    layer_names = [
        name
        for name, state in model_state_dict.items()
        if not torch.all(torch.eq(state, base_state_dict[name]))
    ]
    model.cuda()

    return layer_names
Code
changed_layers(model)
[]

Looks like none of the layers have changed in the model. Another way to do this sanity check would be to load the model and prompt from disk and see how they compare to the previous results.

Code
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()
model.cuda()

trained_prompt = torch.load("/data/blog/2021-04-28-imbd-prompt-training/trained-prompt-1-20-768.pt")
Code
BATCH_SIZE = 8
MAX_LENGTH = 1_000

validation_dataloader = PastDataloader(
    model=model,
    tokenizer=tokenizer,
    df=validation_df,
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
    shuffle=False,
)
Code
accuracy(
    dl=validation_dataloader,
    model=model,
    prompt=trained_prompt,
    label_fn=restricted_labels
)

              precision    recall  f1-score   support

        good       0.92      0.92      0.92     12500
         bad       0.92      0.92      0.92     12500

    accuracy                           0.92     25000
   macro avg       0.92      0.92      0.92     25000
weighted avg       0.92      0.92      0.92     25000

It’s producing identical results, so I’m happy that these results reflect the quality of the prompt rather than sneaky fine tuning of the model.