Distillation - Changing Temperature

Create a distilled model
model distillation
Published April 19, 2022

Distillation is the process of training a small model on a task with the help of a larger, already trained model. The small model is referred to as the student and the large model as the teacher. Training the student combines the regular process of measuring accuracy against the true labels with an additional objective: matching the distribution of predictions from the teacher. Matching the distribution shows the student which other classes are similar to the correct class, and guides it towards the correct weights in a more holistic way.
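
As a concrete reference, here is a minimal sketch of the usual combined distillation loss (the standard hard-label plus soft-label formulation, not necessarily the exact setup used later in this post); the alpha weighting and default values are illustrative, and the temperature parameter is discussed in the next paragraph:

import torch
import torch.nn.functional as F

def distillation_loss(
    student_logits: torch.Tensor,  # [batch_size, num_classes]
    teacher_logits: torch.Tensor,  # [batch_size, num_classes]
    labels: torch.Tensor,          # [batch_size]
    temperature: float = 2.0,
    alpha: float = 0.5,
) -> torch.Tensor:
    # hard-label term: the regular cross entropy against the true classes
    hard_loss = F.cross_entropy(student_logits, labels)

    # soft-label term: match the softened teacher distribution
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature**2  # keeps the gradient magnitude comparable across temperatures

    return alpha * hard_loss + (1 - alpha) * soft_loss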

The teacher is large and well trained, which means it predicts the correct class very strongly. For the teacher to add value to the training we must alter this output so that it shows more of the class distribution. We do that by adjusting the outputs, boosting the low probability classes and reducing the high probability class(es). This adjustment is controlled by a parameter called temperature, which divides the logits before the softmax is applied.
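
To make the effect concrete, here is a quick sketch using made-up teacher logits; dividing by a higher temperature before the softmax flattens the distribution:

import torch

# illustrative teacher logits for five classes (not from a real model)
logits = torch.tensor([5.0, 2.0, 1.0, 0.5, 0.1])

for temperature in [1, 2, 5]:
    probabilities = torch.softmax(logits / temperature, dim=-1)
    print(temperature, probabilities)

# at temperature 1 the top class dominates; at temperature 5 the low
# probability classes receive a large relative boost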

When trying out distillation I was concerned that the temperature parameter was not shaping the teacher output in an appropriate way. The teacher starts by strongly predicting a single class and the remaining probability is distributed between the other classes:

[Figure: original teacher output distribution]

The student is already being trained to correctly predict the class, the teacher is used to show appropriate decision boundaries. This means that the probability of the incorrect classes is what matters. Since the teacher strongly predicts the correct class, the influence of the incorrect classes is low.

Temperature takes probability from the most confident class and reassigns it to the others:

[Figure: output after standard temperature scaling]

However the reassignment pushes the whole distribution towards uniform, meaning that the lowest probability classes get a disproportionately large boost.

I think that it would be better if the probability was reassigned more to the more confident classes, and less to the less confident ones:

[Figure: ideal output, with probability reassigned in proportion to existing confidence]

(I used excalidraw to make these images).

New Temperature Requirements

The new algorithm should retain the relative differences between the classes that receive probability. Since probability is being taken from the most confident class and reassigned to the others, the relative difference between the most confident class and the other classes will necessarily change, but the relative differences among the remaining classes should be preserved.

The new algorithm should reduce the top class and reassign the probability to the other classes. It should not alter the order of the classes, so there is no parameter value which will boost the second highest class above the highest class.

It should have the same value range as the original temperature, being one to infinity.

Retain Relative Difference

We can move probability while retaining relative difference by increasing each class in proportion to its existing probability. An example shows this more easily.

Given the probabilities \([0.5, 0.25, 0.15, 0.1]\), we want to reassign \(0.1\) of probability. We can calculate the ratio of each subsequent class by taking \(\text{ratio}_n = \frac{P_n}{\sum_{i=2}^{4} P_i}\). Multiplying the reassigned probability by this ratio will then produce the per-class increase.

For this example there is a ratio of \([0.5, 0.3, 0.2]\), meaning that the probabilities become \([0.4, 0.3, 0.18, 0.12]\).

We can then establish that the relative difference remains the same by taking the ratio of each number with its subsequent number:

\[ \begin{aligned} \left[ \frac{0.25}{0.15}, \frac{0.15}{0.1} \right] &= \left[ \frac{0.3}{0.18}, \frac{0.18}{0.12} \right] \\ \left[ \frac{25}{15}, \frac{15}{10} \right] &= \left[ \frac{30}{18}, \frac{18}{12} \right] \\ \left[ \frac{5}{3}, \frac{3}{2} \right] &= \left[ \frac{5}{3}, \frac{3}{2} \right] \end{aligned} \]
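
A quick numeric check of the example above, written with torch to match the rest of the post:

import torch

probabilities = torch.tensor([0.5, 0.25, 0.15, 0.1])
reassigned = 0.1

# each non-top class's share of the remaining probability mass
ratios = probabilities[1:] / probabilities[1:].sum()
print(ratios)  # tensor([0.5000, 0.3000, 0.2000])

updated = probabilities.clone()
updated[0] -= reassigned
updated[1:] += reassigned * ratios
print(updated)  # tensor([0.4000, 0.3000, 0.1800, 0.1200])

# the relative differences between the non-top classes are preserved
print(probabilities[1:-1] / probabilities[2:])  # tensor([1.6667, 1.5000])
print(updated[1:-1] / updated[2:])              # tensor([1.6667, 1.5000])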

No Reordering

Only one class gets a reduced probability - the most confident class. The probability that is reassigned from that is spread to the other classes according to the ratio, as described previously. This means that we can predict the point where the most confident class would swap places with the second most confident class:

\[ P_1 - P_{change} = P_2 + (P_{change} * \text{ratio}_2) \]

Solving this for \(P_{change}\) gives us

\[ \begin{aligned} P_{change} &= P_1 - (P_2 + (P_{change} * \text{ratio}_2)) \\ P_{change} &= P_1 - P_2 - (P_{change} * \text{ratio}_2) \\ P_{change} + (P_{change} * \text{ratio}_2) &= P_1 - P_2 \\ (1 + \text{ratio}_2) * P_{change} &= P_1 - P_2 \\ P_{change} &= \frac{P_1 - P_2}{1 + \text{ratio}_2} \end{aligned} \]
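
Plugging the earlier example into this formula confirms that moving the full \(P_{change}\) makes the top two classes meet exactly:

probabilities = [0.5, 0.25, 0.15, 0.1]
p_1, p_2 = probabilities[:2]
ratio_2 = p_2 / (1 - p_1)  # 0.5, the second class's share of the remaining probability

p_change = (p_1 - p_2) / (1 + ratio_2)
print(p_change)                  # 0.1666...

# after moving p_change the top two classes are equal
print(p_1 - p_change)            # 0.3333...
print(p_2 + p_change * ratio_2)  # 0.3333...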

Same Range

To have a range that extends to infinity with a finite change I want an asymptotic progression. It’s easy to come up with one, because I want the infinity point to be the change that makes the most confident class equal to the second most confident. That would be the full \(P_{change}\) from above.

This can be produced with \(1 - k^{n}\) where \(0 < k < 1\) and \(n\) is the new temperature parameter (in the code below the exponent is shifted to \(n - 1\) so that a temperature of one produces no change). We can see a few graphs of this:

Code
from typing import Dict, Union
import pandas as pd

def row(n: int) -> Dict[str, Union[float, int]]:
    values = {"n": n}
    for k in [0.99, 0.9, 0.75, 0.5, 0.25, 0.1, 0.01]:
        values[f"k_{k}"] = change(k=k, n=n)
    return values

def change(k: float, n: int) -> float:
    return 1 - k**n

(
    pd.DataFrame([
        row(n=n) for n in range(20)
    ])
    .set_index("n")
    .plot()
) ; None

Looking at this I think that \(k = 0.9\) looks good as it changes significantly over the range without fully saturating.

Evaluate New Temperature

Code
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("optimum/roberta-large-finetuned-clinc")
model = AutoModelForSequenceClassification.from_pretrained("optimum/roberta-large-finetuned-clinc")
model.eval()
model.cuda()

dataset = load_dataset("clinc_oos", "plus")
id2label = dict(enumerate(dataset["train"].features["intent"].names))
Reusing dataset clinc_oos (/home/matthew/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1)
Code
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd

def top_and_bottom(
    text: str,
    tokenizer: AutoTokenizer,
    model: AutoModelForSequenceClassification,
    temperature: float,
    k: float,
) -> pd.DataFrame:
    logits = get_logits(
        text,
        tokenizer=tokenizer,
        model=model
    )[0]
    logits = scale_1d(logits=logits, temperature=temperature, k=k)
    indices = logits.argsort(descending=True).tolist()
    
    top_5_labels = [
        id2label[label_id]
        for label_id in indices[:5]
    ]
    bottom_4_labels = [
        id2label[label_id]
        for label_id in indices[-4:]
    ]

    df = pd.DataFrame(
        {
            "label": top_5_labels + ["..."] + bottom_4_labels,
            "probability": (
                logits[indices[:5]].tolist()
                + [0.]
                + logits[indices[-4:]].tolist()
            ),
        }
    )
    return df

@torch.no_grad()
def get_logits(
    text: str,
    tokenizer: AutoTokenizer,
    model: AutoModelForSequenceClassification
) -> torch.Tensor:
    tokens = tokenizer(text, return_tensors="pt")
    tokens = tokens.to(model.device)
    return model(**tokens).logits

def scale_1d(
    logits: torch.Tensor,
    temperature: float,
    k: float,
) -> torch.Tensor:
    # convert the logits to probabilities and rank the classes
    logits = logits.softmax(dim=-1)
    indices = logits.argsort(descending=True)

    # p_change is the amount that would make the top two classes equal,
    # scaled back by the asymptotic factor 1 - k^(temperature - 1)
    p1, p2 = logits[indices[:2]]
    remaining_p = 1 - p1
    p_change = (p1 - p2) / (1 + (p2 / remaining_p))
    p_change *= 1 - k**(temperature - 1)

    # take the probability from the top class and spread it across the rest
    # in proportion to their share of the remaining probability
    logits[indices[0]] -= p_change
    logits[indices[1:]] += p_change * logits[indices[1:]] / remaining_p

    return logits
Code
import pandas as pd

df = top_and_bottom(
    "what is the definiton of auspicious",
    tokenizer=tokenizer,
    model=model,
    temperature=1,
    k=0.5,
)

df.plot(x="label", y="probability", kind="bar")
df.iloc[[0,1,9]]
   label                probability
0  definition              0.991237
1  meaning_of_life         0.000330
9  credit_limit_change     0.000011

Code
import pandas as pd

df = top_and_bottom(
    "what is the definiton of auspicious",
    tokenizer=tokenizer,
    model=model,
    temperature=5,
    k=0.5,
)

df.plot(x="label", y="probability", kind="bar")
df.iloc[[0,1,9]]
   label                probability
0  definition              0.096022
1  meaning_of_life         0.034091
9  credit_limit_change     0.001152

Code
import pandas as pd

df = top_and_bottom(
    "what is the definiton of auspicious",
    tokenizer=tokenizer,
    model=model,
    temperature=19,
    k=0.75,
)

df.plot(x="label", y="probability", kind="bar")
df.iloc[[0,1,9]]
   label                probability
0  definition              0.041725
1  meaning_of_life         0.036138
9  credit_limit_change     0.001221

Code
import pandas as pd

df = top_and_bottom(
    "Call me Ishmael. "
    "Some years ago, never mind how long precisely, "
    "having little or no money in my purse, "
    "and nothing particular to interest me on shore, "
    "I thought I would sail about a little and see the watery part of the world",
    tokenizer=tokenizer,
    model=model,
    temperature=1,
    k=0.5,
)

df.plot(x="label", y="probability", kind="bar")
df.iloc[[0,1,9]]
   label             probability
0  balance              0.363756
1  oos                  0.059927
9  insurance_change     0.000506

Code
import pandas as pd

df = top_and_bottom(
    "Call me Ishmael. "
    "Some years ago, never mind how long precisely, "
    "having little or no money in my purse, "
    "and nothing particular to interest me on shore, "
    "I thought I would sail about a little and see the watery part of the world",
    tokenizer=tokenizer,
    model=model,
    temperature=5,
    k=0.5,
)

df.plot(x="label", y="probability", kind="bar")
df.iloc[[0,1,9]]
   label             probability
0  balance              0.103436
1  oos                  0.084447
9  insurance_change     0.000713

Code
import pandas as pd

df = top_and_bottom(
    "Call me Ishmael. "
    "Some years ago, never mind how long precisely, "
    "having little or no money in my purse, "
    "and nothing particular to interest me on shore, "
    "I thought I would sail about a little and see the watery part of the world",
    tokenizer=tokenizer,
    model=model,
    temperature=19,
    k=0.5,
)

df.plot(x="label", y="probability", kind="bar")
df.iloc[[0,1,9]]
   label             probability
0  balance              0.086083
1  oos                  0.086081
9  insurance_change     0.000727

Making it Batchwise

To apply this to distillation I need to make this work with batches instead of single elements. Since we have a working single row version it’s possible to use that to validate a simple batchwise version.

Code
import torch

def scale_batch_simple(
    logits: torch.Tensor,
    temperature: float,
    k: float,
) -> torch.Tensor:
    return torch.cat([
        scale_1d(logits=row, temperature=temperature, k=k)[None, :]
        for row in logits
    ])
Code
logits = torch.rand(16)
temperature = 2
k = 0.5

logits_scaled_1d = scale_1d(
    logits=logits, temperature=temperature, k=k
)
logits_scaled_simple = scale_batch_simple(
    logits=logits[None, :], temperature=temperature, k=k
)

torch.all(torch.eq(
    logits_scaled_1d[None, :], logits_scaled_simple
))
tensor(True)

Now that we have a working batchwise version it can be used to validate a better one. When writing this my main concern is to use only batchwise operations, taking advantage of the parallel performance that GPUs offer. The simple version iterates over the rows, which performs very poorly on CUDA.

Code
def scale_batch(
    logits: torch.Tensor,
    temperature: float,
    k: float,
) -> torch.Tensor:
    batch_size, output_size = logits.shape

    logits = logits.softmax(dim=1)
    logit_indices = logits.argsort(dim=1, descending=True)

    # when selecting multiple unaligned indices from a 2d tensor you can pass two lists
    # they are joined by index to determine the value to select:
    #
    #   t = torch.tensor(range(9)).reshape(3, 3)
    #   >>> tensor([[0, 1, 2],
    #   >>>         [3, 4, 5],
    #   >>>         [6, 7, 8]])
    #
    #   t[[0, 1], [2, 0]]
    #   >>> tensor([2, 3])
    #
    # You can see that this has selected indices (0,2)=2 and (1,0)=3 rather than (0,1)=1 and (2,0)=6
    # This can be extended to select multi dimensionally:
    #
    #   t[[[0], [1]], [[2], [0]]]
    #   >>> tensor([[2],
    #   >>>         [3]])
    #
    # The shape of the output is the same as the shape of the index lookup

    # the single index is used when selecting a single index per batch
    single_index = torch.tensor(range(batch_size), device=logits.device)

    # the remaining index is used when selecting one less than all indices per batch
    rest_index = single_index[:, None].broadcast_to(batch_size, output_size - 1)

    # these all have shape [batch_size, 1]
    top_probability = logits[single_index, logit_indices[:, 0]][:, None]
    remaining_probability = 1 - top_probability

    second_probability = logits[single_index, logit_indices[:, 1]][:, None]

    top_probability_change = (top_probability - second_probability)
    top_probability_change /= 1 + (second_probability / remaining_probability)
    top_probability_change *= 1 - k ** (temperature - 1)

    # this has a shape [batch_size, output_size-1]
    rest_probability = logits[rest_index, logit_indices[:, 1:]]
    rest_probability_change = top_probability_change * rest_probability / remaining_probability

    # if these changes are applied directly to logits then it breaks back propagation
    # generating an offset tensor that can be added to logits works fine
    offset = torch.zeros_like(logits)
    offset[single_index, logit_indices[:, 0]] = -top_probability_change[:, 0]
    offset[rest_index, logit_indices[:, 1:]] = rest_probability_change

    return logits + offset
Code
logits = torch.rand(16, 32)
temperature = 2
k = 0.5

logits_scaled_simple = scale_batch_simple(
    logits=logits, temperature=temperature, k=k
)
logits_scaled_batch = scale_batch(
    logits=logits, temperature=temperature, k=k
)

torch.all(torch.eq(
    logits_scaled_simple, logits_scaled_batch
))
tensor(True)

Nice. I’m particularly pleased that the batchwise operations have perfectly duplicated the original. All of this is running on CPU so exact equality is to be expected; on CUDA there is a chance of small numerical differences (especially with inference mode or AMP turned on).
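
To back up the claim about batchwise performance, a rough timing comparison could look like the following sketch. It assumes a CUDA device is available and reuses the scale_batch_simple and scale_batch functions from above; the batch size is arbitrary and the actual numbers will depend on hardware:

import time
import torch

def mean_seconds(fn, *, logits: torch.Tensor, repeats: int = 100) -> float:
    # warm up once, then time repeated calls, synchronizing so GPU work is counted
    fn(logits=logits, temperature=2, k=0.5)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        fn(logits=logits, temperature=2, k=0.5)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats

logits = torch.rand(256, 151, device="cuda")  # 151 classes, matching clinc_oos plus
print("simple:", mean_seconds(scale_batch_simple, logits=logits))
print("batch: ", mean_seconds(scale_batch, logits=logits))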

Testing It

To establish whether the new approach is better than the old one I need to compare them. The new parameter does not scale in the same way as the old temperature, so if the same settings were used for both approaches there is no reason to believe they would both train well.

In the distillation workshop they had performed a parameter search to find appropriate settings. If I do the same for the two approaches then I can compare them.

WandB can perform a Bayesian search across continuous hyperparameters. This should have a reasonable chance of finding good parameters for the task. If I run it for both approaches then an approximate maximum accuracy can be found for each, which can then be compared.

It’s also possible to run a random sweep across the hyperparameters. This will show how well a random run would do. An approach that has a good average run could be used to try out distillation without requiring an expensive hyperparameter search.
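
For illustration, a WandB Bayesian sweep can be configured along these lines. The metric name and project match the runs analysed below, but the hyperparameter names and ranges here are placeholders rather than the exact search space used:

import wandb

sweep_config = {
    "method": "bayes",  # "random" for the random baseline sweeps
    "metric": {"name": "eval/accuracy", "goal": "maximize"},
    "parameters": {
        # placeholder hyperparameters and ranges, not the actual search space
        "temperature": {"min": 1.0, "max": 20.0},
        "learning_rate": {"min": 1e-5, "max": 1e-3},
    },
}

sweep_id = wandb.sweep(sweep_config, project="distillation-temperature")
# wandb.agent(sweep_id, function=train)  # train would be the training entry point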

The code below uses the WandB API to retrieve the run results.

Code
sweep_id_to_name = {
    "8ystnjqx": "random-scaling-fixed-k",
    "x2mjyyyx": "random-temperature",
    "n2gx1jpe": "scaling-fixed-k",
    "rc7u9aec": "temperature",
}
Code
import pandas as pd
import wandb.apis.public as wandb_api

api = wandb_api.Api()
runs = api.runs(path="matthewfranglen/distillation-temperature")
results = pd.DataFrame([
    {"sweep_id": run.sweep.name, "accuracy": run.summary["eval/accuracy"]}
    for run in runs
])
results["sweep"] = results.sweep_id.map(sweep_id_to_name)
results = results[~results.sweep.isnull()]
Code
results
     sweep_id  accuracy               sweep
272  x2mjyyyx  0.032258  random-temperature
273  x2mjyyyx  0.845806  random-temperature
274  x2mjyyyx  0.032258  random-temperature
275  x2mjyyyx  0.032258  random-temperature
276  x2mjyyyx  0.031290  random-temperature
...       ...       ...                 ...
667  rc7u9aec  0.248710         temperature
668  rc7u9aec  0.012258         temperature
669  rc7u9aec  0.601613         temperature
670  rc7u9aec  0.032258         temperature
671  rc7u9aec  0.567097         temperature

400 rows × 3 columns

Code
(
    results[["sweep", "accuracy"]]
        .groupby("sweep")
        .agg(max)
        .rename(columns={"accuracy": "best_accuracy"})
        .sort_values(by="best_accuracy", ascending=False)
)
                        best_accuracy
sweep
temperature                  0.954194
random-scaling-fixed-k       0.952903
scaling-fixed-k              0.949355
random-temperature           0.944516
Code
(
    results[["sweep", "accuracy"]]
        .groupby("sweep")
        .agg(len)
        .rename(columns={"accuracy": "runs"})
)
                        runs
sweep
random-scaling-fixed-k    40
random-temperature        40
scaling-fixed-k          192
temperature              128
Code
(
    results
        .sort_values(by="accuracy")
        [["accuracy", "sweep"]]
        .boxplot(by="sweep", grid=False, figsize=(9,5))
) ; None

Conclusions

You can see that the highest overall accuracy was achieved by the Bayesian search over the temperature hyperparameters. The maximum accuracy is very close across all approaches, but the higher average for the Bayesian temperature search suggests that it is a more predictable hyperparameter space.

The random searches show that the scaling approach is more likely to perform better with a random set of hyperparameters, however the average accuracy is still very poor (at about 0.1).

So the new approach that I have proposed is worse than the existing temperature approach. It also takes longer to compute.

Code and Runs

You can see the code here: https://gitlab.com/matthewfranglen/distillation-temperature

You can see the runs here: https://wandb.ai/matthewfranglen/distillation-temperature