Cross Language Prompt Internalization

Using the Tatoeba cross-language dataset to internalize a multilingual prompt in a single space
prompt internalization
multilingual prompt internalization
cross language word sense induction
Published

June 18, 2022

Tatoeba is a dataset of translated sentences. With it I should be able to internalize a prompt where the prompted noun description is always in a fixed language (English, as that is the language I speak). This would produce a multilingual noun describer, as every noun would be described in the English description space.

The Tatoeba dataset is exported weekly and is made of several files. I’m interested in the base sentences and then the links that indicate which sentence is a translation of another.

Code
import blog.transformers_logging

Data Preparation

I need to download the Tatoeba files and then merge them. What I need is a list of sentences in English and then the translated sentence in another known language.

Sentence and Translation

I’ve started by downloading the relevant files:

wget https://downloads.tatoeba.org/exports/sentences.tar.bz2
wget https://downloads.tatoeba.org/exports/links.tar.bz2

The sentences file contains the sentence_id, language, text columns. The links file contains the sentence_id, translation_id columns, where both sentence_id and translation_id refer to rows in the sentences file. The links file is symmetric so if row 1,77 (sentence_id 1, translation_id 77) appears then row 77,1 will also appear.

I can use these two files to produce my desired dataset.

Code
from pathlib import Path

DATA_FOLDER = Path("/data/blog/2022-06-18-cross-language-prompt-internalization")
TATOEBA_FOLDER = Path("/data/tatoeba/2022-06-18")
SENTENCES_FILE = TATOEBA_FOLDER / "sentences.tar.bz2"
LINKS_FILE = TATOEBA_FOLDER / "links.tar.bz2"
Code
from typing import List
import tarfile

import pandas as pd

def read_tarfile(path: Path, names: List[str]) -> pd.DataFrame:
    """
    The Tatoeba files are provided as bz2 compressed tarfiles.
    This will correctly load them into pandas.
    """
    # there is one file in each tarfile
    with tarfile.open(path, "r:*") as handle:
        tar_path = handle.getnames()[0]
        return pd.read_csv(
            handle.extractfile(tar_path),
            delimiter="\t",
            names=names,
        )

sentences_df = read_tarfile(SENTENCES_FILE, names=["id", "language", "text"])
sentences_df = sentences_df.set_index("id", drop=True)

# example is [1, 77]
# sentence 77 is a translation of sentence 1
links_df = read_tarfile(LINKS_FILE, names=["sentence_id", "translation_id"])

print(f"Read {len(sentences_df):,} sentences and {len(links_df):,} links")
Read 10,467,414 sentences and 21,259,444 links

I now want to map each sentence_id to its associated translation_id.

The first thing I tried was to join the sentences dataframe to itself using the two series from the links file. Unfortunately that did not work because there are ids in the links file which are not present in the sentences dataset.

Using a map allows all the valid sentence_ids to be looked up without issue.

Code
translation_map = links_df.set_index("sentence_id").translation_id.to_dict()

translation_df = sentences_df.copy()
translation_df["translation_id"] = sentences_df.index.map(translation_map)
translation_df = translation_df.dropna()
translation_df["translation_language"] = translation_df.translation_id.map(sentences_df.language)
translation_df["translation_text"] = translation_df.translation_id.map(sentences_df.text)
translation_df = translation_df.drop(columns="translation_id")
translation_df
language text translation_language translation_text
id
1 cmn 我們試試看! tuk Bir edip göreli
2 cmn 我该去睡觉了。 jpn 寝なきゃ。
3 cmn 你在干什麼啊? nld Wat zit je te doen?
4 cmn 這是什麼啊? grn Mba'épiko upéva?
5 cmn 今天是6月18号,也是Muiriel的生日! tgl Ngayon ay ika-18 ng Hunyo at kaarawan ngayon n...
... ... ... ... ...
10920732 spa Tomen un plátano. kab Ddmet tabanant.
10920733 epo La Adventa sezono estas tempo, kiam oni havas ... deu Die Adventszeit ist eine Zeit, in der man Zeit...
10920735 deu Die hoffnungsvolle Erwartungshaltung der Adven... epo La esperplena atendo de la Adventa sezono devu...
10920736 spa Ve demasiada televisión. kab Bezzaf tettferrij tilivizyu.
10920738 epo La esperplena atendo de la Adventa sezono devu... deu Die hoffnungsvolle Erwartungshaltung der Adven...

8906089 rows × 4 columns

Code
translation_df.to_parquet(DATA_FOLDER / "translated.gz.parquet", compression="gzip")

We now have almost 9 million sentences with their translated counterpart. This is great.
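One thing worth checking about the lookup above: a dictionary holds only one value per key, so a sentence with several translations keeps just one of them. The sketch below uses made-up sentence ids (assumptions, not rows from the Tatoeba export) and plain Python to show the effect, along with a hypothetical all_translations helper that keeps every pair. As far as I know Series.to_dict() behaves like dict(pairs) here and keeps the last value for a repeated index label, but that is worth confirming on the real export.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Toy links (made up): sentence 1 has two translations, 77 and 1276.
links: List[Tuple[int, int]] = [(1, 77), (1, 1276), (77, 1), (1276, 1)]

# Building a dict from pairs keeps only the last value seen for each key.
single_translation: Dict[int, int] = dict(links)


def all_translations(pairs: List[Tuple[int, int]]) -> Dict[int, List[int]]:
    """Group every translation_id under its sentence_id (hypothetical helper)."""
    grouped: Dict[int, List[int]] = defaultdict(list)
    for sentence_id, translation_id in pairs:
        grouped[sentence_id].append(translation_id)
    return dict(grouped)


print(single_translation)
print(all_translations(links))
```

Keeping every pair would probably give more English rows to work with later on, at the cost of a larger intermediate dataset.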

English Sentence and Translation

What I need is a fixed language for one side of this dataset, as I want to map the noun descriptions to a fixed space: the model's descriptions of nouns in that language. As I speak English and it is a common language, I am going to use that.

With this dataset we can then start extracting the nouns. The code from the prompt internalization post will be useful here, so the translation language has to have spaCy support.

Code
import pandas as pd

translation_df = pd.read_parquet(DATA_FOLDER / "translated.gz.parquet")
english_df = translation_df[(translation_df.language == "eng") & (translation_df.translation_language != "eng")]
english_df
language text translation_language translation_text
id
1276 eng Let's try something. mon Ямар нэг зүйл туршицгаая.
1277 eng I have to go to sleep. ido Me mustas irar dormar.
1280 eng Today is June 18th and it is Muiriel's birthday! gos Vandoag is t 18 juni en t is verjoardag van Mu...
1282 eng Muiriel is 20 now. bel Мюрыэль споўнілася 20.
1283 eng The password is "Muiriel". kab Awal uffir, d "Muriel".
... ... ... ... ...
10919638 eng All I can say is amazing. jpn すごいとしか言いようがない。
10919681 eng Yamato-kotoba is the language that has been us... jpn 大和言葉は、日本で昔から使われてきた言葉で、外国語に翻訳し易いものもあれば、し難いものもあります。
10919689 eng Where exactly did you meet Tom? jpn そもそも、どこでトムと会ったの?
10919692 eng Tom's new car can do 140 kilometers per hour. jpn トムの新車は時速140キロ出るんだよ。
10919747 eng That plan was unsuccessful. jpn その計画は失敗に終わった。

1157510 rows × 4 columns

Now it’s time to extract the nouns from the sentences. Once again I can use spaCy for this, as it supports several different languages. I want to retain most of this dataset and keep at least one language as a test set.

The languages that spaCy supports are listed here:

  • Catalan (cat)
  • Chinese (zho)
  • Danish (dan)
  • Dutch (nld)
  • English (eng)
  • Finnish (fin)
  • French (fra)
  • German (deu)
  • Greek (ell)
  • Italian (ita)
  • Japanese (jpn)
  • Korean (kor)
  • Lithuanian (lit)
  • Macedonian (mkd)
  • Norwegian Bokmål (nob)
  • Polish (pol)
  • Portuguese (por)
  • Romanian (ron)
  • Russian (rus)
  • Spanish (spa)
  • Swedish (swe)

I need to find the distribution of these and then pick maybe 5 to train with and one for evaluation. I would like the training set to include a Germanic language, a language written in Cyrillic script and an Asian language.

Code
english_df[english_df.translation_language.isin({
    "cat",
    "zho",
    "dan",
    "nld",
    # "eng",
    "fin",
    "fra",
    "deu",
    "ell",
    "ita",
    "jpn",
    "kor",
    "lit",
    "mkd",
    "nob",
    "pol",
    "por",
    "ron",
    "rus",
    "spa",
    "swe",
})].translation_language.value_counts()
rus    122838
deu     86620
ita     68884
por     51822
fra     40173
spa     36975
jpn     35974
nld     33650
fin     15894
pol     12167
ron      8059
dan      7483
mkd      6790
swe      6102
ell      3990
kor      1903
cat       972
nob       778
lit       302
Name: translation_language, dtype: int64

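Chinese is missing from these counts. The rows shown earlier use the code cmn for Mandarin, so the zho filter may simply never match.

Taking the top seven languages from this I have:

Language     ISO code   Rows
Russian      rus        122838
German       deu        86620
Italian      ita        68884
Portuguese   por        51822
French       fra        40173
Spanish      spa        36975
Japanese     jpn        35974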

This gives me everything I wanted and I can use Portuguese and Spanish as a test set.

One problem that I have is that it will not be simple to find the correspondence between nouns in English and nouns in the translation language. To make this easier I can see if there is only a single noun in both utterances as that should give me the mapping.
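A minimal sketch of that single noun rule, in plain Python. The helper name pair_single_nouns and the noun lists are illustrative assumptions; in the real pipeline the nouns would come from the part of speech tagger rather than being typed in by hand.

```python
from typing import List, Optional, Tuple


def pair_single_nouns(
    english_nouns: List[str], translation_nouns: List[str]
) -> Optional[Tuple[str, str]]:
    """Return the (english, translation) noun pair when both sides have exactly one noun."""
    if len(english_nouns) == 1 and len(translation_nouns) == 1:
        return english_nouns[0], translation_nouns[0]
    # With zero or several nouns on either side the alignment is ambiguous, so skip the row.
    return None


print(pair_single_nouns(["police"], ["警察"]))
print(pair_single_nouns(["Tom", "car"], ["トム", "新車"]))
```

The rule throws away every sentence pair with more than one noun, which is the price of not needing a word aligner.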

Code
import spacy

nlp_names = {
    "eng": "en_core_web_sm",
    "rus": "ru_core_news_sm",
    "deu": "de_core_news_sm",
    "ita": "it_core_news_sm",
    "por": "pt_core_news_sm",
    "fra": "fr_core_news_sm",
    "spa": "es_core_news_sm",
    "jpn": "ja_core_news_sm",
}

nlp = {
    language_code: spacy.load(spacy_name)
    for language_code, spacy_name in nlp_names.items()
}
Code
row = english_df[english_df.translation_language == "jpn"].iloc[5]
doc = nlp["jpn"](row.translation_text)
print(row.text)
print(doc.text)
for token in doc:
    print(token.text, token.pos_)
Call the police!
警察を呼べ!
警察 NOUN
を ADP
呼べ VERB
! PROPN

I immediately see a problem with this, as it is labelling a lot of the punctuation as PROPN. I think I need to increase the size of the spaCy models that are being used. It would be great not to use the largest as they are easily half a gigabyte each.

Code
import spacy

nlp_names = {
    "eng": "en_core_web_md",
    "rus": "ru_core_news_md",
    "deu": "de_core_news_md",
    "ita": "it_core_news_md",
    "por": "pt_core_news_md",
    "fra": "fr_core_news_md",
    "spa": "es_core_news_md",
    "jpn": "ja_core_news_md",
}

nlp = {
    language_code: spacy.load(spacy_name)
    for language_code, spacy_name in nlp_names.items()
}
Code
row = english_df[english_df.translation_language == "jpn"].iloc[5]
doc = nlp["jpn"](row.translation_text)
print(row.text)
print(doc.text)
for token in doc:
    print(token.text, token.pos_)
Call the police!
警察を呼べ!
警察 NOUN
を ADP
呼べ VERB
! VERB
Code
doc = nlp["eng"](row.text)
print(row.text)
for token in doc:
    print(token.text, token.pos_)
Call the police!
Call VERB
the DET
police NOUN
! PUNCT

This isn’t great, as the punctuation is now tagged as a verb, but I can handle that. Translating the noun 警察 returns "the police", which is encouraging as it suggests that my one noun approach will work.

The Encoder that was used in the Prompt Internalization post will work with this. It just needs an adjustment to allow multiple different languages to work.

Code
from typing import List

import spacy
from spacy.tokens import Span
from spacy.matcher import Matcher


class NounExtractor:
    def __init__(self, spacy_name: str) -> None:
        self.nlp = spacy.load(spacy_name)
        self.matcher = Matcher(self.nlp.vocab)
        self.matcher.add("nouns", [
            [{"POS": {"IN": ["NOUN", "PROPN"]}, "OP": "+"}],
        ])

    def get_nouns(self, text: str) -> List[Span]:
        doc = self.nlp(text)
        nouns = self.matcher(doc, as_spans=True)
        return self.unique(text, nouns=nouns)

    def unique(self, text: str, nouns: List[Span]) -> List[Span]:
        text = text.casefold()
        return [
            noun
            for noun in nouns
            if text.count(noun.text.casefold()) == 1
        ]
Code
from typing import Any, Dict, List, Tuple
from spacy.tokens import Span
from transformers import AutoTokenizer

nlp_names = {
    "eng": "en_core_web_md",
    "rus": "ru_core_news_md",
    "deu": "de_core_news_md",
    "ita": "it_core_news_md",
    "por": "pt_core_news_md",
    "fra": "fr_core_news_md",
    "spa": "es_core_news_md",
    "jpn": "ja_core_news_md",
}

class Encoder:
    def __init__(self, name: str = "roberta-base") -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(name)
        self.extractors = {
            language: NounExtractor(spacy_name)
            for language, spacy_name in nlp_names.items()
        }

    def encode(self, text: str, language: str) -> Dict[str, Any]:
        nouns = self.extractors[language].get_nouns(text)
        tokens = self.tokenizer(text, truncation=True, return_offsets_mapping=True)
        labels = self.find(tokens.offset_mapping, nouns)
        return {
            "input_ids": tokens.input_ids,
            "attention_mask": tokens.attention_mask,
            "labels": labels
        }

    def find(self, offsets: List[Tuple[int, int]], nouns: List[Span]) -> List[Tuple[int, int]]:
        starts = {
            start: index
            for index, (start, end) in enumerate(offsets)
            if start != end
        }
        ends = {
            end: index
            for index, (start, end) in enumerate(offsets)
            if start != end
        }
        return [
            (starts[noun.start_char], 1 + ends[noun.end_char] - starts[noun.start_char])
            for noun in nouns
            if noun.start_char in starts and noun.end_char in ends
        ]
Code
encoder = Encoder()
Code
text = "The sand castle has shells"
encoder.encode(text, "eng")
{'input_ids': [0, 133, 6255, 22637, 34, 23647, 2],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1],
 'labels': [(2, 1), (2, 2), (3, 1), (5, 1)]}
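The labels are (token start, token length) pairs, and the overlapping entries seem to come from the "+" operator in the matcher returning every run of adjacent nouns ("sand", "sand castle" and "castle"). The sketch below repeats the lookup from Encoder.find on hand-written data so it can be followed without loading a tokenizer. The offsets and character spans are assumptions typed in by hand for this sentence, and spans_to_labels is a hypothetical stand-in for the method above.

```python
from typing import List, Tuple

text = "The sand castle has shells"

# Assumed (start, end) character offsets per token: <s>, The, sand, castle, has, shells, </s>.
offsets: List[Tuple[int, int]] = [(0, 0), (0, 3), (4, 8), (9, 15), (16, 19), (20, 26), (0, 0)]

# Assumed character spans of the matched nouns, including the overlapping "sand castle".
noun_spans: List[Tuple[int, int]] = [(4, 8), (4, 15), (9, 15), (20, 26)]


def spans_to_labels(
    offsets: List[Tuple[int, int]], spans: List[Tuple[int, int]]
) -> List[Tuple[int, int]]:
    """Convert character spans into (first token index, token count) pairs."""
    # Special tokens have empty offsets (start == end), so they are skipped.
    starts = {start: index for index, (start, end) in enumerate(offsets) if start != end}
    ends = {end: index for index, (start, end) in enumerate(offsets) if start != end}
    return [
        (starts[start], 1 + ends[end] - starts[start])
        for start, end in spans
        if start in starts and end in ends
    ]


labels = spans_to_labels(offsets, noun_spans)
print([text[start:end] for start, end in noun_spans])
print(labels)
```

A noun whose boundaries do not line up with token boundaries is silently dropped by the final condition, which is one reason a sentence can end up with fewer labels than nouns.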

I can then run this over the selected languages.

First the English text has to be tokenized and have its nouns extracted, then the same is done for the translated text. This is going to take a while.

Code
restricted_df = english_df[
    english_df.translation_language.isin(set(nlp_names.keys()))
].copy()

encoded_text_df = pd.DataFrame(
    restricted_df.apply(
        lambda row: encoder.encode(text=row.text, language=row.language),
        axis="columns",
    )
    .tolist(),
    index=restricted_df.index,
)
encoded_translation_df = pd.DataFrame(
    restricted_df.apply(
        lambda row: encoder.encode(text=row.translation_text, language=row.translation_language),
        axis="columns",
    )
    .tolist(),
    index=restricted_df.index,
)

How many nouns are there per sentence? We really want there to be a lot with only a single noun as those are the ones we can work with.

Code
encoded_text_df.labels.apply(len).value_counts().head()
1    160012
2    106243
0     93578
3     45191
4     19118
Name: labels, dtype: int64
Code
encoded_translation_df.labels.apply(len).value_counts().head()
1    156232
2    108637
0     88285
3     47375
4     20541
Name: labels, dtype: int64

This suggests that at best we can work with ~150,000 sentences. This feels like a good amount.

Code
restricted_df = pd.merge(restricted_df, encoded_text_df, left_index=True, right_index=True)
restricted_df = pd.merge(
    restricted_df,
    encoded_translation_df.rename(columns={
        "input_ids": "translation_input_ids",
        "attention_mask": "translation_attention_mask",
        "labels": "translation_labels",
    }),
    left_index=True,
    right_index=True,
)
restricted_df = restricted_df.reset_index(drop=True)
restricted_df
language text translation_language translation_text input_ids attention_mask labels translation_input_ids translation_attention_mask translation_labels
0 eng I was in the mountains. fra J'étais à la montagne. [0, 100, 21, 11, 5, 9787, 4, 2] [1, 1, 1, 1, 1, 1, 1, 1] [(5, 1)] [0, 863, 108, 1140, 4349, 354, 6534, 897, 2712... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] [(8, 2)]
1 eng I told them to send me another ticket. deu Ich sagte ihnen, sie sollen mir eine neue Fahr... [0, 100, 174, 106, 7, 2142, 162, 277, 3682, 4, 2] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] [(8, 1)] [0, 100, 611, 17929, 859, 939, 10245, 225, 6, ... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [(19, 4)]
2 eng It depends on the context. deu Das hängt vom Kontext ab. [0, 243, 7971, 15, 5, 5377, 4, 2] [1, 1, 1, 1, 1, 1, 1, 1] [(5, 1)] [0, 495, 281, 1368, 1561, 2590, 90, 23953, 112... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] [(8, 2)]
3 eng So what? jpn それが何か? [0, 2847, 99, 116, 2] [1, 1, 1, 1, 1] [] [0, 46311, 46, 48916, 48538, 47856, 15722, 487... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] []
4 eng That is somewhat explained at the end. ita È in qualche modo spiegato alla fine. [0, 1711, 16, 5568, 2002, 23, 5, 253, 4, 2] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] [(7, 1)] [0, 3849, 23133, 11, 22043, 2871, 11134, 139, ... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] [(6, 2), (13, 1)]
... ... ... ... ... ... ... ... ... ... ...
443281 eng All I can say is amazing. jpn すごいとしか言いようがない。 [0, 3684, 38, 64, 224, 16, 2770, 4, 2] [1, 1, 1, 1, 1, 1, 1, 1, 1] [] [0, 48735, 46311, 10674, 47780, 48281, 47918, ... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... []
443282 eng Yamato-kotoba is the language that has been us... jpn 大和言葉は、日本で昔から使われてきた言葉で、外国語に翻訳し易いものもあれば、し難いものもあります。 [0, 975, 424, 3938, 12, 330, 1242, 19614, 16, ... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [(1, 3), (5, 3), (16, 2), (20, 1), (27, 1)] [0, 48262, 42393, 10659, 14285, 36484, 11423, ... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [(1, 4), (1, 10), (14, 3), (19, 1), (38, 3), (...
443283 eng Where exactly did you meet Tom? jpn そもそも、どこでトムと会ったの? [0, 13841, 2230, 222, 47, 972, 1560, 116, 2] [1, 1, 1, 1, 1, 1, 1, 1, 1] [(6, 1)] [0, 46311, 46, 48885, 46311, 46, 48885, 47341,... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [(12, 2)]
443284 eng Tom's new car can do 140 kilometers per hour. jpn トムの新車は時速140キロ出るんだよ。 [0, 15691, 18, 92, 512, 64, 109, 9680, 8130, 2... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] [(1, 1), (4, 1), (8, 1), (10, 1)] [0, 49280, 49598, 48018, 47240, 7487, 49416, 2... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [(1, 2), (5, 3), (10, 3), (14, 2)]
443285 eng That plan was unsuccessful. jpn その計画は失敗に終わった。 [0, 1711, 563, 21, 15943, 4, 2] [1, 1, 1, 1, 1, 1, 1] [(2, 1)] [0, 46311, 46, 48018, 36484, 11423, 23133, 478... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [(6, 3), (11, 3)]

443286 rows × 10 columns

Code
restricted_df.to_parquet(DATA_FOLDER / "encoded-english.gz.parquet", compression="gzip")
Code
len(restricted_df[
    (restricted_df.labels.apply(len) == 1) &
    (restricted_df.translation_labels.apply(len) == 1) &
    (restricted_df.translation_language != "por") &
    (restricted_df.translation_language != "spa")
])
96680

After all of this we have almost 97,000 training rows to work with, once the Portuguese and Spanish test languages are excluded. If this turns out to be too few then adding more languages should be possible.

Encoding Dataset

I now need a way to pass all of this data into the trainer. To do this effectively I should incorporate the prompt in the English input_ids so that I don’t have to alter the teacher input too much.

Code
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
prompt = " Pet: Dog, Color: Yellow, Vehicle: Tractor, Fruit: Banana, <mask>:" # add noun at end
tokenized_prompt = tokenizer(prompt, return_attention_mask=False, add_special_tokens=False).input_ids
Code
import pandas as pd

restricted_df = pd.read_parquet(DATA_FOLDER / "encoded-english.gz.parquet")

single_noun_df = restricted_df[
    (restricted_df.labels.apply(len) == 1) &
    (restricted_df.translation_labels.apply(len) == 1)
]
single_noun_df = single_noun_df[[
    "language", "text",
    "input_ids", "labels",
    "translation_language", "translation_text",
    "translation_input_ids", "translation_labels",
]].copy()
Code
from typing import List, Tuple

def add_prompt(input_ids: List[int], labels: List[Tuple[int, int]]) -> List[int]:
    noun_start, noun_length = labels[0]
    return (
        input_ids[:-1] +
        tokenized_prompt +
        input_ids[noun_start : noun_start+noun_length] +
        input_ids[-1:]
    )
Code
tokenizer.decode(
    add_prompt(restricted_df.iloc[0].input_ids, restricted_df.iloc[0].labels)
)
'<s>I was in the mountains. Pet: Dog, Color: Yellow, Vehicle: Tractor, Fruit: Banana,<mask>: mountains</s>'
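To make the splice easier to follow, here is the same operation on readable token strings instead of ids. The token lists are assumptions written by hand (the prompt is shortened to a single example pair, and the real tokenizer may split the text differently); only the slicing logic mirrors add_prompt.

```python
from typing import List, Tuple


def add_prompt_tokens(
    tokens: List[str], prompt: List[str], labels: List[Tuple[int, int]]
) -> List[str]:
    """Insert the prompt and a copy of the noun just before the closing token."""
    noun_start, noun_length = labels[0]
    return (
        tokens[:-1]                                     # sentence without </s>
        + prompt                                        # few-shot prompt ending in "<mask>:"
        + tokens[noun_start : noun_start + noun_length] # the noun being described
        + tokens[-1:]                                   # closing </s>
    )


# Hand-written tokens for illustration only.
sentence = ["<s>", "I", " was", " in", " the", " mountains", ".", "</s>"]
short_prompt = [" Pet", ":", " Dog", ",", " <mask>", ":"]

result = add_prompt_tokens(sentence, short_prompt, labels=[(5, 1)])
print("".join(result))
```

Because the noun is appended after the mask, the teacher sees a pattern of the form "category: noun" and its prediction for the mask acts as a description of the noun.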
Code
single_noun_df["input_ids"] = single_noun_df.apply(
    lambda row: add_prompt(row.input_ids, row.labels),
    axis="columns"
)
Code
single_noun_df = single_noun_df[[
    "text",
    "input_ids",
    "translation_language",
    "translation_input_ids",
    "translation_labels"
]]
single_noun_df = single_noun_df.rename(columns={
    "input_ids": "teacher_input_ids",
    "translation_language": "language",
    "translation_input_ids": "input_ids",
    "translation_labels": "labels",
})
Code
single_noun_df.to_parquet(DATA_FOLDER / "single-noun.gz.parquet", compression="gzip")

At this point we have something that could be provided to the model for training. The next trick will be to get the trainer to accept the teacher_input_ids in addition to everything else. The Hugging Face Trainer often restricts the columns passed during training to those that the model accepts.

Code
import datasets

train_languages = {
    "rus",
    "deu",
    "ita",
    "fra",
    "jpn",
}
test_languages = {
    "por",
    "spa",
}

train_ds = datasets.Dataset.from_pandas(
    single_noun_df[single_noun_df.language.isin(train_languages)]
)
test_ds = datasets.Dataset.from_pandas(
    single_noun_df[single_noun_df.language.isin(test_languages)]
)

train_ds.save_to_disk(DATA_FOLDER / "single-noun-train.dataset")
test_ds.save_to_disk(DATA_FOLDER / "single-noun-test.dataset")

Training a Model

Code
from pathlib import Path

import datasets

DATA_FOLDER = Path("/data/blog/2022-06-18-cross-language-prompt-internalization")
RUN_DIRECTORY = Path(DATA_FOLDER / "runs")
RUN_DIRECTORY.mkdir(parents=True, exist_ok=True)

MODEL_NAME = "roberta-base"

BATCH_SIZE = 64

LEARNING_RATE = 1e-4
TEMPERATURE = 2
EPOCHS = 8
# MAX_STEPS = 5_000
# MAX_STEPS = 50
EVALUATION_STEPS = 1_000
# EVALUATION_STEPS = 10

train_ds = datasets.load_from_disk(DATA_FOLDER / "single-noun-train.dataset")
test_ds = datasets.load_from_disk(DATA_FOLDER / "single-noun-test.dataset")
Code
from typing import Any, Dict, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer, Trainer, TrainingArguments
from transformers.modeling_outputs import MaskedLMOutput
from transformers.tokenization_utils_base import BatchEncoding


class MultilingualMaskedPromptInternalizationTrainingArguments(TrainingArguments):
    def __init__(self, *args, temperature: float = 2.0, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.temperature = temperature


class MultilingualMaskedPromptInternalizationTrainer(Trainer):
    def __init__(
        self,
        *args,
        teacher_model: AutoModelForMaskedLM = None,
        tokenizer: AutoTokenizer = None,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self._move_model_to_device(self.teacher, self.model.device)
        self.teacher.eval()
        self.mask_token_id = tokenizer.mask_token_id

    def compute_loss(
        self,
        model: AutoModelForMaskedLM,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        student_output = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"]
        )
        teacher_predictions = self._teacher_predictions(
            input_ids=inputs["teacher_input_ids"],
            attention_mask=inputs["teacher_attention_mask"],
        )
        loss = self._student_loss(
            student_output=student_output,
            teacher_predictions=teacher_predictions,
            labels=inputs["labels"],
        )

        return (loss, student_output) if return_outputs else loss

    @torch.inference_mode()
    def _teacher_predictions(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        outputs_teacher = self.teacher(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        mask_indices = input_ids == self.mask_token_id
        return outputs_teacher.logits[mask_indices]

    def _student_loss(
        self,
        student_output: MaskedLMOutput,
        teacher_predictions: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        # Calculating the student prediction is tricky.
        # Is the output for a multi token target the mean of the output for each token?
        # Should the loss instead be measured per token?

        # When calculating this it is very important to avoid breaking back propagation.
        # Each row has its own (start, length) span, so the loss is calculated per row
        # and the per row losses are averaged.

        losses = []
        for target, output, (start, length) in zip(
            teacher_predictions, student_output.logits, labels
        ):
            prediction = output[start : start + length]
            prediction = prediction.mean(dim=0)
            prediction = F.log_softmax(prediction / self.args.temperature, dim=-1)
            target = F.softmax(target / self.args.temperature, dim=-1)
            loss = F.kl_div(
                input=prediction[None, :],
                target=target[None, :],
                reduction="batchmean",
                log_target=False,
            )
            loss = loss * (self.args.temperature ** 2)
            losses.append(loss)
        return sum(losses) / len(losses)
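The per row loss in _student_loss amounts to \(T^2 \cdot \mathrm{KL}(p_{teacher} \,\|\, p_{student})\), where both distributions are softmaxes of logits divided by the temperature \(T\) and the student logits are first averaged over the noun tokens. The sketch below recomputes that quantity in plain Python on tiny made-up logits. The function names are hypothetical and the three-word vocabulary is an assumption chosen to keep the numbers readable; it is a sanity check on the arithmetic, not a replacement for the torch code.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float) -> List[float]:
    """Temperature-scaled softmax, shifted by the maximum for numerical stability."""
    scaled = [value / temperature for value in logits]
    highest = max(scaled)
    exponentials = [math.exp(value - highest) for value in scaled]
    total = sum(exponentials)
    return [value / total for value in exponentials]


def distillation_loss(
    student_span_logits: List[List[float]],
    teacher_logits: List[float],
    temperature: float = 2.0,
) -> float:
    """KL(teacher || student) on softened distributions, scaled by temperature squared."""
    span_length = len(student_span_logits)
    # Average the student logits over the tokens of the noun span.
    mean_logits = [sum(column) / span_length for column in zip(*student_span_logits)]
    student = softmax(mean_logits, temperature)
    teacher = softmax(teacher_logits, temperature)
    divergence = sum(
        t * (math.log(t) - math.log(s)) for t, s in zip(teacher, student) if t > 0
    )
    return divergence * temperature ** 2


teacher = [1.0, 2.0, 3.0]
matching = distillation_loss([[0.0, 2.0, 4.0], [2.0, 2.0, 2.0]], teacher)
different = distillation_loss([[3.0, 2.0, 1.0]], teacher)
print(matching, different)
```

The first call shows a consequence of averaging: a two token span whose mean matches the teacher gives a loss of zero even though neither token matches on its own.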
Code
from typing import Any, Dict, List
from transformers import AutoTokenizer

class TeacherStudentCollator:
    """
    The teacher inputs need to be padded and have an associated attention mask.
    """

    def __init__(self, tokenizer: AutoTokenizer) -> None:
        self.tokenizer = tokenizer

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        teacher_inputs = [
            {"input_ids": row["teacher_input_ids"]}
            for row in features
        ]
        teacher_batch = self.tokenizer.pad(
            teacher_inputs,
            padding=True,
            return_tensors="pt",
        )
        padded_teacher_inputs = {
            "teacher_input_ids": teacher_batch["input_ids"],
            "teacher_attention_mask": teacher_batch["attention_mask"],
        }

        student_inputs = [
            {
                "input_ids": row["input_ids"],
                "labels": row["labels"][0] # known to have a single entry
            }
            for row in features
        ]
        padded_student_inputs = self.tokenizer.pad(
            student_inputs,
            padding=True,
            return_tensors="pt",
        )
        
        batch = {
            **padded_teacher_inputs, **padded_student_inputs
        }
        if "label" in batch:
            batch["labels"] = batch["label"]
            del batch["label"]
        if "label_ids" in batch:
            batch["labels"] = batch["label_ids"]
            del batch["label_ids"]

        return batch

There is one special bit of behaviour that needs to be fixed. The parameters that are passed to the trainer and collator are based on the parameters accepted by the model.forward function. I need to pass the teacher_input_ids and teacher_attention_mask to the trainer so I need to patch the forward method to accept these parameters even though it doesn’t use them.

Code
def add_parameters(model: AutoModelForMaskedLM) -> AutoModelForMaskedLM:
    old_forward = model.forward
    def fake_forward(
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        #
        teacher_input_ids=None,
        teacher_attention_mask=None,
    ):
        return old_forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    model.forward = fake_forward

    return model
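The reason the patch works can be shown without transformers at all. The sketch below imitates signature-based column filtering in plain Python; accepted_columns, forward and with_teacher_parameters are hypothetical names, and this is an illustration of the idea rather than the Trainer's actual implementation.

```python
import inspect
from typing import Any, Callable, Dict, List


def accepted_columns(forward: Callable[..., Any], columns: List[str]) -> List[str]:
    """Keep only the columns that match a parameter name of the forward function."""
    parameters = set(inspect.signature(forward).parameters)
    return [column for column in columns if column in parameters]


def forward(input_ids=None, attention_mask=None, labels=None) -> Dict[str, Any]:
    # Stand-in for a model forward method.
    return {"input_ids": input_ids}


def with_teacher_parameters(old_forward: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap forward so its signature also lists the teacher inputs, which it then ignores."""
    def patched(
        input_ids=None,
        attention_mask=None,
        labels=None,
        teacher_input_ids=None,
        teacher_attention_mask=None,
    ):
        return old_forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    return patched


columns = ["input_ids", "labels", "teacher_input_ids", "text"]
before = accepted_columns(forward, columns)
after = accepted_columns(with_teacher_parameters(forward), columns)
print(before)
print(after)
```

TrainingArguments also has a remove_unused_columns option. Setting it to False might make the patch unnecessary, since the collator above only reads the columns it needs, but that is untested with this setup.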
Code
from transformers import AutoModelForMaskedLM, AutoTokenizer

training_args = MultilingualMaskedPromptInternalizationTrainingArguments(
    report_to="none",

    output_dir=RUN_DIRECTORY,
    num_train_epochs=EPOCHS,
    # max_steps=MAX_STEPS,

    evaluation_strategy="steps",
    logging_steps=EVALUATION_STEPS,
    eval_steps=EVALUATION_STEPS,
    save_steps=EVALUATION_STEPS,

    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    fp16=False,
    learning_rate=LEARNING_RATE,
    seed=33,

    logging_dir=RUN_DIRECTORY / "logs",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,

    temperature=TEMPERATURE,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
teacher_model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
student_model = add_parameters(AutoModelForMaskedLM.from_pretrained(MODEL_NAME))
data_collator = TeacherStudentCollator(tokenizer=tokenizer)

trainer = MultilingualMaskedPromptInternalizationTrainer(
    model=student_model,
    args=training_args,
    teacher_model=teacher_model,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()
[12088/12088 1:37:20, Epoch 8/8]
Step Training Loss Validation Loss
1000 0.833800 0.751095
2000 0.600800 0.731829
3000 0.508300 0.715951
4000 0.434500 0.708059
5000 0.388800 0.711071
6000 0.357900 0.723349
7000 0.315500 0.709269
8000 0.294100 0.715075
9000 0.273700 0.723964
10000 0.248100 0.729855
11000 0.237400 0.726436
12000 0.226100 0.722286

TrainOutput(global_step=12088, training_loss=0.39204590451551696, metrics={'train_runtime': 5840.7064, 'train_samples_per_second': 132.422, 'train_steps_per_second': 2.07, 'total_flos': 4.098902992389058e+16, 'train_loss': 0.39204590451551696, 'epoch': 8.0})
Code
student_model.save_pretrained(RUN_DIRECTORY / "best-model-2")

For clarity, I saved this model twice: best-model is for 2 epochs and best-model-2 is for 8 epochs.

Evaluation

Now we can try out the model with different sentences to see what it predicts for the nouns.

Code
from transformers import AutoModelForMaskedLM, AutoTokenizer

# the best-model-2 was trained for 8 epochs, above
student_model = AutoModelForMaskedLM.from_pretrained(RUN_DIRECTORY / "best-model-2")
student_model.eval()

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
Code
from typing import Tuple, List
import torch

@torch.no_grad()
def get_predictions(text: str, start: int, end: int) -> Tuple[str, List[str]]:
    tokens = tokenizer(text, return_tensors="pt")
    target = tokenizer.decode(tokens.input_ids[0, start:end])

    output = student_model(**tokens)
    predictions = output.logits[0, start:end].mean(dim=0)
    predicted_tokens = predictions.argsort(descending=True)[:10]
    predicted_words = [
        word.strip()
        for word in tokenizer.batch_decode(predicted_tokens)
    ]

    return target, predicted_words

The first comparison will be of the different senses of bass again. Here we want the first and third sentences to have a similar description and the second sentence to be distinct. The words used should also be descriptive of the distinct meaning.

Code
#collapse_input
first_phrase = "We spotted a large bass in the ocean."
second_phrase = "The bass player did not receive the acknowledgment she deserves."
third_phrase = "The black sea bass, is a member of the wreckfish family."

first_target, first_predicted_words = get_predictions(first_phrase, 5, 6)
second_target, second_predicted_words = get_predictions(second_phrase, 2, 3)
third_target, third_predicted_words = get_predictions(third_phrase, 4, 5)

print(f"First Phrase is: {first_phrase} Target is: {first_target}")
print(f"Description is: {', '.join(first_predicted_words)}")
print()

print(f"Second Phrase is: {second_phrase} Target is: {second_target}")
print(f"Description is: {', '.join(second_predicted_words)}")
print()

print(f"Third Phrase is: {third_phrase} Target is: {third_target}")
print(f"Description is: {', '.join(third_predicted_words)}")
print()

print(f"First & Second: {sorted(set(first_predicted_words) & set(second_predicted_words))}")
print(f"First & Third: {sorted(set(first_predicted_words) & set(third_predicted_words))}")
print(f"Second & Third: {sorted(set(second_predicted_words) & set(third_predicted_words))}")
First Phrase is: We spotted a large bass in the ocean. Target is:  bass
Description is: Species, Size, Color, Shape, Animal, Item, Body, Weight, Type, Sex

Second Phrase is: The bass player did not receive the acknowledgment she deserves. Target is:  bass
Description is: Position, Artist, Name, Player, Singer, Role, Note, Age, Function, Driver

Third Phrase is: The black sea bass, is a member of the wreckfish family. Target is:  bass
Description is: Species, Size, Animal, Color, Item, Type, Creature, Fish, Race, Sex

First & Second: []
First & Third: ['Animal', 'Color', 'Item', 'Sex', 'Size', 'Species', 'Type']
Second & Third: []

This has worked well: the first and third sentences have a \(\frac{7}{10}\) overlap, while the second sentence has no overlap with either of them. Furthermore the words seem to fit each word sense well.
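
The overlap fractions quoted in this evaluation can be computed directly from two description lists. This is a minimal sketch: the helper name `overlap_fraction` is hypothetical and not part of the code above, and the lists are copied from the printed output.

```python
from typing import List


def overlap_fraction(first: List[str], second: List[str]) -> float:
    """Fraction of the distinct words in `first` that also appear in `second`."""
    first_words, second_words = set(first), set(second)
    if not first_words:
        return 0.0
    return len(first_words & second_words) / len(first_words)


# Descriptions copied from the printed output for the three bass sentences.
first_description = [
    "Species", "Size", "Color", "Shape", "Animal",
    "Item", "Body", "Weight", "Type", "Sex",
]
second_description = [
    "Position", "Artist", "Name", "Player", "Singer",
    "Role", "Note", "Age", "Function", "Driver",
]
third_description = [
    "Species", "Size", "Animal", "Color", "Item",
    "Type", "Creature", "Fish", "Race", "Sex",
]

print(overlap_fraction(first_description, third_description))   # 0.7
print(overlap_fraction(first_description, second_description))  # 0.0
```

Dividing by the size of the first list matches the "out of ten" framing used here; a Jaccard score (intersection over union) would be an alternative if the lists could differ in length.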

The next thing is to compare the same sentence in two different languages. Training did not use Spanish or English input, so both languages are held out and we can test a sentence in each. Remember that while the teacher was given prompted English sentences, the student model never saw an English input.

Code
#collapse_input
spanish_text = "Friday es mi canción favorita."
spanish_start, spanish_end = 1, 2

english_text = "Friday is my favourite song."
english_start, english_end = 1, 2

spanish_target, spanish_predicted_words = get_predictions(spanish_text, spanish_start, spanish_end)
english_target, english_predicted_words = get_predictions(english_text, english_start, english_end)

overlap = set(spanish_predicted_words) & set(english_predicted_words)
difference = set(spanish_predicted_words) ^ set(english_predicted_words)

print(f"Spanish Phrase is: {spanish_text}")
print(f"Spanish Description is: {', '.join(spanish_predicted_words)}")

print(f"English Phrase is: {english_text}")
print(f"English Description is: {', '.join(english_predicted_words)}")
print()

print(f"Description Overlap is: {', '.join(sorted(overlap))}")
print(f"Description Difference is: {', '.join(sorted(difference))}")
Spanish Phrase is: Friday es mi canción favorita.
Spanish Description is: Date, Day, Time, Days, When, Theme, Week, Reason, Weather, Event
English Phrase is: Friday is my favourite song.
English Description is: Day, Date, Time, Theme, Name, Event, Days, Month, Reason, Subject

Description Overlap is: Date, Day, Days, Event, Reason, Theme, Time
Description Difference is: Month, Name, Subject, Weather, Week, When

This is a similar outcome with a \(\frac{7}{10}\) overlap.

The bigger problem here is that the words used to describe Friday are wrong. This sentence refers to the song by Rebecca Black, not the day of the week. None of the words relate to the song sense.

The sense of the word is clearly spelt out by the sentence, rather than through cultural reference, so this is not a hard inference to make.

The next evaluation is the same word used with two different senses in the same sentence.

Code
#collapse_input
text = "I like to drive my Malibu while drinking Malibu."
print(f"Phrase is: {text}")

first_target, first_predicted_words = get_predictions(text, 6, 8)
second_target, second_predicted_words = get_predictions(text, 10, 12)

assert first_target == " Malibu"
assert second_target == " Malibu"

print(f"First Malibu (car) Description is: {', '.join(first_predicted_words)}")
print(f"Second Malibu (drink) Description is: {', '.join(second_predicted_words)}")

print()

print(f"First & Second: {sorted(set(first_predicted_words) & set(second_predicted_words))}")
print(f"First ^ Second: {sorted(set(first_predicted_words) ^ set(second_predicted_words))}")
Phrase is: I like to drive my Malibu while drinking Malibu.
First Malibu (car) Description is: Vehicle, Color, Destination, Driver, Car, Location, Year, Season, Model, Speed
Second Malibu (drink) Description is: Color, Vehicle, Destination, Driver, Car, Location, Year, Season, Speed, Weather

First & Second: ['Car', 'Color', 'Destination', 'Driver', 'Location', 'Season', 'Speed', 'Vehicle', 'Year']
First ^ Second: ['Model', 'Weather']

There is clearly a problem here. The first description should relate to driving the car and the second to drinking the alcoholic drink. Instead both descriptions relate to the car, with a \(\frac{9}{10}\) overlap.

As a final evaluation we can test some longer text that has several targets.

Code
#collapse_input
spanish_phrase = (
    "Retiremos el equipo de la cancha, "
    "Boca no merece jugar esta copa que "
    "hace tiempo viene siendo desprestigiada.\n"
    "Ya no se juega al futbol."
)

english_phrase = (
    "Let's remove the team from the field, "
    "Boca does not deserve to play this cup that "
    "has long been discredited. "
    "Football is no longer played."
)

print(f"Spanish Phrase is: {spanish_phrase}")
print(f"English Phrase is: {english_phrase}")
print()

for (spanish_start, spanish_end), (english_start, english_end) in [
    [[5, 7], [5, 6]], # equipo -> sport team
    [[12, 14], [10, 12]], # Boca -> football team in Argentina
    [[21, 23], [18, 19]], # copa -> sport cup
    [[26, 29], [21, 22]], # tiempo -> long time
    [[47, 49], [25, 26]], # futbol -> football
]:
    spanish_target, spanish_description = get_predictions(
        spanish_phrase,
        spanish_start,
        spanish_end,
    )
    english_target, english_description = get_predictions(
        english_phrase,
        english_start,
        english_end,
    )
    overlap = set(spanish_description) & set(english_description)
    difference = set(spanish_description) ^ set(english_description)

    print(f"Spanish word is: {spanish_target}, English word is: {english_target}")
    print(f"Spanish Description is: {', '.join(spanish_description)}")
    print(f"English Description is: {', '.join(english_description)}")
    print(f"Overlap is: {', '.join(sorted(overlap))} ({len(overlap)})")
    print(f"Difference is: {', '.join(sorted(difference))} ({len(difference)})")
    print()
Spanish Phrase is: Retiremos el equipo de la cancha, Boca no merece jugar esta copa que hace tiempo viene siendo desprestigiada.
Ya no se juega al futbol.
English Phrase is: Let's remove the team from the field, Boca does not deserve to play this cup that has long been discredited. Football is no longer played.

Spanish word is:  equipo, English word is:  team
Spanish Description is: Team, Name, Position, Location, Player, Age, Purpose, Color, Goal, Organization
English Description is: Team, Position, Name, Location, Player, Goal, Teams, Age, Color, Destination
Overlap is: Age, Color, Goal, Location, Name, Player, Position, Team (8)
Difference is: Destination, Organization, Purpose, Teams (4)

Spanish word is:  Boca, English word is:  Boca
Spanish Description is: Color, Name, Food, Location, Type, Drink, Size, Age, Source, Flavor
English Description is: Color, Team, Location, Game, Size, Type, Player, Name, Age, Season
Overlap is: Age, Color, Location, Name, Size, Type (6)
Difference is: Drink, Flavor, Food, Game, Player, Season, Source, Team (8)

Spanish word is:  copa, English word is:  cup
Spanish Description is: Age, Purpose, Name, Size, Location, Color, Type, Function, Food, Time
English Description is: Game, Sport, Team, Position, Goal, Type, Size, Season, Color, Player
Overlap is: Color, Size, Type (3)
Difference is: Age, Food, Function, Game, Goal, Location, Name, Player, Position, Purpose, Season, Sport, Team, Time (14)

Spanish word is:  tiempo, English word is:  long
Spanish Description is: Size, Color, Age, Weather, Time, Shape, Food, Location, Weight, Purpose
English Description is: Age, Time, Reason, Date, Year, Name, Location, Purpose, Color, Season
Overlap is: Age, Color, Location, Purpose, Time (5)
Difference is: Date, Food, Name, Reason, Season, Shape, Size, Weather, Weight, Year (10)

Spanish word is:  futbol, English word is:  Football
Spanish Description is: Age, Purpose, Reason, Color, Time, Weather, Size, Location, Season, Activity
English Description is: Sport, Team, Football, Game, Position, Ball, Player, Goal, Season, Sports
Overlap is: Season (1)
Difference is: Activity, Age, Ball, Color, Football, Game, Goal, Location, Player, Position, Purpose, Reason, Size, Sport, Sports, Team, Time, Weather (18)

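Here we can see a wide variation in performance, from an overlap of 8 for equipo and team down to 1 for futbol and Football. I’m most surprised by Football itself: the English description is clearly about the sport, while the Spanish description is generic and shares only Season with it.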

Conclusions

I think this technique is promising but the model has suffered from some kind of collapse. It may have lost the overall sense of the word and its place in the sentence. Maybe a dual training objective that includes the classic language modelling task would help it retain the original language knowledge?

This is an easy extension as the student model can attempt to predict the masked target noun.
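
One way to write such a dual objective, as a sketch rather than the loss used in the training above, is to add a masked language modelling term to the existing distillation loss. The weight \(\lambda\) is a hypothetical hyperparameter introduced here for illustration, as are the symbols \(w\) (the target noun) and \(x_{\setminus w}\) (the input sentence with that noun masked).

```latex
\mathcal{L}_{\text{total}}
  = \mathcal{L}_{\text{distill}} + \lambda \, \mathcal{L}_{\text{MLM}},
\qquad
\mathcal{L}_{\text{MLM}}
  = -\log p_{\text{student}}\left(w \mid x_{\setminus w}\right)
```

Setting \(\lambda = 0\) recovers the current training, so the extra term can be introduced gradually and its effect measured against the existing results.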

Another way to refine this would be to reconsider the combination of the token probability distribution. Instead of taking the mean of the tokens perhaps it would be better to just take the first token? The prompted model only has a single masked token to work with after all.
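
The two combination strategies can be compared on toy data. This is a minimal sketch that uses plain Python lists in place of the logits tensor so that it runs without the model; `combine_token_scores` and `top_k` are hypothetical helpers and the scores are made up. In `get_predictions` the "first" strategy would correspond to indexing the row at `start` rather than calling `.mean(dim=0)` over `start:end`.

```python
from typing import List


def combine_token_scores(
    token_scores: List[List[float]], strategy: str = "mean"
) -> List[float]:
    """Collapse per-token vocabulary scores into one score per vocabulary entry."""
    if not token_scores:
        raise ValueError("token_scores must contain at least one token")
    if strategy == "first":
        # use only the first token of the target word
        return list(token_scores[0])
    if strategy == "mean":
        # average each vocabulary entry over all tokens of the target word
        count = len(token_scores)
        return [sum(column) / count for column in zip(*token_scores)]
    raise ValueError(f"unknown strategy: {strategy}")


def top_k(scores: List[float], k: int) -> List[int]:
    """Indices of the k highest scores, best first."""
    return sorted(range(len(scores)), key=lambda index: scores[index], reverse=True)[:k]


# Made-up scores for a two-token target over a four-entry vocabulary.
toy_scores = [
    [4.0, 1.0, 0.0, 2.0],  # first token of the target
    [0.0, 1.0, 6.0, 1.0],  # second token of the target
]

print(top_k(combine_token_scores(toy_scores, "mean"), 2))   # [2, 0]
print(top_k(combine_token_scores(toy_scores, "first"), 2))  # [0, 3]
```

In this toy case a strong score on the second token dominates the mean, which is the kind of effect that taking only the first token would avoid.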

Finally, is the use of temperature really helping? The language modelling task does not have such sharp probability distributions when compared to classifiers, so boosting the low probability entries may not help.
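
To see what temperature does to a distribution, here is a minimal sketch on made-up logits. It assumes the usual distillation convention of dividing the logits by the temperature before the softmax; `softmax_with_temperature` is a hypothetical helper, not the function used in training.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Softmax over logits that have been divided by the temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [logit / temperature for logit in logits]
    # subtract the maximum for numerical stability
    highest = max(scaled)
    exponentials = [math.exp(value - highest) for value in scaled]
    total = sum(exponentials)
    return [value / total for value in exponentials]


# Made-up logits: one confident entry and two unlikely ones.
toy_logits = [3.0, 1.0, 0.0]

for temperature in (1.0, 2.0, 4.0):
    probabilities = softmax_with_temperature(toy_logits, temperature)
    print(temperature, [round(probability, 3) for probability in probabilities])
```

Higher temperatures move probability mass from the top entry to the low probability entries. If the teacher distribution is already fairly flat, that extra flattening may wash out the signal rather than reveal it.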

A hyperparameter search is always going to be worthwhile for this, and a 2 epoch training run takes only 20 minutes, so it’s feasible to do.