Prompt Training - Huggingface

Using Huggingface Trainer and Dataset to do Prompt Training
Published

June 13, 2021

As part of the prompt training paper I need a methodical set of results. To do this I want to be able to sweep across several different combinations of model, prompt size, and task. It would also be good to compare the prompt training approach to fine tuning.

I’ve also found that the documentation for linking together the huggingface Trainer and Datasets is somewhat lacking. This post can serve as an example of joining those two together. Moving away from my custom training loop has the benefit of making the code easier to understand, which would help when people who are interested in the technique attempt to reproduce it.


Requirements

This is going to link the huggingface Trainer with huggingface Datasets to train a model. The model will be a frozen language model with a trainable prompt and a linear classification layer. Only the prompt and the classification layer are trained.

The huggingface Trainer is quite prescriptive. It is made to work with the huggingface model structure: the model takes input_ids and possibly an attention_mask, and the loss is computed against the label (or labels) that are provided.

The Dataset then provides an abstraction over what behaves like a list of dictionaries. Underneath is a more efficient in-memory representation (Apache Arrow tables) that can be indexed in various ways. It’s quite neat and almost totally model agnostic: it does not tokenize anything automatically.

The biggest challenge to address is the translation from input_ids to inputs_embeds. The prompt is in the embedding space, not the token space, so it cannot be added to the tokens. I have already tried to do this within the dataset, and it does not seem to fit the way Datasets is meant to be used. The Trainer also expects to receive input_ids and does not correctly handle inputs_embeds, which further reinforces my opinion. Finally, the prompt itself is a trainable parameter, so associating it with the model is consistent with the other parameters.
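The translation itself is just an embedding lookup followed by a concatenation along the sequence axis. A minimal sketch with a toy torch.nn.Embedding standing in for GPT-2's wte (the sizes are made up for illustration) shows the shapes involved. The prompt has no token ids, so it can only join the sequence after the lookup.

```python
import torch

torch.manual_seed(0)

# Toy sizes (assumptions for illustration); GPT-2 small would be 50257 x 768.
vocab_size, hidden_size, prompt_tokens = 100, 16, 5
wte = torch.nn.Embedding(vocab_size, hidden_size)

# A trainable prompt lives in embedding space: shape (1, prompt_tokens, hidden_size).
prompt = torch.nn.Parameter(torch.randn(1, prompt_tokens, hidden_size))

# A batch of 2 sequences of 3 token ids each.
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])

# input_ids -> inputs_embeds is a row lookup in the embedding matrix.
inputs_embeds = wte(input_ids)  # (2, 3, hidden_size)
same_as_lookup = torch.equal(inputs_embeds, wte.weight[input_ids])

# The prompt is shared by every example: expand it over the batch and
# concatenate along the sequence dimension.
batch_prompt = prompt.expand(input_ids.shape[0], -1, -1)
combined = torch.cat([inputs_embeds, batch_prompt], dim=1)
print(combined.shape)  # torch.Size([2, 8, 16])
```

expand avoids copying the prompt once per example; repeat_interleave over the batch dimension produces the same values with a copy.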


Language Model for Prompting

I want the model to be a GPT2 model, and the task is a classification task, so GPT2ForSequenceClassification makes a natural base class.

This class has a forward method that I can reuse as long as I perform the translation before invoking it. So my model overrides forward to create the inputs_embeds and then passes them up to the parent implementation.

Code
import torch
from transformers import GPT2ForSequenceClassification

class GPT2ForPromptTraining(GPT2ForSequenceClassification):
    def __init__(self, config) -> None:
        super().__init__(config)

        embedding = self.transformer.wte
        vocab_size = embedding.weight.shape[0]
        prompt_indexes = torch.randint(
            size=(config.prompt_tokens,),
            low=0,
            high=vocab_size,
            device=self.device
        )
        self.prompt = torch.nn.Parameter(
            embedding(prompt_indexes).clone()[None, :, :]
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        **kwargs,
    ):
        inputs_embeds = self._add_prompt(self._to_embedding(input_ids))
        if attention_mask is not None:
            attention_mask = self._extend_attention_mask(attention_mask)
        return super().forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs
        )

    def _to_embedding(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.transformer.wte(tokens)

    def _add_prompt(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        # so at this point it would be nice to add the prompt to the end of the attended tokens
        # something for after this is working...
        batch_size = inputs_embeds.shape[0]
        prompt_embeds = (
            self.prompt
                .repeat_interleave(batch_size, dim=0)
        )
        return torch.cat([
            inputs_embeds,
            prompt_embeds,
        ], dim=1)

    def _extend_attention_mask(self, attention_mask: torch.Tensor) -> torch.Tensor:
        batch_size = attention_mask.shape[0]
        prompt_size = self.prompt.shape[1]
        return torch.cat([
            attention_mask,
            torch.ones((batch_size, prompt_size), device=self.device),
        ], dim=1)
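One thing worth checking about this constructor: from_pretrained builds the model (running __init__) before it copies the checkpoint weights in, so as far as I can tell self.prompt is sampled from embeddings that have not yet been overwritten by the pretrained GPT-2 matrix. The toy sketch below reproduces that ordering with load_state_dict; TinyPromptModel and resample_prompt are hypothetical names, and the "pretrained" weights are just random numbers.

```python
import torch


class TinyPromptModel(torch.nn.Module):
    """Hypothetical stand-in for GPT2ForPromptTraining: an embedding plus a prompt."""

    def __init__(self, vocab_size: int = 50, hidden_size: int = 8, prompt_tokens: int = 5) -> None:
        super().__init__()
        self.wte = torch.nn.Embedding(vocab_size, hidden_size)
        indexes = torch.randint(0, vocab_size, (prompt_tokens,))
        # Sampled from whatever wte holds at construction time,
        # i.e. before any checkpoint weights are loaded.
        self.prompt = torch.nn.Parameter(self.wte(indexes).detach().clone()[None, :, :])

    def resample_prompt(self) -> torch.Tensor:
        """Re-draw the prompt from the current (loaded) embedding rows."""
        indexes = torch.randint(0, self.wte.num_embeddings, (self.prompt.shape[1],))
        with torch.no_grad():
            self.prompt.copy_(self.wte(indexes)[None, :, :])
        return indexes


torch.manual_seed(0)
model = TinyPromptModel()

# Simulate from_pretrained: construct first, then load "pretrained" embeddings.
model.load_state_dict({"wte.weight": torch.randn(50, 8) * 10}, strict=False)

# Before re-sampling, the first prompt vector matches no row of the loaded matrix.
matches_loaded_row = any(torch.equal(model.prompt[0, 0], row) for row in model.wte.weight)
print(matches_loaded_row)  # False

indexes = model.resample_prompt()
print(torch.equal(model.prompt[0], model.wte.weight[indexes]))  # True
```

For the real class, the same re-sampling could be run against model.transformer.wte straight after from_pretrained returns. Because copy_ works in place, an optimizer that already holds model.prompt would still see the new values.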

Now we can create a configuration that contains the number of prompt tokens to use. It’s quite important to have a valid configuration for this as it allows the model to be created correctly by the framework. Adding custom parameters to the constructor works when creating the model directly, but not when the trainer wants to recreate the model.

Code
from transformers import GPT2Config

config = GPT2Config.from_pretrained("gpt2")
config.pad_token_id = config.eos_token_id
config.prompt_tokens = 5
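To see why the configuration is a good home for prompt_tokens, here is a small round-trip check. Extra keyword arguments to a transformers config are stored as attributes and written to config.json, so they survive save_pretrained followed by from_pretrained, which is the path used when a saved checkpoint is reloaded. This sketch builds GPT2Config locally instead of downloading it.

```python
import tempfile

from transformers import GPT2Config

# Built locally (no download); prompt_tokens is our own extra attribute.
config = GPT2Config(prompt_tokens=5)
config.pad_token_id = config.eos_token_id

with tempfile.TemporaryDirectory() as folder:
    config.save_pretrained(folder)  # writes config.json into the folder
    reloaded = GPT2Config.from_pretrained(folder)

print(reloaded.prompt_tokens, reloaded.pad_token_id)  # 5 50256
```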

Now we can put this together with the tokenizer. I want the main part of the model to be in eval mode as I do not want to train it.

Code
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2ForPromptTraining.from_pretrained("gpt2", config=config)
model.transformer.eval()
# model.prompt is a bare nn.Parameter, so it has no train()/eval() mode to set
model.score.train()

tokenizer.pad_token = tokenizer.eos_token
Some weights of GPT2ForPromptTraining were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight', 'prompt']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
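A caveat on the eval() call: eval() only switches layers such as dropout to inference behaviour; it does not stop weights from being trained. As far as I know, the Trainer also calls model.train() at the start of each training step, which turns the transformer's dropout back on. What actually keeps GPT-2 fixed in this setup is the custom optimizer, which only covers the prompt and the score layer. The toy sketch below (made-up module sizes) shows the difference, along with requires_grad_(False), which additionally skips computing gradients for the frozen part.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: `body` plays model.transformer, `head` plays model.score.
body = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.1))
head = torch.nn.Linear(4, 2)
model = torch.nn.Sequential(body, head)

body.eval()
print(body.training)  # False
model.train()  # what the Trainer does before each training step
print(body.training)  # True: dropout is active again

# Freezing the body: no gradients are computed or stored for its weights.
body.requires_grad_(False)
loss = model(torch.randn(3, 4)).sum()
loss.backward()
print(body[0].weight.grad is None, head.weight.grad is not None)  # True True
```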
Code
model(**tokenizer("hello world", return_tensors="pt")) ; None
GPT2ForPromptTraining will not detect padding tokens in `inputs_embeds`. Results may be unexpected if using padding tokens in conjunction with `inputs_embeds.`

This message is produced when passing the inputs_embeds to a model where the configuration has a padding token set. We want the padding token, but this message is going to get very noisy during training. The message can be suppressed by altering the logging level of the GPT2 module.

Code
import logging
import transformers.models.gpt2.modeling_gpt2 as gpt2_module

gpt2_module.logger.setLevel(logging.CRITICAL)
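The warning matters because of how GPT2ForSequenceClassification chooses the position it classifies from. As I understand the transformers code of this period, when given input_ids it takes the logits at the last non-padding token (by counting tokens that differ from pad_token_id). When given only inputs_embeds, it cannot tell where the padding is and falls back to the final position. For the model above, the final position is always the last prompt token, which is what we want. The sketch below reimplements that selection on a toy logits tensor; pooled_logits is a hypothetical helper, not a library function.

```python
from typing import Optional

import torch


def pooled_logits(
    logits: torch.Tensor,
    input_ids: Optional[torch.Tensor],
    pad_token_id: int,
) -> torch.Tensor:
    """Pick one logit vector per sequence, mimicking the GPT-2 classification head.

    logits: (batch, seq_len, num_labels); input_ids: (batch, seq_len) or None.
    """
    batch_size, seq_len = logits.shape[:2]
    if input_ids is not None:
        # Position of the last non-pad token, assuming right padding.
        last = torch.ne(input_ids, pad_token_id).sum(-1) - 1
    else:
        # Without input_ids there is no way to find padding: use the final position.
        last = torch.full((batch_size,), seq_len - 1, dtype=torch.long)
    return logits[torch.arange(batch_size), last]


pad = 50256
input_ids = torch.tensor([[10, 11, pad], [20, 21, 22]])
logits = torch.arange(12, dtype=torch.float).reshape(2, 3, 2)

with_ids = pooled_logits(logits, input_ids, pad).tolist()
embeds_only = pooled_logits(logits, None, pad).tolist()
print(with_ids)     # [[2.0, 3.0], [10.0, 11.0]]
print(embeds_only)  # [[4.0, 5.0], [10.0, 11.0]]
```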

So the model is working. What I need now is to hook it up to the trainer and try it on a task. Using the existing datasets that are available should make this easy.


Dataset Transformation

As far as datasets go, I think that GLUE is well studied and has several classification tasks. A quick review shows that the Recognizing Textual Entailment (RTE) and Stanford Sentiment Treebank (SST-2) tasks look great. They are both text classification tasks, so they should fit the model I have created.

Code
from datasets import load_dataset

rte_dataset = load_dataset("glue", "rte")
sst2_dataset = load_dataset("glue", "sst2")
Reusing dataset glue (/home/matthew/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Reusing dataset glue (/home/matthew/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Code
rte_dataset["train"][0]
{'idx': 0,
 'label': 1,
 'sentence1': 'No Weapons of Mass Destruction Found in Iraq Yet.',
 'sentence2': 'Weapons of Mass Destruction Found in Iraq.'}
Code
sst2_dataset["train"][0]
{'idx': 0,
 'label': 0,
 'sentence': 'hide new secretions from the parental units '}

load_dataset returns a DatasetDict: a dictionary keyed by split (train, validation and test) whose values are Dataset objects. These behave like lists of dictionaries, but they have several useful methods for transforming or indexing the data.

Looking at these two, I think that SST-2 is the better one to start with, as transforming it into the desired form should be easier. RTE has a pair of sentences per example, and handling the pair takes a little extra care.

The next task is to transform the train and validation datasets so that they have the required input_ids column. We can achieve that by tokenizing the sentence column.

Code
from datasets import Dataset

def transform_dataset(dataset: Dataset) -> Dataset:
    return (
        dataset.map(tokenizer, input_columns="sentence")
        # the label column is dropped by the map function,
        # so we need to restore it.
            .add_column(
                "label", dataset["label"]
            )
            .remove_columns("sentence")
    )

sst2_dataset = load_dataset("glue", "sst2")

train_ds = transform_dataset(sst2_dataset["train"])
valid_ds = transform_dataset(sst2_dataset["validation"])
Reusing dataset glue (/home/matthew/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Loading cached processed dataset at /home/matthew/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ea20a903b4675e27.arrow
Loading cached processed dataset at /home/matthew/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-5a7d29392a2b324d.arrow
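The tokenized rows still have different lengths, because the dataset never pads them. As far as I know, the padding happens in the Trainer. When a tokenizer is passed in, the Trainer collates batches with DataCollatorWithPadding, which pads input_ids and attention_mask to the longest row in the batch and renames label to labels. The Trainer also drops dataset columns such as idx that do not appear in the model's forward signature. Since forward takes **kwargs, labels passes straight through to GPT2ForSequenceClassification. Below is a simplified stand-in for that collation step. pad_batch is a hypothetical helper; it right-pads, matching the GPT-2 tokenizer's default, and the token ids are made up.

```python
from typing import Any, Dict, List

import torch


def pad_batch(features: List[Dict[str, Any]], pad_token_id: int) -> Dict[str, torch.Tensor]:
    """Simplified collation: right-pad rows to the longest one and rename label -> labels."""
    longest = max(len(row["input_ids"]) for row in features)
    input_ids, attention_mask = [], []
    for row in features:
        padding = longest - len(row["input_ids"])
        input_ids.append(list(row["input_ids"]) + [pad_token_id] * padding)
        attention_mask.append(list(row["attention_mask"]) + [0] * padding)
    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
        # The dataset column is "label"; the model's forward takes "labels".
        "labels": torch.tensor([row["label"] for row in features]),
    }


# Two made-up tokenized rows of different lengths (idx column already dropped).
rows = [
    {"input_ids": [31373, 995], "attention_mask": [1, 1], "label": 1},
    {"input_ids": [40], "attention_mask": [1], "label": 0},
]
batch = pad_batch(rows, pad_token_id=50256)
print(batch["input_ids"].tolist())       # [[31373, 995], [40, 50256]]
print(batch["attention_mask"].tolist())  # [[1, 1], [1, 0]]
print(batch["labels"].tolist())          # [1, 0]
```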

Training

Now that we have the model and the datasets we can train the model. We want to know how well the training is going so we need a metric to track that. Luckily the datasets framework has an associated metric for the glue task.

We want the metric, which takes a set of predictions and the associated references, to work with the compute_metrics callback of the Trainer. That callback receives an EvalPrediction object with predictions and label_ids attributes. The predictions are the raw model outputs (the logits for each class), so they need to be translated into the favoured class by taking the argmax.

Code
from datasets import load_metric

sst2_metric = load_metric("glue", "sst2")
def compute_metrics(run):
    targets = run.label_ids
    predictions = run.predictions.argmax(axis=1)
    return sst2_metric.compute(predictions=predictions, references=targets)
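The GLUE SST-2 metric reports plain accuracy. For reference, the same number can be computed with numpy alone, which may be handy because newer releases of datasets have moved metrics into the separate evaluate library. This is a sketch; compute_accuracy is a hypothetical name, and SimpleNamespace stands in for the Trainer's EvalPrediction with made-up values.

```python
from types import SimpleNamespace

import numpy as np


def compute_accuracy(run) -> dict:
    """Accuracy from an EvalPrediction-like object with .predictions and .label_ids."""
    # predictions are raw logits of shape (num_examples, num_labels)
    predicted = np.argmax(run.predictions, axis=1)
    return {"accuracy": float((predicted == run.label_ids).mean())}


# A made-up evaluation result: 3 examples, 2 classes.
fake_run = SimpleNamespace(
    predictions=np.array([[0.2, 1.3], [2.0, -1.0], [0.1, 0.4]]),
    label_ids=np.array([1, 0, 0]),
)
print(compute_accuracy(fake_run))  # {'accuracy': 0.6666666666666666}
```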

Now we have everything required to perform the training. This uses a custom optimizer that can only alter the prompt and the score (which is the classification layer). When passing it to the Trainer we pass it in a tuple with None. The optimizers argument is an (optimizer, scheduler) pair so that a custom scheduler can be provided as well; passing None lets the Trainer create its default scheduler.

Code
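from pathlib import Path
from transformers import TrainingArguments, Trainer

MODEL_RUN_FOLDER = Path("/data/blog/2021-06-13-prompt-training-using-transformers/runs")
MODEL_RUN_FOLDER.mkdir(parents=True, exist_ok=True)

# Only the prompt and the classification layer are given to the optimizer,
# so the transformer weights are never updated.
# lr=1e-3 is torch's AdamW default, written out explicitly: the Trainer does not
# apply TrainingArguments.learning_rate to an optimizer that is passed in.
optimizer = torch.optim.AdamW(
    [model.prompt] + list(model.score.parameters()),
    lr=1e-3,
)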

training_args = TrainingArguments(
    report_to=[],            
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=128,
    learning_rate=5e-5,
    num_train_epochs=5,
    evaluation_strategy="epoch",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=100,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, None)
)

trainer.train()
[2635/2635 19:28, Epoch 5/5]
Epoch  Training Loss  Validation Loss  Accuracy  Runtime (s)  Samples Per Second
1      0.423000       0.334669         0.860092  2.084100     418.412000
2      0.391400       0.307949         0.877294  2.172300     401.425000
3      0.394200       0.307660         0.871560  2.187400     398.641000
4      0.377600       0.313920         0.872706  2.126400     410.081000
5      0.384900       0.306334         0.877294  2.193800     397.480000

TrainOutput(global_step=2635, training_loss=0.41107175779976024, metrics={'train_runtime': 1169.4491, 'train_samples_per_second': 2.253, 'total_flos': 1.1866525027310592e+16, 'epoch': 5.0, 'init_mem_cpu_alloc_delta': 2172895232, 'init_mem_gpu_alloc_delta': 511169536, 'init_mem_cpu_peaked_delta': 465125376, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 655671296, 'train_mem_gpu_alloc_delta': 1015777280, 'train_mem_cpu_peaked_delta': 366313472, 'train_mem_gpu_peaked_delta': 11247239168})
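A note on the learning rate in that run: the Trainer does not change the settings of an optimizer that is passed in, so learning_rate=5e-5 from TrainingArguments only applies when the Trainer builds its own optimizer. The default linear scheduler it adds scales whatever lr the optimizer was created with, and torch.optim.AdamW defaults to lr=1e-3. The torch-only sketch below uses LambdaLR with a hand-written linear decay as a stand-in for the Trainer's scheduler (an assumption for illustration, not the Trainer's own code).

```python
import torch

param = torch.nn.Parameter(torch.zeros(3))
optimizer = torch.optim.AdamW([param])  # no lr given, so torch's default is used
print(optimizer.param_groups[0]["lr"])  # 0.001

total_steps = 10
# Linear decay from the optimizer's own lr down to 0, with no warmup.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: max(0.0, (total_steps - step) / total_steps)
)

lrs = []
for _ in range(total_steps):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()  # no gradients here, so parameters are left untouched
    scheduler.step()
print(lrs[0], lrs[5])  # 0.001 0.0005
```

If 5e-5 was the intended learning rate, passing lr=training_args.learning_rate when building the optimizer would make the two agree. The results above were produced with the 1e-3 default.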

Sanity Check

It would be good to check that the weights of the transformer part of the model are unaltered. This would confirm that the performance of the model is entirely down to the prompt and classification layer.

Code
from typing import List

def changed_layers(model: GPT2ForPromptTraining) -> List[str]:
    model.cpu()
    comparison = GPT2ForSequenceClassification.from_pretrained("gpt2")

    # only comparing the transformer as the score layer is part of the trained parameters
    model_state_dict = model.transformer.state_dict()
    base_state_dict = (
        comparison
            .transformer
            .state_dict()
    )

    layer_names = [
        name
        for name, state in model_state_dict.items()
        if not torch.all(torch.eq(state, base_state_dict[name]))
    ]

    return layer_names
Code
changed_layers(model)
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[]

That’s great. The training has not altered the transformer part of the model.
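changed_layers reloads GPT-2 from the hub and also moves the trained model to the CPU as a side effect. An alternative is to snapshot the frozen weights before training and compare them afterwards. Here is a sketch on a toy model with an optimizer over only the head; snapshot and changed are hypothetical helper names.

```python
from typing import Dict, List

import torch


def snapshot(module: torch.nn.Module) -> Dict[str, torch.Tensor]:
    """Detached copies of every tensor in the module's state dict."""
    return {name: tensor.detach().clone() for name, tensor in module.state_dict().items()}


def changed(module: torch.nn.Module, before: Dict[str, torch.Tensor]) -> List[str]:
    """Names of state dict entries that differ from the snapshot."""
    return [
        name
        for name, tensor in module.state_dict().items()
        if not torch.equal(tensor.cpu(), before[name].cpu())
    ]


torch.manual_seed(0)
body, head = torch.nn.Linear(4, 4), torch.nn.Linear(4, 2)
before_body, before_head = snapshot(body), snapshot(head)

# Only the head is given to the optimizer, as with the prompt training setup.
optimizer = torch.optim.AdamW(head.parameters(), lr=1e-2)
for _ in range(3):
    optimizer.zero_grad()
    head(body(torch.randn(8, 4))).pow(2).mean().backward()
    optimizer.step()

print(changed(body, before_body))  # []
print(changed(head, before_head))  # ['weight', 'bias']
```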


Fine Tuning

The performance of the model seems ok considering that this is working with GPT2-small. How well can GPT2-small perform when it has been fine tuned? Knowing this will show how well the prompt is performing.

Code
model = GPT2ForSequenceClassification.from_pretrained("gpt2", config=config)

training_args = TrainingArguments(
    report_to=[],            
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=128,
    learning_rate=5e-5,
    num_train_epochs=5,
    evaluation_strategy="epoch",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=100,
    load_best_model_at_end=True, # doesn't matter for this comparison
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # no custom optimizer here: the Trainer builds its default AdamW over all
    # parameters using learning_rate=5e-5, so everything is fine tuned
)

trainer.train()
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[2635/2635 18:42, Epoch 5/5]
Epoch  Training Loss  Validation Loss  Accuracy  Runtime (s)  Samples Per Second
1      0.217800       0.240399         0.912844  1.309900     665.720000
2      0.156600       0.258620         0.909404  1.373200     635.021000
3      0.121500       0.319108         0.908257  1.341500     650.019000
4      0.087200       0.316802         0.913991  1.315500     662.865000
5      0.070700       0.349986         0.916284  1.379500     632.091000

TrainOutput(global_step=2635, training_loss=0.14710736473778394, metrics={'train_runtime': 1123.3031, 'train_samples_per_second': 2.346, 'total_flos': 1.1866158862428672e+16, 'epoch': 5.0, 'init_mem_cpu_alloc_delta': -65011712, 'init_mem_gpu_alloc_delta': 513573376, 'init_mem_cpu_peaked_delta': 65011712, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 1998848, 'train_mem_gpu_alloc_delta': 2014632448, 'train_mem_cpu_peaked_delta': 8552448, 'train_mem_gpu_peaked_delta': 10310513152})

So prompt training clearly loses accuracy compared to fine tuning the model: 0.877 against 0.916, about 4 percentage points. That is to be expected, as it is only able to alter a tiny fraction of the parameters.
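To put "a tiny fraction" in numbers, here is a back-of-the-envelope count for GPT-2 small. It assumes hidden size 768, 12 layers, a vocabulary of 50257 and 1024 positions, with 5 prompt tokens and the two-class score layer, which has no bias in GPT2ForSequenceClassification.

```python
hidden, layers, vocab, positions = 768, 12, 50257, 1024


def linear(n_in: int, n_out: int) -> int:
    """Parameter count of a dense layer with bias."""
    return n_in * n_out + n_out


per_layer = (
    2 * 2 * hidden                  # ln_1 and ln_2 (weight + bias)
    + linear(hidden, 3 * hidden)    # attn.c_attn
    + linear(hidden, hidden)        # attn.c_proj
    + linear(hidden, 4 * hidden)    # mlp.c_fc
    + linear(4 * hidden, hidden)    # mlp.c_proj
)
# token + position embeddings, the blocks, and the final layer norm (ln_f)
transformer = vocab * hidden + positions * hidden + layers * per_layer + 2 * hidden

trainable = 5 * hidden + hidden * 2  # prompt + score weight (num_labels = 2)
print(transformer, trainable, f"{trainable / transformer:.4%}")  # 124439808 5376 0.0043%
```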


Improvements

How can the prompt training setup be improved? I think that there are a few adjustments that can be made which could improve performance.

Attention Span

It would be better to place the prompt directly after the text; currently any gap in attention falls between the text and the prompt. This may be easier to describe with a visualization. Say we have two sentences in a batch, and they have been tokenized. One of the sentences turns into a much shorter list of tokens, so it is padded (shown here with 0):

sentence          token 1  token 2  token 3
something short   123      0        0
something long    123      456      789

To tell the model not to pay attention to the end of the first sentence the attention mask is used. This indicates which tokens are valid:

sentence          mask 1  mask 2  mask 3
something short   1       0       0
something long    1       1       1

We want to add the prompt to this. At the moment we are just adding it directly on to the end:

sentence          token 1  token 2  token 3  prompt 1  prompt 2  prompt 3
something short   123      0        0        ppp       ppp       ppp
something long    123      456      789      ppp       ppp       ppp

But this leaves big gaps in the middle of the input that the model may handle poorly. The attention mask is really there to make batching of different-length inputs possible. The masked tokens are skipped by attention, but the prompt still gets position embeddings as if it followed them, so its distance from the end of the text varies across the batch. I don’t think that unattended tokens within the input are good for performance.
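The layout I am after (text, then prompt, then padding) can be built without a per-example loop by computing each position's offset from the end of the real text. Below is a minimal sketch. It assumes right padding (the GPT-2 tokenizer's default) and a prompt without the leading batch dimension; place_prompt is a hypothetical name.

```python
import torch


def place_prompt(inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, prompt: torch.Tensor):
    """Insert `prompt` directly after each row's real tokens (right padding assumed).

    inputs_embeds: (batch, seq_len, hidden); attention_mask: (batch, seq_len) of 0/1;
    prompt: (prompt_len, hidden). Returns embeddings and mask of length seq_len + prompt_len.
    """
    batch, seq_len, hidden = inputs_embeds.shape
    prompt_len = prompt.shape[0]

    lengths = attention_mask.sum(dim=1, keepdim=True)                 # (batch, 1)
    positions = torch.arange(seq_len + prompt_len, device=inputs_embeds.device)
    offset = positions[None, :] - lengths                              # (batch, total)
    is_prompt = (offset >= 0) & (offset < prompt_len)

    # Extend with zeros, then overwrite the prompt slots with the matching prompt row.
    extended = torch.cat([inputs_embeds, inputs_embeds.new_zeros(batch, prompt_len, hidden)], dim=1)
    prompt_rows = prompt[offset.clamp(0, prompt_len - 1)]              # (batch, total, hidden)
    embeds = torch.where(is_prompt[..., None], prompt_rows, extended)

    # Text and prompt are attended; everything after the prompt is padding.
    mask = (offset < prompt_len).to(attention_mask.dtype)
    return embeds, mask


torch.manual_seed(0)
embeds = torch.randn(2, 3, 4)
mask = torch.tensor([[1, 0, 0], [1, 1, 1]])
prompt = torch.randn(2, 4)
new_embeds, new_mask = place_prompt(embeds, mask, prompt)
print(new_mask.tolist())  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```

Indexing the prompt keeps it in the autograd graph, so gradients still reach it as they do with the slice assignment in the loop-based version.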

Classification Bias

A likely smaller contributing factor is the lack of a bias in the score layer (GPT2ForSequenceClassification creates it with bias=False). Adding one might make the classifier easier to optimise, and it is a reasonably small change to make.

Code
import torch

class GPT2ForPromptTraining(GPT2ForSequenceClassification):
    def __init__(self, config) -> None:
        super().__init__(config)

        assert self.config.pad_token_id is not None

        # enabling bias on this didn't help - see below
        self.score = torch.nn.Linear(
            in_features=self.score.in_features,
            out_features=self.score.out_features,
            bias=True
        )

        embedding = self.transformer.wte
        vocab_size = embedding.weight.shape[0]
        prompt_indexes = torch.randint(
            size=(config.prompt_tokens,),
            low=0,
            high=vocab_size,
            device=self.device
        )
        self.prompt = torch.nn.Parameter(
            embedding(prompt_indexes).clone()[None, :, :]
        )

        # Padding-token embeddings appended to the input; the prompt is then copied over them.
        # Making it a parameter means it moves to the right device with the model.
        # It is left out of the optimizer, so it is never updated.
        input_extension = torch.full(
            (config.prompt_tokens,),
            self.config.pad_token_id,
            dtype=torch.long,
            device=self.device
        )
        self.input_extension = torch.nn.Parameter(
            embedding(input_extension)[None, :, :]
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        **kwargs,
    ):
        if attention_mask is not None:
            inputs_embeds = self._extend_inputs_embeds(
                self._to_embedding(input_ids)
            )
            attention_mask = self._extend_attention_mask(attention_mask)
            self._copy_prompt(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )
        else:
            inputs_embeds = self._add_prompt(
                self._to_embedding(input_ids)
            )

        return super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)

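    def _to_embedding(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.transformer.wte(tokens)

    def _add_prompt(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        # used by forward when there is no attention_mask: append the prompt to every input
        batch_size = inputs_embeds.shape[0]
        prompt_embeds = self.prompt.repeat_interleave(batch_size, dim=0)
        return torch.cat([inputs_embeds, prompt_embeds], dim=1)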

    def _extend_inputs_embeds(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        batch_size = inputs_embeds.shape[0]
        input_extension = (
            self.input_extension
                .repeat_interleave(batch_size, dim=0)
        )
        return torch.cat([
            inputs_embeds,
            input_extension,
        ], dim=1)

    def _extend_attention_mask(self, attention_mask: torch.Tensor) -> torch.Tensor:
        batch_size = attention_mask.shape[0]
        prompt_size = self.prompt.shape[1]
        return torch.cat([
            attention_mask,
            torch.zeros((batch_size, prompt_size), device=self.device),
        ], dim=1)

    def _copy_prompt(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor) -> None:
        prompt = self.prompt
        prompt_size = prompt.shape[1]
        attention_indexes = attention_mask.sum(dim=1).long().tolist()
        for batch_index, token_index in enumerate(attention_indexes):
            end_index = token_index + prompt_size
            inputs_embeds[batch_index, token_index:end_index] = prompt[0]
            attention_mask[batch_index, token_index:end_index] = 1
Code
config = GPT2Config.from_pretrained("gpt2")
config.pad_token_id = config.eos_token_id
config.prompt_tokens = 5

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2ForPromptTraining.from_pretrained("gpt2", config=config)
model.transformer.eval()
model.score.train()

tokenizer.pad_token = tokenizer.eos_token
Some weights of GPT2ForPromptTraining were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight', 'input_extension', 'score.bias', 'prompt']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Code
training_args = TrainingArguments(
    report_to=[],            
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=128,
    learning_rate=5e-5,
    num_train_epochs=5,
    evaluation_strategy="epoch",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=100,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    optimizers=(torch.optim.AdamW([model.prompt] + list(model.score.parameters()), lr=1e-3), None)  # lr=1e-3 is torch's AdamW default
)

trainer.train()
[2635/2635 22:18, Epoch 5/5]
Epoch  Training Loss  Validation Loss  Accuracy  Runtime (s)  Samples Per Second
1      0.430500       0.345811         0.848624  4.375400     199.296000
2      0.401500       0.336704         0.849771  4.301200     202.734000
3      0.402400       0.320549         0.865826  4.315400     202.065000
4      0.384000       0.321338         0.858945  4.578700     190.447000
5      0.389000       0.317897         0.864679  4.685400     186.112000

TrainOutput(global_step=2635, training_loss=0.4206086059222864, metrics={'train_runtime': 1338.8794, 'train_samples_per_second': 1.968, 'total_flos': 1.1866891382903388e+16, 'epoch': 5.0, 'init_mem_cpu_alloc_delta': -75022336, 'init_mem_gpu_alloc_delta': 512233984, 'init_mem_cpu_peaked_delta': 75022336, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 864256, 'train_mem_gpu_alloc_delta': 1018737664, 'train_mem_cpu_peaked_delta': 7725056, 'train_mem_gpu_peaked_delta': 11245451776})
Code
model.score.bias
Parameter containing:
tensor([0.0091, 0.0107], device='cuda:0', requires_grad=True)

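The original prompt training run reached 0.877294 accuracy, so this run (0.864679 after the final epoch, 0.865826 at best) is about a percentage point worse.

Adding the bias has achieved almost nothing. I can see that because the two values in it are so similar: softmax is unchanged when the same constant is added to both logits, so only their difference (about 0.0016) affects the predictions.

Placing the prompt directly after the text has not improved accuracy. I do still feel that the direct append is the correct approach. It might need a bit of tuning? Maybe I have made a mistake with the implementation?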
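One possible culprit, which I have not verified against these runs: GPT2ForSequenceClassification cannot locate padding when it is given inputs_embeds, so as far as I can tell it classifies from the final position of every row. With the prompt copied directly after the text, the final position is the last prompt token only for the longest example in each batch; for every other example it is a padding position. The sketch below pools at the last attended position instead. pool_last_attended is a hypothetical helper, and using it would mean computing the score and loss in the subclass rather than in the parent's forward.

```python
import torch


def pool_last_attended(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Hidden state at the last attended position of each row (right padding assumed).

    hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len) of 0/1.
    """
    last = attention_mask.long().sum(dim=1) - 1                    # (batch,)
    batch_index = torch.arange(hidden_states.shape[0], device=hidden_states.device)
    return hidden_states[batch_index, last]                        # (batch, hidden)


# Toy example: 2 rows, 4 positions, hidden size 1; row 0 has 2 attended positions.
hidden_states = torch.arange(8, dtype=torch.float).reshape(2, 4, 1)
attention_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])
pooled = pool_last_attended(hidden_states, attention_mask)
print(pooled.squeeze(-1).tolist())  # [1.0, 7.0]
```

With the mask produced by _extend_attention_mask and _copy_prompt, the last attended position is the final prompt token. The pooled vector would then go through self.score and a cross-entropy loss against labels.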


Classifier Only Training

Another sanity check for this would be to train only the classifier of the model. We can do this with a custom optimizer over the GPT2ForSequenceClassification model. This would show how much the prompt is contributing to the accuracy.

Code
model = GPT2ForSequenceClassification.from_pretrained("gpt2", config=config)
model.transformer.eval()
model.score.train()

training_args = TrainingArguments(
    report_to=[],            
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=128,
    learning_rate=5e-5,
    num_train_epochs=5,
    evaluation_strategy="epoch",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=100,
    load_best_model_at_end=True, # doesn't matter for this comparison
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # optimize only the score (classification layer); lr=1e-3 is torch's AdamW default
    optimizers=(torch.optim.AdamW(model.score.parameters(), lr=1e-3), None)
)

trainer.train()
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[2635/2635 20:02, Epoch 5/5]
Epoch  Training Loss  Validation Loss  Accuracy  Runtime (s)  Samples Per Second
1      0.512300       0.436623         0.801606  5.172100     168.598000
2      0.496800       0.409076         0.819954  4.798300     181.730000
3      0.488300       0.399698         0.830275  5.218100     167.111000
4      0.481400       0.402876         0.825688  4.915300     177.404000
5      0.479300       0.396408         0.832569  5.015800     173.851000

TrainOutput(global_step=2635, training_loss=0.5017196546695717, metrics={'train_runtime': 1202.647, 'train_samples_per_second': 2.191, 'total_flos': 1.1866158862428672e+16, 'epoch': 5.0, 'init_mem_cpu_alloc_delta': 12288, 'init_mem_gpu_alloc_delta': 511154176, 'init_mem_cpu_peaked_delta': 0, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 1150976, 'train_mem_gpu_alloc_delta': 1013609472, 'train_mem_cpu_peaked_delta': 7798784, 'train_mem_gpu_peaked_delta': 10305080320})

So we can see that the prompt is adding about 4.5 percentage points of accuracy (0.877 against 0.833). That’s a measurable difference. The prompt could be tuned further by extending it: in the original Google paper they used 20 tokens and got better performance. So there are things that can be done.

When doing the training runs for the paper I should repeat each one a few times, as I’ve seen a bit of variance in the results.
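For the repeated runs, a small helper can run one training per seed and report the spread. summarise_runs is a hypothetical name, and train_once is a hypothetical callable supplied by the caller. With the real Trainer, it might call transformers.set_seed(seed) before building the model, so that the random prompt initialisation also changes, and then return trainer.evaluate()["eval_accuracy"].

```python
import statistics
from typing import Callable, Dict, Iterable


def summarise_runs(train_once: Callable[[int], float], seeds: Iterable[int]) -> Dict[str, float]:
    """Run train_once(seed) for every seed and summarise the returned accuracies."""
    accuracies = [train_once(seed) for seed in seeds]
    return {
        "runs": len(accuracies),
        "mean": statistics.mean(accuracies),
        # sample standard deviation needs at least two runs
        "stdev": statistics.stdev(accuracies) if len(accuracies) > 1 else 0.0,
        "min": min(accuracies),
        "max": max(accuracies),
    }
```

A call such as summarise_runs(train_once, range(5)) would then give five runs per configuration, and the same call can sit inside the sweep over model, prompt size and task.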

All in all this has been quite successful.