Wikipedia Preprocessing Performance

Try to improve the preprocessing performance
Published

August 7, 2021

So I’ve got a problem with the amount of time that the data takes to process. In the last post I processed a single file containing about 21,000 pages and it took 12 hours, or about 2 seconds per page. If I were to process the entire Wikipedia data dump like this it would take half a year!

This is clearly broken and I need to investigate it further.


Performance Investigation

What I found, using the %%timeit Jupyter magic to measure the average run time of a cell, was that the actual row processing code takes about 10 ms to run. So the problem isn’t the code that transforms the pages - what is it?

It turns out that the datasets.Dataset.map method hashes the functions that you pass to it. This allows it to create a cache for the transformations that you apply. My current thought is that the large amount of supplementary data that is associated with the transformation (like the input dataset, the pmi tokens, and the page indices) causes this hashing process to be extremely slow.
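
One way to test that theory would be to hash a trivial function that closes over a large object, using the same `Hasher` that `datasets` uses for its cache fingerprints. This is a minimal sketch that I haven’t run as part of this post (it assumes `datasets.fingerprint.Hasher` is the hasher that `map` relies on):

Code
from datasets.fingerprint import Hasher

big_lookup = {str(i): i for i in range(6_000_000)}

def fast_transform(row):
    # the function body is trivial...
    return {"index": big_lookup.get(row["title"], -1)}

# ...but hashing it has to serialize the captured dictionary,
# which should take far longer than the transformation itself
%time Hasher.hash(fast_transform)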

So I need to perform this transformation over pandas dataframes. Since I want to save the output I need the transformation to produce data that pandas will write nicely - plain Python lists, then. Here is the code I have ended up with:

Code
#collapse
from pathlib import Path
import pandas as pd

from typing import *
import torch
from transformers import AutoTokenizer

DATA_FOLDER = Path("/data/blog/2021-08-01-wikipedia-page-pmi/")
TITLE_TOKENS = sorted(DATA_FOLDER.glob("*-pmi.gz.parquet"))

token_df = pd.concat([
    pd.read_parquet(path)
    for path in TITLE_TOKENS
]).set_index("title")

def encode_row(
    tokenizer: AutoTokenizer,
    row: Union[pd.Series, Dict[str, Any]],
    title_to_index: Dict[str, int],
    max_length: int = 256,
) -> Dict[str, List[int]]:
    tokenized_text = tokenizer(
        row["text"],
        return_offsets_mapping=True,
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    labels = _to_boundaries(
        token_offsets=tokenized_text["offset_mapping"],
        link_starts=row["start"],
        link_ends=row["end"],
        link_targets=row["link"],
        title_to_index=title_to_index,
    )
    return {
        "input_ids": tokenized_text["input_ids"],
        "attention_mask": tokenized_text["attention_mask"],
        "label": labels,
    }
    
def _to_boundaries(
    token_offsets: List[Tuple[int, int]],
    link_starts: List[int],
    link_ends: List[int],
    link_targets: List[str],
    title_to_index: Dict[str, int],
) -> List[List[int]]:
    # pre-allocate one [is_start, is_link, link_target] entry per token;
    # entries are replaced wholesale below, so the shared inner list is safe
    boundaries = [[0, 0, 0]] * len(token_offsets)

    link_iter = zip(link_starts, link_ends, link_targets)
    try:
        link_start, link_end, link_target = _next_link(
            token_start=0,
            links=link_iter,
            title_to_index=title_to_index,
        )

        within = False
        for index, (token_start, token_end) in enumerate(token_offsets):
            if token_start >= link_end:
                link_start, link_end, link_target = _next_link(
                    token_start=0,
                    links=link_iter,
                    title_to_index=title_to_index,
                )
                within = False

            if token_start < link_end and token_end > link_start:
                boundaries[index] = [
                    0 if within else 1,
                    1,
                    link_target,
                ]
                within = True
    except StopIteration:
        pass

    return boundaries

def _next_link(
    token_start: int,
    links: Iterator[Tuple[int, int, str]],
    title_to_index: Dict[str, int],
) -> Tuple[int, int, int]:
    # advance to the next link that ends after the current token position
    # and whose target is a known page title
    link_start, link_end, link_target = next(links)
    while token_start >= link_end or link_target not in title_to_index:
        link_start, link_end, link_target = next(links)
    return link_start, link_end, title_to_index[link_target]

This is a slight change to the approach that was used before. Instead of generating the output token by token, the entire output is created in advance and then the entries that fall within a link are filled out as the links are processed.

The other change is to just use the index of the target page instead of expanding it to the list of token indices. This makes implementing the model slightly easier - the tokens are used for a binary cross entropy loss, which means I need to expand them out to the size of the vocabulary with 1s for the significant tokens. To achieve this without destroying my memory I am going to use a sparse tensor to hold the expanded form and then index it with these values.
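
To make that expansion concrete, here is a toy sketch of it (a 10 entry vocabulary rather than the real ~50k):

Code
import torch

vocab_size = 10  # really ~50k for BART
token_indices = torch.tensor([3, 7])  # the significant tokens for one page

target = torch.zeros(vocab_size)
target[token_indices] = 1
target  # tensor([0., 0., 0., 1., 0., 0., 0., 1., 0., 0.])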

I think that this simplifies the code quite a bit, both here and in the model.

So let’s see how it performs.


Performance Test

Remember that there are 6 million rows to encode, so even at 1ms per row this would still take almost two hours. These tests against a single row may not be representative of running against the entire dataset. They will give an idea though…

Code
from transformers import AutoTokenizer
import pandas as pd

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

df = pd.read_parquet(
    "/data/blog/2021-07-28-wikipedia-link-recognition/enwiki-20210701-pages-articles1.xml-p1p41242.gz.parquet"
)

title_to_index = pd.read_parquet(
    "/data/blog/2021-07-30-wikipedia-data-generation/title-to-index.gz.parquet"
)["index"].to_dict()
Code
%%timeit

tokenizer(
    df.iloc[0].text,
    return_offsets_mapping=True,
    padding="max_length",
    truncation=True,
    max_length=256,
)
1.3 ms ± 2.66 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Code
%%timeit

encode_row(
    tokenizer=tokenizer,
    row=df.iloc[0],
    title_to_index=title_to_index,
)
13.5 ms ± 40.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

These two tests show that about 1.3ms is spent tokenizing the text and then 13.5ms is spent on the entire encoding, so the boundary detection code takes around 12ms. This is actually quite slow - the previous version was around 9ms. I think this is acceptable as it works better with pandas, and I can save the encoded dataframes to work with them later. Doing it this way completely avoids the function hashing that was so slow earlier.
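
The boundary step could also be timed in isolation by tokenizing once up front. This is a sketch that I haven’t run as part of this post:

Code
row = df.iloc[0]
tokenized_text = tokenizer(
    row["text"],
    return_offsets_mapping=True,
    padding="max_length",
    truncation=True,
    max_length=256,
)
Code
%%timeit

_to_boundaries(
    token_offsets=tokenized_text["offset_mapping"],
    link_starts=row["start"],
    link_ends=row["end"],
    link_targets=row["link"],
    title_to_index=title_to_index,
)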

If I transform all of the files now I can then work on the model.

Code
%%time

def encode(row) -> Dict[str, Any]:
    return encode_row(
        tokenizer=tokenizer,
        row=row,
        title_to_index=title_to_index,
    )

processed_df = pd.merge(
    df,
    pd.DataFrame(df.apply(encode, axis="columns").tolist()),
    left_index=True,
    right_index=True
)

CPU times: user 4min 5s, sys: 7.57 s, total: 4min 13s
Wall time: 3min 29s
Code
processed_df.head()
title text link start end input_ids attention_mask label
0 Anarchism Anarchism is a political philosophy and moveme... [political philosophy, Political movement, aut... [15, 40, 70, 127, 179, 264, 317, 344, 362, 392... [35, 48, 79, 136, 184, 272, 336, 355, 383, 410... [0, 4688, 13161, 1809, 16, 10, 559, 10561, 8, ... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
1 Autism Autism is a developmental disorder characteriz... [developmental disorder, Regressive autism, de... [12, 308, 375, 461, 473, 562, 588, 612, 621, 6... [34, 318, 399, 468, 494, 569, 601, 619, 631, 6... [0, 37434, 1809, 16, 10, 18477, 8364, 17407, 3... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
2 Albedo sunlight relative to various surface conditio... [sunlight, diffuse reflection, sunlight, solar... [1, 117, 139, 172, 239, 397, 417, 820, 865, 14... [9, 135, 154, 187, 249, 406, 427, 839, 876, 14... [0, 20843, 5407, 7, 1337, 4084, 1274, 48392, 7... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
3 A A, or a, is the first letter and the first vow... [Letter (alphabet), vowel letter, English alph... [22, 43, 63, 95, 144, 168, 203, 224, 258, 598,... [28, 55, 86, 119, 145, 171, 223, 229, 267, 609... [0, 250, 6, 50, 10, 6, 16, 5, 78, 1601, 8, 5, ... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
4 Alabama Alabama () is a state in the Southeastern regi... [Southeastern United States, United States, Te... [29, 56, 83, 107, 128, 144, 177, 217, 246, 272... [41, 69, 92, 114, 135, 158, 188, 237, 264, 283... [0, 37388, 36418, 16, 10, 194, 11, 5, 208, 211... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...

Processing a single file is quick enough. Let’s do the rest.


Encode the Dataset

So now we can encode the entire dataset. If the performance test is accurate then this would still take around a day. I’ve got a good feeling about this though.

Code
#collapse
from transformers import AutoTokenizer
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
DATA_FOLDER = Path("/data/blog/2021-08-07-wikipedia-preprocessing-performance")

DATA_FOLDER.mkdir(exist_ok=True, parents=True)

def encode(row) -> Dict[str, Any]:
    return encode_row(
        tokenizer=tokenizer,
        row=row,
        title_to_index=title_to_index,
    )

for path in tqdm(sorted(
    Path("/data/blog/2021-07-28-wikipedia-link-recognition/")
        .glob("*.gz.parquet")
)):
    df = pd.read_parquet(path)
    df = pd.merge(
        df,
        pd.DataFrame(df.apply(encode, axis="columns").tolist()),
        left_index=True,
        right_index=True
    )
    df.to_parquet(DATA_FOLDER / path.name)

At 3h 41m, that suggests that the per-row speed was around 2.2ms. I’m not sure why it’s so much faster than the single-row test suggested. It might be good to spot check one of the files.

Code
df = pd.read_parquet(sorted(DATA_FOLDER.glob("*.gz.parquet"))[-1])
df.head()
title text link start end input_ids attention_mask label
0 David Stagg David Stagg (born 18 October 1983, in Townsvil... [Townsville, Queensland, rugby league, Queensl... [38, 50, 99, 155, 206, 245, 275, 457, 487, 516... [48, 60, 111, 181, 222, 265, 304, 473, 490, 53... [0, 8773, 312, 7165, 36, 5400, 504, 779, 13668... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
1 KTXT KTXT may refer to: * KTXT-FM, a radio station ... [KTXT-FM, KTTZ-TV] [21, 100] [28, 107] [0, 33893, 29070, 189, 9115, 7, 35, 1009, 2404... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
2 Utica Psychiatric Center The Utica Psychiatric Center, also known as Ut... [Utica, New York, New York (state), Greek Revi... [76, 110, 319, 585, 1459, 1971, 3801, 4679, 55... [81, 118, 332, 617, 1474, 1979, 3846, 4704, 56... [0, 133, 11183, 2426, 43254, 824, 6, 67, 684, ... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
3 Olean Wholesale Grocery Olean Wholesale Grocery was a retailers' coop... [retailers' cooperative, supermarket, New York... [31, 74, 90, 100, 118, 166, 228, 250, 272, 363... [53, 85, 98, 112, 122, 211, 236, 267, 286, 390... [0, 384, 21926, 3990, 7003, 1627, 7461, 36597,... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
4 Clibanarii empire, at Taq-i Bostan, near Kermanshah, Ira... [Kermanshah, Iran, Grivpanvar, Sasanian, Byzan... [31, 43, 203, 222, 255, 290, 320, 652, 829, 87... [41, 47, 213, 230, 264, 303, 330, 662, 838, 89... [0, 15167, 6, 23, 9002, 1343, 12, 118, 163, 26... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...

This seems fine.


Sparse Tensor Operations

The encoded dataset is not enough. I have the token indices for each link, which cannot be directly used as the loss target. What is needed is a full ~50k token tensor per row that has the appropriate indices set. Vivifying this would take a vast amount of memory - it would be something like 6,000,000 x 50,000 x 4 bytes.

This matrix is sparse though. Only about \(\frac{1}{1000}\) of the entries (50 link tokens out of a ~50k vocabulary) are set to a non-zero value. A sparse tensor could represent the entire matrix within my memory budget.
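
A quick back-of-the-envelope check supports this (a sketch; the COO layout stores two int64 coordinates plus a float32 value per non-zero entry):

Code
rows, tokens_per_row, vocab_size = 6_328_478, 50, 50_265

dense_bytes = rows * vocab_size * 4  # float32 everywhere
nnz = rows * tokens_per_row
sparse_bytes = nnz * (2 * 8 + 4)  # two int64 indices + one float32 value

print(f"dense:  {dense_bytes / 1e12:.2f} TB")  # ~1.27 TB
print(f"sparse: {sparse_bytes / 1e9:.2f} GB")  # ~6.33 GB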

So let’s make that first and then see how easy it is to work with.

Code
#collapse
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
Code
%%time

import torch

lm_targets = torch.sparse_coo_tensor(
    indices=[
        [
            row
            for row, tokens in enumerate(token_df.tokens.values)
            for _ in tokens
        ],
        [
            column
            for row, tokens in enumerate(token_df.tokens.values)
            for column in tokens
        ],
    ],
    values=torch.ones(len(token_df) * 50),
    size=(len(token_df), tokenizer.vocab_size)
)
CPU times: user 35.4 s, sys: 4.17 s, total: 39.6 s
Wall time: 39 s
Code
%%timeit

torch.cat([
    lm_targets[index].to_dense()[None, :]
    for index in range(3)
])
1.62 s ± 3.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

There is clearly a problem here. The sparse tensor is far too slow. I also can’t use it directly for the loss, as binary cross entropy doesn’t support sparse tensors. I need a better way to handle this.

I still think that the tensor cannot be converted to a dense tensor without running out of memory. We can check this though.

Code
lm_targets = lm_targets.to_dense()
RuntimeError: [enforce fail at CPUAllocator.cpp:67] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 1272403786680 bytes. Error code 12 (Cannot allocate memory)

Not surprising - that’s over 1 TB of memory that it tried to allocate (6,328,478 rows × 50,265 vocabulary entries × 4 bytes comes to exactly the 1,272,403,786,680 bytes in the error).

I’ll have to create the targets as dense tensors just before running the loss function. Doing this means that only a single batch of targets will be vivified at any one time.
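
A rough upper bound for a single batch (a sketch, assuming the batch size of 32 and 256 token sequences used later, with every token position inside a link) suggests this is affordable:

Code
batch_size, seq_len, vocab_size = 32, 256, 50_265

# worst case: every token position gets a vocabulary sized float32 target row
batch_target_bytes = batch_size * seq_len * vocab_size * 4
print(f"{batch_target_bytes / 1e9:.2f} GB")  # ~1.65 GB, far less once masked to link tokens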


Generate BCE Tensor at Runtime

The last thing is to try efficiently forming the target block using index operations. I was hoping to avoid this, as I suspect it will be a bottleneck for training performance. I’ll have to check that later.

The first thing to do is to convert the token dataframe into a tensor that we can use to efficiently look up the indexes that need to be set.

Code
BATCH_SIZE = 32
EPOCHS = 5
MODEL_NAME = "facebook/bart-base"
Code
#collapse
import torch
import numpy as np

token_indices = np.concatenate(token_df.tokens.values).reshape(-1, 50)
token_indices = torch.tensor(token_indices, dtype=torch.int)
token_indices = token_indices.cuda()
token_indices.shape
torch.Size([6328478, 50])

Now we can write something that can efficiently create the BCE target for the link tokens. This works by flattening out the target tensor and adding a per-row offset to each set of token indices. Doing this seems easier to understand than trying to set the N-dimensional indexes correctly.
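
A toy version of the flatten-and-offset trick (a sketch with a vocabulary of 5 and two link tokens) might make it clearer; the real implementation follows in the next cell:

Code
import torch

vocab_size = 5
token_indices = torch.tensor([[1, 3], [0, 4]])  # two link tokens, two target tokens each

rows = token_indices.shape[0]
targets = torch.zeros(vocab_size * rows)
offsets = torch.arange(rows) * vocab_size            # tensor([0, 5])
flat = (token_indices + offsets[:, None]).flatten()  # tensor([1, 3, 5, 9])
targets[flat] = 1
targets.view(rows, vocab_size)
# tensor([[0., 1., 0., 1., 0.],
#         [1., 0., 0., 0., 1.]])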

Code
#collapse
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def calculate_boundary_loss(
    predictions: torch.Tensor,
    all_labels: torch.Tensor
) -> torch.Tensor:
    """ Calculate the loss for the boundary predictions (start link, within link).
        The predictions are only the boundary predictions.
        The labels combine the boundary labels and the link target index. """
    return torch.nn.functional.binary_cross_entropy_with_logits(
        predictions.reshape(-1, 2),
        all_labels.reshape(-1, 3)[:, :2].float()
    )

def calculate_link_loss(
    predictions: torch.Tensor,
    all_labels: torch.Tensor, # [3] start, is_link, index
    token_indices: torch.Tensor, # index -> 50 tokens
    vocab_size: int = tokenizer.vocab_size,
) -> torch.Tensor:
    """ Calculate the loss for the link predictions.
        The labels for this are only valid within a link.
        The predictions are only the link target predictions.
        The labels combine the boundary labels and the link target index. """
    flat_indexes = all_labels.view(-1, 3).long()
    mask = flat_indexes[:, 1] > 0
    flat_indexes = flat_indexes[mask][:, 2]
    flat_predictions = predictions.view(-1, vocab_size)[mask]
    rows = flat_indexes.shape[0]

    targets = torch.zeros(vocab_size * rows, device=predictions.device)
    target_offsets = torch.arange(rows, device=predictions.device) * vocab_size
    target_indexes = (
        token_indices[flat_indexes] + target_offsets[:, None]
    ).flatten()
    targets[target_indexes] = 1

    return torch.nn.functional.binary_cross_entropy_with_logits(
        flat_predictions,
        targets.view(-1, vocab_size)
    )
Code
%%timeit

calculate_boundary_loss(
    predictions=boundary_predictions,
    all_labels=all_labels
)
73.6 µs ± 182 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Code
%%timeit

calculate_link_loss(
    predictions=link_predictions,
    all_labels=all_labels,
    token_indices=token_indices
)
108 ms ± 7.11 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

This takes quite a while for the token loss. That makes me sad. I should break this time down to work out how much is spent forming the target tensor and how much is the binary cross entropy loss itself. After all, I can’t shorten the binary cross entropy calculation, but the target construction can be worked on.
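
A sketch of how that breakdown could be measured (hypothetical, not run here - it reuses the link_predictions and all_labels tensors from the timing cells above, and GPU timings would need torch.cuda.synchronize to be fully accurate):

Code
flat_indexes = all_labels.view(-1, 3).long()
mask = flat_indexes[:, 1] > 0
flat_indexes = flat_indexes[mask][:, 2]
flat_predictions = link_predictions.view(-1, tokenizer.vocab_size)[mask]
rows = flat_indexes.shape[0]
target_offsets = torch.arange(rows, device=link_predictions.device) * tokenizer.vocab_size

def build_targets() -> torch.Tensor:
    # the target construction half of calculate_link_loss
    targets = torch.zeros(tokenizer.vocab_size * rows, device=link_predictions.device)
    targets[(token_indices[flat_indexes] + target_offsets[:, None]).flatten()] = 1
    return targets.view(-1, tokenizer.vocab_size)

targets = build_targets()
Code
%timeit build_targets()
Code
%timeit torch.nn.functional.binary_cross_entropy_with_logits(flat_predictions, targets)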

Anyway, let’s define the model (much the same as before) and then try training it.

Code
#collapse
from transformers import BartForConditionalGeneration, BartConfig
from transformers.models.bart.modeling_bart import shift_tokens_right
import torch

class BartForLinks(BartForConditionalGeneration):
    def __init__(self, config: BartConfig) -> None:
        super().__init__(config)
        self.link_head = torch.nn.Linear(in_features=config.d_model, out_features=2, bias=True)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
    ):
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
        )
        link_logits = self.lm_head(outputs[0]) + self.final_logits_bias
        boundary_logits = self.link_head(outputs[0])
        logits = torch.cat([
            link_logits,
            boundary_logits,
        ], dim=-1)

        output = (logits,) + outputs[1:]

        if labels is not None:
            boundary_loss = calculate_boundary_loss(
                predictions=boundary_logits,
                all_labels=labels
            )
            link_loss = calculate_link_loss(
                predictions=link_logits,
                all_labels=labels,
                token_indices=token_indices,
            )
            loss = boundary_loss + link_loss
            return ((loss,) + output)
        return output
Code
from transformers import AutoTokenizer
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
DATA_FOLDER = Path("/data/blog/2021-08-07-wikipedia-preprocessing-performance")

df = pd.read_parquet(sorted(DATA_FOLDER.glob("*.gz.parquet"))[-1])
df = df[["input_ids", "attention_mask", "label"]]
df
input_ids attention_mask label
0 [0, 8773, 312, 7165, 36, 5400, 504, 779, 13668... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
1 [0, 33893, 29070, 189, 9115, 7, 35, 1009, 2404... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
2 [0, 133, 11183, 2426, 43254, 824, 6, 67, 684, ... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
3 [0, 384, 21926, 3990, 7003, 1627, 7461, 36597,... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
4 [0, 15167, 6, 23, 9002, 1343, 12, 118, 163, 26... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
... ... ... ...
164852 [0, 11094, 3572, 5277, 424, 36, 5400, 195, 759... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
164853 [0, 1655, 1738, 312, 1722, 25420, 36, 2036, 49... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
164854 [0, 495, 41839, 38847, 17737, 16254, 102, 16, ... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
164855 [0, 133, 2197, 9, 5, 5428, 22699, 27975, 6, 67... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...
164856 [0, 133, 511, 16, 10, 889, 9, 5, 1609, 12, 721... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [...

164857 rows × 3 columns

Code
import datasets

ds = datasets.Dataset.from_pandas(df)
split = ds.train_test_split(test_size=BATCH_SIZE*100)
split["train"]
Dataset({
    features: ['input_ids', 'attention_mask', 'label'],
    num_rows: 161657
})

It’s worth pointing out that this is ~160k rows to train with, compared with ~20k before, so any comparison reflects training on more data as well as the changes to the loss calculation.

Code
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = BartForLinks.from_pretrained(MODEL_NAME)
Some weights of BartForLinks were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['link_head.weight', 'link_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Code
from pathlib import Path
from transformers import Trainer, TrainingArguments

MODEL_RUN_FOLDER = Path("/data/blog/2021-08-07-wikipedia-preprocessing-performance/runs")
MODEL_RUN_FOLDER.mkdir(parents=True, exist_ok=True)

training_args = TrainingArguments(
    report_to=[],            
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    warmup_ratio=0.06,
    num_train_epochs=EPOCHS,
    evaluation_strategy="steps",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=1_000,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=split["train"],
    eval_dataset=split["test"],
    tokenizer=tokenizer,
)

trainer.train()
[25260/25260 6:22:40, Epoch 5/5]
Step Training Loss Validation Loss Runtime Samples Per Second
1000 0.165800 0.098542 30.318200 105.547000
2000 0.096400 0.095966 30.434200 105.145000
3000 0.092100 0.089802 30.372700 105.358000
4000 0.089500 0.086527 30.340600 105.469000
5000 0.087800 0.085644 30.414800 105.212000
6000 0.083200 0.085244 30.349200 105.439000
7000 0.082800 0.083528 30.374700 105.351000
8000 0.082300 0.083160 30.370200 105.366000
9000 0.081400 0.083385 30.405800 105.243000
10000 0.080800 0.081835 30.352400 105.428000
11000 0.075800 0.085361 30.379900 105.333000
12000 0.075200 0.085136 30.416100 105.207000
13000 0.075000 0.085434 30.355800 105.416000
14000 0.075000 0.084871 30.363100 105.391000
15000 0.074800 0.083396 30.339500 105.473000
16000 0.070600 0.085049 30.269200 105.718000
17000 0.069700 0.084664 30.351900 105.430000
18000 0.069500 0.083927 30.352500 105.428000
19000 0.069200 0.084356 30.287100 105.656000
20000 0.069100 0.083348 30.310400 105.574000
21000 0.066400 0.086040 30.319200 105.544000
22000 0.064900 0.086213 30.297800 105.618000
23000 0.065100 0.085610 30.277200 105.690000
24000 0.065400 0.085772 30.247700 105.793000
25000 0.065100 0.085953 30.321100 105.537000

TrainOutput(global_step=25260, training_loss=0.07957131958913917, metrics={'train_runtime': 22960.9925, 'train_samples_per_second': 1.1, 'total_flos': 1.7309594740053504e+17, 'epoch': 5.0, 'init_mem_cpu_alloc_delta': -154177536, 'init_mem_gpu_alloc_delta': 558664704, 'init_mem_cpu_peaked_delta': 154218496, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 21311488, 'train_mem_gpu_alloc_delta': 1683520512, 'train_mem_cpu_peaked_delta': 312684544, 'train_mem_gpu_peaked_delta': 17412875264})

This is running about 20% faster than before (~105 samples per second compared to ~85 before) which is nice.

Code
model.save_pretrained("/data/blog/2021-08-07-wikipedia-preprocessing-performance/model")

Evaluation

Since I haven’t produced proper metrics yet I’ll have to review the links that are spotted and the top tokens. This should give me an idea of whether it’s working well or not.

Code
from typing import *
import torch

@torch.no_grad()
def infer(text: str) -> List[Tuple[str, Optional[List[str]]]]:
    tokens = tokenizer(
        text,
        return_attention_mask=False,
        return_tensors="pt",
        max_length=256,
        truncation=True,
    )["input_ids"]
    tokens = tokens.to(model.device)

    output = model(tokens)[0]
    lm_output = output[0, :, :model.config.vocab_size]
    link_tokens = lm_output.argsort(dim=-1, descending=True)[:, :50]
    is_link = (output[0, :, -1] > 0).tolist()

    text_tokens = tokenizer.batch_decode(tokens[0, :, None])

    return [
        (token, tokenizer.batch_decode(link[:10]) if is_link else None)
        for token, link, is_link in zip(text_tokens, link_tokens[:, :, None].tolist(), is_link)
    ]
Code
model.eval() ; None
Code
infer("I like to drive my Malibu")
[('<s>', None),
 ('I', None),
 (' like', None),
 (' to', None),
 (' drive', None),
 (' my', None),
 (' Mal',
  [' headlights',
   'ibu',
   ' chrome',
   ' Chevy',
   ' convertible',
   ' windshield',
   ' Toyota',
   ' bumper',
   ' sedan',
   ' Honda']),
 ('ibu',
  ['ibu',
   ' headlights',
   ' sedan',
   ' Chevy',
   ' chrome',
   ' convertible',
   ' bumper',
   ' Chevrolet',
   ' windshield',
   ' cars']),
 ('</s>', None)]
Code
infer("I like to drive to Malibu")
[('<s>', None),
 ('I', None),
 (' like', None),
 (' to', None),
 (' drive', None),
 (' to', None),
 (' Mal', None),
 ('ibu', None),
 ('</s>', None)]
Code
infer("I like to drink my Malibu")
[('<s>', None),
 ('I', None),
 (' like', None),
 (' to', None),
 (' drink', None),
 (' my', None),
 (' Mal',
  [' drank',
   ' bottles',
   ' drinks',
   ' Drink',
   ' bottled',
   ' drink',
   ' bottle',
   ' vodka',
   'ibu',
   ' champagne']),
 ('ibu',
  ['ibu',
   ' bottles',
   ' bottled',
   ' Drink',
   ' drink',
   ' drinks',
   ' bottle',
   ' drank',
   ' cocktails',
   ' CBD']),
 ('</s>', None)]
Code
infer("I like to listen to Malibu")
[('<s>', None),
 ('I', None),
 (' like', None),
 (' to', None),
 (' listen', None),
 (' to', None),
 (' Mal',
  [' louder',
   ' stereo',
   ':',
   'ibu',
   ' Kanye',
   ' rap',
   ' voic',
   ' rappers',
   ' Mal',
   ' music']),
 ('ibu',
  ['ibu',
   ' stereo',
   ' Mal',
   ' tunes',
   ' rap',
   ' Nirvana',
   'Mal',
   ' rappers',
   ' music',
   ' louder']),
 ('</s>', None)]
Code
infer("I like to read Malibu")
[('<s>', None),
 ('I', None),
 (' like', None),
 (' to', None),
 (' read', None),
 (' Mal',
  ['ibu',
   ' Mal',
   'Mal',
   ' in',
   '.',
   '?:',
   ':',
   ' trope',
   ' coolest',
   ' bikini']),
 ('ibu',
  ['ibu',
   ' Mal',
   'Mal',
   'azon',
   ' ebook',
   ' Kindle',
   ' surfing',
   ' Coconut',
   ' paperback',
   ' myths']),
 ('</s>', None)]
Code
infer("""
On Thursday 2nd September I will officially be Scottish.
I have ordered my kilt and sporran. Fortunately I can
already play the bagpipes so I should be able to ace the
citizens test even though I dislike Scotch whisky.""")
[('<s>', None),
 ('\n', None),
 ('On', None),
 (' Thursday', None),
 (' 2', None),
 ('nd', None),
 (' September', None),
 (' I', None),
 (' will', None),
 (' officially', None),
 (' be', None),
 (' Scottish',
  [' Scottish',
   ' Scots',
   ' Scotch',
   ' accents',
   ' vowel',
   'anism',
   ' pronunciation',
   ' fuelled',
   ' separat',
   ' Presbyter']),
 ('.', None),
 ('\n', None),
 ('I', None),
 (' have', None),
 (' ordered', None),
 (' my', None),
 (' k', None),
 ('ilt', None),
 (' and', None),
 (' spor', None),
 ('ran', None),
 ('.', None),
 (' Fortunately', None),
 (' I', None),
 (' can', None),
 ('\n', None),
 ('al', None),
 ('ready', None),
 (' play', None),
 (' the', None),
 (' bag', None),
 ('p', None),
 ('ipes', None),
 (' so', None),
 (' I', None),
 (' should', None),
 (' be', None),
 (' able', None),
 (' to', None),
 (' ace', None),
 (' the', None),
 ('\n', None),
 ('citizens', None),
 (' test', None),
 (' even', None),
 (' though', None),
 (' I', None),
 (' dislike', None),
 (' Scotch',
  [' Scotch',
   ' whisky',
   ' whiskey',
   ' drank',
   ' vodka',
   ' distilled',
   ' drinkers',
   ' alcohol',
   ' champagne',
   ' liquor']),
 (' whisky',
  [' whisky',
   ' Scotch',
   ' whiskey',
   ' vodka',
   ' champagne',
   ' malt',
   ' distilled',
   ' bottles',
   ' whisk',
   ' beverages']),
 ('.', None),
 ('</s>', None)]
Code
infer("Just tried using Chrome - did not work")
[('<s>', None),
 ('Just', None),
 (' tried', None),
 (' using', None),
 (' Chrome',
  [' Chrome',
   ' Google',
   ' Gmail',
   ' browsers',
   'Google',
   ' browser',
   ' Firefox',
   ' HTML',
   ' Browser',
   'google']),
 (' -', None),
 (' did', None),
 (' not', None),
 (' work', None),
 ('</s>', None)]

These results seem really good? The drive to Malibu example didn’t spot the Malibu entity, and the Scottish utterance missed quite a few things, so link spotting could do with some work. The quality of the suggested links seems extremely appropriate though.

What is needed at this point is a way to work out the closest page to a given output. I most wanted to use cosine similarity; however, that has the same problem of vivifying the entire index. Using faiss might be a suitable solution, as it can handle indexes that exceed the available memory.
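
If I do go the faiss route, a minimal sketch of what that could look like (illustrative only - the vectors here are hypothetical, and at the full ~6.3 million pages a quantised or on-disk index such as IndexIVFPQ would probably be needed rather than a flat one):

Code
import faiss
import numpy as np

# hypothetical page vectors - one per page title, normalised so that
# inner product search is equivalent to cosine similarity
page_vectors = np.random.rand(100_000, 768).astype("float32")
faiss.normalize_L2(page_vectors)

index = faiss.IndexFlatIP(page_vectors.shape[1])
index.add(page_vectors)

scores, page_ids = index.search(page_vectors[:1], 10)  # top 10 pages for one query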

There is something to consider with all of this - the target vectors only contain 1 or 0 values. Would it be suitable to just calculate the intersection with the 50 tokens per page? I could softmax the output and then sum the scores at each page’s token indexes.

Code
#collapse
from typing import *
import pandas as pd
import torch

@torch.no_grad()
def link_match(
    text: str,
    token_indices: torch.Tensor,
    token_df: pd.DataFrame,
) -> List[Tuple[str, Optional[pd.Series]]]:
    tokens = tokenizer(
        text,
        return_attention_mask=False,
        return_tensors="pt",
        max_length=256,
        truncation=True,
    )["input_ids"]
    tokens = tokens.to(model.device)

    output = model(tokens)[0]
    lm_output = output[0, :, :model.config.vocab_size]
    is_link = (output[0, :, -1] > 0).tolist()
    text_tokens = tokenizer.batch_decode(tokens[0, :, None])

    return [
        (token, _best_match(token_lm_output, token_indices, token_df) if is_link else None)
        for token, token_lm_output, is_link in zip(text_tokens, lm_output, is_link)
        if is_link
    ]

@torch.no_grad()
def _best_match(
    lm_output: torch.Tensor, token_indices: torch.Tensor, token_df: pd.DataFrame, count: int = 10
) -> List[Tuple[str, float]]:
    lm_output = lm_output.softmax(dim=-1)
    similarity, indices = (
        lm_output[token_indices.long()]
            .sum(dim=-1)
            .sort(descending=True)
    )
    titles = token_df.index[indices[:count].tolist()]
    
    return [
        (title, score.item())
        for title, score in zip(titles, similarity[:count])
    ]
Code
link_match("I like to drive my Malibu", token_indices, token_df)
[(' Mal',
  [('Pontiac Grand Prix', 0.06864984333515167),
   ('Chevrolet Chevelle', 0.06712870299816132),
   ('Volkswagen Beetle', 0.06323488801717758),
   ('Buick LeSabre', 0.06209747493267059),
   ('Chevrolet Malibu', 0.06151050329208374),
   ('Chevrolet Chevette', 0.06141059845685959),
   ('Cadillac de Ville series', 0.061309993267059326),
   ('Cadillac Sixty Special', 0.06083429604768753),
   ('AMC Ambassador', 0.06063532456755638),
   ('Checker Taxi', 0.05995693802833557)]),
 ('ibu',
  [('Chevrolet Chevelle', 0.21512983739376068),
   ('Chevrolet Malibu', 0.21005548536777496),
   ('Chevrolet Chevy Malibu', 0.17331840097904205),
   ('GM Uzbekistan', 0.17286597192287445),
   ('Daewoo Magnus', 0.170789435505867),
   ('General Motors Lambda platform', 0.16903597116470337),
   ('General Motors G platform (1969)', 0.16706596314907074),
   ('General Motors L platform', 0.16692540049552917),
   ('General Motors Epsilon platform', 0.16670387983322144),
   ('Chevrolet Constantia', 0.16264280676841736)])]
Code
link_match("I like to drink my Malibu", token_indices, token_df)
[(' Mal',
  [('Rev (drink)', 0.05509297549724579),
   ('List of national drinks', 0.052384063601493835),
   ('Nutcracker (drink)', 0.04882398247718811),
   ('Hall & Woodhouse', 0.04539231210947037),
   ('Alcopop', 0.04496541991829872),
   ('Grodziskie', 0.0443902313709259),
   ('Sun Drop', 0.04417373612523079),
   ("Allen's Coffee Brandy", 0.04390135407447815),
   ('List of drinks', 0.04351271688938141),
   ('Bawls', 0.042810361832380295)]),
 ('ibu',
  [('Rev (drink)', 0.07191766798496246),
   ('Bawls', 0.07085655629634857),
   ('Nutcracker (drink)', 0.06999941170215607),
   ('Sun Drop', 0.0686754584312439),
   ('Orijin', 0.06737425923347473),
   ("Allen's Coffee Brandy", 0.0665014386177063),
   ('Inca Kola', 0.06455733627080917),
   ("Grandpa Graf's", 0.06308974325656891),
   ('Ramune', 0.06274813413619995),
   ('List of drinks', 0.06196099519729614)])]
Code
link_match("Just tried using Chrome - did not work", token_indices, token_df)
[(' Chrome',
  [('Google Docs', 0.12235280871391296),
   ('Google Quick Search Box', 0.12225416302680969),
   ('Browser extension', 0.11824813485145569),
   ('Google Chrome Experiments', 0.11770586669445038),
   ('Browser synchronization', 0.11544383317232132),
   ('Site-specific browser', 0.11451999843120575),
   ('Google Safe Browsing', 0.11403866112232208),
   ('Browser wars', 0.11401140689849854),
   ('Google Now', 0.11334848403930664),
   ('Internet Explorer', 0.11278551071882248)])]
Code
link_match("""
On Thursday 2nd September I will officially be Scottish.
I have ordered my kilt and sporran. Fortunately I can
already play the bagpipes so I should be able to ace the
citizens test even though I dislike Scotch whisky.""", token_indices, token_df)
[(' Scottish',
  [('Scottish people', 0.05110573768615723),
   ('Outline of Scotland', 0.04648435115814209),
   ('Scottish Cant', 0.03993150591850281),
   ('Borderers', 0.03975354880094528),
   ('Penny Scots', 0.038706615567207336),
   ('Caledonians (disambiguation)', 0.03863389417529106),
   ('Scotts', 0.03849662095308304),
   ('Scottish New Zealanders', 0.0384201854467392),
   ('Cohee', 0.03766922652721405),
   ('List of Scottish Americans', 0.037501320242881775)]),
 (' Scotch',
  [('Scotch whisky', 0.41742798686027527),
   ('Beam Suntory', 0.41561058163642883),
   ('Whisky', 0.41511985659599304),
   ('Sazerac Company', 0.41439545154571533),
   ('Irish coffee', 0.4140739440917969),
   ('Well drink', 0.4129604697227478),
   ('Pot still', 0.4125441312789917),
   ('Indian whisky', 0.41232988238334656),
   ('German whisky', 0.4116297662258148),
   ('Single malt whisky', 0.4113738238811493)]),
 (' whisky',
  [('Whisky', 0.6420206427574158),
   ('Scotch whisky', 0.6416874527931213),
   ('German whisky', 0.6406434774398804),
   ('Vat 69', 0.6404236555099487),
   ('Beam Suntory', 0.6403605937957764),
   ('Copper Fox Distillery', 0.6396106481552124),
   ('Glenfiddich', 0.6391409039497375),
   ('Diageo', 0.6386524438858032),
   ('Single malt whisky', 0.6385629177093506),
   ('Outline of whisky', 0.6385478377342224)])]

At this point I really need to group the entity tokens together and work with the outputs as a unit. It certainly feels like there is something good here.

Code
#collapse
from typing import *
import pandas as pd
import torch

@torch.no_grad()
def grouped_link_match(
    text: str,
    token_indices: torch.Tensor,
    token_df: pd.DataFrame,
) -> List[Tuple[str, Optional[pd.Series]]]:
    tokens = tokenizer(
        text,
        return_attention_mask=False,
        return_tensors="pt",
        max_length=256,
        truncation=True,
    )["input_ids"]
    tokens = tokens.to(model.device)

    output = model(tokens)[0]
    links = _group_links(tokens[0], output[0])
    return [
        (link, _best_match(lm_output, token_indices, token_df))
        for link, lm_output in links
    ]

@torch.no_grad()
def _best_match(
    lm_output: torch.Tensor, token_indices: torch.Tensor, token_df: pd.DataFrame, count: int = 10
) -> List[Tuple[str, float]]:
    lm_output = lm_output.softmax(dim=-1)
    similarity, indices = (
        lm_output[token_indices.long()]
            .sum(dim=-1)
            .sort(descending=True)
    )
    titles = token_df.index[indices[:count].tolist()]
    
    return [
        (title, score.item())
        for title, score in zip(titles, similarity[:count])
    ]

@torch.no_grad()
def _group_links(
    tokens: List[int],
    output: torch.Tensor
) -> List[Tuple[str, torch.Tensor]]:
    link_output = output[:, model.config.vocab_size:]
    lm_output = output[:, :model.config.vocab_size]

    current_lm = None
    current_tokens = None
    results = []
    for token, link, lm in zip(tokens, link_output, lm_output):
        # link starts when link_output[0] > 0,
        # link continues as long as link_output[0] <= 0 and link_output[1] > 0
        if current_lm is not None:
            if link[1] > 0 and link[0] <= 0:
                current_lm += lm
                current_tokens.append(token)
                continue
            results.append((tokenizer.decode(current_tokens), current_lm))
            current_lm = None
            current_tokens = None
        if link[0] > 0:
            current_lm = lm
            current_tokens = [token]
    if current_lm is not None:
        results.append((tokenizer.decode(current_tokens), current_lm))

    return results
Code
grouped_link_match("I like to drive my Malibu", token_indices, token_df)
[(' Malibu',
  [('Chevrolet Chevelle', 0.8978313207626343),
   ('Chevrolet Malibu', 0.8612598776817322),
   ('Tokyo Metro 1000 series', 0.7493164539337158),
   ('Chevrolet Chevy Malibu', 0.6922683715820312),
   ('GM Uzbekistan', 0.6906720995903015),
   ('Early fuel evaporator', 0.66644287109375),
   ('Ramos Arizpe Assembly', 0.6644928455352783),
   ('TC Mouras', 0.6639765501022339),
   ('TC Pista Mouras', 0.6638318300247192),
   ('Peekskill meteorite', 0.6621743440628052)])]
Code
grouped_link_match("I like to drink Malibu", token_indices, token_df)
[(' Malibu',
  [('Rev (drink)', 0.4330148696899414),
   ('Nutcracker (drink)', 0.4130439758300781),
   ('List of national drinks', 0.3681219220161438),
   ('Drinking culture of Korea', 0.3641522526741028),
   ('Daiquiri', 0.35108959674835205),
   ('Astro pop (cocktail)', 0.3366580605506897),
   ('Black drink', 0.32350510358810425),
   ('List of drinks', 0.3150339722633362),
   ('Alcopop', 0.31245726346969604),
   ('Soft drink', 0.30603936314582825)])]
Code
grouped_link_match("""
On Thursday 2nd September I will officially be Scottish.
I have ordered my kilt and sporran. Fortunately I can
already play the bagpipes so I should be able to ace the
citizens test even though I dislike Scotch whisky.""", token_indices, token_df)
[(' Scottish',
  [('Scottish people', 0.05110573768615723),
   ('Outline of Scotland', 0.04648435115814209),
   ('Scottish Cant', 0.03993150591850281),
   ('Borderers', 0.03975354880094528),
   ('Penny Scots', 0.038706615567207336),
   ('Caledonians (disambiguation)', 0.03863389417529106),
   ('Scotts', 0.03849662095308304),
   ('Scottish New Zealanders', 0.0384201854467392),
   ('Cohee', 0.03766922652721405),
   ('List of Scottish Americans', 0.037501320242881775)]),
 (' Scotch whisky',
  [('Beam Suntory', 0.9981793165206909),
   ('Irish coffee', 0.9980509877204895),
   ('Sazerac Company', 0.9980474710464478),
   ('Irish whiskey', 0.9980360269546509),
   ('Master blender', 0.9980276226997375),
   ('Brown–Forman', 0.9978804588317871),
   ('Diageo', 0.9978688359260559),
   ('Well drink', 0.9978491067886353),
   ('Scotch whisky', 0.9978407621383667),
   ('German whisky', 0.997822642326355)])]
Code
grouped_link_match("""
Written and Directed by the guy who did the voice of Bunny in Toy Story 4
""", token_indices, token_df)
[(' Toy Story 4',
  [('Toy (disambiguation)', 0.9941051006317139),
   ('Toy Story That Time Forgot', 0.9892314672470093),
   ('Toyfinity', 0.9786142110824585),
   ('Eddy Goldfarb', 0.977798342704773),
   ('Transogram', 0.977685272693634),
   ('Toy forts and castles', 0.977618396282196),
   ('Toy museum', 0.977617621421814),
   ('Penny toy', 0.977617621421814),
   ('Ronnen Harary', 0.9776003360748291),
   ('Toy Biz', 0.9775903820991516)])]

This shows what happens when the language model output is summed. What about if it is softmaxed first and then the average of the softmax scores is taken?

Code
#collapse
from typing import *
import pandas as pd
import torch

@torch.no_grad()
def grouped_link_match(
    text: str,
    token_indices: torch.Tensor,
    token_df: pd.DataFrame,
) -> List[Tuple[str, Optional[pd.Series]]]:
    tokens = tokenizer(
        text,
        return_attention_mask=False,
        return_tensors="pt",
        max_length=256,
        truncation=True,
    )["input_ids"]
    tokens = tokens.to(model.device)

    output = model(tokens)[0]
    links = _group_links(tokens[0], output[0])
    return [
        (link, _best_match(lm_output, token_indices, token_df))
        for link, lm_output in links
    ]

@torch.no_grad()
def _best_match(
    lm_output: torch.Tensor, token_indices: torch.Tensor, token_df: pd.DataFrame, count: int = 10
) -> List[Tuple[str, float]]:
    similarity, indices = (
        lm_output[token_indices.long()]
            .sum(dim=-1)
            .sort(descending=True)
    )
    titles = token_df.index[indices[:count].tolist()]
    
    return [
        (title, score.item())
        for title, score in zip(titles, similarity[:count])
    ]

@torch.no_grad()
def _group_links(
    tokens: List[int],
    output: torch.Tensor
) -> List[Tuple[str, torch.Tensor]]:
    # [LM_OUT] : IS_START, IS_LINK
    link_output = output[:, model.config.vocab_size:]
    lm_output = output[:, :model.config.vocab_size].softmax(dim=-1)

    current_lm = None
    current_tokens = None
    results = []
    for token, link, lm in zip(tokens, link_output, lm_output):
        # link starts when link_output[0] > 0,
        # link continues as long as link_output[0] <= 0 and link_output[1] > 0
        if current_lm is not None:
            if link[1] > 0 and link[0] <= 0:
                current_lm += lm
                current_tokens.append(token)
                continue
            results.append((tokenizer.decode(current_tokens), current_lm / len(current_tokens)))
            current_lm = None
            current_tokens = None
        if link[0] > 0:
            current_lm = lm
            current_tokens = [token]
    if current_lm is not None:
        results.append((tokenizer.decode(current_tokens), current_lm / len(current_tokens)))

    return results
Code
grouped_link_match("I like to drive my Malibu", token_indices, token_df)
[(' Malibu',
  [('Chevrolet Chevelle', 0.0009975789580494165),
   ('Chevrolet Malibu', 0.0009974718559533358),
   ('Chevrolet Chevy Malibu', 0.0009967833757400513),
   ('GM Uzbekistan', 0.0009967612568289042),
   ('Daewoo Magnus', 0.0009967193473130465),
   ('General Motors Lambda platform', 0.0009966971119865775),
   ('General Motors L platform', 0.0009966607904061675),
   ('General Motors Epsilon platform', 0.0009966471698135138),
   ('General Motors G platform (1969)', 0.000996638904325664),
   ('Early fuel evaporator', 0.0009966373909264803)])]
Code
grouped_link_match("I like to drink Malibu", token_indices, token_df)
[(' Malibu',
  [('Rev (drink)', 0.06084747612476349),
   ('List of national drinks', 0.05884766951203346),
   ('List of drinks', 0.05846615880727768),
   ('Nutcracker (drink)', 0.05762157216668129),
   ('Alcopop', 0.055029839277267456),
   ('Drinking culture of Korea', 0.05182889848947525),
   ('Cocktail', 0.051059555262327194),
   ('Non-alcoholic drink', 0.05031442642211914),
   ('Astro pop (cocktail)', 0.05030803754925728),
   ('Daiquiri', 0.050109438598155975)])]
Code
grouped_link_match("""
On Thursday 2nd September I will officially be Scottish.
I have ordered my kilt and sporran. Fortunately I can
already play the bagpipes so I should be able to ace the
citizens test even though I dislike Scotch whisky.""", token_indices, token_df)
[(' Scottish',
  [('Scottish people', 0.05110573023557663),
   ('Outline of Scotland', 0.04648435115814209),
   ('Scottish Cant', 0.03993150219321251),
   ('Borderers', 0.039753541350364685),
   ('Penny Scots', 0.03870661184191704),
   ('Caledonians (disambiguation)', 0.03863389045000076),
   ('Scotts', 0.03849661722779274),
   ('Scottish New Zealanders', 0.0384201817214489),
   ('Cohee', 0.03766922280192375),
   ('List of Scottish Americans', 0.037501316517591476)]),
 (' Scotch whisky',
  [('Scotch whisky', 0.5295577645301819),
   ('Whisky', 0.5285703539848328),
   ('Beam Suntory', 0.5279856324195862),
   ('German whisky', 0.5261366963386536),
   ('Sazerac Company', 0.5258281230926514),
   ('Well drink', 0.5252501368522644),
   ('Copper Fox Distillery', 0.5252254605293274),
   ('Irish coffee', 0.5250889658927917),
   ('Single malt whisky', 0.5249685049057007),
   ('Indian whisky', 0.5249568819999695)])]
Code
grouped_link_match("""
Written and Directed by the guy who did the voice of Bunny in Toy Story 4
""", token_indices, token_df)
[(' Toy Story 4',
  [('Toy (disambiguation)', 0.11604376882314682),
   ('Toy Story That Time Forgot', 0.11078159511089325),
   ('Toyfinity', 0.10907316952943802),
   ('Eddy Goldfarb', 0.10809780657291412),
   ('Toy forts and castles', 0.10793942213058472),
   ('Ronnen Harary', 0.10709124058485031),
   ('Transogram', 0.10682820528745651),
   ('Toy museum', 0.10634520649909973),
   ('Toy Biz', 0.1063094362616539),
   ('Penny toy', 0.1061907410621643)])]
Code
grouped_link_match("I saw the Statue of Liberty today", token_indices, token_df)
[(' Statue of Liberty',
  [('Replicas of the Statue of Liberty', 0.001020157360471785),
   ('Frédéric Auguste Bartholdi', 0.001020144671201706),
   ('Statue of Liberty', 0.0010201444383710623),
   ('Statue of Liberty in popular culture', 0.0010201348923146725),
   ('Statue of Liberty National Monument', 0.0010201276745647192),
   ('Egypt Carrying the Light to Asia', 0.00102012709248811),
   ('Strengthen the Arm of Liberty Monument (Pine Bluff, Arkansas)',
    0.0010201262775808573),
   ('Eugene Daub', 0.0010201223194599152),
   ('Sancarlone', 0.0010201220866292715),
   ('Statue of Liberty (Seattle)', 0.001020121737383306)])]

When I do this, the noticeable difference is that the Scotch whisky disambiguation becomes correct. Fundamentally I think that playing around with disambiguation approaches is fun, but the approach needs to match the training target. I need to ensure that the training of the model and the entity disambiguation are aligned, so that further training actually improves the results.