import blog.transformers_logging
June 24, 2022
I should also try the concurrent training of the model with a language modelling task. Does this need to split the model? Can I just use a ratio between the two tasks and a separate head?
It would be good to come up with some target signatures for word senses using Wikipedia and see how well the model can determine the sense, as an evaluation metric.
Need to further study the model collapse that happens in this notebook. Why does the overlap metric seem to do so well?
Does the dataset need to be reprocessed to capitalize the first letter of the target?
To evaluate these changes I’m going to take all of the evaluations that were done in the last post and run them over each version of the model.
# from src/main/python/blog/prompt_internalization/multilingual/roberta/trainer.py
from itertools import starmap
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer, Trainer, TrainingArguments
from transformers.modeling_outputs import MaskedLMOutput
class MultilingualMaskedPromptInternalizationTrainingArguments(TrainingArguments):
def __init__(
self,
*args,
temperature: float = 2.0,
mean_prediction: bool = True,
ignore_tokens: Optional[List[int]] = None,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.temperature = temperature
self.mean_prediction = mean_prediction
if ignore_tokens is not None:
self.ignore_tokens = ignore_tokens
else:
self.ignore_tokens = []
class MultilingualMaskedPromptInternalizationTrainer(Trainer):
def __init__(
self,
*args,
teacher_model: AutoModelForMaskedLM = None,
tokenizer: AutoTokenizer = None,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.teacher = teacher_model
self._move_model_to_device(self.teacher, self.model.device)
self.teacher.eval()
self.mask_token_id = tokenizer.mask_token_id
def compute_loss(
self,
model: AutoModelForMaskedLM,
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
outputs: MaskedLMOutput = model(
input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
)
if self.args.mean_prediction:
predictions = self._student_predictions_mean(
outputs=outputs, labels=inputs["labels"]
)
else:
predictions = self._student_predictions_first(
outputs=outputs, labels=inputs["labels"]
)
targets = self._teacher_predictions(
input_ids=inputs["teacher_input_ids"],
attention_mask=inputs["teacher_attention_mask"],
)
loss = self._loss(predictions=predictions, targets=targets)
if not return_outputs:
return loss
# This directly calculates the kl_div and overlap metrics.
# It's much faster to do this using CUDA operations instead of waiting for cpu numpy.
with torch.inference_mode():
kl_div = F.kl_div(
input=F.log_softmax(predictions.to(torch.float32), dim=-1),
target=F.softmax(targets.to(torch.float32), dim=-1),
reduction="none",
log_target=False,
)
kl_div = kl_div.sum(dim=1)
overlap = starmap(
torch.isin,
zip(
predictions.argsort(descending=True)[:, :10],
targets.argsort(descending=True)[:, :10],
),
)
overlap = map(torch.sum, overlap)
overlap = torch.tensor(list(overlap), device=self.model.device)
overlap = overlap / 10
# This will reshape the metrics to be [batch_size, 2] which will then
# get correctly passed to the metric calculation
metric_output = torch.cat([kl_div[:, None], overlap[:, None]], dim=1)
return loss, metric_output
@torch.inference_mode()
def _teacher_predictions(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
outputs_teacher = self.teacher(
input_ids=input_ids,
attention_mask=attention_mask,
)
mask_indices = input_ids == self.mask_token_id
teacher_predictions = outputs_teacher.logits[mask_indices]
teacher_predictions[:, self.args.ignore_tokens] = teacher_predictions.min()
return teacher_predictions
def _student_predictions_mean(
self, outputs: MaskedLMOutput, labels: torch.Tensor
) -> torch.Tensor:
# When calculating this it is very important to avoid breaking back propagation.
# torch.cat will break back propagation, so the prediction is added per row to a holder
logits = outputs.logits
        predictions = torch.zeros((logits.shape[0], logits.shape[-1]), device=logits.device)
for index, (start, length) in enumerate(labels):
prediction = logits[index, start : start + length]
prediction = prediction.mean(dim=0)
predictions[index] += prediction
return predictions
def _student_predictions_first(
self,
outputs: MaskedLMOutput,
labels: torch.Tensor,
) -> torch.Tensor:
return outputs.logits[range(outputs.logits.shape[0]), labels[:, 0]]
def _loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
predictions = F.log_softmax(
predictions.to(torch.float32) / self.args.temperature, dim=-1
)
targets = F.softmax(targets.to(torch.float32) / self.args.temperature, dim=-1)
loss = F.kl_div(
input=predictions,
target=targets,
reduction="batchmean",
log_target=False,
)
return loss * (self.args.temperature**2)
# from src/main/python/blog/prompt_internalization/multilingual/roberta/collator.py
from typing import Any, Dict, List
from transformers import AutoTokenizer
class TeacherStudentCollator:
"""
The teacher inputs need to be padded and have an associated attention mask.
"""
def __init__(self, tokenizer: AutoTokenizer) -> None:
self.tokenizer = tokenizer
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
teacher_inputs = self._teacher_inputs(features)
student_inputs = self._student_inputs(features)
batch = {**teacher_inputs, **student_inputs}
if "label" in batch:
batch["labels"] = batch["label"]
del batch["label"]
if "label_ids" in batch:
batch["labels"] = batch["label_ids"]
del batch["label_ids"]
return batch
def _teacher_inputs(self, features: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
teacher_inputs = [{"input_ids": row["teacher_input_ids"]} for row in features]
teacher_batch = self.tokenizer.pad(
teacher_inputs,
padding=True,
return_tensors="pt",
)
return {
"teacher_input_ids": teacher_batch["input_ids"],
"teacher_attention_mask": teacher_batch["attention_mask"],
}
def _student_inputs(self, features: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
student_inputs = [
{
"input_ids": row["input_ids"],
"labels": row["labels"][0], # known to have a single entry
}
for row in features
]
return self.tokenizer.pad(
student_inputs,
padding=True,
return_tensors="pt",
)
# from src/main/python/blog/prompt_internalization/multilingual/roberta/metrics.py
from typing import Dict
from transformers import EvalPrediction
def compute_metrics(model_output: EvalPrediction) -> Dict[str, float]:
kl_div = model_output.predictions[:, 0].mean()
overlap = model_output.predictions[:, 1].mean()
return {
"kl_div": kl_div,
"overlap": overlap,
}
# from src/main/python/blog/prompt_internalization/multilingual/roberta/train.py
from pathlib import Path
from typing import List, Optional
import datasets
from transformers import AutoModelForMaskedLM, AutoTokenizer
from .collator import TeacherStudentCollator
from .metrics import compute_metrics
from .trainer import (
MultilingualMaskedPromptInternalizationTrainer,
MultilingualMaskedPromptInternalizationTrainingArguments,
)
DATASET_FOLDER = Path("/data/tatoeba/2022-06-18/dataset/")
MODEL_FOLDER = Path("/data/prompt-internalization/multilingual/")
RUN_FOLDER = Path("/tmp/runs")
MODEL_FOLDER.mkdir(parents=True, exist_ok=True)
RUN_FOLDER.mkdir(parents=True, exist_ok=True)
def train(
*,
model_name: str = "xlm-roberta-base",
dataset_name: str = "xlm-roberta",
batch_size: int = 64,
learning_rate: float = 1e-4,
temperature: float = 2,
fp16: bool = False,
mean_prediction: bool = False,
ignore_tokens: Optional[List[int]] = None,
epochs: Optional[float] = 2,
max_steps: int = -1,
evaluation_steps: int = 500,
) -> Path:
run_name = "-".join(
[
f"{model_name}",
f"e{epochs}" if max_steps == -1 else f"ms{max_steps}",
f"bs{batch_size}",
f"lr{learning_rate}",
f"t{temperature}",
]
+ (["fp16"] if fp16 else [])
+ (["mean"] if mean_prediction else [])
+ ([f"it{len(ignore_tokens)}"] if ignore_tokens else [])
)
print(f"Starting {run_name}")
train_ds = datasets.load_from_disk(DATASET_FOLDER / f"{dataset_name}-train.dataset")
test_ds = datasets.load_from_disk(DATASET_FOLDER / f"{dataset_name}-test.dataset")
training_args = MultilingualMaskedPromptInternalizationTrainingArguments(
report_to="none",
output_dir=RUN_FOLDER,
num_train_epochs=epochs,
max_steps=max_steps,
seed=33,
# number of steps before moving evaluation results from GPU to CPU see
# https://discuss.huggingface.co/t/cuda-out-of-memory-when-using-trainer-with-compute-metrics/2941
eval_accumulation_steps=5,
#
# hyperparameters
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
fp16=fp16,
temperature=temperature,
mean_prediction=mean_prediction,
ignore_tokens=ignore_tokens,
learning_rate=learning_rate,
#
# evaluation settings
evaluation_strategy="steps",
logging_steps=evaluation_steps,
eval_steps=evaluation_steps,
save_steps=evaluation_steps,
#
# checkpoint settings
logging_dir=RUN_FOLDER / "logs",
save_total_limit=2,
load_best_model_at_end=True,
metric_for_best_model="overlap",
greater_is_better=True,
remove_unused_columns=False,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
teacher_model = AutoModelForMaskedLM.from_pretrained(model_name)
student_model = AutoModelForMaskedLM.from_pretrained(model_name)
data_collator = TeacherStudentCollator(tokenizer=tokenizer)
trainer = MultilingualMaskedPromptInternalizationTrainer(
model=student_model,
args=training_args,
teacher_model=teacher_model,
train_dataset=train_ds,
eval_dataset=test_ds,
data_collator=data_collator,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
)
trainer.train()
student_model.save_pretrained(MODEL_FOLDER / run_name)
return MODEL_FOLDER / run_name
# from src/main/python/blog/prompt_internalization/multilingual/roberta/evaluate.py
from pathlib import Path
from typing import List, Optional, Tuple
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
def evaluate(
model_name: str, model_path: Path, ignore_tokens: Optional[List[int]] = None
) -> None:
if ignore_tokens is None:
ignore_tokens = []
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_path)
model.eval()
bass_evaluation(model=model, tokenizer=tokenizer, ignore_tokens=ignore_tokens)
friday_evaluation(model=model, tokenizer=tokenizer, ignore_tokens=ignore_tokens)
malibu_evaluation(model=model, tokenizer=tokenizer, ignore_tokens=ignore_tokens)
football_evaluation(model=model, tokenizer=tokenizer, ignore_tokens=ignore_tokens)
def bass_evaluation(
model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, ignore_tokens: List[int]
) -> None:
first_phrase = "We spotted a large bass in the ocean."
second_phrase = "The bass player did not receive the acknowledgment she deserves."
third_phrase = "The black sea bass, is a member of the wreckfish family."
first_predicted_words = get_predictions(
model=model,
tokenizer=tokenizer,
text=first_phrase,
noun="bass",
ignore_tokens=ignore_tokens,
)
second_predicted_words = get_predictions(
model=model,
tokenizer=tokenizer,
text=second_phrase,
noun="bass",
ignore_tokens=ignore_tokens,
)
third_predicted_words = get_predictions(
model=model,
tokenizer=tokenizer,
text=third_phrase,
noun="bass",
ignore_tokens=ignore_tokens,
)
print("=== BASS EVALUATION ===")
print(f"First Phrase is: {first_phrase} Target is: bass")
print(f"Description is: {', '.join(first_predicted_words)}")
print()
print(f"Second Phrase is: {second_phrase} Target is: bass")
print(f"Description is: {', '.join(second_predicted_words)}")
print()
print(f"Third Phrase is: {third_phrase} Target is: bass")
print(f"Description is: {', '.join(third_predicted_words)}")
print()
print(
f"First & Second: {sorted(set(first_predicted_words) & set(second_predicted_words))}"
)
print(
f"First & Third: {sorted(set(first_predicted_words) & set(third_predicted_words))}"
)
print(
f"Second & Third: {sorted(set(second_predicted_words) & set(third_predicted_words))}"
)
print()
def friday_evaluation(
model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, ignore_tokens: List[int]
) -> None:
spanish_text = "Friday es mi canción favorita."
english_text = "Friday is my favourite song."
spanish_predicted_words = get_predictions(
model=model,
tokenizer=tokenizer,
text=spanish_text,
noun="Friday",
ignore_tokens=ignore_tokens,
)
english_predicted_words = get_predictions(
model=model,
tokenizer=tokenizer,
text=english_text,
noun="Friday",
ignore_tokens=ignore_tokens,
)
overlap = set(spanish_predicted_words) & set(english_predicted_words)
difference = set(spanish_predicted_words) ^ set(english_predicted_words)
print("=== FRIDAY EVALUATION ===")
print(f"Spanish Phrase is: {spanish_text}")
print(f"Spanish Description is: {', '.join(spanish_predicted_words)}")
print(f"English Phrase is: {english_text}")
print(f"English Description is: {', '.join(english_predicted_words)}")
print()
print(f"Description Overlap is: {', '.join(sorted(overlap))}")
print(f"Description Difference is: {', '.join(sorted(difference))}")
print()
def malibu_evaluation(
model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, ignore_tokens: List[int]
) -> None:
text = "I like to drive my Malibu while drinking Malibu."
first_predicted_words = get_predictions(
model=model,
tokenizer=tokenizer,
text=text,
noun="Malibu",
ignore_tokens=ignore_tokens,
)
second_predicted_words = get_predictions(
model=model,
tokenizer=tokenizer,
text=text,
noun="Malibu",
index=1,
ignore_tokens=ignore_tokens,
)
print("=== MALIBU EVALUATION ===")
print(f"Phrase is: {text}")
print(f"First Malibu (car) Description is: {', '.join(first_predicted_words)}")
print(f"Second Malibu (drink) Description is: {', '.join(second_predicted_words)}")
print()
print(
f"First & Second: {sorted(set(first_predicted_words) & set(second_predicted_words))}"
)
print(
f"First ^ Second: {sorted(set(first_predicted_words) ^ set(second_predicted_words))}"
)
print()
def football_evaluation(
model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, ignore_tokens: List[int]
) -> None:
spanish_phrase = (
"Retiremos el equipo de la cancha, "
"Boca no merece jugar esta copa que "
"hace tiempo viene siendo desprestigiada.\n"
"Ya no se juega al futbol."
)
english_phrase = (
"Let's remove the team from the field, "
"Boca does not deserve to play this cup that "
"has long been discredited. "
"Football is no longer played."
)
print("=== FOOTBALL EVALUATION ===")
print(f"Spanish Phrase is: {spanish_phrase}")
print(f"English Phrase is: {english_phrase}")
print()
for spanish_noun, english_noun in [
["equipo", "team"],
["Boca", "Boca"],
["copa", "cup"],
["tiempo", "long"],
["futbol", "Football"],
]:
spanish_description = get_predictions(
model=model,
tokenizer=tokenizer,
text=spanish_phrase,
noun=spanish_noun,
ignore_tokens=ignore_tokens,
)
english_description = get_predictions(
model=model,
tokenizer=tokenizer,
text=english_phrase,
noun=english_noun,
ignore_tokens=ignore_tokens,
)
overlap = set(spanish_description) & set(english_description)
difference = set(spanish_description) ^ set(english_description)
print(f"Spanish word is: {spanish_noun}, English word is: {english_noun}")
print(f"Spanish Description is: {', '.join(spanish_description)}")
print(f"English Description is: {', '.join(english_description)}")
print(f"Overlap is: {', '.join(sorted(overlap))} ({len(overlap)})")
print(f"Difference is: {', '.join(sorted(difference))} ({len(difference)})")
print()
@torch.inference_mode()
def get_predictions(
*,
model: AutoModelForMaskedLM,
tokenizer: AutoTokenizer,
text: str,
noun: str,
index: int = 0,
ignore_tokens: Optional[List[int]] = None,
) -> List[str]:
if ignore_tokens is None:
ignore_tokens = []
tokens = tokenizer(text, return_tensors="pt")
start, _end = get_noun(
tokenizer=tokenizer, tokens=tokens.input_ids[0], noun=noun, index=index
)
output = model(**tokens)
predictions = output.logits[0, start]
predictions[ignore_tokens] = predictions.min()
predicted_tokens = predictions.argsort(descending=True)[:10]
predicted_words = [
word.strip() for word in tokenizer.batch_decode(predicted_tokens)
]
return predicted_words
def get_noun(
tokenizer: AutoTokenizer, tokens: torch.Tensor, noun: str, index: int
) -> Tuple[int, int]:
length = tokens.shape[0]
current_index = index
for start_index in range(length):
word = tokenizer.decode(tokens[start_index]).strip()
if not noun.startswith(word):
continue
for end_index in range(start_index + 1, length):
word = tokenizer.decode(tokens[start_index:end_index]).strip()
if not noun == word:
continue
if current_index > 0:
current_index -= 1
else:
return start_index, end_index
raise AssertionError(f"Did not find {noun}[{index}] in {tokenizer.decode(tokens)}")
Now we can try training it with the larger model. Given that the code was quite well defined in the last post this should be quick.
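The launch cell is not shown here, but reading the hyperparameters off the run folder name that appears later (roberta-large-e2-bs32-lr0.0001-t2), the run was presumably started with something like this sketch; the dataset_name value is an assumption:

```python
# A sketch of the launch call, not the original cell.
model_path = train(
    model_name="roberta-large",
    dataset_name="roberta-large",  # assumption: data encoded with the matching tokenizer
    batch_size=32,
    learning_rate=1e-4,
    temperature=2,
    epochs=2,
    evaluation_steps=1_000,  # the table below reports every 1,000 steps
)
```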
Step | Training Loss | Validation Loss | Kl Div | Overlap |
---|---|---|---|---|
1000 | 1.758800 | 1.667300 | 2.010278 | 0.736336 |
2000 | 1.720300 | 1.667612 | 2.018803 | 0.652912 |
3000 | 1.709100 | 1.669667 | 2.007134 | 0.558660 |
4000 | 1.711100 | 1.663013 | 2.008279 | 0.558660 |
5000 | 1.706700 | 1.661403 | 2.010511 | 0.558660 |
6000 | 1.703900 | 1.660943 | 2.006225 | 0.558660 |
This overlap looks odd given that the loss and KL divergence remain high.
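The evaluation below was presumably produced by a call along these lines (a sketch, assuming the path returned by train above):

```python
# Sketch of the evaluation call; model_path is assumed to be the folder returned by train().
evaluate(model_name="roberta-large", model_path=model_path)
```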
Could not locate the tokenizer configuration file, will try to use the model config instead.
=== BASS EVALUATION ===
First Phrase is: We spotted a large bass in the ocean. Target is: bass
Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
Second Phrase is: The bass player did not receive the acknowledgment she deserves. Target is: bass
Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
Third Phrase is: The black sea bass, is a member of the wreckfish family. Target is: bass
Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
First & Second: ['Item', 'Location', 'Name', 'Owner', 'Personality', 'Size', 'Species', 'Status', 'Time', 'Type']
First & Third: ['Item', 'Location', 'Name', 'Owner', 'Personality', 'Size', 'Species', 'Status', 'Time', 'Type']
Second & Third: ['Item', 'Location', 'Name', 'Owner', 'Personality', 'Size', 'Species', 'Status', 'Time', 'Type']
=== FRIDAY EVALUATION ===
Spanish Phrase is: Friday es mi canción favorita.
Spanish Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
English Phrase is: Friday is my favourite song.
English Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
Description Overlap is: Item, Location, Name, Owner, Personality, Size, Species, Status, Time, Type
Description Difference is:
=== MALIBU EVALUATION ===
Phrase is: I like to drive my Malibu while drinking Malibu.
First Malibu (car) Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
Second Malibu (drink) Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
First & Second: ['Item', 'Location', 'Name', 'Owner', 'Personality', 'Size', 'Species', 'Status', 'Time', 'Type']
First ^ Second: []
=== FOOTBALL EVALUATION ===
Spanish Phrase is: Retiremos el equipo de la cancha, Boca no merece jugar esta copa que hace tiempo viene siendo desprestigiada.
Ya no se juega al futbol.
English Phrase is: Let's remove the team from the field, Boca does not deserve to play this cup that has long been discredited. Football is no longer played.
Spanish word is: equipo, English word is: team
Spanish Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
English Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
Overlap is: Item, Location, Name, Owner, Personality, Size, Species, Status, Time, Type (10)
Difference is: (0)
Spanish word is: Boca, English word is: Boca
Spanish Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
English Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
Overlap is: Item, Location, Name, Owner, Personality, Size, Species, Status, Time, Type (10)
Difference is: (0)
Spanish word is: copa, English word is: cup
Spanish Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
English Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
Overlap is: Item, Location, Name, Owner, Personality, Size, Species, Status, Time, Type (10)
Difference is: (0)
Spanish word is: tiempo, English word is: long
Spanish Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
English Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
Overlap is: Item, Location, Name, Owner, Personality, Size, Species, Status, Time, Type (10)
Difference is: (0)
Spanish word is: futbol, English word is: Football
Spanish Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
English Description is: Name, Owner, Location, Size, Item, Type, Time, Species, Personality, Status
Overlap is: Item, Location, Name, Owner, Personality, Size, Species, Status, Time, Type (10)
Difference is: (0)
Something very bad has happened here. The model has completely collapsed into producing the same output for every token.
Why has this happened? Is the large model somehow incompatible with the tokenized data for roberta-base?
One sanity check is to test the prompted teacher to see what it returns. If it always returns the same output then this would be a strong indication that the tokens are wrong somehow.
We can use the roberta-large tokenizer to decode the roberta-base encoded data as well to check compatibility. I believe that the tokenizers are identical.
from pathlib import Path
import datasets
from transformers import AutoModelForMaskedLM, AutoTokenizer
DATASET_FOLDER = Path("/data/tatoeba/2022-06-18/encoded/")
train_ds = datasets.load_from_disk(DATASET_FOLDER / MODEL_NAME / "single-noun-train.dataset")
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
from typing import Any, List
import torch
@torch.inference_mode()
def get_predictions_from_tokens(tokens: List[int]) -> List[str]:
print(tokenizer.decode(tokens))
inputs = torch.tensor(tokens, dtype=torch.long)[None, :]
output = model(inputs)
predictions = output.logits[inputs == tokenizer.mask_token_id]
tokens = predictions[0].argsort(descending=True)[:10]
return tokenizer.batch_decode(tokens)
<s>I was in the mountains. Pet: Dog, Color: Yellow, Vehicle: Tractor, Fruit: Banana,<mask>: mountains</s>
[' Location',
' Destination',
' Place',
' Places',
' Region',
' Country',
' Environment',
' Area',
' Where',
' Scene']
<s>I told them to send me another ticket. Pet: Dog, Color: Yellow, Vehicle: Tractor, Fruit: Banana,<mask>: ticket</s>
[' Type',
' Size',
' Status',
' Price',
' Item',
' Quantity',
' Condition',
' Reason',
' Currency',
' Result']
<s>It depends on the context. Pet: Dog, Color: Yellow, Vehicle: Tractor, Fruit: Banana,<mask>: context</s>
[' Context',
' Location',
' Subject',
' Meaning',
' Note',
' Source',
' Type',
' Age',
' Size',
' Mood']
<s>The Germans are very crafty. Pet: Dog, Color: Yellow, Vehicle: Tractor, Fruit: Banana,<mask>: Germans</s>
[' People',
' Country',
' Race',
' Culture',
' Species',
' Language',
' Nation',
' Peoples',
' Population',
' Tribe']
row 0 overlap: [' Location']
row 1 overlap: [' Location']
row 2 overlap: [' Location']
row 3 overlap: [' Location']
These predictions are varied and relevant to the different terms. The teacher is clearly capable of handling the tokens and producing relevant output. Every one has predicted Location, which is not ideal.
The problem here is that the overlap metric is hovering around \(\frac{5}{10}\) while this spot check is only \(\frac{1}{10}\). Checking the overlap metric seems to be quite important.
Let’s start by confirming that the student is hopelessly busted. Based on the evaluation above, the student looks like it is producing the same output for every token.
from pathlib import Path
import datasets
from transformers import AutoModelForMaskedLM, AutoTokenizer
DATASET_FOLDER = Path("/data/tatoeba/2022-06-18/encoded/")
train_ds = datasets.load_from_disk(DATASET_FOLDER / MODEL_NAME / "single-noun-train.dataset")
MODEL_FOLDER = Path("/data/prompt-internalization/multilingual/roberta-large-e2-bs32-lr0.0001-t2/")
model = AutoModelForMaskedLM.from_pretrained(MODEL_FOLDER)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
import torch
with torch.inference_mode():
outputs = model(tokenizer("hello world", return_tensors="pt").input_ids)
tokens = outputs.logits[0].argsort(dim=1, descending=True)[:, :10]
predictions = [
tokenizer.batch_decode(row[:, None])
for row in tokens
]
bad_predictions = [
" Item",
" Location",
" Name",
" Owner",
" Personality",
" Size",
" Species",
" Status",
" Time",
" Type"
]
[
len(set(row) & set(bad_predictions))
for row in predictions
]
[10, 10, 10, 10]
So every single token in that 4-token input gets the same top 10 predictions. The model is clearly busted big time. Why is the metric so wrong?
In this sentence the word world is a noun, so we can pick on that for the comparison.
from pathlib import Path
from transformers import AutoModelForMaskedLM, AutoTokenizer
MODEL_FOLDER = Path("/data/prompt-internalization/multilingual/roberta-large-e2-bs32-lr0.0001-t2/")
student_model = AutoModelForMaskedLM.from_pretrained(MODEL_FOLDER)
student_model.eval()
teacher_model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
teacher_model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
import torch
with torch.inference_mode():
teacher_tokens = tokenizer(
"hello world. Pet: Dog, Color: Yellow, Vehicle: Tractor, Fruit: Banana,<mask>: world",
return_tensors="pt"
).input_ids
teacher_output = teacher_model(teacher_tokens)
teacher_predictions = teacher_output.logits[teacher_tokens == tokenizer.mask_token_id]
teacher_predictions = teacher_predictions[0]
student_tokens = tokenizer(
"hello world.",
return_tensors="pt"
).input_ids
student_output = student_model(student_tokens)
student_predictions = student_output.logits[0, 2] # <s> 0, hello 1, world 2
teacher_predictions.shape, student_predictions.shape
(torch.Size([50265]), torch.Size([50265]))
[' Country',
' Location',
' Destination',
' Planet',
' Region',
' Name',
' Area',
' City',
' Language',
' World']
[' Name',
' Owner',
' Location',
' Size',
' Item',
' Type',
' Time',
' Species',
' Personality',
' Status']
If the metric is working correctly then it should show a 0.2 overlap, as only 2 of the tokens in the top 10 match.
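As a direct check of that pair (a sketch using the two prediction tensors from above):

```python
# Compare the top 10 tokens of the collapsed student against the teacher directly.
student_top10 = student_predictions.argsort(descending=True)[:10]
teacher_top10 = teacher_predictions.argsort(descending=True)[:10]
overlap = torch.isin(student_top10, teacher_top10).sum().item() / 10
overlap  # expected to be 0.2, matching the two shared tokens above
```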
import numpy as np
from itertools import starmap
def _compute_mean_overlap(predictions: np.array) -> float:
argsorted = predictions.argsort(axis=-1)
outputs_top_10 = argsorted[:, 0, :10]
targets_top_10 = argsorted[:, 1, :10]
overlap = np.array(
list(map(np.sum, starmap(np.isin, zip(outputs_top_10, targets_top_10))))
)
overlap = overlap / 10
return overlap.mean()
0.7
There is clearly a problem here. Since I got over-clever with the code in the metric, it would be good to break down how I got it so wrong.
['�',
'<pad>',
' guiIcon',
'iHUD',
' Dragonbound',
'channelAvailability',
' davidjl',
'��士',
'<unk>',
'FactoryReloaded']
Well, that was hard. The numpy argsort method sorts in ascending order and has no option to flip it, so taking the first ten entries picks out the least likely tokens instead of the most likely ones.
import numpy as np
from itertools import starmap
def _compute_mean_overlap(predictions: np.array) -> float:
argsorted = predictions.argsort(axis=-1)
outputs_top_10 = argsorted[:, 0, -10:] # changed slice
targets_top_10 = argsorted[:, 1, -10:] # changed slice
overlap = np.array(
list(map(np.sum, starmap(np.isin, zip(outputs_top_10, targets_top_10))))
)
overlap = overlap / 10
return overlap.mean()
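As a toy check of the slice direction (illustrative values, not taken from the model output):

```python
import numpy as np

scores = np.array([0.1, 0.9, 0.5, 0.7])
scores.argsort()       # ascending: array([0, 2, 3, 1])
scores.argsort()[:2]   # the broken slice: the two *least* likely entries, array([0, 2])
scores.argsort()[-2:]  # the fixed slice: the two most likely entries, array([3, 1])
```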
I’ll have to redo the training with this. The metric being broken is not the cause of the model collapse though. That’s what needs investigating next.
The loss is what guides the student to mimic the teacher. If the student breaks, it’s because the loss made it do it. Since I managed to mess up the overlap metric, let’s see if I did the same for the loss function.
If I’ve managed to get the roberta-base results to date with a thoroughly broken trainer that would be hilarious.
from pathlib import Path
from transformers import AutoModelForMaskedLM, AutoTokenizer
MODEL_FOLDER = Path("/data/prompt-internalization/multilingual/roberta-large-e2-bs32-lr0.0001-t2/")
student_model = AutoModelForMaskedLM.from_pretrained(MODEL_FOLDER)
student_model.eval()
teacher_model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
teacher_model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
import torch
with torch.inference_mode():
teacher_tokens = tokenizer(
"hello world. Pet: Dog, Color: Yellow, Vehicle: Tractor, Fruit: Banana,<mask>: world",
return_tensors="pt"
).input_ids
teacher_output = teacher_model(teacher_tokens)
teacher_predictions = teacher_output.logits[teacher_tokens == tokenizer.mask_token_id]
teacher_predictions = teacher_predictions[0]
student_tokens = tokenizer(
"hello world.",
return_tensors="pt"
).input_ids
student_output = student_model(student_tokens)
student_predictions = student_output.logits[0, 2] # <s> 0, hello 1, world 2
teacher_predictions.shape, student_predictions.shape
(torch.Size([50265]), torch.Size([50265]))
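The cell that produced the value below is not shown, but it was presumably the same Kullback-Leibler divergence calculation that the trainer uses, applied to this single prediction pair without temperature scaling. A minimal sketch:

```python
import torch.nn.functional as F

with torch.inference_mode():
    # KL divergence between the student and teacher distributions for this one token,
    # mirroring the trainer _loss method with the temperature scaling left out.
    loss = F.kl_div(
        input=F.log_softmax(student_predictions[None].to(torch.float32), dim=-1),
        target=F.softmax(teacher_predictions[None].to(torch.float32), dim=-1),
        reduction="batchmean",
        log_target=False,
    )
loss
```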
tensor(2.4641)
This is without temperature scaling. Based on this, the student and teacher do not agree on the distribution, as this is a high loss value. Let’s see if the trainer reproduces this.
from pathlib import Path
RUN_FOLDER = Path("/tmp/ml-prompt-internalization")
# these settings don't matter, apart from mean_prediction
# just want to create this to work out the loss
training_args = MultilingualMaskedPromptInternalizationTrainingArguments(
report_to="none",
output_dir=RUN_FOLDER,
num_train_epochs=1,
# max_steps=max_steps,
seed=33,
# number of steps before moving evaluation results from GPU to CPU see
# https://discuss.huggingface.co/t/cuda-out-of-memory-when-using-trainer-with-compute-metrics/2941
eval_accumulation_steps=5,
#
# hyperparameters
per_device_train_batch_size=64,
per_device_eval_batch_size=64,
fp16=False,
temperature=1, # want to disable it so that loss should equal kl_div
mean_prediction=False,
learning_rate=1e-4,
#
# evaluation settings
evaluation_strategy="steps",
logging_steps=1_000,
eval_steps=1_000,
save_steps=1_000,
#
# checkpoint settings
logging_dir=RUN_FOLDER / "logs",
save_total_limit=2,
load_best_model_at_end=True,
metric_for_best_model="overlap",
greater_is_better=True,
remove_unused_columns=False,
)
trainer = MultilingualMaskedPromptInternalizationTrainer(
args=training_args,
model=student_model,
teacher_model=teacher_model,
tokenizer=tokenizer,
)
teacher_model.cpu()
student_model.cpu() ; None
PyTorch: setting up devices
[' Country',
' Location',
' Destination',
' Planet',
' Region',
' Name',
' Area',
' City',
' Language',
' World']
[' Name',
' Owner',
' Location',
' Size',
' Item',
' Type',
' Time',
' Species',
' Personality',
' Status']
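The value below was presumably obtained by pushing a single hand-built batch through the trainer's compute_loss; a sketch under that assumption, where the labels row holds the start index and token length of world in the student tokens:

```python
# A sketch, not the original cell: build one batch by hand and run compute_loss.
student_tokens = tokenizer("hello world.", return_tensors="pt")
teacher_tokens = tokenizer(
    "hello world. Pet: Dog, Color: Yellow, Vehicle: Tractor, Fruit: Banana,<mask>: world",
    return_tensors="pt",
)
batch = {
    "input_ids": student_tokens.input_ids,
    "attention_mask": student_tokens.attention_mask,
    "labels": torch.tensor([[2, 1]]),  # <s> 0, hello 1, world 2; length 1
    "teacher_input_ids": teacher_tokens.input_ids,
    "teacher_attention_mask": teacher_tokens.attention_mask,
}
trainer.compute_loss(student_model, batch)
```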
tensor(2.4641, grad_fn=<MulBackward0>)
At least the loss function appears to be correct.
One thing to check is that the teacher does not get altered by the training process. I think that I have put the teacher in eval mode, which should not change during training, and the extraction of the predictions is done with inference_mode on. It’s easy to check this, as I can just inspect the prediction tensors to see if they are tracking gradients.
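A quick sketch of that check, using the teacher predictions computed earlier:

```python
# Predictions extracted under inference_mode should not track gradients,
# so nothing from the loss can flow back into the teacher.
teacher_predictions.requires_grad  # expected: False
```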
It’s not this either then.
I want to try training again with the fixed metric to see how well it does. It may be that the task itself is prone to collapse and that I need to work on the difficulty.
PyTorch: setting up devices
Could not locate the tokenizer configuration file, will try to use the model config instead.
/home/matthew/.cache/pypoetry/virtualenvs/blog-HrtMnrOS-py3.10/lib/python3.10/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
warnings.warn(
Step | Training Loss | Validation Loss | Kl Div | Overlap |
---|---|---|---|---|
1000 | 1.799300 | 1.667589 | 2.009000 | 0.247837 |
2000 | 1.721000 | 1.668797 | 2.024170 | 0.237137 |
3000 | 1.714600 | 1.669277 | 2.018861 | 0.251481 |
4000 | 1.713100 | 1.667412 | 2.017111 | 0.247837 |
5000 | 1.712600 | 1.663882 | 2.014157 | 0.228524 |
6000 | 1.706900 | 1.668680 | 2.015491 | 0.230340 |
7000 | 1.701000 | 1.666185 | 2.010081 | 0.230340 |
8000 | 1.711700 | 1.664344 | 2.011197 | 0.234733 |
9000 | 1.701200 | 1.663693 | 2.007802 | 0.230340 |
This shows that the overlap metric decreases quite rapidly and that the loss function doesn’t really improve over the run. Given that the roberta-base values for loss and Kullback-Leibler divergence were significantly lower, there is a problem with learning the task here.
PyTorch: setting up devices
Could not locate the tokenizer configuration file, will try to use the model config instead.
/home/matthew/.cache/pypoetry/virtualenvs/blog-HrtMnrOS-py3.10/lib/python3.10/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
warnings.warn(
Step | Training Loss | Validation Loss | Kl Div | Overlap |
---|---|---|---|---|
1000 | 1.348200 | 1.221799 | 2.162052 | 0.250932 |
2000 | 1.244600 | 1.223369 | 2.153403 | 0.215783 |
3000 | 1.235600 | 1.218408 | 2.134269 | 0.257677 |
4000 | 1.236200 | 1.210570 | 2.146475 | 0.250932 |
5000 | 1.231500 | 1.208394 | 2.134102 | 0.236969 |
6000 | 1.226900 | 1.207478 | 2.143890 | 0.236969 |
Could not locate the tokenizer configuration file, will try to use the model config instead.
=== BASS EVALUATION ===
First Phrase is: We spotted a large bass in the ocean. Target is: bass
Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
Second Phrase is: The bass player did not receive the acknowledgment she deserves. Target is: bass
Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
Third Phrase is: The black sea bass, is a member of the wreckfish family. Target is: bass
Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
First & Second: ['Item', 'Location', 'Name', 'Owner', 'Personality', 'Size', 'Species', 'Status', 'Subject', 'Type']
First & Third: ['Item', 'Location', 'Name', 'Owner', 'Personality', 'Size', 'Species', 'Status', 'Subject', 'Type']
Second & Third: ['Item', 'Location', 'Name', 'Owner', 'Personality', 'Size', 'Species', 'Status', 'Subject', 'Type']
=== FRIDAY EVALUATION ===
Spanish Phrase is: Friday es mi canción favorita.
Spanish Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
English Phrase is: Friday is my favourite song.
English Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
Description Overlap is: Item, Location, Name, Owner, Personality, Size, Species, Status, Subject, Type
Description Difference is:
=== MALIBU EVALUATION ===
Phrase is: I like to drive my Malibu while drinking Malibu.
First Malibu (car) Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
Second Malibu (drink) Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
First & Second: ['Item', 'Location', 'Name', 'Owner', 'Personality', 'Size', 'Species', 'Status', 'Subject', 'Type']
First ^ Second: []
=== FOOTBALL EVALUATION ===
Spanish Phrase is: Retiremos el equipo de la cancha, Boca no merece jugar esta copa que hace tiempo viene siendo desprestigiada.
Ya no se juega al futbol.
English Phrase is: Let's remove the team from the field, Boca does not deserve to play this cup that has long been discredited. Football is no longer played.
Spanish word is: equipo, English word is: team
Spanish Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
English Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
Overlap is: Item, Location, Name, Owner, Personality, Size, Species, Status, Subject, Type (10)
Difference is: (0)
Spanish word is: Boca, English word is: Boca
Spanish Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
English Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
Overlap is: Item, Location, Name, Owner, Personality, Size, Species, Status, Subject, Type (10)
Difference is: (0)
Spanish word is: copa, English word is: cup
Spanish Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
English Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
Overlap is: Item, Location, Name, Owner, Personality, Size, Species, Status, Subject, Type (10)
Difference is: (0)
Spanish word is: tiempo, English word is: long
Spanish Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
English Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
Overlap is: Item, Location, Name, Owner, Personality, Size, Species, Status, Subject, Type (10)
Difference is: (0)
Spanish word is: futbol, English word is: Football
Spanish Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
English Description is: Name, Type, Owner, Size, Location, Species, Subject, Personality, Item, Status
Overlap is: Item, Location, Name, Owner, Personality, Size, Species, Status, Subject, Type (10)
Difference is: (0)
The evaluation results haven’t changed, but at least this time the overlap metric is not claiming a strong performance.
There are a few ways to force the model to be more varied in its output. The problem seems to be that the task causes the model to collapse on the most common terms.
A few ways to avoid this are:
* Force the model to do another task at the same time to maintain flexibility. Language modelling the input would be a good way to try this, as that is the original thing that the model was trained to do (a rough sketch of a combined loss follows this list).
* Manually exclude the top 10 entries that it is currently predicting from both the teacher and student. There is clearly value in predicting them, as the model gets around 0.2 to 0.25 overlap by doing this.
* Measure variation across the batch predictions and punish the model for being too consistent, or measure variation across all of the tokens that the model predicts for the input and punish consistency.
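As a very rough sketch of the first option, the distillation loss could be combined with a standard masked language modelling loss through a fixed ratio. The mlm_weight value here is purely illustrative and would need tuning:

```python
import torch
import torch.nn.functional as F

def combined_loss(
    distillation_loss: torch.Tensor,
    mlm_logits: torch.Tensor,  # [batch, seq, vocab] from the same student forward pass
    mlm_labels: torch.Tensor,  # [batch, seq] with -100 marking positions to ignore
    mlm_weight: float = 0.5,   # illustrative ratio between the two tasks
) -> torch.Tensor:
    # Standard masked language modelling loss over the student logits,
    # added to the distillation loss with a fixed weighting.
    mlm_loss = F.cross_entropy(
        mlm_logits.reshape(-1, mlm_logits.size(-1)),
        mlm_labels.reshape(-1),
        ignore_index=-100,
    )
    return distillation_loss + mlm_weight * mlm_loss
```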
This post has already grown quite significantly and each of these is an experiment on its own, so they will be addressed in separate posts.
There are several ways that the training process could be improved which don’t relate to the model or loss itself.
The identification of nouns could be improved. The current dataset uses the medium spacy models, and the accuracy of the large models is better.
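For reference, switching to the large pipeline would look something like this (a sketch for English; the dataset covers several languages, each with its own model name):

```python
import spacy

# en_core_web_lg is the large English pipeline; the dataset currently uses the md models.
nlp = spacy.load("en_core_web_lg")
doc = nlp("I was in the mountains.")
nouns = [token.text for token in doc if token.pos_ == "NOUN"]
nouns  # expected to include mountains
```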
When prompting the teacher, the example pairs come with capitalized words. Capitalizing the noun for the teacher prompt would therefore make it fit better with the other example pairs. As a reminder, this is a prompted sentence:
I was in the mountains. Pet: Dog, Color: Yellow, Vehicle: Tractor, Fruit: Banana,<mask>: mountains
We can see that the noun in the sentence is mountains. In the prompt every pair is capitalized except for mountains, and this can lead to different predictions for the mask token. Normalizing this should help.
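A sketch of that normalization, capitalizing only the first letter of the target noun when building the teacher prompt (the helper name is illustrative):

```python
def build_teacher_prompt(text: str, noun: str) -> str:
    # Uppercase just the first letter so proper nouns like Germans are left intact.
    capitalized = noun[:1].upper() + noun[1:]
    return (
        f"{text} Pet: Dog, Color: Yellow, Vehicle: Tractor, Fruit: Banana,"
        f"<mask>: {capitalized}"
    )

build_teacher_prompt("I was in the mountains.", "mountains")
# 'I was in the mountains. Pet: Dog, Color: Yellow, Vehicle: Tractor, Fruit: Banana,<mask>: Mountains'
```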
Once again this will have to be a separate (shorter) post.