Cross Language Prompt Internalization - Model Collapse

Does XLM-RoBERTa collapse into a fixed set of outputs if it is trained for a long time?
prompt internalization
multilingual prompt internalization
cross language word sense induction
Published

July 10, 2022

This is a further investigation into model collapse. The investigation of prompt internalization has been hindered by collapse, which makes me think that the task itself is not stable. If a model that is trained for a long time collapses, then I may have to review the structure of the task itself.

Code
import blog.transformers_logging

Code

Just smashing everything in from the last time.

Code
# from src/main/python/blog/prompt_internalization/multilingual/roberta/trainer.py
from itertools import starmap
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer, Trainer, TrainingArguments
from transformers.modeling_outputs import MaskedLMOutput


class MultilingualMaskedPromptInternalizationTrainingArguments(TrainingArguments):
    """
    TrainingArguments extended with the prompt internalization settings.

    Extra keyword arguments:
        temperature: softmax temperature applied to both the student and the
            teacher logits when computing the distillation loss.
        mean_prediction: when True the student prediction is the mean of its
            logits over the labelled token span; otherwise the logits of the
            first labelled token are used.
        ignore_tokens: token ids whose teacher logits are pushed down to the
            batch minimum. Defaults to an empty list.

    All other arguments are passed straight to ``TrainingArguments``.
    """

    def __init__(
        self,
        *args,
        temperature: float = 2.0,
        mean_prediction: bool = True,
        ignore_tokens: Optional[List[int]] = None,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.temperature = temperature
        self.mean_prediction = mean_prediction
        # A fresh list per instance, so no mutable default is shared.
        self.ignore_tokens = [] if ignore_tokens is None else ignore_tokens


class MultilingualMaskedPromptInternalizationTrainer(Trainer):
    """
    Distil a teacher masked LM's mask-token predictions into a student masked LM.

    The teacher is run on ``teacher_input_ids`` and its logits at the mask token
    are the target. The student is run on ``input_ids`` and its logits at the
    position(s) given by ``labels`` are trained to match that target with a
    temperature-scaled KL divergence.

    During evaluation two per-example metrics are returned for
    ``compute_metrics``: the unscaled KL divergence and the top-10 token overlap.
    """

    def __init__(
        self,
        *args,
        teacher_model: AutoModelForMaskedLM = None,
        tokenizer: AutoTokenizer = None,
        **kwargs,
    ) -> None:
        # Fail with a clear message instead of an AttributeError part way
        # through setup.
        if teacher_model is None:
            raise ValueError("teacher_model must be provided")
        if tokenizer is None:
            raise ValueError("tokenizer must be provided to find the mask token id")
        # NOTE(review): ``tokenizer`` is consumed here and not forwarded to
        # Trainer, so the base class never sees it. For example, it is then
        # presumably not saved with checkpoints. Forward it if that is wanted,
        # after checking the argument name in the installed transformers version.
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self._move_model_to_device(self.teacher, self.model.device)
        # The teacher only supplies targets, so keep it in eval mode.
        self.teacher.eval()
        self.mask_token_id = tokenizer.mask_token_id

    def compute_loss(
        self,
        model: AutoModelForMaskedLM,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs: bool = False,
        num_items_in_batch: Optional[Any] = None,
    ) -> Union[Tuple[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor]:
        """
        Compute the distillation loss and, optionally, the per-example metrics.

        Args:
            model: the student model.
            inputs: a batch with ``input_ids``, ``attention_mask``, ``labels``,
                ``teacher_input_ids`` and ``teacher_attention_mask``.
            return_outputs: when True, also return the evaluation metrics.
            num_items_in_batch: accepted so that Trainer versions which pass it
                can call this method. It is unused because the loss is already
                a batch mean.

        Returns:
            The loss, or ``(loss, {"metrics": metrics})``. Here ``metrics`` is a
            ``[batch_size, 2]`` tensor holding the KL divergence and the top-10
            overlap for each example.
        """
        outputs: MaskedLMOutput = model(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )
        if self.args.mean_prediction:
            predictions = self._student_predictions_mean(
                outputs=outputs, labels=inputs["labels"]
            )
        else:
            predictions = self._student_predictions_first(
                outputs=outputs, labels=inputs["labels"]
            )
        targets = self._teacher_predictions(
            input_ids=inputs["teacher_input_ids"],
            attention_mask=inputs["teacher_attention_mask"],
        )
        loss = self._loss(predictions=predictions, targets=targets)

        if not return_outputs:
            return loss

        metric_output = self._per_example_metrics(
            predictions=predictions, targets=targets
        )
        # Trainer.prediction_step treats a non-dict output as a ModelOutput-style
        # tuple whose first entry is the loss, and keeps only ``outputs[1:]``.
        # Returning the bare [batch_size, 2] tensor therefore dropped the first
        # example of every evaluation batch. A dict is passed through whole, and
        # its single value becomes the predictions given to compute_metrics.
        return loss, {"metrics": metric_output}

    @torch.no_grad()
    def _per_example_metrics(
        self, predictions: torch.Tensor, targets: torch.Tensor, top_k: int = 10
    ) -> torch.Tensor:
        """
        Return a [batch_size, 2] tensor of (KL divergence, top-k overlap) per row.

        The KL divergence is KL(teacher || student) at temperature 1. The
        overlap is the fraction of the student's top ``top_k`` token ids that
        also appear in the teacher's top ``top_k``. Everything stays on the
        device, because tensor ops are much faster than moving logits to numpy.
        """
        predictions = predictions.to(torch.float32)
        targets = targets.to(torch.float32)

        kl_div = F.kl_div(
            input=F.log_softmax(predictions, dim=-1),
            target=F.softmax(targets, dim=-1),
            reduction="none",
            log_target=False,
        ).sum(dim=1)

        student_top = predictions.topk(top_k, dim=-1).indices
        teacher_top = targets.topk(top_k, dim=-1).indices
        # Compare every student top-k id with every teacher top-k id at once:
        # [batch, k, 1] == [batch, 1, k] -> [batch, k, k] -> any -> [batch, k].
        in_teacher_top = (student_top[:, :, None] == teacher_top[:, None, :]).any(
            dim=-1
        )
        overlap = in_teacher_top.sum(dim=-1).to(kl_div.dtype) / top_k

        return torch.stack([kl_div, overlap], dim=1)

    @torch.inference_mode()
    def _teacher_predictions(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Return the teacher's logits at every mask token in the batch.

        The result has one row per mask token. Pairing it with the student's
        per-example predictions assumes exactly one mask token per teacher
        sequence (TODO confirm against the dataset). Logits for
        ``args.ignore_tokens`` are set to the batch minimum.
        """
        outputs_teacher = self.teacher(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        mask_indices = input_ids == self.mask_token_id
        teacher_predictions = outputs_teacher.logits[mask_indices]
        teacher_predictions[:, self.args.ignore_tokens] = teacher_predictions.min()
        return teacher_predictions

    def _student_predictions_mean(
        self, outputs: MaskedLMOutput, labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Return the student logits averaged over each example's labelled span.

        Each row of ``labels`` is read as ``(start, length)``. The logits for
        tokens ``start : start + length`` are averaged into a single row, and
        the result is ``[batch_size, vocab_size]``.
        """
        logits = outputs.logits
        # torch.stack is differentiable, so gradients flow back into every
        # selected logit. The previous version wrote vocabulary-sized rows into
        # a 1-D zeros(batch_size) buffer, which raised a shape error.
        rows = [
            logits[index, start : start + length].mean(dim=0)
            for index, (start, length) in enumerate(labels)
        ]
        return torch.stack(rows)

    def _student_predictions_first(
        self,
        outputs: MaskedLMOutput,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """Return the student logits at position ``labels[:, 0]`` for each example."""
        return outputs.logits[range(outputs.logits.shape[0]), labels[:, 0]]

    def _loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Return the temperature-scaled KL(teacher || student) loss.

        Both logit sets are divided by ``args.temperature`` before the
        softmax. The batch-mean KL divergence is then multiplied by
        ``temperature ** 2``.
        """
        predictions = F.log_softmax(
            predictions.to(torch.float32) / self.args.temperature, dim=-1
        )
        targets = F.softmax(targets.to(torch.float32) / self.args.temperature, dim=-1)
        loss = F.kl_div(
            input=predictions,
            target=targets,
            reduction="batchmean",
            log_target=False,
        )
        return loss * (self.args.temperature**2)



# from src/main/python/blog/prompt_internalization/multilingual/roberta/collator.py
from typing import Any, Dict, List

from transformers import AutoTokenizer


class TeacherStudentCollator:
    """
    Collate rows that carry both teacher and student inputs into one batch.

    Each feature must provide ``teacher_input_ids``, ``input_ids`` and
    ``labels``. The teacher and student token sequences are padded
    independently, each gets its own attention mask, and everything is
    returned as PyTorch tensors. Only the first entry of each row's ``labels``
    is kept.
    """

    def __init__(self, tokenizer: AutoTokenizer) -> None:
        self.tokenizer = tokenizer

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch = dict(self._teacher_inputs(features))
        batch.update(self._student_inputs(features))
        # Normalise alternative label key names to "labels".
        for alias in ("label", "label_ids"):
            if alias in batch:
                batch["labels"] = batch.pop(alias)
        return batch

    def _teacher_inputs(self, features: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
        padded = self.tokenizer.pad(
            [{"input_ids": feature["teacher_input_ids"]} for feature in features],
            padding=True,
            return_tensors="pt",
        )
        # Rename so that the teacher tensors do not collide with the student's.
        teacher_input_ids = padded["input_ids"]
        teacher_attention_mask = padded["attention_mask"]
        return {
            "teacher_input_ids": teacher_input_ids,
            "teacher_attention_mask": teacher_attention_mask,
        }

    def _student_inputs(self, features: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
        # Each row's labels is known to have a single entry, so only that one is kept.
        rows = [
            {"input_ids": feature["input_ids"], "labels": feature["labels"][0]}
            for feature in features
        ]
        return self.tokenizer.pad(rows, padding=True, return_tensors="pt")



# from src/main/python/blog/prompt_internalization/multilingual/roberta/metrics.py
from typing import Dict

from transformers import EvalPrediction


def compute_metrics(model_output: EvalPrediction) -> Dict[str, float]:
    """
    Average the per-example evaluation metrics.

    ``model_output.predictions`` is expected to be an ``[n_examples, 2]`` array
    whose columns are the KL divergence and the top-10 overlap, in that order.

    Returns:
        The column means as plain Python floats, matching the annotated return
        type rather than returning numpy scalars.
    """
    predictions = model_output.predictions
    return {
        "kl_div": float(predictions[:, 0].mean()),
        "overlap": float(predictions[:, 1].mean()),
    }



# from src/main/python/blog/prompt_internalization/multilingual/roberta/train.py
from pathlib import Path
from typing import List, Optional

import datasets
from transformers import AutoModelForMaskedLM, AutoTokenizer

from .collator import TeacherStudentCollator
from .metrics import compute_metrics
from .trainer import (
    MultilingualMaskedPromptInternalizationTrainer,
    MultilingualMaskedPromptInternalizationTrainingArguments,
)

DATASET_FOLDER = Path("/data/tatoeba/2022-06-18/dataset/")
MODEL_FOLDER = Path("/data/prompt-internalization/multilingual/")
RUN_FOLDER = Path("/tmp/runs")

# Importing this module creates the output folders if they are missing.
for _output_folder in (MODEL_FOLDER, RUN_FOLDER):
    _output_folder.mkdir(parents=True, exist_ok=True)
del _output_folder


def train(
    *,
    model_name: str = "xlm-roberta-base",
    dataset_name: str = "xlm-roberta",
    batch_size: int = 64,
    learning_rate: float = 1e-4,
    temperature: float = 2,
    fp16: bool = False,
    mean_prediction: bool = False,
    ignore_tokens: Optional[List[int]] = None,
    epochs: Optional[float] = 2,
    max_steps: int = -1,
    evaluation_steps: int = 500,
) -> Path:
    """
    Train a student copy of ``model_name`` against a teacher copy and save it.

    The tokenizer, teacher and student are all loaded from ``model_name``. The
    train and test datasets are loaded from ``DATASET_FOLDER`` using
    ``dataset_name``. Checkpoints and logs go under ``RUN_FOLDER / run_name``,
    where ``run_name`` encodes the hyperparameters. The student is then saved
    to ``MODEL_FOLDER / run_name``.

    Args:
        model_name: pretrained masked LM to use for the tokenizer, teacher and student.
        dataset_name: prefix of the ``-train.dataset`` / ``-test.dataset`` folders.
        batch_size: per-device batch size for both training and evaluation.
        learning_rate: optimizer learning rate.
        temperature: distillation softmax temperature.
        fp16: train with mixed precision.
        mean_prediction: average the student logits over the labelled span
            instead of using the first labelled token.
        ignore_tokens: token ids suppressed in the teacher predictions.
        epochs: number of training epochs. The run name records it only when
            ``max_steps`` is -1.
        max_steps: when not -1, recorded in the run name instead of ``epochs``.
        evaluation_steps: interval, in steps, for logging, evaluation and checkpoints.

    Returns:
        The folder the trained student model was saved to.
    """
    run_name = "-".join(
        [
            f"{model_name}",
            f"e{epochs}" if max_steps == -1 else f"ms{max_steps}",
            f"bs{batch_size}",
            f"lr{learning_rate}",
            f"t{temperature}",
        ]
        + (["fp16"] if fp16 else [])
        + (["mean"] if mean_prediction else [])
        + ([f"it{len(ignore_tokens)}"] if ignore_tokens else [])
    )
    # Each run gets its own output folder. Sharing RUN_FOLDER between runs let
    # checkpoint rotation (save_total_limit) act on checkpoints left behind by
    # earlier runs.
    run_folder = RUN_FOLDER / run_name
    print(f"Starting {run_name}")
    train_ds = datasets.load_from_disk(DATASET_FOLDER / f"{dataset_name}-train.dataset")
    test_ds = datasets.load_from_disk(DATASET_FOLDER / f"{dataset_name}-test.dataset")

    training_args = MultilingualMaskedPromptInternalizationTrainingArguments(
        report_to="none",
        output_dir=str(run_folder),
        num_train_epochs=epochs,
        max_steps=max_steps,
        seed=33,
        # number of steps before moving evaluation results from GPU to CPU see
        # https://discuss.huggingface.co/t/cuda-out-of-memory-when-using-trainer-with-compute-metrics/2941
        eval_accumulation_steps=5,
        #
        # hyperparameters
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        fp16=fp16,
        temperature=temperature,
        mean_prediction=mean_prediction,
        ignore_tokens=ignore_tokens,
        learning_rate=learning_rate,
        #
        # evaluation settings
        evaluation_strategy="steps",
        logging_steps=evaluation_steps,
        eval_steps=evaluation_steps,
        save_steps=evaluation_steps,
        #
        # checkpoint settings
        logging_dir=str(run_folder / "logs"),
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="overlap",
        greater_is_better=True,
        # The collator needs the teacher_* columns, which the student model
        # does not accept, so Trainer must not drop them.
        remove_unused_columns=False,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    teacher_model = AutoModelForMaskedLM.from_pretrained(model_name)
    student_model = AutoModelForMaskedLM.from_pretrained(model_name)
    data_collator = TeacherStudentCollator(tokenizer=tokenizer)

    trainer = MultilingualMaskedPromptInternalizationTrainer(
        model=student_model,
        args=training_args,
        teacher_model=teacher_model,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    # With load_best_model_at_end the trainer presumably loads the best
    # checkpoint (by overlap) back into student_model before this save.
    # Confirm this for the installed transformers version.
    model_folder = MODEL_FOLDER / run_name
    student_model.save_pretrained(model_folder)

    return model_folder



# from src/main/python/blog/prompt_internalization/multilingual/roberta/evaluate.py
from pathlib import Path
from typing import List, Optional, Tuple

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer


def evaluate(
    model_name: str, model_path: Path, ignore_tokens: Optional[List[int]] = None
) -> None:
    """
    Load a trained model and print every qualitative evaluation.

    The tokenizer comes from ``model_name`` and the weights from
    ``model_path``. Logits for ``ignore_tokens`` are pushed to the minimum
    before the top predictions are picked. No ids are suppressed by default.
    """
    tokens_to_ignore = ignore_tokens if ignore_tokens is not None else []

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_path)
    model.eval()

    for run_evaluation in (
        bass_evaluation,
        friday_evaluation,
        malibu_evaluation,
        football_evaluation,
    ):
        run_evaluation(model=model, tokenizer=tokenizer, ignore_tokens=tokens_to_ignore)


def bass_evaluation(
    model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, ignore_tokens: List[int]
) -> None:
    """
    Print predictions for "bass" in three sentences and their pairwise overlaps.

    The first and third sentences use the fish sense and the second uses the
    musical sense. Each pairwise intersection is printed as a sorted list.
    """
    phrases = (
        "We spotted a large bass in the ocean.",
        "The bass player did not receive the acknowledgment she deserves.",
        "The black sea bass, is a member of the wreckfish family.",
    )
    ordinals = ("First", "Second", "Third")
    descriptions = [
        get_predictions(
            model=model,
            tokenizer=tokenizer,
            text=phrase,
            noun="bass",
            ignore_tokens=ignore_tokens,
        )
        for phrase in phrases
    ]

    print("=== BASS EVALUATION ===")
    for ordinal, phrase, words in zip(ordinals, phrases, descriptions):
        print(f"{ordinal} Phrase is: {phrase} Target is: bass")
        print(f"Description is: {', '.join(words)}")
        print()

    for left, right in ((0, 1), (0, 2), (1, 2)):
        shared = sorted(set(descriptions[left]) & set(descriptions[right]))
        print(f"{ordinals[left]} & {ordinals[right]}: {shared}")
    print()


def friday_evaluation(
    model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, ignore_tokens: List[int]
) -> None:
    """
    Compare predictions for "Friday" (a song title) in Spanish and English.

    Prints each sentence with its predictions, then the words the two share and
    the words only one of them has (the symmetric difference).
    """
    texts = {
        "Spanish": "Friday es mi canción favorita.",
        "English": "Friday is my favourite song.",
    }
    words_by_language = {
        language: get_predictions(
            model=model,
            tokenizer=tokenizer,
            text=text,
            noun="Friday",
            ignore_tokens=ignore_tokens,
        )
        for language, text in texts.items()
    }
    spanish = set(words_by_language["Spanish"])
    english = set(words_by_language["English"])

    print("=== FRIDAY EVALUATION ===")
    for language, text in texts.items():
        print(f"{language} Phrase is: {text}")
        print(f"{language} Description is: {', '.join(words_by_language[language])}")
    print()

    print(f"Description Overlap is: {', '.join(sorted(spanish & english))}")
    print(f"Description Difference is: {', '.join(sorted(spanish ^ english))}")
    print()


def malibu_evaluation(
    model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, ignore_tokens: List[int]
) -> None:
    """
    Compare predictions for the two occurrences of "Malibu" in one sentence.

    The first occurrence is the car and the second is the drink. Their shared
    predictions and their symmetric difference are printed as sorted lists.
    """
    text = "I like to drive my Malibu while drinking Malibu."
    car_words, drink_words = (
        get_predictions(
            model=model,
            tokenizer=tokenizer,
            text=text,
            noun="Malibu",
            index=occurrence,
            ignore_tokens=ignore_tokens,
        )
        for occurrence in (0, 1)
    )
    car, drink = set(car_words), set(drink_words)

    print("=== MALIBU EVALUATION ===")
    print(f"Phrase is: {text}")
    print(f"First Malibu (car) Description is: {', '.join(car_words)}")
    print(f"Second Malibu (drink) Description is: {', '.join(drink_words)}")
    print()

    print(f"First & Second: {sorted(car & drink)}")
    print(f"First ^ Second: {sorted(car ^ drink)}")
    print()


def football_evaluation(
    model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, ignore_tokens: List[int]
) -> None:
    """
    Compare predictions for aligned nouns in a Spanish sentence and its translation.

    For each (Spanish, English) noun pair, prints both prediction lists. It
    then prints their overlap (intersection) and difference (symmetric
    difference), each followed by its size.
    """
    spanish_phrase = (
        "Retiremos el equipo de la cancha, "
        "Boca no merece jugar esta copa que "
        "hace tiempo viene siendo desprestigiada.\n"
        "Ya no se juega al futbol."
    )
    english_phrase = (
        "Let's remove the team from the field, "
        "Boca does not deserve to play this cup that "
        "has long been discredited. "
        "Football is no longer played."
    )
    noun_pairs = (
        ("equipo", "team"),
        ("Boca", "Boca"),
        ("copa", "cup"),
        ("tiempo", "long"),
        ("futbol", "Football"),
    )

    print("=== FOOTBALL EVALUATION ===")
    print(f"Spanish Phrase is: {spanish_phrase}")
    print(f"English Phrase is: {english_phrase}")
    print()

    for spanish_noun, english_noun in noun_pairs:
        spanish_words, english_words = (
            get_predictions(
                model=model,
                tokenizer=tokenizer,
                text=phrase,
                noun=noun,
                ignore_tokens=ignore_tokens,
            )
            for phrase, noun in (
                (spanish_phrase, spanish_noun),
                (english_phrase, english_noun),
            )
        )
        shared = set(spanish_words) & set(english_words)
        distinct = set(spanish_words) ^ set(english_words)

        print(f"Spanish word is: {spanish_noun}, English word is: {english_noun}")
        print(f"Spanish Description is: {', '.join(spanish_words)}")
        print(f"English Description is: {', '.join(english_words)}")
        print(f"Overlap is: {', '.join(sorted(shared))} ({len(shared)})")
        print(f"Difference is: {', '.join(sorted(distinct))} ({len(distinct)})")
        print()


@torch.inference_mode()
def get_predictions(
    *,
    model: AutoModelForMaskedLM,
    tokenizer: AutoTokenizer,
    text: str,
    noun: str,
    index: int = 0,
    ignore_tokens: Optional[List[int]] = None,
    top_k: int = 10,
) -> List[str]:
    """
    Return the model's top predicted words at the first token of ``noun`` in ``text``.

    Args:
        model: masked language model to query.
        tokenizer: tokenizer matching ``model``.
        text: the sentence to run through the model.
        noun: the word whose first-token position is read. It must appear in ``text``.
        index: which occurrence of ``noun`` to use, where 0 is the first.
        ignore_tokens: token ids whose logits are set to the minimum logit so
            that they are pushed out of the top predictions.
        top_k: number of words to return. Defaults to 10.

    Returns:
        The decoded, whitespace-stripped top ``top_k`` tokens, most likely first.

    Raises:
        AssertionError: raised by ``get_noun`` when the requested occurrence is missing.
    """
    if ignore_tokens is None:
        ignore_tokens = []

    # Put the inputs on the model's device, so that a GPU model works too.
    tokens = tokenizer(text, return_tensors="pt").to(model.device)
    start, _ = get_noun(
        tokenizer=tokenizer, tokens=tokens.input_ids[0], noun=noun, index=index
    )

    logits = model(**tokens).logits[0, start]
    logits[ignore_tokens] = logits.min()
    # topk only needs the best few entries instead of sorting the whole vocabulary.
    predicted_tokens = logits.topk(top_k).indices
    return [word.strip() for word in tokenizer.batch_decode(predicted_tokens)]


def get_noun(
    tokenizer: AutoTokenizer, tokens: torch.Tensor, noun: str, index: int
) -> Tuple[int, int]:
    """
    Find the token span of the ``index``-th occurrence of ``noun`` in ``tokens``.

    A span ``tokens[start:end]`` matches when it decodes, with whitespace
    stripped, to exactly ``noun``. Occurrences are counted left to right from
    0, and each start position counts at most once.

    Returns:
        ``(start, end)``, where ``end`` is exclusive.

    Raises:
        AssertionError: if the requested occurrence is not found.
    """
    length = tokens.shape[0]
    remaining = index
    for start_index in range(length):
        word = tokenizer.decode(tokens[start_index]).strip()
        # A token that decodes to nothing (e.g. a bare word-boundary marker)
        # would otherwise be a prefix of every noun. It would then put the span
        # start on a whitespace token and count the same occurrence twice.
        if not word or not noun.startswith(word):
            continue
        # ``end_index`` is exclusive, so it must be able to reach ``length`` for
        # a noun that ends on the final token.
        for end_index in range(start_index + 1, length + 1):
            word = tokenizer.decode(tokens[start_index:end_index]).strip()
            if word != noun:
                continue
            if remaining <= 0:
                return start_index, end_index
            remaining -= 1
            # Count this start once, even if later tokens decode to whitespace
            # only and still match ``noun``.
            break
    raise AssertionError(f"Did not find {noun}[{index}] in {tokenizer.decode(tokens)}")

Training

Training for 20 epochs is about 10x longer than I have trained for so far.

Code
# Base checkpoint: train loads the tokenizer, teacher and student from it, and
# evaluate loads the tokenizer from it.
MODEL_NAME = "xlm-roberta-base"
Code
# Train for 20 epochs (10x train's default of 2) at temperature 2.
# train returns the folder the student model was saved to.
model_path = train(
    model_name=MODEL_NAME,
    batch_size=32,
    learning_rate=1e-4,
    temperature=2,
    mean_prediction=False,
    epochs=20,
    evaluation_steps=1_000,
)
Starting xlm-roberta-base-e20-bs32-lr0.0001-t2
/home/matthew/.cache/pypoetry/virtualenvs/blog-HrtMnrOS-py3.10/lib/python3.10/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
[61060/61060 7:14:05, Epoch 20/20]
Step Training Loss Validation Loss Kl Div Overlap
1000 0.403300 0.280854 0.256531 0.683585
2000 0.288300 0.259887 0.240529 0.694900
3000 0.262300 0.261534 0.245546 0.698782
4000 0.225200 0.249906 0.228982 0.707599
5000 0.217900 0.240995 0.226439 0.712368
6000 0.211700 0.236759 0.219825 0.718522
7000 0.186700 0.240164 0.225746 0.713896
8000 0.181800 0.234820 0.216973 0.721330
9000 0.180100 0.228655 0.213052 0.722672
10000 0.160700 0.231911 0.216320 0.723374
11000 0.157300 0.227415 0.213988 0.726265
12000 0.159000 0.228098 0.211121 0.724159
13000 0.141000 0.226735 0.210241 0.729279
14000 0.138900 0.224571 0.207329 0.728123
15000 0.139900 0.225365 0.210097 0.728846
16000 0.126000 0.228829 0.214268 0.727834
17000 0.124400 0.226964 0.211337 0.729919
18000 0.123600 0.227304 0.211372 0.726987
19000 0.113800 0.226793 0.211822 0.729589
20000 0.109600 0.228013 0.211201 0.730952
21000 0.113300 0.224208 0.209901 0.730745
22000 0.102500 0.225952 0.210528 0.731695
23000 0.100200 0.222585 0.209529 0.730849
24000 0.101700 0.222588 0.207854 0.733223
25000 0.096400 0.223735 0.209063 0.734669
26000 0.091300 0.221564 0.204739 0.733430
27000 0.092500 0.220193 0.204844 0.733843
28000 0.087600 0.220678 0.203973 0.736011
29000 0.083700 0.218368 0.203532 0.737291
30000 0.085100 0.224064 0.207729 0.733987
31000 0.082200 0.220344 0.206388 0.737167
32000 0.078200 0.223062 0.207357 0.735866
33000 0.078000 0.223453 0.208466 0.734565
34000 0.076200 0.221971 0.206944 0.737229
35000 0.071800 0.218891 0.203820 0.738592
36000 0.073300 0.220247 0.205179 0.737725
37000 0.071100 0.219950 0.205713 0.737043
38000 0.067000 0.218195 0.203988 0.738902
39000 0.068200 0.219006 0.203538 0.737910
40000 0.067100 0.216856 0.202330 0.739025
41000 0.063200 0.217489 0.202965 0.738798
42000 0.064100 0.217092 0.202647 0.738860
43000 0.062700 0.216701 0.203053 0.739872
44000 0.060000 0.218087 0.203738 0.739542
45000 0.060000 0.218724 0.204076 0.739418
46000 0.059400 0.214085 0.200080 0.741483
47000 0.056200 0.215620 0.201491 0.740595
48000 0.057100 0.217038 0.202258 0.741400
49000 0.056500 0.215010 0.200240 0.742288
50000 0.053800 0.214259 0.199056 0.742370
51000 0.054400 0.215045 0.200357 0.742329
52000 0.053700 0.214869 0.199806 0.741937
53000 0.051600 0.215995 0.201306 0.741710
54000 0.051500 0.213915 0.199177 0.742907
55000 0.051700 0.215324 0.200396 0.742143
56000 0.049500 0.214356 0.199690 0.742453
57000 0.049700 0.214086 0.199557 0.742825
58000 0.049800 0.213373 0.198633 0.744518
59000 0.048100 0.213276 0.198745 0.744167
60000 0.048300 0.212866 0.198230 0.744229
61000 0.048100 0.212841 0.198155 0.744539

Code
# Print the qualitative evaluations for the temperature 2 model trained above.
evaluate(model_name=MODEL_NAME, model_path=model_path)
Could not locate the tokenizer configuration file, will try to use the model config instead.
=== BASS EVALUATION ===
First Phrase is: We spotted a large bass in the ocean. Target is: bass
Description is: Location, Type, Description, Size, Area, Color, Position, View, Status, Material

Second Phrase is: The bass player did not receive the acknowledgment she deserves. Target is: bass
Description is: Description, Status, Type, Name, Title, Owner, Rating, Position, Location, Feature

Third Phrase is: The black sea bass, is a member of the wreckfish family. Target is: bass
Description is: Type, Status, Family, Owner, Description, Name, Country, Age, Location, Animal

First & Second: ['Description', 'Location', 'Position', 'Status', 'Type']
First & Third: ['Description', 'Location', 'Status', 'Type']
Second & Third: ['Description', 'Location', 'Name', 'Owner', 'Status', 'Type']

=== FRIDAY EVALUATION ===
Spanish Phrase is: Friday es mi canción favorita.
Spanish Description is: Description, Tags, Tag, Status, Theme, Location, Title, Album, Category, Labels
English Phrase is: Friday is my favourite song.
English Description is: Description, Tags, Tag, Status, Album, Location, Title, Theme, Labels, Comments

Description Overlap is: Album, Description, Labels, Location, Status, Tag, Tags, Theme, Title
Description Difference is: Category, Comments

=== MALIBU EVALUATION ===
Phrase is: I like to drive my Malibu while drinking Malibu.
First Malibu (car) Description is: Food, Color, Type, Name, Material, Product, Weight, Plant, Animal, Fruit
Second Malibu (drink) Description is: Food, Color, Drink, Name, Type, Material, Fruit, Weight, Product, Wine

First & Second: ['Color', 'Food', 'Fruit', 'Material', 'Name', 'Product', 'Type', 'Weight']
First ^ Second: ['Animal', 'Drink', 'Plant', 'Wine']

=== FOOTBALL EVALUATION ===
Spanish Phrase is: Retiremos el equipo de la cancha, Boca no merece jugar esta copa que hace tiempo viene siendo desprestigiada.
Ya no se juega al futbol.
English Phrase is: Let's remove the team from the field, Boca does not deserve to play this cup that has long been discredited. Football is no longer played.

Spanish word is: equipo, English word is: team
Spanish Description is: Type, Name, Title, Status, Description, Location, Team, Brand, Owner, Game
English Description is: Name, Type, Title, Description, Status, Owner, Brand, Team, Location, Logo
Overlap is: Brand, Description, Location, Name, Owner, Status, Team, Title, Type (9)
Difference is: Game, Logo (2)

Spanish word is: Boca, English word is: Boca
Spanish Description is: Name, Type, Title, Status, Description, Game, Owner, Location, Brand, Tag
English Description is: Name, Title, Type, Status, Description, Owner, Game, Brand, Location, Color
Overlap is: Brand, Description, Game, Location, Name, Owner, Status, Title, Type (9)
Difference is: Color, Tag (2)

Spanish word is: copa, English word is: cup
Spanish Description is: Type, Title, Game, Status, Description, Category, Sports, Theme, Location, Sport
English Description is: Type, Title, Status, Category, Game, Description, Sports, Sport, Location, Application
Overlap is: Category, Description, Game, Location, Sport, Sports, Status, Title, Type (9)
Difference is: Application, Theme (2)

Spanish word is: tiempo, English word is: long
Spanish Description is: Age, Location, Duration, Type, Description, Date, Status, Game, Year, Time
English Description is: Status, Type, Age, Title, Description, Rating, Location, Name, Year, Game
Overlap is: Age, Description, Game, Location, Status, Type, Year (7)
Difference is: Date, Duration, Name, Rating, Time, Title (6)

Spanish word is: futbol, English word is: Football
Spanish Description is: Sport, Type, Sports, Game, Style, Theme, Category, Application, Title, Description
English Description is: Type, Sports, Sport, Game, Title, Category, Style, Football, Theme, Application
Overlap is: Application, Category, Game, Sport, Sports, Style, Theme, Title, Type (9)
Difference is: Description, Football (2)

The metrics improve across the entire training run, yet the output is weak, with generic terms like Description and Title turning up frequently. This does suggest collapse.

We can try training again with an increased temperature.

Code
# Same settings as the previous run, except the temperature is raised from 2 to 5.
model_path = train(
    model_name=MODEL_NAME,
    batch_size=32,
    learning_rate=1e-4,
    temperature=5,  # CHANGED: was 2 in the previous run
    mean_prediction=False,
    epochs=20,
    evaluation_steps=1_000,
)
Starting xlm-roberta-base-e20-bs32-lr0.0001-t5
/home/matthew/.cache/pypoetry/virtualenvs/blog-HrtMnrOS-py3.10/lib/python3.10/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
[61060/61060 4:26:49, Epoch 20/20]
Step Training Loss Validation Loss Kl Div Overlap
1000 0.499100 0.346719 0.272460 0.674107
2000 0.351100 0.313100 0.255789 0.690770
3000 0.318900 0.308473 0.246127 0.698678
4000 0.278800 0.305208 0.235727 0.705637
5000 0.268300 0.288019 0.229908 0.713236
6000 0.261800 0.285359 0.226173 0.714516
7000 0.231600 0.290514 0.225802 0.714887
8000 0.225700 0.285120 0.224180 0.716890
9000 0.223400 0.275393 0.217442 0.718687
10000 0.200400 0.276047 0.215493 0.722321
11000 0.196100 0.276407 0.215938 0.723973
12000 0.197700 0.274994 0.216346 0.723229
13000 0.176200 0.273291 0.215743 0.725563
14000 0.174100 0.272258 0.212242 0.726802
15000 0.174000 0.272064 0.217354 0.725769
16000 0.158100 0.275212 0.216617 0.724117
17000 0.155900 0.272179 0.212783 0.729486
18000 0.154900 0.267659 0.210043 0.730704
19000 0.142900 0.269398 0.214941 0.729073
20000 0.138800 0.269864 0.210844 0.729589
21000 0.142400 0.272709 0.212704 0.729858
22000 0.130100 0.269864 0.211297 0.731633
23000 0.127200 0.265007 0.211502 0.731798
24000 0.128300 0.270539 0.214050 0.732149
25000 0.121700 0.264980 0.211163 0.734111
26000 0.116000 0.265250 0.210008 0.734235
27000 0.117500 0.265351 0.208235 0.733533
28000 0.111000 0.264313 0.207080 0.735329
29000 0.106400 0.267540 0.208386 0.734256
30000 0.108400 0.268394 0.207938 0.735226
31000 0.104700 0.263199 0.208138 0.736919
32000 0.099500 0.263530 0.207196 0.735267
33000 0.099300 0.263550 0.205301 0.734524
34000 0.097200 0.267153 0.208381 0.735722
35000 0.091800 0.263137 0.207557 0.735928
36000 0.093700 0.266297 0.209930 0.735350
37000 0.090900 0.261557 0.206160 0.737766
38000 0.085800 0.260821 0.206914 0.738158
39000 0.087600 0.260837 0.206006 0.738117
40000 0.085800 0.261977 0.205463 0.737208
41000 0.081100 0.259678 0.205240 0.738385
42000 0.082100 0.261295 0.204675 0.739253
43000 0.080200 0.258047 0.203045 0.739913
44000 0.076800 0.261033 0.205304 0.740760
45000 0.077300 0.262062 0.207669 0.738963
46000 0.076400 0.258826 0.203268 0.740491
47000 0.072400 0.260828 0.204339 0.741441
48000 0.073100 0.260119 0.205025 0.740657
49000 0.072800 0.258310 0.202318 0.741689
50000 0.069100 0.259067 0.202928 0.742536
51000 0.070100 0.259791 0.204569 0.740677
52000 0.069500 0.258562 0.203154 0.741070
53000 0.066500 0.258652 0.203341 0.741441
54000 0.066400 0.259573 0.203057 0.741668
55000 0.066400 0.259439 0.203471 0.743052
56000 0.063900 0.258449 0.202918 0.741978
57000 0.064000 0.259101 0.203254 0.742907
58000 0.064400 0.257644 0.202360 0.743176
59000 0.062200 0.257244 0.202339 0.742680
60000 0.062300 0.258215 0.202372 0.743093
61000 0.062000 0.257468 0.201944 0.743196

Code
# Print the qualitative evaluations for the temperature 5 model trained above.
evaluate(model_name=MODEL_NAME, model_path=model_path)
Could not locate the tokenizer configuration file, will try to use the model config instead.
=== BASS EVALUATION ===
First Phrase is: We spotted a large bass in the ocean. Target is: bass
Description is: Location, Area, Color, Position, View, Land, Type, Country, Ocean, Views

Second Phrase is: The bass player did not receive the acknowledgment she deserves. Target is: bass
Description is: Description, Name, Type, Status, Owner, Title, Location, Rating, Position, Details

Third Phrase is: The black sea bass, is a member of the wreckfish family. Target is: bass
Description is: Status, Type, Family, Description, Owner, Name, Age, Country, Location, Religion

First & Second: ['Location', 'Position', 'Type']
First & Third: ['Country', 'Location', 'Type']
Second & Third: ['Description', 'Location', 'Name', 'Owner', 'Status', 'Type']

=== FRIDAY EVALUATION ===
Spanish Phrase is: Friday es mi canción favorita.
Spanish Description is: Description, Tags, Tag, Album, Status, Title, Theme, Location, Comments, Motto
English Phrase is: Friday is my favourite song.
English Description is: Description, Tags, Tag, Album, Title, Status, Location, Theme, Comments, Labels

Description Overlap is: Album, Comments, Description, Location, Status, Tag, Tags, Theme, Title
Description Difference is: Labels, Motto

=== MALIBU EVALUATION ===
Phrase is: I like to drive my Malibu while drinking Malibu.
First Malibu (car) Description is: Food, Name, Type, Color, Material, Product, Animal, Owner, Cat, Plant
Second Malibu (drink) Description is: Food, Name, Color, Product, Type, Plant, Material, Language, Owner, Animal

First & Second: ['Animal', 'Color', 'Food', 'Material', 'Name', 'Owner', 'Plant', 'Product', 'Type']
First ^ Second: ['Cat', 'Language']

=== FOOTBALL EVALUATION ===
Spanish Phrase is: Retiremos el equipo de la cancha, Boca no merece jugar esta copa que hace tiempo viene siendo desprestigiada.
Ya no se juega al futbol.
English Phrase is: Let's remove the team from the field, Boca does not deserve to play this cup that has long been discredited. Football is no longer played.

Spanish word is: equipo, English word is: team
Spanish Description is: Title, Type, Description, Name, Status, Team, Game, Location, Brand, Category
English Description is: Title, Type, Name, Description, Status, Brand, Location, Owner, Team, Game
Overlap is: Brand, Description, Game, Location, Name, Status, Team, Title, Type (9)
Difference is: Category, Owner (2)

Spanish word is: Boca, English word is: Boca
Spanish Description is: Title, Name, Type, Description, Game, Status, Owner, Brand, Location, Color
English Description is: Title, Type, Name, Game, Description, Brand, Location, Sport, Owner, Status
Overlap is: Brand, Description, Game, Location, Name, Owner, Status, Title, Type (9)
Difference is: Color, Sport (2)

Spanish word is: copa, English word is: cup
Spanish Description is: Type, Title, Game, Description, Status, Category, Sports, Sport, Theme, Name
English Description is: Type, Title, Game, Description, Sports, Sport, Category, Status, Location, Name
Overlap is: Category, Description, Game, Name, Sport, Sports, Status, Title, Type (9)
Difference is: Location, Theme (2)

Spanish word is: tiempo, English word is: long
Spanish Description is: Title, Type, Description, Game, Status, Name, Category, Location, Sports, Country
English Description is: Type, Title, Description, Game, Status, Name, Location, Sport, Sports, Country
Overlap is: Country, Description, Game, Location, Name, Sports, Status, Title, Type (9)
Difference is: Category, Sport (2)

Spanish word is: futbol, English word is: Football
Spanish Description is: Type, Sports, Game, Sport, Title, Style, Category, Theme, Description, Football
English Description is: Type, Sports, Sport, Game, Title, Football, Category, Style, Description, Theme
Overlap is: Category, Description, Football, Game, Sport, Sports, Style, Theme, Title, Type (10)
Difference is:  (0)

The collapse is not as bad here, although the bass evaluation still performs poorly.

Broadly, I think that the evaluation itself can be gamed by producing fixed output, and that the model loses the ability to distinguish different nouns in the input. This is a dataset problem, and I think that the solution lies in a better dataset.