from typing import *
from pathlib import Path
March 29, 2021
Now that we have a dataset (see the previous post for details) we can train the model. Training is tricky, so I want to use the Weights & Biases integration with transformers to report on progress.
In this post we are going to train the model on the dataset that has been created. Let's start by defining the train, validation and test datasets. The validation dataset is very small because the transformers trainer seems to exceed CUDA memory quite easily if it is large.
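The split itself happened in the previous post; a minimal sketch of what it could look like (the glob pattern and the TEST_FILES name are assumptions on my part; TRAIN_FILES, VALID_FILES and DATA_FOLDER appear later in the notebook, and only the sizes come from the output below):

files = sorted(DATA_FOLDER.glob("articles/*.txt"))  # assumption: one article per file
TRAIN_FILES = files[:-20_000]
VALID_FILES = files[-20_000:-10_000]
TEST_FILES = files[-10_000:]
len(TRAIN_FILES), len(VALID_FILES), len(TEST_FILES)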
(1575968, 10000, 10000)
We are going to use the LineByLineTextDataset that was evaluated as part of creating the dataset. This loads articles from individual files; while this is slow, it is a much more memory-efficient way to handle this volume of data.
from typing import *
from dataclasses import dataclass
from pathlib import Path
from transformers import PreTrainedTokenizer
import torch


@dataclass
class LineByLineTextDataset(torch.utils.data.Dataset):
    """Lazily loads and tokenizes one article file per index."""

    tokenizer: PreTrainedTokenizer
    files: List[Path]
    block_size: int = 512

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        # read and tokenize the article on demand, truncating or
        # padding it to exactly block_size tokens
        text = self.files[i].read_text()
        return self.tokenizer(
            text,
            max_length=self.block_size,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
        )
I want to log the runs to Weights & Biases as that can help me track the different settings. I'm using the configuration approach which I covered in a previous post. This really just involves creating a configuration dictionary for the run and then naming the run to reflect it. Naming the run after the settings helps me spot duplicate runs.
import hashlib


def get_id_for_dict(config_dict: Dict[str, Any]) -> str:
    """This function creates a unique hash
    based on the initial config dictionary
    that is used to label the model."""
    unique_str = ''.join(
        f"'{key}':'{value}';"
        for key, value in sorted(config_dict.items())
    )
    return hashlib.sha1(unique_str.encode('utf-8')).hexdigest()[:5]
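The hyperparameters referenced below are defined earlier in the notebook. For reference, values consistent with the run name output further down would be (LEARNING_RATE is a guess, the rest can be read off the run name):

MODEL_NAME = "distilgpt2"  # from the run name output below
BATCH_SIZE = 8             # from the run name output below
MAX_STEPS = 100_000        # from the run name output below
LEARNING_RATE = 5e-5       # assumption: the transformers default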
initial_config = {
    "batches": MAX_STEPS,
    "learning_rate": LEARNING_RATE,
    "model_name": MODEL_NAME,
    "batch_size": BATCH_SIZE,
}

RUN_NAME = "_".join(
    [
        f"{value}_{label}"
        for value, label in [
            (MODEL_NAME, "model"),
            (BATCH_SIZE, "batch_size"),
            (MAX_STEPS, "batches"),
            (len(train_ds), "train_size"),
            (len(valid_ds), "valid_size"),
        ]
    ]
    + [get_id_for_dict(initial_config)]
)
RUN_NAME
'distilgpt2_model_8_batch_size_100000_batches_1575968_train_size_25_valid_size_f6434'
wandb: Currently logged in as: matthewfranglen (use `wandb login --relogin` to force relogin)
/home/matthew/Programming/Blog/blog/notebooks/wandb/run-20210329_172006-1drj67vy
Now that we have the data and the dataset class, we can train the model. First I need to create the datasets that wrap the files and come up with a perplexity metric. Perplexity is a measure of language model quality:
In information theory, perplexity is a measurement of how well a probability distribution or probability model predicts a sample. It may be used to compare probability models. A low perplexity indicates the probability distribution is good at predicting the sample.
So a low perplexity score is good. The distilgpt2 model that I am using has a reported perplexity of 21.1 compared to a base of 16.3 for GPT-2. I should not expect to beat these scores. My validation set is very small though.
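Perplexity relates to the cross-entropy loss in a simple way: it is the exponential of the mean loss, so a model with perplexity k is, per token, about as uncertain as a uniform choice between k candidates. A minimal illustration:

import torch

# exp(mean cross-entropy) = perplexity; log(perplexity) = loss in nats
loss = torch.tensor(21.1).log()  # the loss implied by distilgpt2's reported perplexity
print(loss)        # tensor(3.0493) nats per token
print(loss.exp())  # tensor(21.1000), back to the perplexity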
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", pad_token="<|endoftext|>") # TODO: Custom tokenizer
train_ds = LineByLineTextDataset(tokenizer, TRAIN_FILES, block_size=1024)
valid_ds = LineByLineTextDataset(tokenizer, VALID_FILES[:25], block_size=1024) # Small for quick validation
len(train_ds), len(valid_ds)
(1575968, 25)
import torch.nn.functional as F


def compute_metrics(pred):
    # This loss calculation comes directly from the GPT2 forward method,
    # which correctly offsets the labels to match the predicting positions
    labels = torch.tensor(pred.label_ids)
    lm_logits = torch.tensor(pred.predictions)

    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens, take the mean cross-entropy, then exponentiate
    loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    perplexity = torch.exp(loss)
    return {
        "perplexity": perplexity.item()
    }
from transformers import Trainer, TrainingArguments, AutoModelWithLMHead, DataCollatorForLanguageModeling
model = AutoModelWithLMHead.from_pretrained(MODEL_NAME)
training_arguments = TrainingArguments(
    report_to=["wandb"],
    output_dir=str(DATA_FOLDER / "output"),
    logging_dir=str(DATA_FOLDER / "output"),
    overwrite_output_dir=True,
    evaluation_strategy="steps",
    max_steps=MAX_STEPS,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE * 2,
    eval_steps=500,     # number of update steps between two evaluations
    warmup_steps=500,   # number of warmup steps for the learning rate scheduler
    logging_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="perplexity",
    greater_is_better=False,
    run_name=RUN_NAME,
)
trainer = Trainer(
    model=model,
    args=training_arguments,
    data_collator=DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False,  # causal language modelling, not masked
    ),
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    compute_metrics=compute_metrics,
)
/home/matthew/.cache/pypoetry/virtualenvs/blog-HrtMnrOS-py3.8/lib/python3.8/site-packages/transformers/models/auto/modeling_auto.py:966: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.
warnings.warn(
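As the warning says, on newer versions of transformers the non-deprecated way to load a causal language model is a one-line swap:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

Training then runs for the full 100,000 steps, evaluating every 500: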
Step | Training Loss | Validation Loss | Perplexity | Eval Runtime (s) | Eval Samples/s |
---|---|---|---|---|---|
500 | 4.360400 | 3.702356 | 40.984215 | 5.150400 | 4.854000 |
1000 | 3.780800 | 3.352830 | 28.907923 | 4.727000 | 5.289000 |
1500 | 3.526800 | 3.172196 | 24.133615 | 4.719500 | 5.297000 |
2000 | 3.397700 | 3.064181 | 21.662312 | 4.783200 | 5.227000 |
2500 | 3.318800 | 2.981329 | 19.929367 | 4.717400 | 5.299000 |
3000 | 3.226500 | 2.917050 | 18.689455 | 4.765600 | 5.246000 |
3500 | 3.171300 | 2.861650 | 17.685238 | 4.756400 | 5.256000 |
4000 | 3.124400 | 2.823017 | 17.012983 | 4.759700 | 5.252000 |
4500 | 3.082100 | 2.780272 | 16.301422 | 4.744800 | 5.269000 |
5000 | 3.046500 | 2.745767 | 15.746825 | 4.682300 | 5.339000 |
5500 | 3.013700 | 2.720514 | 15.351411 | 4.747000 | 5.266000 |
6000 | 2.992600 | 2.699369 | 15.028552 | 4.679200 | 5.343000 |
6500 | 2.965500 | 2.678053 | 14.708805 | 4.670900 | 5.352000 |
7000 | 2.943100 | 2.655248 | 14.377626 | 4.665500 | 5.358000 |
7500 | 2.912600 | 2.634398 | 14.080054 | 4.699400 | 5.320000 |
8000 | 2.915000 | 2.612985 | 13.780510 | 4.650400 | 5.376000 |
8500 | 2.882700 | 2.605551 | 13.676591 | 4.692200 | 5.328000 |
9000 | 2.863800 | 2.581378 | 13.345110 | 4.728100 | 5.288000 |
9500 | 2.862200 | 2.568025 | 13.170488 | 4.643900 | 5.383000 |
10000 | 2.817100 | 2.564833 | 13.125784 | 4.682800 | 5.339000 |
10500 | 2.840500 | 2.545068 | 12.869972 | 4.688900 | 5.332000 |
11000 | 2.813000 | 2.529012 | 12.662861 | 4.698200 | 5.321000 |
11500 | 2.809200 | 2.527631 | 12.645857 | 4.695300 | 5.324000 |
12000 | 2.805400 | 2.519823 | 12.541773 | 4.640300 | 5.388000 |
12500 | 2.791800 | 2.504210 | 12.352239 | 4.685600 | 5.335000 |
13000 | 2.772300 | 2.501084 | 12.312334 | 4.705500 | 5.313000 |
13500 | 2.767800 | 2.485152 | 12.117881 | 4.627500 | 5.403000 |
14000 | 2.742900 | 2.473366 | 11.971764 | 4.691600 | 5.329000 |
14500 | 2.767800 | 2.466652 | 11.894001 | 4.661300 | 5.363000 |
15000 | 2.734800 | 2.467555 | 11.906019 | 4.620200 | 5.411000 |
15500 | 2.742700 | 2.447979 | 11.674358 | 4.687200 | 5.334000 |
16000 | 2.730800 | 2.447885 | 11.674782 | 4.691700 | 5.329000 |
16500 | 2.710600 | 2.441351 | 11.598569 | 4.744900 | 5.269000 |
17000 | 2.716400 | 2.435222 | 11.526530 | 4.688400 | 5.332000 |
17500 | 2.691600 | 2.424081 | 11.398600 | 4.690800 | 5.330000 |
18000 | 2.689700 | 2.417682 | 11.325281 | 4.682700 | 5.339000 |
18500 | 2.687900 | 2.411923 | 11.262743 | 4.748300 | 5.265000 |
19000 | 2.671900 | 2.409560 | 11.234286 | 4.730300 | 5.285000 |
19500 | 2.665200 | 2.401380 | 11.142518 | 4.712200 | 5.305000 |
20000 | 2.654200 | 2.395683 | 11.076626 | 4.731000 | 5.284000 |
20500 | 2.681200 | 2.392558 | 11.042551 | 4.688000 | 5.333000 |
21000 | 2.657300 | 2.390264 | 11.016036 | 4.662200 | 5.362000 |
21500 | 2.645500 | 2.391599 | 11.031819 | 4.658300 | 5.367000 |
22000 | 2.649100 | 2.377134 | 10.872148 | 4.723200 | 5.293000 |
22500 | 2.631000 | 2.377407 | 10.875364 | 4.683000 | 5.338000 |
23000 | 2.639600 | 2.371849 | 10.814195 | 4.700600 | 5.318000 |
23500 | 2.633900 | 2.374701 | 10.847894 | 4.681200 | 5.340000 |
24000 | 2.648700 | 2.360695 | 10.694958 | 4.721800 | 5.295000 |
24500 | 2.636500 | 2.358757 | 10.675319 | 4.694500 | 5.325000 |
25000 | 2.633800 | 2.356332 | 10.647686 | 4.701600 | 5.317000 |
25500 | 2.622600 | 2.351246 | 10.593875 | 4.691000 | 5.329000 |
26000 | 2.599400 | 2.347776 | 10.557463 | 4.728000 | 5.288000 |
26500 | 2.598600 | 2.346375 | 10.543932 | 4.680300 | 5.342000 |
27000 | 2.612500 | 2.335228 | 10.427504 | 4.652700 | 5.373000 |
27500 | 2.601000 | 2.337018 | 10.445247 | 4.701600 | 5.317000 |
28000 | 2.589500 | 2.330009 | 10.370514 | 4.709500 | 5.308000 |
28500 | 2.593500 | 2.332979 | 10.402362 | 4.715700 | 5.301000 |
29000 | 2.598800 | 2.326984 | 10.340738 | 4.746900 | 5.267000 |
29500 | 2.599000 | 2.320810 | 10.275547 | 4.671900 | 5.351000 |
30000 | 2.572900 | 2.320885 | 10.277387 | 4.706800 | 5.311000 |
30500 | 2.589400 | 2.324665 | 10.316133 | 4.663300 | 5.361000 |
31000 | 2.577600 | 2.316571 | 10.232056 | 4.670700 | 5.353000 |
31500 | 2.568700 | 2.312280 | 10.187517 | 4.742000 | 5.272000 |
32000 | 2.567600 | 2.313001 | 10.196151 | 4.705200 | 5.313000 |
32500 | 2.566200 | 2.305023 | 10.114561 | 4.675500 | 5.347000 |
33000 | 2.550300 | 2.306582 | 10.130407 | 4.677100 | 5.345000 |
33500 | 2.560200 | 2.299845 | 10.061002 | 4.692800 | 5.327000 |
34000 | 2.545700 | 2.297130 | 10.033359 | 4.721100 | 5.295000 |
34500 | 2.556900 | 2.294095 | 10.005368 | 4.673800 | 5.349000 |
35000 | 2.565300 | 2.290545 | 9.967186 | 4.729100 | 5.286000 |
35500 | 2.547500 | 2.286399 | 9.929468 | 4.735700 | 5.279000 |
36000 | 2.550700 | 2.289327 | 9.958371 | 4.691100 | 5.329000 |
36500 | 2.545600 | 2.283216 | 9.897723 | 4.740600 | 5.274000 |
37000 | 2.551600 | 2.281248 | 9.875886 | 4.722300 | 5.294000 |
37500 | 2.544100 | 2.275671 | 9.821963 | 4.690300 | 5.330000 |
38000 | 2.530200 | 2.278805 | 9.851675 | 4.686800 | 5.334000 |
38500 | 2.536900 | 2.277662 | 9.840133 | 4.725400 | 5.291000 |
39000 | 2.528900 | 2.271726 | 9.781415 | 4.675500 | 5.347000 |
39500 | 2.530200 | 2.271304 | 9.776792 | 4.707800 | 5.310000 |
40000 | 2.546100 | 2.268227 | 9.745853 | 4.672700 | 5.350000 |
40500 | 2.527500 | 2.268438 | 9.749769 | 4.667900 | 5.356000 |
41000 | 2.513100 | 2.265372 | 9.716460 | 4.724500 | 5.292000 |
41500 | 2.533900 | 2.261983 | 9.685142 | 4.760400 | 5.252000 |
42000 | 2.513400 | 2.256519 | 9.634706 | 4.679500 | 5.342000 |
42500 | 2.538600 | 2.258655 | 9.654160 | 4.660400 | 5.364000 |
43000 | 2.511800 | 2.256743 | 9.635938 | 4.668100 | 5.356000 |
43500 | 2.507800 | 2.251618 | 9.588696 | 4.689700 | 5.331000 |
44000 | 2.512700 | 2.251925 | 9.589813 | 4.705400 | 5.313000 |
44500 | 2.498000 | 2.249427 | 9.568373 | 4.663400 | 5.361000 |
45000 | 2.499300 | 2.251360 | 9.587020 | 4.671100 | 5.352000 |
45500 | 2.507800 | 2.251170 | 9.583761 | 4.735400 | 5.279000 |
46000 | 2.520600 | 2.247334 | 9.547936 | 4.684400 | 5.337000 |
46500 | 2.496600 | 2.248953 | 9.561157 | 4.659200 | 5.366000 |
47000 | 2.511800 | 2.244533 | 9.518897 | 4.750900 | 5.262000 |
47500 | 2.488200 | 2.242113 | 9.495392 | 4.660600 | 5.364000 |
48000 | 2.492300 | 2.239975 | 9.474918 | 4.719600 | 5.297000 |
48500 | 2.489200 | 2.239032 | 9.465775 | 4.672300 | 5.351000 |
49000 | 2.498800 | 2.237262 | 9.448555 | 4.685300 | 5.336000 |
49500 | 2.486200 | 2.233460 | 9.414551 | 4.660900 | 5.364000 |
50000 | 2.481400 | 2.231460 | 9.396680 | 4.660300 | 5.364000 |
50500 | 2.497800 | 2.232087 | 9.400949 | 4.661900 | 5.363000 |
51000 | 2.492500 | 2.233625 | 9.416335 | 4.656400 | 5.369000 |
51500 | 2.492400 | 2.225406 | 9.338825 | 4.710400 | 5.307000 |
52000 | 2.465500 | 2.226707 | 9.351920 | 4.667600 | 5.356000 |
52500 | 2.505900 | 2.224616 | 9.331812 | 4.687400 | 5.333000 |
53000 | 2.500400 | 2.226945 | 9.352785 | 4.663900 | 5.360000 |
53500 | 2.504700 | 2.224303 | 9.330186 | 4.745800 | 5.268000 |
54000 | 2.480200 | 2.222404 | 9.311388 | 4.730800 | 5.285000 |
54500 | 2.469000 | 2.220685 | 9.295632 | 4.755200 | 5.257000 |
55000 | 2.472500 | 2.217708 | 9.267027 | 4.695700 | 5.324000 |
55500 | 2.468500 | 2.219917 | 9.286234 | 4.747300 | 5.266000 |
56000 | 2.469700 | 2.216012 | 9.250226 | 4.686600 | 5.334000 |
56500 | 2.474600 | 2.213224 | 9.224898 | 4.744900 | 5.269000 |
57000 | 2.458900 | 2.212269 | 9.215661 | 4.681900 | 5.340000 |
57500 | 2.465600 | 2.214800 | 9.237897 | 4.733600 | 5.281000 |
58000 | 2.452100 | 2.211943 | 9.210569 | 4.717900 | 5.299000 |
58500 | 2.461700 | 2.211983 | 9.213121 | 4.677700 | 5.344000 |
59000 | 2.467300 | 2.207248 | 9.169579 | 4.728700 | 5.287000 |
59500 | 2.468000 | 2.210982 | 9.203514 | 4.758800 | 5.253000 |
60000 | 2.468900 | 2.207473 | 9.171269 | 4.681700 | 5.340000 |
60500 | 2.454300 | 2.205665 | 9.154585 | 4.722100 | 5.294000 |
61000 | 2.467700 | 2.207709 | 9.172531 | 4.731300 | 5.284000 |
61500 | 2.445300 | 2.207564 | 9.171982 | 4.664200 | 5.360000 |
62000 | 2.453200 | 2.204978 | 9.147490 | 4.788700 | 5.221000 |
62500 | 2.450800 | 2.202402 | 9.126425 | 4.703200 | 5.316000 |
63000 | 2.456000 | 2.201880 | 9.120384 | 4.766600 | 5.245000 |
63500 | 2.461100 | 2.199798 | 9.101039 | 4.711700 | 5.306000 |
64000 | 2.464100 | 2.199305 | 9.096694 | 4.752100 | 5.261000 |
64500 | 2.453400 | 2.198195 | 9.086848 | 4.730800 | 5.284000 |
65000 | 2.444200 | 2.194160 | 9.050518 | 4.691100 | 5.329000 |
65500 | 2.450000 | 2.195012 | 9.056788 | 4.686700 | 5.334000 |
66000 | 2.455800 | 2.194539 | 9.052542 | 4.774500 | 5.236000 |
66500 | 2.434100 | 2.194012 | 9.047109 | 4.694200 | 5.326000 |
67000 | 2.444200 | 2.192459 | 9.032662 | 4.736400 | 5.278000 |
67500 | 2.434600 | 2.192902 | 9.039184 | 4.701300 | 5.318000 |
68000 | 2.445200 | 2.193277 | 9.040983 | 4.682200 | 5.339000 |
68500 | 2.445300 | 2.190604 | 9.017118 | 4.737100 | 5.277000 |
69000 | 2.451200 | 2.188234 | 8.994356 | 4.822100 | 5.184000 |
69500 | 2.449500 | 2.189156 | 9.002600 | 4.686700 | 5.334000 |
70000 | 2.456100 | 2.184335 | 8.959962 | 4.695300 | 5.324000 |
70500 | 2.439100 | 2.184447 | 8.961812 | 4.755600 | 5.257000 |
71000 | 2.454800 | 2.183449 | 8.952945 | 4.685000 | 5.336000 |
71500 | 2.459900 | 2.183132 | 8.950529 | 4.694100 | 5.326000 |
72000 | 2.424800 | 2.180700 | 8.929421 | 4.786000 | 5.224000 |
72500 | 2.433800 | 2.182412 | 8.944829 | 4.668500 | 5.355000 |
73000 | 2.440200 | 2.179533 | 8.919517 | 4.678700 | 5.343000 |
73500 | 2.439500 | 2.180656 | 8.928808 | 4.689300 | 5.331000 |
74000 | 2.424700 | 2.180942 | 8.930764 | 4.679600 | 5.342000 |
74500 | 2.452800 | 2.179202 | 8.916446 | 4.710100 | 5.308000 |
75000 | 2.440200 | 2.178581 | 8.909504 | 4.693800 | 5.326000 |
75500 | 2.422300 | 2.176174 | 8.887369 | 4.746800 | 5.267000 |
76000 | 2.428500 | 2.175789 | 8.884776 | 4.796300 | 5.212000 |
76500 | 2.458800 | 2.176743 | 8.893144 | 4.767100 | 5.244000 |
77000 | 2.411500 | 2.179561 | 8.919050 | 4.710000 | 5.308000 |
77500 | 2.434900 | 2.176929 | 8.895175 | 4.695500 | 5.324000 |
78000 | 2.412200 | 2.174379 | 8.873323 | 4.699400 | 5.320000 |
78500 | 2.446300 | 2.176090 | 8.888458 | 4.739800 | 5.274000 |
79000 | 2.418200 | 2.174819 | 8.877240 | 4.701800 | 5.317000 |
79500 | 2.422200 | 2.175611 | 8.884649 | 4.737300 | 5.277000 |
80000 | 2.426600 | 2.173164 | 8.862095 | 4.770400 | 5.241000 |
80500 | 2.423500 | 2.175104 | 8.880151 | 4.764900 | 5.247000 |
81000 | 2.430300 | 2.170741 | 8.842427 | 4.683700 | 5.338000 |
81500 | 2.420800 | 2.170275 | 8.836911 | 4.713700 | 5.304000 |
82000 | 2.424800 | 2.171069 | 8.844029 | 4.693800 | 5.326000 |
82500 | 2.418400 | 2.167583 | 8.812852 | 4.741500 | 5.273000 |
83000 | 2.407900 | 2.170138 | 8.834461 | 4.757800 | 5.255000 |
83500 | 2.424400 | 2.170450 | 8.838120 | 4.692400 | 5.328000 |
84000 | 2.419600 | 2.168507 | 8.820310 | 4.700900 | 5.318000 |
84500 | 2.436900 | 2.166972 | 8.806765 | 4.742700 | 5.271000 |
85000 | 2.434100 | 2.166539 | 8.803656 | 4.694000 | 5.326000 |
85500 | 2.420600 | 2.167152 | 8.809423 | 4.683700 | 5.338000 |
86000 | 2.417200 | 2.167440 | 8.810533 | 4.756200 | 5.256000 |
86500 | 2.424800 | 2.167591 | 8.812583 | 4.717800 | 5.299000 |
87000 | 2.401800 | 2.167519 | 8.812274 | 4.759100 | 5.253000 |
87500 | 2.429100 | 2.166193 | 8.799897 | 4.707300 | 5.311000 |
88000 | 2.413400 | 2.165663 | 8.795700 | 4.688200 | 5.333000 |
88500 | 2.415900 | 2.165704 | 8.795581 | 4.697000 | 5.323000 |
89000 | 2.412600 | 2.164068 | 8.780826 | 4.757800 | 5.255000 |
89500 | 2.423400 | 2.164751 | 8.787511 | 4.702100 | 5.317000 |
90000 | 2.411500 | 2.165240 | 8.791935 | 4.728100 | 5.288000 |
90500 | 2.436600 | 2.164468 | 8.784675 | 4.753600 | 5.259000 |
91000 | 2.412600 | 2.164178 | 8.782509 | 4.699400 | 5.320000 |
91500 | 2.411700 | 2.163445 | 8.776644 | 4.683100 | 5.338000 |
92000 | 2.428700 | 2.164179 | 8.782456 | 4.694600 | 5.325000 |
92500 | 2.415000 | 2.163734 | 8.778793 | 4.699700 | 5.320000 |
93000 | 2.406600 | 2.163711 | 8.778646 | 4.703800 | 5.315000 |
93500 | 2.421600 | 2.162858 | 8.770714 | 4.742500 | 5.271000 |
94000 | 2.421700 | 2.162106 | 8.764411 | 4.688700 | 5.332000 |
94500 | 2.404100 | 2.162465 | 8.767489 | 4.731300 | 5.284000 |
95000 | 2.394900 | 2.162495 | 8.767938 | 4.703700 | 5.315000 |
95500 | 2.421000 | 2.162645 | 8.768976 | 4.744800 | 5.269000 |
96000 | 2.413500 | 2.161637 | 8.760166 | 4.692500 | 5.328000 |
96500 | 2.423300 | 2.162421 | 8.767182 | 4.713600 | 5.304000 |
97000 | 2.412100 | 2.162054 | 8.763784 | 4.714200 | 5.303000 |
97500 | 2.422300 | 2.162224 | 8.765325 | 4.692800 | 5.327000 |
98000 | 2.414400 | 2.161440 | 8.758332 | 4.690100 | 5.330000 |
98500 | 2.407600 | 2.162189 | 8.764719 | 4.705100 | 5.313000 |
99000 | 2.426400 | 2.161754 | 8.760895 | 4.713700 | 5.304000 |
99500 | 2.411900 | 2.161397 | 8.757938 | 4.745500 | 5.268000 |
100000 | 2.403700 | 2.161592 | 8.759642 | 4.746800 | 5.267000 |
TrainOutput(global_step=100000, training_loss=2.578038330078125, metrics={'train_runtime': 68247.9183, 'train_samples_per_second': 1.465, 'total_flos': 402616693555200000, 'epoch': 0.51})
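The reported epoch of 0.51 checks out against the step count and batch size:

MAX_STEPS * BATCH_SIZE / len(train_ds)  # 100000 * 8 / 1575968 ≈ 0.51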
The wandb run summary:

metric | value |
---|---|
_runtime | 68253 |
_timestamp | 1617103059 |
_step | 100000 |
train/loss | 2.4037 |
train/learning_rate | 0.0 |
train/epoch | 0.51 |
eval/loss | 2.16159 |
eval/perplexity | 8.75964 |
eval/runtime | 4.7468 |
eval/samples_per_second | 5.267 |
train/train_runtime | 68247.9183 |
train/train_samples_per_second | 1.465 |
train/total_flos | 402616693555200000 |
(The wandb history sparklines are omitted: train/loss, eval/loss and eval/perplexity fall steadily before plateauing, and train/learning_rate decays to zero.)
After quite a bit of poking I've managed to get perplexity out of this, and I'm quite pleased with how it has gone. If you are in the same position, remember that you can look at the source code for something in Jupyter by adding `??`. Looking at the source of `GPT2LMHeadModel` (with `model??`) helped me find the correct way to calculate the perplexity.
In order to get this running without exceeding the CUDA RAM I had to use an extremely limited validation set. There are currently only 25 articles in it, so the reported validation perplexity should be treated as quite noisy. (I believe the problem is that, without setting eval_accumulation_steps, the trainer keeps the logits for every evaluation example on the GPU so that it can pass them all to compute_metrics.)
Training distilgpt2 is interesting because the model itself has a reported perplexity: the model card gives 21.1, against 16.3 for base GPT-2. The perplexity reported here was already about 13.6 after 8,500 batches, which suggests to me that the validation set is easier than the data those reported scores were measured on.
The training has completed and the perplexity has reached 8.76, which is unbelievably low. A more systematic evaluation would be appropriate, so it's lucky I have the test set ready.
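The outputs below come from spot checking a single tokenized article from the test set. A sketch of how they would be produced (TEST_FILES being the held-out portion of the split above):

test_ds = LineByLineTextDataset(tokenizer, TEST_FILES, block_size=1024)
print(len(test_ds))  # 10000
item = test_ds[0]
print(item)          # input_ids padded to the block size, with a matching attention_mask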
10000
{'input_ids': tensor([[ 400, 2178, 91, ..., 50256, 50256, 50256]]), 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0]])}
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1]])
tensor([False, False, False, ..., False, False], device='cuda:0')  # 152 values, all False
@torch.no_grad()
def accuracy_and_perplexity(
    model: AutoModelWithLMHead,
    tokenizer: GPT2TokenizerFast,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
) -> Dict[str, Any]:
    assert input_ids.shape[0] == 1, "batch size of 1 required"
    # limit the tokens to the attended tokens and cut out the batch dimension
    tokens = attention_mask.sum(dim=-1)
    input_ids = input_ids[:, :tokens][0]
    lm_logits = model(input_ids=input_ids.cuda()).logits

    # This loss calculation comes directly from the GPT2 forward method,
    # which correctly offsets the labels to match the predicting positions.
    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = input_ids[..., 1:].contiguous().cuda()

    # Flatten the tokens
    view_labels = shift_labels.view(-1)
    view_logits = shift_logits.view(-1, shift_logits.size(-1))

    predicted_tokens = shift_logits.argmax(dim=-1)
    loss = F.cross_entropy(view_logits, view_labels)
    perplexity = torch.exp(loss)
    accuracy = (predicted_tokens == shift_labels).sum() / tokens.cuda()
    return {
        "perplexity": perplexity.item(),
        "accuracy": accuracy.item(),
        "tokens": tokens.item(),
        "sentence": tokenizer.decode(input_ids),
        "predicted": tokenizer.decode(predicted_tokens),
    }
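Invoking this on the single spot-check article from above gives the result shown below (a sketch; `item` is the tokenized entry printed earlier):

model = model.cuda()
accuracy_and_perplexity(
    model,
    tokenizer,
    input_ids=item["input_ids"],
    attention_mask=item["attention_mask"],
)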
{'perplexity': 14.126980781555176,
'accuracy': 0.5197368264198303,
'tokens': 152,
'sentence': 'thumb|left\nEl estornino esmeralda (Lamprotornis iris) es una especie de estúrnido africano de plumaje verde irisado que puebla las tierras bajas y sabanas de Costa de Marfil, Guinea y Sierra Leona.* BirdLife International. 2012. Coccycolius iris. The IUCN Red List of Threatened Species. Version 2015.3. Acceso: 12 de noviembre de 2015.\n* (2007): Species factsheet: Coccycolius iris. Retrieved 2007-JUL-20.*Xeno-canto. L. iris. Canto.iris',
'predicted': 'umb|right|El nado de unalda deenapususus)is) es una especie de aornbulago dericano de laase quede deis., seueda el montras deajas de lasanas. la Rica Marfil. en Ec Sierra Leona. Life International:\n. Theodonus iris. I IUCN Red List of Threatened Species. 2015.1. Acceso: 19 de feviembre de 2015.ir* Bird\nen). The accountsheet. Ircycolius iris. The 2015-08ULI2015- --canto.comamp iris. Theanto. 2012'}
For a spot check this is kinda amazing? A 50% accurate language model. When I look at the actual and predicted text there does seem to be real overlap, even if it's a bit janky.
I need to do this evaluation more systematically.
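A sketch of that systematic pass, collecting the per-article metrics for the whole test set into a pandas dataframe (tqdm is just optional progress reporting):

import pandas as pd
from tqdm.auto import tqdm

results = pd.DataFrame([
    accuracy_and_perplexity(model, tokenizer, **test_ds[index])
    for index in tqdm(range(len(test_ds)))
])
results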
 | perplexity | accuracy | tokens | sentence | predicted |
---|---|---|---|---|---|
0 | 14.126981 | 0.519737 | 152 | thumb|left\nEl estornino esmeralda (Lamprotorn... | umb|right|El nado de unalda deenapususus)is) e... |
1 | 4.801258 | 0.676510 | 745 | La Temporada 2009/10 del Fútbol profesional Ve... | porada de-10 de Camútbol Clubesional esenezol... |
2 | 6.119050 | 0.634195 | 503 | El Cementerio nacional de Puerto Rico es un ce... | municipementerio deacional de San Rico es una... |
3 | 3.590329 | 0.756757 | 37 | Egestria rubicunda es una especie de coleópter... | umbi esiaul es una especie de insectleóptero d... |
4 | 7.143648 | 0.657143 | 35 | Cancello e Arnone es un municipio situado en e... | umbll es Ialdo es unaio itado en el territorio... |
... | ... | ... | ... | ... | ... |
9995 | 2.648958 | 0.833333 | 54 | Acolastus medvedevi es un coleóptero de la fam... | umbophus esusaicusi es unaleóptero de la famil... |
9996 | 2.312675 | 0.777778 | 36 | Phyllopodium es un género con 39 especies de p... | umbotod es un género de dos especies de planta... |
9997 | 10.965172 | 0.509766 | 1024 | es una serie de manga escrita por Rando Ayamin... | una serie de anime escrita por eli Kaka y ilu... |
9998 | 9.442551 | 0.574230 | 357 | El Parque Central de Kaliningrado (en ruso: Це... | municipque N de laamrado esen inguso: Мертаан... |
9999 | 12.729012 | 0.510742 | 1024 | Arabia Saudita y Turquía han disfrutado de una... | ia esita (esihadquía Saudace sidpututado de la... |
10000 rows × 5 columns
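The overall scores are then just the column means:

results[["perplexity", "accuracy", "tokens"]].mean()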
perplexity 9.970363
accuracy 0.590948
tokens 456.019700
dtype: float64
The last bit of evaluation I can do is to run the suggested prompt from the Hugging Face gpt2-small-spanish model through it.
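A sketch of the generation call that produced the output below (the decoding parameters are assumptions; only the prompt comes from the model card):

prompt = "La inteligencia artificial en latinoamérica"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
output = model.generate(input_ids, max_length=40)
print(tokenizer.decode(output[0]))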
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
'La inteligencia artificial en latinoamérica se ha desarrollado cerca del en el\xa0Estado de Brasil. Los sistemas de inteligencia artificial en latinoamérica permit'
Artificial intelligence in Latin America has developed close to the xa0State of Brazil. Artificial intelligence systems in Latin America allow
Apart from that junk character that seems pretty good? I’m amazed. I expected to have to train this quite a bit more.