Huggingface Distillation Workshop

Create a distilled model
model distillation
Published April 13, 2022

Huggingface Workshop

I attended a workshop on model distillation (available on YouTube here). It covered both distillation and compilation for AWS Inferentia, and had a heavy AWS element.

I’m going to use this blog post to recreate that workshop locally.

Workshop Video

The video from the workshop has been uploaded to YouTube and is available here:

youtube: https://www.youtube.com/watch?v=3fulTyMXhWQ

Workshop Code

The code that was used for the workshop is available here.

Huggingface Book

One of the presenters, Lewis Tunstall, co-wrote a book on NLP with Transformers. You can see it at https://transformersbook.com/

Workshop Notes

My initial notes on the workshop were:

What is Distillation?

In distillation the student is trained on the task, and is also trained to produce the same class distribution as the teacher (referred to as the knowledge distillation loss). The result is a composite loss function in which the student's loss on the task is combined with the degree to which the student output matches the teacher output.
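
Concretely, writing the student logits as \(z_s\), the teacher logits as \(z_t\), the gold label as \(y\), the temperature as \(T\) (explained below) and the weighting as \(\alpha\), the composite loss implemented by the trainer later in this post is:

\[ \mathcal{L} = \alpha \, \mathcal{L}_{\mathrm{CE}}(z_s, y) + (1 - \alpha) \, T^2 \, \mathrm{KL}\!\left( \mathrm{softmax}(z_t / T) \,\middle\|\, \mathrm{softmax}(z_s / T) \right) \]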

Why is distillation better than just fine tuning?

The large model has likely learned details about the decision boundaries between the different classes. The knowledge distillation loss provides hints about those boundaries to the student.

What does the teacher teach?

The teacher has learned decision boundaries between the different classes. This is the information that the teacher needs to transmit to the student.

The teacher is likely very confident in its output, so it may well assign more than 99% of the probability weight to the correct class. To effectively pass the information about the decision boundaries to the student we need to adjust the teacher output to show these boundaries more clearly.

There is a temperature parameter which spreads out the predicted class distribution. The teacher is likely very confident in its prediction, so without it the student would gain little from the teacher distribution beyond what the gold label already provides. The spread distribution still reflects the teacher, but shows more of the other classes, which means the student has to put effort into matching those as well as the actual correct class.
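
The softened distribution is just a softmax with the logits divided by the temperature \(T\):

\[ p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} \]

With \(T = 1\) this is the ordinary softmax; as \(T\) grows the probabilities flatten towards a uniform distribution while keeping the same ranking.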

What values of temperature are good?

Usually values of T=2-5 are a good starting point, but in practice you can get better models even with larger values like 10-20.

Having T<1 tends to make the teacher's probabilities very similar to the ground truth label, so it isn't very common in my experience (since you don't gain any new information).

Does Distillation work across Architectures?

A paper suggests that distillation doesn't work if the teacher and student have different architectures: https://arxiv.org/pdf/2010.13382.pdf

What is inferentia?

Inferentia is a custom AWS chip designed for machine learning inference; models are compiled for it rather than run directly. The compilation does involve some loss of precision. Once the model has been compiled it can run efficiently on Inferentia instances, faster even than on a GPU. Compilation also compresses the model: in the workshop the model shrank to about \(\frac{2}{3}\) of its original size.

The compilation process involves mapping the model operations to Inferentia operations. Part of this is operator fusion, an optimization that combines several operations into one. In the workshop around 95% of the model operations were fused.

There are GPU operations which have no corresponding Inferentia operation; these cannot be run on Inferentia.

An Inferentia model does not support dynamically sized input; all inputs need to be of a fixed size.
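
As a small illustration of what fixed-size input means in practice, here is a sketch (not from the workshop) that pads and truncates every input to the same length. The max_length of 128 is just an illustrative choice, and the actual Neuron compilation step is only hinted at in a comment since it needs the AWS Neuron SDK.

Code
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "nreimers/MiniLMv2-L12-H384-distilled-from-RoBERTa-Large"
)

# pad and truncate so that every input has exactly the same shape,
# which is what a compiled (traced) model expects
tokens = tokenizer(
    "what is the definition of auspicious",
    padding="max_length",
    truncation=True,
    max_length=128,
    return_tensors="pt",
)
assert tokens["input_ids"].shape == (1, 128)

# the Neuron SDK would then trace the model with example inputs of exactly
# this shape; shorter or longer text still has to be padded or truncated to
# 128 tokens before being passed to the compiled model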

Code
import blog.transformers_logging

Distillation Code

The workshop went over the following code, which is in the notebook.

Code
from sagemaker.huggingface import HuggingFace
from huggingface_hub import HfFolder

# hyperparameters, which are passed into the training job
hyperparameters={
    'teacher_id':'optimum/roberta-large-finetuned-clinc',           
    'student_id':'nreimers/MiniLMv2-L12-H384-distilled-from-RoBERTa-Large',           
    'dataset_id':'clinc_oos',           
    'dataset_config':'plus',             
    'epochs': 10,             
    # distillation parameter
    'alpha': 0.055199695773231194, # 0.5,
    'temparature': 19, # 4 
    'learning_rate': 1e-4, # 3e-5
    # hpo parameter
    "run_hpo": False,
    "n_trials": 100,   
    # push to hub config
    'push_to_hub': True,                            
    'hub_model_id': 'MiniLMv2-L12-H384-distilled-finetuned-clinc', 
    'hub_token': HfFolder.get_token()      
}

# define Training Job Name 
job_name = f'knowledge-distillation'

# create the Estimator
huggingface_estimator = HuggingFace(
    entry_point          = 'knowledge_distillation.py',        
    source_dir           = './scripts',       
    instance_type        = 'ml.p3.2xlarge',   
    instance_count       = 1,                 
    role                 = role,    
    base_job_name        = job_name, 
    transformers_version = '4.17',            
    pytorch_version      = '1.10',             
    py_version           = 'py38',            
    hyperparameters      = hyperparameters,   
)
ModuleNotFoundError: No module named 'sagemaker'

The code is heavily reliant on SageMaker. It also appears to lack anything to do with distillation itself.

If we review it carefully, it really is just a way of triggering a training run on another machine. The actual training code is referenced here:

entry_point = 'knowledge_distillation.py',        
source_dir  = './scripts', 

The knowledge_distillation.py script contains everything needed to perform distillation locally.

It’s also reasonable to think that Hugging Face has code related to distillation that would be worth reviewing. While the knowledge_distillation.py script gives me enough to work with, I did also find a very exciting script in transformers about distilling a classifier which seems solid.

Local Distillation

Copying from the knowledge_distillation.py script, we can see that distillation is achieved by extending the training arguments and trainer. It's actually quite a simple extension.

Code
# from src/main/python/blog/distillation/trainer.py
import logging
import sys
from pathlib import Path
from typing import Any, Dict, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset, load_metric
from torch import nn
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    Trainer,
    TrainingArguments,
)


def distill(
    *,
    teacher_id: str,
    student_id: str,
    dataset_id: str,
    dataset_name: str,
    output_dir: Path,
    epochs: int,
    per_device_train_batch_size: int,
    per_device_eval_batch_size: int,
    fp16: bool,
    learning_rate: float,
    alpha: float,
    temperature: float,
) -> None:
    # Set up logging
    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # init tokenizer
    tokenizer = AutoTokenizer.from_pretrained(teacher_id)
    student_tokenizer = AutoTokenizer.from_pretrained(student_id)

    # sample input
    sample = "This is a basic example, with different words to test."

    # assert results
    assert tokenizer(sample) == student_tokenizer(
        sample
    ), "Tokenizers are not compatible"

    # load datasets
    dataset = load_dataset(dataset_id, dataset_name)

    # process dataset
    def process(examples: Dict[str, Any]) -> Dict[str, Any]:
        tokenized_inputs = tokenizer(examples["text"], truncation=True, max_length=512)
        return tokenized_inputs

    tokenized_datasets = dataset.map(process, batched=True)
    tokenized_datasets = tokenized_datasets.rename_column("intent", "labels")

    # define metrics and metrics function
    accuracy_metric = load_metric("accuracy")

    def compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        acc = accuracy_metric.compute(predictions=predictions, references=labels)
        return {
            "accuracy": acc["accuracy"],
        }

    # create label2id, id2label dicts for nice outputs for the model
    labels = tokenized_datasets["train"].features["labels"].names
    num_labels = len(labels)
    id2label = dict(enumerate(labels))
    label2id = {label: label_id for label_id, label in id2label.items()}

    # define training args
    training_args = DistillationTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        fp16=fp16,
        learning_rate=learning_rate,
        seed=33,
        # logging & evaluation strategies
        logging_dir=output_dir / "logs",
        logging_strategy="epoch",  # to get more information to TB
        evaluation_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        report_to="tensorboard",
        # distillation parameters
        alpha=alpha,
        temperature=temperature,
    )

    # define data_collator
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # define teacher model
    teacher_model = AutoModelForSequenceClassification.from_pretrained(
        teacher_id,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
    )

    # init method is needed when using hpo
    def student_init():
        return AutoModelForSequenceClassification.from_pretrained(
            student_id, num_labels=num_labels, id2label=id2label, label2id=label2id
        )

    trainer = DistillationTrainer(
        model_init=student_init,
        args=training_args,
        teacher_model=teacher_model,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # train model with initial hyperparameters or hyperparameters from the best run
    trainer.train()

    # save the best model locally (the original script saved to S3 using os.environ["SM_MODEL_DIR"])
    trainer.save_model(output_dir / "best_model")


class DistillationTrainingArguments(TrainingArguments):
    def __init__(
        self, *args, alpha: float = 0.5, temperature: float = 2.0, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.temperature = temperature


class DistillationTrainer(Trainer):
    def __init__(
        self, *args, teacher_model: AutoModelForSequenceClassification = None, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        # place teacher on same device as student
        self._move_model_to_device(self.teacher, self.model.device)
        self.teacher.eval()

    def compute_loss(
        self,
        model: AutoModelForSequenceClassification,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        # compute student output
        outputs_student = model(**inputs)
        student_loss = outputs_student.loss
        # compute teacher output
        with torch.no_grad():
            outputs_teacher = self.teacher(**inputs)

        # assert size
        assert outputs_student.logits.size() == outputs_teacher.logits.size()

        # Soften probabilities and compute distillation loss
        loss_function = nn.KLDivLoss(reduction="batchmean")
        loss_logits = (
            loss_function(
                F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
                F.softmax(outputs_teacher.logits / self.args.temperature, dim=-1),
            )
            * (self.args.temperature ** 2)
        )
        # Return weighted student loss
        loss = self.args.alpha * student_loss + (1.0 - self.args.alpha) * loss_logits
        return (loss, outputs_student) if return_outputs else loss
Code
from pathlib import Path

# the parameters here are copied from the settings, above, and from the defaults in the script.
distill(
    teacher_id="optimum/roberta-large-finetuned-clinc",
    student_id="nreimers/MiniLMv2-L12-H384-distilled-from-RoBERTa-Large",
    dataset_id="clinc_oos",
    dataset_name="plus",
    output_dir=Path("/data/blog/2022-04-13-huggingface-distillation-workshop"),
    epochs=10,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    fp16=True,

    # distillation parameter
    learning_rate=1e-4,
    alpha=0.055199695773231194,
    temperature=19,
)
2022-04-26 10:09:07,068 - datasets.builder - WARNING - Reusing dataset clinc_oos (/home/matthew/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1)
2022-04-26 10:09:07,086 - datasets.arrow_dataset - WARNING - Loading cached processed dataset at /home/matthew/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1/cache-87263ad3bb40db18.arrow
2022-04-26 10:09:07,102 - datasets.arrow_dataset - WARNING - Loading cached processed dataset at /home/matthew/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1/cache-16acf3948860a75f.arrow
[2390/2390 03:26, Epoch 10/10]
Epoch Training Loss Validation Loss Accuracy
1 0.827000 0.704925 0.376129
2 0.612300 0.492485 0.799355
3 0.431400 0.348086 0.893548
4 0.308900 0.254955 0.938065
5 0.233900 0.202458 0.945161
6 0.188300 0.170242 0.947097
7 0.160000 0.150101 0.948710
8 0.141900 0.136279 0.950000
9 0.130800 0.128748 0.952581
10 0.125000 0.126360 0.950323

The settings from the workshop have achieved an accuracy of 95.2%. This compares very well to the teacher accuracy of 97%.

What is interesting is that I have run this training with a few other settings. When the fp16 setting was disabled the student model failed to train at all. This might be worth further investigation.

What does Temperature change?

The workshop described the temperature parameter as smoothing the output of the teacher model. This is because the teacher can be very confident in its prediction, which means that there is little additional information for the student compared to just training against the gold label.

The teacher is there to show where suitable decision boundaries are to be found, so the teacher shows the other classes which are close to the target class. Smoothing the probabilities should increase the probabilities of the classes other than the target.

Let’s see how that looks in practice.

We are going to load the teacher, get its output for an example from the test set (available here), and then see how changing the temperature changes that output.

Code
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("optimum/roberta-large-finetuned-clinc")
model = AutoModelForSequenceClassification.from_pretrained("optimum/roberta-large-finetuned-clinc")
model.eval()
model.cuda()

dataset = load_dataset("clinc_oos", "plus")
id2label = dict(enumerate(dataset["train"].features["intent"].names))
Reusing dataset clinc_oos (/home/matthew/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1)
Code
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd

def top_and_bottom(
    text: str,
    tokenizer: AutoTokenizer,
    model: AutoModelForSequenceClassification,
    temperature: float,
) -> pd.DataFrame:
    """Show the temperature-scaled probabilities for the five most probable
    classes and for ranks 7-10 (the tail of the top ten predictions)."""
    logits = get_logits(
        text,
        tokenizer=tokenizer,
        model=model
    )[0]
    # divide the logits by the temperature before the softmax to soften them
    logits = logits / temperature
    logits = logits.softmax(dim=-1)
    # indices of the ten most probable classes, most probable first
    logit_indices = logits.argsort(descending=True)[:10].tolist()
    
    top_5_labels = [
        id2label[label_id]
        for label_id in logit_indices[:5]
    ]
    bottom_4_labels = [
        id2label[label_id]
        for label_id in logit_indices[-4:]
    ]

    df = pd.DataFrame(
        {
            "label": top_5_labels + ["..."] + bottom_4_labels,
            "probability": (
                logits[logit_indices[:5]].tolist()
                + [0.]
                + logits[logit_indices[-4:]].tolist()
            ),
        }
    )
    return df

@torch.no_grad()
def get_logits(
    text: str,
    tokenizer: AutoTokenizer,
    model: AutoModelForSequenceClassification
) -> torch.Tensor:
    tokens = tokenizer(text, return_tensors="pt")
    tokens = tokens.to(model.device)
    return model(**tokens).logits

Temperature of 1

Here we can see that 99% of the probability has been assigned to the correct class, definition.

Code
import pandas as pd

df = top_and_bottom(
    "what is the definiton of auspicious",
    tokenizer=tokenizer,
    model=model,
    temperature=1,
)

df.plot(x="label", y="probability", kind="bar")
df.iloc[[0,1,9]]
label probability
0 definition 0.991237
1 meaning_of_life 0.000330
9 reminder_update 0.000106

Temperature of 5

A temperature of 5 was suggested as a good starting point for distillation. Now the most confident label only has ~5% of the probability.

Code
import pandas as pd

df = top_and_bottom(
    "what is the definiton of auspicious",
    tokenizer=tokenizer,
    model=model,
    temperature=5,
)

df.plot(x="label", y="probability", kind="bar")
df.iloc[[0,1,9]]
label probability
0 definition 0.046073
1 meaning_of_life 0.009290
9 reminder_update 0.007396

Temperature of 19

A temperature of 19 was used in the workshop. Now the most confident label only has ~1% of the probability, and the second label has about two thirds of the top label's probability.

Code
import pandas as pd

df = top_and_bottom(
    "what is the definiton of auspicious",
    tokenizer=tokenizer,
    model=model,
    temperature=19,
)

df.plot(x="label", y="probability", kind="bar")
df.iloc[[0,1,9]]
label probability
0 definition 0.011118
1 meaning_of_life 0.007295
9 reminder_update 0.006870

The original model is so confident that the difference between the second best prediction and the last prediction shown is not that great.

Further Work

There are two directions in which I would be interested in taking this technique. The first is to consider whether temperature is the best way to transmit the boundary information. The second is to consider whether the KL divergence approach could be used to internalize prompted tasks into a student.

Is Temperature the best Boundary Communicator?

As we saw earlier, the temperature parameter tends to just boost all of the non-correct classes.

I wonder what it would be like to feed in something that does not express an intent at all. At a temperature of 19 the relative difference between the first wrong answer and the last answer shown is only ~6% (0.007295 vs 0.006870). How much information is the student gaining from such a small relative difference?

The example we investigated had a clear-cut answer. If we try an input where the teacher is uncertain, how does temperature affect the output?

Code
import pandas as pd

df = top_and_bottom(
    "Call me Ishmael. "
    "Some years ago, never mind how long precisely, "
    "having little or no money in my purse, "
    "and nothing particular to interest me on shore, "
    "I thought I would sail about a little and see the watery part of the world",
    tokenizer=tokenizer,
    model=model,
    temperature=1,
)

df.plot(x="label", y="probability", kind="bar")
df.iloc[[0,1,9]]
label probability
0 balance 0.363756
1 oos 0.059927
9 maybe 0.010566

Code
import pandas as pd

df = top_and_bottom(
    "Call me Ishmael. "
    "Some years ago, never mind how long precisely, "
    "having little or no money in my purse, "
    "and nothing particular to interest me on shore, "
    "I thought I would sail about a little and see the watery part of the world",
    tokenizer=tokenizer,
    model=model,
    temperature=19,
)

df.plot(x="label", y="probability", kind="bar")
df.iloc[[0,1,9]]
label probability
0 balance 0.008549
1 oos 0.007775
9 maybe 0.007096

This is a poorly classified sequence and the original output is uncertain. After applying the temperature all of the probabilities have been pushed towards each other, so much of the relative difference has been lost. You can see here that the gap between the second class and the last class shown has largely vanished. The uncertain example retains more of this difference than the clear-cut one, but the difference has still become small.

| text | temperature | absolute second score | absolute last score | relative difference % |
| --- | --- | --- | --- | --- |
| what is … | 1 | 0.000330 | 0.000106 | 211.3 |
| what is … | 5 | 0.009290 | 0.007396 | 25.6 |
| what is … | 19 | 0.007295 | 0.006870 | 6.1 |
| Call me … | 1 | 0.059927 | 0.010566 | 467.1 |
| Call me … | 19 | 0.007775 | 0.007096 | 9.5 |
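
Here the relative difference is the gap between the second and last scores as a fraction of the last score, \((p_{\text{second}} - p_{\text{last}}) / p_{\text{last}}\), expressed as a percentage; for the first row that is \((0.000330 - 0.000106) / 0.000106 \approx 211\%\).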

Two things occur to me based on this.

The first is that the student is being informed of the correct answer already, so the loss of information about the correct class is acceptable, as that information is available through the other loss term.

The second is that this is intended to show the student appropriate decision boundaries. Does such a flat probability distribution do this well? I wonder if a more discriminative process would be better for transmitting the boundary information.

Can KL Divergence Loss be used to Internalize Prompted Output?

The KL divergence loss matches the output distributions of the two models. This could be used to train a model to internalize a prompt.

The aim would be to have a token-sensitive prompt (e.g. the sentiment of an utterance towards a specific mentioned entity) where the teacher output is interpreted according to that specific token. The student would then be evaluated on its raw output for that same token. This would train the student to predict entity sentiment for each token, allowing entity sentiment to be predicted for every entity in the utterance in a single pass.

This would be an extension of my previous work in this area. I will call it prompt internalization.
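
As a very rough sketch of the loss involved, the per-token class logits of the teacher and student could be compared with the same softened KL divergence used in the trainer above. The function and tensor shapes here are hypothetical illustrations, not code from the workshop.

Code
import torch
import torch.nn.functional as F

def per_token_distillation_loss(
    student_logits: torch.Tensor,  # (batch, sequence_length, num_classes)
    teacher_logits: torch.Tensor,  # (batch, sequence_length, num_classes)
    temperature: float = 2.0,
) -> torch.Tensor:
    # soften both distributions with the temperature, as in the distillation trainer
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # compare the distributions token by token, averaging over every token position
    loss = F.kl_div(
        student_log_probs.flatten(0, 1),
        teacher_probs.flatten(0, 1),
        reduction="batchmean",
    )
    return loss * temperature**2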