Fine Tuning GPT-2 with Transformers - Training the Model

Train GPT-2 with some Spanish data
Published: March 29, 2021

Now that we have a dataset (see the previous post for details) we can train the model. Training is tricky, so I want to use the Weights & Biases integration with transformers to report on progress.

In this post we are going to train the model with the dataset that has been created. Let's start by defining the train and validation datasets. The validation dataset is kept very small because the transformers Trainer seems to exceed CUDA memory quite easily during evaluation if it is large.

Code
from typing import *
from pathlib import Path
Code
MAX_STEPS = 100_000
LEARNING_RATE = 5e-5
MODEL_NAME = "distilgpt2"
BATCH_SIZE = 8
Code
DATA_FOLDER = Path("data/2021-03-28-fine-tune-gpt2")
OUTPUT_FOLDER = DATA_FOLDER / "spanish-articles"
Code
from sklearn.model_selection import train_test_split

files = list(OUTPUT_FOLDER.glob("*.txt"))
TRAIN_FILES, TEST_FILES = train_test_split(files, test_size=10_000)
TRAIN_FILES, VALID_FILES = train_test_split(TRAIN_FILES, test_size=10_000)

len(TRAIN_FILES), len(VALID_FILES), len(TEST_FILES)
(1575968, 10000, 10000)

We are going to use the LineByLineTextDataset that was evaluated as part of creating the dataset. This loads each article from its own file. While this is slow, it is a much more memory-efficient way to handle this volume of data.

Code
from typing import *
from dataclasses import dataclass

from transformers import PreTrainedTokenizer
from itertools import chain, islice
import torch

@dataclass
class LineByLineTextDataset(torch.utils.data.Dataset):
    tokenizer: PreTrainedTokenizer
    files: List[Path]
    block_size: int = 512

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        text = self.files[i].read_text()
        return self.tokenizer(
            text,
            max_length=self.block_size,
            return_tensors="pt",
            truncation=True,
            padding="max_length"
        )

Weights & Biases

I want to log the runs to Weights & Biases as that helps me track the different settings. I'm using the configuration approach which I covered in a previous post. This really just involves creating a configuration dictionary for the run and then naming the run to reflect it. Naming the run after the settings makes it easy to spot if a run duplicates an earlier one. (The run name below uses the sizes of the train and validation datasets, which are created in the training section further down.)

Code
import hashlib

def get_id_for_dict(config_dict: Dict[str, Any]) -> str:
    """ This function creates a unique hash
        based on the initial config dictionary
        that is used to label the model. """

    unique_str = ''.join(
        f"'{key}':'{value}';"
        for key, value in sorted(config_dict.items())
    )
    return hashlib.sha1(unique_str.encode('utf-8')).hexdigest()[:5]

initial_config = {
    "batches": MAX_STEPS,
    "learning_rate": LEARNING_RATE,
    "model_name": MODEL_NAME,
    "batch_size": BATCH_SIZE,
}

RUN_NAME = "_".join(
    [
        f"{value}_{label}"
        for value, label in [
            (MODEL_NAME, "model"),
            (BATCH_SIZE, "batch_size"),
            (MAX_STEPS, "batches"),
            (len(train_ds), "train_size"),
            (len(valid_ds), "valid_size"),
        ]
    ]
    + [get_id_for_dict(initial_config)]
)
RUN_NAME
'distilgpt2_model_8_batch_size_100000_batches_1575968_train_size_25_valid_size_f6434'
Code
import wandb

run = wandb.init(
    project="mf-blog-spanish-gpt2",
    name=RUN_NAME,
    config=initial_config
)
wandb: Currently logged in as: matthewfranglen (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.10.23 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
Tracking run with wandb version 0.10.20
Syncing run distilgpt2_model_8_batch_size_100000_batches_1575968_train_size_25_valid_size_f6434 to Weights & Biases (Documentation).
Project page: https://wandb.ai/matthewfranglen/mf-blog-spanish-gpt2
Run page: https://wandb.ai/matthewfranglen/mf-blog-spanish-gpt2/runs/1drj67vy
Run data is saved locally in /home/matthew/Programming/Blog/blog/notebooks/wandb/run-20210329_172006-1drj67vy


Train the Model

Now that we have the data and the dataset class, we can train the model. First I need to create the datasets that wrap the files and define a perplexity metric. Perplexity is a measure of language model quality:

In information theory, perplexity is a measurement of how well a probability distribution or probability model predicts a sample. It may be used to compare probability models. A low perplexity indicates the probability distribution is good at predicting the sample.

So a low perplexity score is good. The distilgpt2 model that I am using has a reported perplexity of 21.1, compared to 16.3 for the base GPT-2. I should not expect to beat these scores, although my validation set is very small so any comparison will be rough.
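Perplexity is just the exponential of the mean cross-entropy loss, so it can be computed directly from a loss value. A minimal sketch, using a made-up loss value purely for illustration:

Code
import torch

# Perplexity is the exponential of the mean cross-entropy loss.
# The loss value here is made up purely for illustration.
loss = torch.tensor(2.16)
perplexity = torch.exp(loss)
perplexity  # ≈ tensor(8.6711)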

Code
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", pad_token="<|endoftext|>") # TODO: Custom tokenizer
train_ds = LineByLineTextDataset(tokenizer, TRAIN_FILES, block_size=1024)
valid_ds = LineByLineTextDataset(tokenizer, VALID_FILES[:25], block_size=1024) # Small for quick validation

len(train_ds), len(valid_ds)
(1575968, 25)
Code
import torch.nn.functional as F

def compute_metrics(pred):
    # This loss calculation comes directly from the GPT2 forward method,
    # which correctly offsets the labels to match the positions being predicted

    labels = torch.tensor(pred.label_ids)
    lm_logits = torch.tensor(pred.predictions)

    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    perplexity = torch.exp(loss)
    return {
        "perplexity": perplexity.item()
    }
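To make the shift concrete, here is a toy example with made-up token ids: the logits at position i are scored against the token at position i+1, so the final logit and the first label are dropped.

Code
import torch

# Four made-up token ids and random logits over a 50 token vocabulary.
labels = torch.tensor([[10, 20, 30, 40]])
lm_logits = torch.randn(1, 4, 50)

shift_logits = lm_logits[..., :-1, :]  # logits for positions 0, 1, 2 ...
shift_labels = labels[..., 1:]         # ... are scored against tokens 20, 30, 40

shift_logits.shape, shift_labels
# (torch.Size([1, 3, 50]), tensor([[20, 30, 40]]))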
Code
from transformers import Trainer, TrainingArguments, AutoModelWithLMHead, DataCollatorForLanguageModeling

model = AutoModelWithLMHead.from_pretrained(MODEL_NAME)

training_arguments = TrainingArguments(
    report_to=["wandb"],
    output_dir=str(DATA_FOLDER / "output"),
    logging_dir=str(DATA_FOLDER / "output"),
    overwrite_output_dir=True,
    evaluation_strategy="steps",
    max_steps=MAX_STEPS,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE*2,
    eval_steps=500, # Number of update steps between two evaluations.
    warmup_steps=500, # number of warmup steps for learning rate scheduler
    logging_steps=500,
    load_best_model_at_end=True,    
    metric_for_best_model="perplexity",
    greater_is_better=False,
    run_name=RUN_NAME
)

trainer = Trainer(
    model=model,
    args=training_arguments,
    data_collator=DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False,
    ),
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    compute_metrics=compute_metrics,
)
/home/matthew/.cache/pypoetry/virtualenvs/blog-HrtMnrOS-py3.8/lib/python3.8/site-packages/transformers/models/auto/modeling_auto.py:966: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.
  warnings.warn(
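The warning above notes that AutoModelWithLMHead is deprecated. I have kept the deprecated class for this run; if I were loading the model again, the warning suggests the replacement would just be:

Code
from transformers import AutoModelForCausalLM

# The replacement suggested by the deprecation warning - it loads the same
# causal language modelling head for distilgpt2.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)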
Code
trainer.train()
[100000/100000 18:57:27, Epoch 0/1]
Step Training Loss Validation Loss Perplexity Runtime Samples Per Second
500 4.360400 3.702356 40.984215 5.150400 4.854000
1000 3.780800 3.352830 28.907923 4.727000 5.289000
1500 3.526800 3.172196 24.133615 4.719500 5.297000
2000 3.397700 3.064181 21.662312 4.783200 5.227000
2500 3.318800 2.981329 19.929367 4.717400 5.299000
3000 3.226500 2.917050 18.689455 4.765600 5.246000
3500 3.171300 2.861650 17.685238 4.756400 5.256000
4000 3.124400 2.823017 17.012983 4.759700 5.252000
4500 3.082100 2.780272 16.301422 4.744800 5.269000
5000 3.046500 2.745767 15.746825 4.682300 5.339000
5500 3.013700 2.720514 15.351411 4.747000 5.266000
6000 2.992600 2.699369 15.028552 4.679200 5.343000
6500 2.965500 2.678053 14.708805 4.670900 5.352000
7000 2.943100 2.655248 14.377626 4.665500 5.358000
7500 2.912600 2.634398 14.080054 4.699400 5.320000
8000 2.915000 2.612985 13.780510 4.650400 5.376000
8500 2.882700 2.605551 13.676591 4.692200 5.328000
9000 2.863800 2.581378 13.345110 4.728100 5.288000
9500 2.862200 2.568025 13.170488 4.643900 5.383000
10000 2.817100 2.564833 13.125784 4.682800 5.339000
10500 2.840500 2.545068 12.869972 4.688900 5.332000
11000 2.813000 2.529012 12.662861 4.698200 5.321000
11500 2.809200 2.527631 12.645857 4.695300 5.324000
12000 2.805400 2.519823 12.541773 4.640300 5.388000
12500 2.791800 2.504210 12.352239 4.685600 5.335000
13000 2.772300 2.501084 12.312334 4.705500 5.313000
13500 2.767800 2.485152 12.117881 4.627500 5.403000
14000 2.742900 2.473366 11.971764 4.691600 5.329000
14500 2.767800 2.466652 11.894001 4.661300 5.363000
15000 2.734800 2.467555 11.906019 4.620200 5.411000
15500 2.742700 2.447979 11.674358 4.687200 5.334000
16000 2.730800 2.447885 11.674782 4.691700 5.329000
16500 2.710600 2.441351 11.598569 4.744900 5.269000
17000 2.716400 2.435222 11.526530 4.688400 5.332000
17500 2.691600 2.424081 11.398600 4.690800 5.330000
18000 2.689700 2.417682 11.325281 4.682700 5.339000
18500 2.687900 2.411923 11.262743 4.748300 5.265000
19000 2.671900 2.409560 11.234286 4.730300 5.285000
19500 2.665200 2.401380 11.142518 4.712200 5.305000
20000 2.654200 2.395683 11.076626 4.731000 5.284000
20500 2.681200 2.392558 11.042551 4.688000 5.333000
21000 2.657300 2.390264 11.016036 4.662200 5.362000
21500 2.645500 2.391599 11.031819 4.658300 5.367000
22000 2.649100 2.377134 10.872148 4.723200 5.293000
22500 2.631000 2.377407 10.875364 4.683000 5.338000
23000 2.639600 2.371849 10.814195 4.700600 5.318000
23500 2.633900 2.374701 10.847894 4.681200 5.340000
24000 2.648700 2.360695 10.694958 4.721800 5.295000
24500 2.636500 2.358757 10.675319 4.694500 5.325000
25000 2.633800 2.356332 10.647686 4.701600 5.317000
25500 2.622600 2.351246 10.593875 4.691000 5.329000
26000 2.599400 2.347776 10.557463 4.728000 5.288000
26500 2.598600 2.346375 10.543932 4.680300 5.342000
27000 2.612500 2.335228 10.427504 4.652700 5.373000
27500 2.601000 2.337018 10.445247 4.701600 5.317000
28000 2.589500 2.330009 10.370514 4.709500 5.308000
28500 2.593500 2.332979 10.402362 4.715700 5.301000
29000 2.598800 2.326984 10.340738 4.746900 5.267000
29500 2.599000 2.320810 10.275547 4.671900 5.351000
30000 2.572900 2.320885 10.277387 4.706800 5.311000
30500 2.589400 2.324665 10.316133 4.663300 5.361000
31000 2.577600 2.316571 10.232056 4.670700 5.353000
31500 2.568700 2.312280 10.187517 4.742000 5.272000
32000 2.567600 2.313001 10.196151 4.705200 5.313000
32500 2.566200 2.305023 10.114561 4.675500 5.347000
33000 2.550300 2.306582 10.130407 4.677100 5.345000
33500 2.560200 2.299845 10.061002 4.692800 5.327000
34000 2.545700 2.297130 10.033359 4.721100 5.295000
34500 2.556900 2.294095 10.005368 4.673800 5.349000
35000 2.565300 2.290545 9.967186 4.729100 5.286000
35500 2.547500 2.286399 9.929468 4.735700 5.279000
36000 2.550700 2.289327 9.958371 4.691100 5.329000
36500 2.545600 2.283216 9.897723 4.740600 5.274000
37000 2.551600 2.281248 9.875886 4.722300 5.294000
37500 2.544100 2.275671 9.821963 4.690300 5.330000
38000 2.530200 2.278805 9.851675 4.686800 5.334000
38500 2.536900 2.277662 9.840133 4.725400 5.291000
39000 2.528900 2.271726 9.781415 4.675500 5.347000
39500 2.530200 2.271304 9.776792 4.707800 5.310000
40000 2.546100 2.268227 9.745853 4.672700 5.350000
40500 2.527500 2.268438 9.749769 4.667900 5.356000
41000 2.513100 2.265372 9.716460 4.724500 5.292000
41500 2.533900 2.261983 9.685142 4.760400 5.252000
42000 2.513400 2.256519 9.634706 4.679500 5.342000
42500 2.538600 2.258655 9.654160 4.660400 5.364000
43000 2.511800 2.256743 9.635938 4.668100 5.356000
43500 2.507800 2.251618 9.588696 4.689700 5.331000
44000 2.512700 2.251925 9.589813 4.705400 5.313000
44500 2.498000 2.249427 9.568373 4.663400 5.361000
45000 2.499300 2.251360 9.587020 4.671100 5.352000
45500 2.507800 2.251170 9.583761 4.735400 5.279000
46000 2.520600 2.247334 9.547936 4.684400 5.337000
46500 2.496600 2.248953 9.561157 4.659200 5.366000
47000 2.511800 2.244533 9.518897 4.750900 5.262000
47500 2.488200 2.242113 9.495392 4.660600 5.364000
48000 2.492300 2.239975 9.474918 4.719600 5.297000
48500 2.489200 2.239032 9.465775 4.672300 5.351000
49000 2.498800 2.237262 9.448555 4.685300 5.336000
49500 2.486200 2.233460 9.414551 4.660900 5.364000
50000 2.481400 2.231460 9.396680 4.660300 5.364000
50500 2.497800 2.232087 9.400949 4.661900 5.363000
51000 2.492500 2.233625 9.416335 4.656400 5.369000
51500 2.492400 2.225406 9.338825 4.710400 5.307000
52000 2.465500 2.226707 9.351920 4.667600 5.356000
52500 2.505900 2.224616 9.331812 4.687400 5.333000
53000 2.500400 2.226945 9.352785 4.663900 5.360000
53500 2.504700 2.224303 9.330186 4.745800 5.268000
54000 2.480200 2.222404 9.311388 4.730800 5.285000
54500 2.469000 2.220685 9.295632 4.755200 5.257000
55000 2.472500 2.217708 9.267027 4.695700 5.324000
55500 2.468500 2.219917 9.286234 4.747300 5.266000
56000 2.469700 2.216012 9.250226 4.686600 5.334000
56500 2.474600 2.213224 9.224898 4.744900 5.269000
57000 2.458900 2.212269 9.215661 4.681900 5.340000
57500 2.465600 2.214800 9.237897 4.733600 5.281000
58000 2.452100 2.211943 9.210569 4.717900 5.299000
58500 2.461700 2.211983 9.213121 4.677700 5.344000
59000 2.467300 2.207248 9.169579 4.728700 5.287000
59500 2.468000 2.210982 9.203514 4.758800 5.253000
60000 2.468900 2.207473 9.171269 4.681700 5.340000
60500 2.454300 2.205665 9.154585 4.722100 5.294000
61000 2.467700 2.207709 9.172531 4.731300 5.284000
61500 2.445300 2.207564 9.171982 4.664200 5.360000
62000 2.453200 2.204978 9.147490 4.788700 5.221000
62500 2.450800 2.202402 9.126425 4.703200 5.316000
63000 2.456000 2.201880 9.120384 4.766600 5.245000
63500 2.461100 2.199798 9.101039 4.711700 5.306000
64000 2.464100 2.199305 9.096694 4.752100 5.261000
64500 2.453400 2.198195 9.086848 4.730800 5.284000
65000 2.444200 2.194160 9.050518 4.691100 5.329000
65500 2.450000 2.195012 9.056788 4.686700 5.334000
66000 2.455800 2.194539 9.052542 4.774500 5.236000
66500 2.434100 2.194012 9.047109 4.694200 5.326000
67000 2.444200 2.192459 9.032662 4.736400 5.278000
67500 2.434600 2.192902 9.039184 4.701300 5.318000
68000 2.445200 2.193277 9.040983 4.682200 5.339000
68500 2.445300 2.190604 9.017118 4.737100 5.277000
69000 2.451200 2.188234 8.994356 4.822100 5.184000
69500 2.449500 2.189156 9.002600 4.686700 5.334000
70000 2.456100 2.184335 8.959962 4.695300 5.324000
70500 2.439100 2.184447 8.961812 4.755600 5.257000
71000 2.454800 2.183449 8.952945 4.685000 5.336000
71500 2.459900 2.183132 8.950529 4.694100 5.326000
72000 2.424800 2.180700 8.929421 4.786000 5.224000
72500 2.433800 2.182412 8.944829 4.668500 5.355000
73000 2.440200 2.179533 8.919517 4.678700 5.343000
73500 2.439500 2.180656 8.928808 4.689300 5.331000
74000 2.424700 2.180942 8.930764 4.679600 5.342000
74500 2.452800 2.179202 8.916446 4.710100 5.308000
75000 2.440200 2.178581 8.909504 4.693800 5.326000
75500 2.422300 2.176174 8.887369 4.746800 5.267000
76000 2.428500 2.175789 8.884776 4.796300 5.212000
76500 2.458800 2.176743 8.893144 4.767100 5.244000
77000 2.411500 2.179561 8.919050 4.710000 5.308000
77500 2.434900 2.176929 8.895175 4.695500 5.324000
78000 2.412200 2.174379 8.873323 4.699400 5.320000
78500 2.446300 2.176090 8.888458 4.739800 5.274000
79000 2.418200 2.174819 8.877240 4.701800 5.317000
79500 2.422200 2.175611 8.884649 4.737300 5.277000
80000 2.426600 2.173164 8.862095 4.770400 5.241000
80500 2.423500 2.175104 8.880151 4.764900 5.247000
81000 2.430300 2.170741 8.842427 4.683700 5.338000
81500 2.420800 2.170275 8.836911 4.713700 5.304000
82000 2.424800 2.171069 8.844029 4.693800 5.326000
82500 2.418400 2.167583 8.812852 4.741500 5.273000
83000 2.407900 2.170138 8.834461 4.757800 5.255000
83500 2.424400 2.170450 8.838120 4.692400 5.328000
84000 2.419600 2.168507 8.820310 4.700900 5.318000
84500 2.436900 2.166972 8.806765 4.742700 5.271000
85000 2.434100 2.166539 8.803656 4.694000 5.326000
85500 2.420600 2.167152 8.809423 4.683700 5.338000
86000 2.417200 2.167440 8.810533 4.756200 5.256000
86500 2.424800 2.167591 8.812583 4.717800 5.299000
87000 2.401800 2.167519 8.812274 4.759100 5.253000
87500 2.429100 2.166193 8.799897 4.707300 5.311000
88000 2.413400 2.165663 8.795700 4.688200 5.333000
88500 2.415900 2.165704 8.795581 4.697000 5.323000
89000 2.412600 2.164068 8.780826 4.757800 5.255000
89500 2.423400 2.164751 8.787511 4.702100 5.317000
90000 2.411500 2.165240 8.791935 4.728100 5.288000
90500 2.436600 2.164468 8.784675 4.753600 5.259000
91000 2.412600 2.164178 8.782509 4.699400 5.320000
91500 2.411700 2.163445 8.776644 4.683100 5.338000
92000 2.428700 2.164179 8.782456 4.694600 5.325000
92500 2.415000 2.163734 8.778793 4.699700 5.320000
93000 2.406600 2.163711 8.778646 4.703800 5.315000
93500 2.421600 2.162858 8.770714 4.742500 5.271000
94000 2.421700 2.162106 8.764411 4.688700 5.332000
94500 2.404100 2.162465 8.767489 4.731300 5.284000
95000 2.394900 2.162495 8.767938 4.703700 5.315000
95500 2.421000 2.162645 8.768976 4.744800 5.269000
96000 2.413500 2.161637 8.760166 4.692500 5.328000
96500 2.423300 2.162421 8.767182 4.713600 5.304000
97000 2.412100 2.162054 8.763784 4.714200 5.303000
97500 2.422300 2.162224 8.765325 4.692800 5.327000
98000 2.414400 2.161440 8.758332 4.690100 5.330000
98500 2.407600 2.162189 8.764719 4.705100 5.313000
99000 2.426400 2.161754 8.760895 4.713700 5.304000
99500 2.411900 2.161397 8.757938 4.745500 5.268000
100000 2.403700 2.161592 8.759642 4.746800 5.267000

TrainOutput(global_step=100000, training_loss=2.578038330078125, metrics={'train_runtime': 68247.9183, 'train_samples_per_second': 1.465, 'total_flos': 402616693555200000, 'epoch': 0.51})
Code
trainer.save_model()
Code
run.finish()

Waiting for W&B process to finish, PID 2180937
Program ended successfully.
Find user logs for this run at: /home/matthew/Programming/Blog/blog/notebooks/wandb/run-20210329_172006-1drj67vy/logs/debug.log
Find internal logs for this run at: /home/matthew/Programming/Blog/blog/notebooks/wandb/run-20210329_172006-1drj67vy/logs/debug-internal.log

Run summary:


_runtime 68253
_timestamp 1617103059
_step 100000
train/loss 2.4037
train/learning_rate 0.0
train/epoch 0.51
eval/loss 2.16159
eval/perplexity 8.75964
eval/runtime 4.7468
eval/samples_per_second 5.267
train/train_runtime 68247.9183
train/train_samples_per_second 1.465
train/total_flos 402616693555200000

Run history:


_runtime▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/loss█▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/epoch▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
eval/loss█▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/perplexity█▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/runtime▄▆▂▁▃▃▃▂▁▃▄▃▂▄▃▂▄▃▂▂▁▄▂▄▄▅▅█▇▂▃▃▃▃▅▅▃▃▄▅
eval/samples_per_second▅▃▇█▆▆▆▇▇▆▅▆▇▅▆▇▅▆▇▇█▅▆▅▅▄▄▁▂▇▆▆▆▆▄▄▆▆▅▄
train/train_runtime
train/train_samples_per_second
train/total_flos

Synced 5 W&B file(s), 1 media file(s), 0 artifact file(s) and 0 other file(s)

Synced distilgpt2_model_8_batch_size_100000_batches_1575968_train_size_25_valid_size_f6434: https://wandb.ai/matthewfranglen/mf-blog-spanish-gpt2/runs/1drj67vy

After quite a bit of poking I've managed to get perplexity out of this. I'm quite pleased with how this has gone. If you are in the same position, remember that you can look at the source code for something in Jupyter by appending ?? to it. Looking at the source of GPT2LMHeadModel (with model??) helped me find the correct way to calculate the perplexity.
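Outside of Jupyter the same trick can be done with the standard inspect module. A minimal sketch that prints the forward method containing the label shift and loss calculation:

Code
import inspect
from transformers import GPT2LMHeadModel

# Print the source of the forward method, which contains the label
# shifting and cross entropy loss copied into compute_metrics above.
print(inspect.getsource(GPT2LMHeadModel.forward))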

In order to get this running without exceeding the CUDA memory I had to use an extremely limited validation set. There are currently only 25 articles in it, so the reported validation perplexity should be treated as quite inaccurate.
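One setting I have not tried that might allow a larger validation set is the eval_accumulation_steps argument to TrainingArguments, which moves the accumulated prediction tensors to the CPU every few evaluation steps instead of keeping every logit on the GPU. A sketch of what that would look like, not something used for this run:

Code
from transformers import TrainingArguments

# Hypothetical alternative - not the arguments used for this run.
# eval_accumulation_steps moves accumulated predictions to the CPU every
# N evaluation steps, trading speed for GPU memory.
alternative_arguments = TrainingArguments(
    output_dir=str(DATA_FOLDER / "output"),
    per_device_eval_batch_size=BATCH_SIZE,
    eval_accumulation_steps=1,
)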

Training the distilgpt2 model is interesting because the model itself has a reported perplexity: the model card gives 21.1, compared to 16.3 for the base GPT-2. The perplexity being reported during training is already about 13.6 after 8,500 batches, which suggests to me that the validation set is easier than the training data.


The training has now completed and the perplexity has reached 8.76, which is unbelievably low. A more systematic evaluation would be appropriate, so it's lucky that I have the test set ready.

Code
test_ds = LineByLineTextDataset(tokenizer=tokenizer, files=TEST_FILES, block_size=1024)
len(test_ds)
10000
Code
next(iter(test_ds))
{'input_ids': tensor([[  400,  2178,    91,  ..., 50256, 50256, 50256]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0]])}
Code
# _15 is Jupyter's output history - it refers to the first test_ds entry shown above
_15["attention_mask"].sum(dim=-1)
tensor([152])
Code
_15["attention_mask"][:, :_15["attention_mask"].sum(dim=-1)]
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]])
Code
with torch.no_grad():
    print(model(_15["input_ids"][:, :152][0].cuda()).logits.argmax(dim=-1) == _15["input_ids"][0, :152].cuda())
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False], device='cuda:0')
Code
@torch.no_grad()
def accuracy_and_perplexity(
    model: AutoModelWithLMHead,
    tokenizer: GPT2TokenizerFast,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor
) -> Dict[str, Any]:
    assert input_ids.shape[0] == 1, "batch size of 1 required"

    # limit the tokens to the attended tokens and cut out the batch
    tokens = attention_mask.sum(dim=-1)
    input_ids = input_ids[:, :tokens][0]
    
    lm_logits = model(input_ids=input_ids.cuda()).logits
    
    # This loss calculation comes directly from the GPT2 forward method,
    # which correctly offsets the labels to match the positions being predicted

    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = input_ids[..., 1:].contiguous().cuda()
    
    # Flatten the tokens
    view_labels = shift_labels.view(-1)
    view_logits = shift_logits.view(-1, shift_logits.size(-1))
    predicted_tokens = shift_logits.argmax(dim=-1)
    
    loss = F.cross_entropy(view_logits, view_labels)
    perplexity = torch.exp(loss)
    
    accuracy = (predicted_tokens == shift_labels).sum() / tokens.cuda()
    return {
        "perplexity": perplexity.item(),
        "accuracy": accuracy.item(),
        "tokens": tokens.item(),
        "sentence": tokenizer.decode(input_ids),
        "predicted": tokenizer.decode(predicted_tokens),
    }
Code
accuracy_and_perplexity(model=model, tokenizer=tokenizer, **next(iter(test_ds)))
{'perplexity': 14.126980781555176,
 'accuracy': 0.5197368264198303,
 'tokens': 152,
 'sentence': 'thumb|left\nEl estornino esmeralda (Lamprotornis iris) es una especie de estúrnido africano de plumaje verde irisado que puebla las tierras bajas y sabanas de Costa de Marfil, Guinea y Sierra Leona.* BirdLife International. 2012. Coccycolius iris. The IUCN Red List of Threatened Species. Version 2015.3. Acceso: 12 de noviembre de 2015.\n*  (2007): Species factsheet: Coccycolius iris. Retrieved 2007-JUL-20.*Xeno-canto. L. iris. Canto.iris',
 'predicted': 'umb|right|El nado de unalda deenapususus)is) es una especie de aornbulago dericano de laase quede deis., seueda el montras deajas de lasanas. la Rica Marfil. en Ec Sierra Leona. Life International:\n. Theodonus iris. I IUCN Red List of Threatened Species.  2015.1. Acceso: 19 de feviembre de 2015.ir* Bird\nen). The accountsheet. Ircycolius iris. The 2015-08ULI2015- --canto.comamp iris. Theanto. 2012'}

For a spot check this is kinda amazing? A 50% accurate language model. When I look at the actual and predicted text there does seem to be some overlap. It's a bit janky though.

I need to do this evaluation more systematically.

Code
import pandas as pd
from tqdm.auto import tqdm

def evaluate(
    model: AutoModelWithLMHead,
    tokenizer: GPT2TokenizerFast,
    dataset: LineByLineTextDataset,
) -> pd.DataFrame:
    return pd.DataFrame(
        accuracy_and_perplexity(
            model=model,
            tokenizer=tokenizer,
            **datum
        )
        for datum in tqdm(dataset)
    )
Code
evaluation_df = evaluate(model=model, tokenizer=tokenizer, dataset=test_ds)
Code
evaluation_df
perplexity accuracy tokens sentence predicted
0 14.126981 0.519737 152 thumb|left\nEl estornino esmeralda (Lamprotorn... umb|right|El nado de unalda deenapususus)is) e...
1 4.801258 0.676510 745 La Temporada 2009/10 del Fútbol profesional Ve... porada de-10 de Camútbol Clubesional esenezol...
2 6.119050 0.634195 503 El Cementerio nacional de Puerto Rico es un ce... municipementerio deacional de San Rico es una...
3 3.590329 0.756757 37 Egestria rubicunda es una especie de coleópter... umbi esiaul es una especie de insectleóptero d...
4 7.143648 0.657143 35 Cancello e Arnone es un municipio situado en e... umbll es Ialdo es unaio itado en el territorio...
... ... ... ... ... ...
9995 2.648958 0.833333 54 Acolastus medvedevi es un coleóptero de la fam... umbophus esusaicusi es unaleóptero de la famil...
9996 2.312675 0.777778 36 Phyllopodium es un género con 39 especies de p... umbotod es un género de dos especies de planta...
9997 10.965172 0.509766 1024 es una serie de manga escrita por Rando Ayamin... una serie de anime escrita por eli Kaka y ilu...
9998 9.442551 0.574230 357 El Parque Central de Kaliningrado (en ruso: Це... municipque N de laamrado esen inguso: Мертаан...
9999 12.729012 0.510742 1024 Arabia Saudita y Turquía han disfrutado de una... ia esita (esihadquía Saudace sidpututado de la...

10000 rows × 5 columns

Code
evaluation_df.drop(columns=["sentence", "predicted"]).mean()
perplexity      9.970363
accuracy        0.590948
tokens        456.019700
dtype: float64
Code
evaluation_df.to_csv("language-model-evaluation.csv")

The last bit of evaluation I can do is to run the suggested prompt from the Hugging Face gpt2-small-spanish model through it.

Code
from transformers import pipeline

language_model = pipeline(
    'text-generation',
    model=model.cpu(),
    tokenizer=tokenizer,
    config={'max_length': 1024}
)

result = language_model('La inteligencia artificial en latinoamérica se ha desarrollado ')[0]['generated_text']
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Code
result
'La inteligencia artificial en latinoamérica se ha desarrollado  cerca del  en el\xa0Estado de  Brasil. Los sistemas de inteligencia artificial en latinoamérica permit'

Artificial intelligence in Latin America has developed close to the  xa0State of Brazil. Artificial intelligence systems in Latin America allow

Apart from that junk character, that seems pretty good? I'm amazed. I expected to have to train this quite a bit more.