Wikipedia Link Model

First go at training a Wikipedia Link recognizer
Published

August 4, 2021

Now that I’ve processed the data I can finally train a model. Obviously the first thing to do is to process the data a bit more!

This time I need to generate the tokens for the text and work out what each link is targeting. Let’s smash that out and then work on the model definition.


Dataset

So, the final stretch. This needs to take the text for the page and tokenize it. The tokens that are part of a link then need to be marked up with the PMI tokens. I’ve already come up with code that can do the link markup, so I just need to apply it now.

Code
from pathlib import Path
import pandas as pd

DATA_FOLDER = Path("/data/blog/2021-08-01-wikipedia-page-pmi/")
TITLE_TOKENS = sorted(DATA_FOLDER.glob("*-pmi.gz.parquet"))

token_df = pd.concat([
    pd.read_parquet(path)
    for path in TITLE_TOKENS
]).set_index("title")
CPU times: user 4.86 s, sys: 1.71 s, total: 6.57 s
Wall time: 6.07 s
Code
from typing import *
import torch
from transformers import AutoTokenizer

ZERO = torch.zeros(50, dtype=torch.int) - 1
NO_ENTITY = torch.cat([
    torch.tensor([False, False]),
    ZERO,
])

def encode(
    tokenizer: AutoTokenizer,
    row: Union[pd.Series, Dict[str, Any]],
    link_tokens: pd.DataFrame,
    max_length: int = 256,
) -> Dict[str, torch.Tensor]:
    tokenized_text = tokenizer(
        row["text"],
        return_offsets_mapping=True,
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    labels = to_boundaries(
        token_offsets=tokenized_text["offset_mapping"],
        link_starts=row["start"],
        link_ends=row["end"],
        link_targets=row["link"],
        link_tokens=link_tokens,
    )
    return {
        "input_ids": torch.tensor(tokenized_text["input_ids"], dtype=torch.int),
        "attention_mask": torch.tensor(tokenized_text["attention_mask"], dtype=torch.int),
        "label": labels,
    }
    
def to_boundaries(
    token_offsets: List[Tuple[int, int]],
    link_starts: List[int],
    link_ends: List[int],
    link_targets: List[str],
    link_tokens: pd.DataFrame,
) -> torch.Tensor:
    boundaries = []

    link_iter = zip(link_starts, link_ends, link_targets)
    try:
        link_start, link_end, link_target = next(link_iter)

        while link_target not in link_tokens.index:
            link_start, link_end, link_target = next(link_iter)
        tokens = link_tokens.loc[link_target].item()

        within = False
        for token_start, token_end in token_offsets:
            if token_start == token_end: # zero width token
                boundaries.append(
                    torch.cat([
                        torch.tensor([False, within]),
                        ZERO,
                    ])
                )
                continue

            while token_start >= link_end:
                link_start, link_end, link_target = next(link_iter)
                # skip any link whose target has no PMI tokens, then refresh
                # the PMI token ids for the new link target
                while link_target not in link_tokens.index:
                    link_start, link_end, link_target = next(link_iter)
                tokens = link_tokens.loc[link_target].item()
                within = False

            if token_start < link_end and token_end > link_start:
                # inside link
                boundaries.append(
                    torch.cat([
                        torch.tensor([not within, True]),
                        torch.tensor(tokens, dtype=torch.int),
                    ])
                )
                within = True
            else:
                boundaries.append(NO_ENTITY)
    except StopIteration:
        boundaries += [NO_ENTITY] * (len(token_offsets) - len(boundaries))

    return torch.cat(boundaries).reshape(len(token_offsets), -1)
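To make the label layout concrete, here is a toy illustration of what to_boundaries produces (the offsets and the link_tokens frame are made up, not the real data): one 52-wide vector per token, holding a link-start flag, a within-link flag, and then the 50 PMI token ids (or -1s when the token is not inside a link).

Code
import pandas as pd
import torch

# made-up PMI tokens for a single pretend target page
toy_link_tokens = pd.DataFrame(
    {"tokens": [list(range(50))]},
    index=["Malibu"],
)

toy_labels = to_boundaries(
    # offsets for something like: <s> I drive Malibu </s>
    token_offsets=[(0, 0), (0, 1), (2, 7), (8, 14), (0, 0)],
    link_starts=[8],
    link_ends=[14],
    link_targets=["Malibu"],
    link_tokens=toy_link_tokens,
)
toy_labels.shape  # torch.Size([5, 52])
toy_labels[3, :2]  # the link-start and within-link flags are both 1 for the " Malibu" span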
Code
from transformers import AutoTokenizer
import datasets
import pandas as pd

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

df = pd.read_parquet(
    "/data/blog/2021-07-28-wikipedia-link-recognition/enwiki-20210701-pages-articles1.xml-p1p41242.gz.parquet"
)
ds = datasets.Dataset.from_pandas(df)
Code
ds = ds.map(lambda row: encode(tokenizer, row, link_tokens=token_df))

df = ds.to_pandas()
df.to_parquet(
    "/data/blog/2021-08-04-complete-wikipedia-data/dataset.gz.parquet",
    compression="gzip",
)
Code
import pandas as pd

df = (
    pd.read_parquet("/data/blog/2021-08-04-complete-wikipedia-data/dataset.gz.parquet")
        .drop(columns=["end", "link", "start", "text", "title"])
)

So this is incredibly slow. 12 hours to process 21k rows! I’ll need to improve this quite a lot to be able to use all the data.
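One easy thing to try (untested here) would be parallelising the map across processes; num_proc is a standard argument of datasets.Dataset.map, though each worker gets its own copy of token_df, so memory becomes a concern.

Code
# a minimal sketch of parallelising the encoding step, assuming it is CPU-bound;
# num_proc spawns worker processes, each holding its own copy of token_df
ds = ds.map(
    lambda row: encode(tokenizer, row, link_tokens=token_df),
    num_proc=8,
)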


Create Model

Now we need to define the model. I’ve commented out large parts of this because the tokens seem to get broken when the decoder input ids are generated. This is something to investigate further, but since the task is not conditional generation the specific decoder input ids are unlikely to matter.

Code
from transformers import BartForConditionalGeneration, BartConfig
from transformers.models.bart.modeling_bart import shift_tokens_right
import torch

class BartForLinks(BartForConditionalGeneration):
    def __init__(self, config: BartConfig) -> None:
        super().__init__(config)
        self.link_head = torch.nn.Linear(in_features=config.d_model, out_features=2, bias=True)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        # decoder_input_ids=None,
        # decoder_attention_mask=None,
        # head_mask=None,
        # decoder_head_mask=None,
        # encoder_outputs=None,
        # past_key_values=None,
        # inputs_embeds=None,
        # decoder_inputs_embeds=None,
        labels=None,
        # use_cache=None,
        # output_attentions=None,
        # output_hidden_states=None,
        # return_dict=None,
    ):
        # if labels is not None:
        #     if decoder_input_ids is None:
        #         decoder_input_ids = shift_tokens_right(
        #             labels, self.config.pad_token_id, self.config.decoder_start_token_id
        #         )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            # decoder_input_ids=decoder_input_ids,
            # encoder_outputs=encoder_outputs,
            # decoder_attention_mask=decoder_attention_mask,
            # head_mask=head_mask,
            # decoder_head_mask=decoder_head_mask,
            # past_key_values=past_key_values,
            # inputs_embeds=inputs_embeds,
            # decoder_inputs_embeds=decoder_inputs_embeds,
            # use_cache=use_cache,
            # output_attentions=output_attentions,
            # output_hidden_states=output_hidden_states,
            # return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
        link_logits = self.link_head(outputs[0])
        logits = torch.cat([
            lm_logits,
            link_logits,
        ], dim=-1)

        output = (logits,) + outputs[1:]

        if labels is not None:
            # calculate labels 0-1 against link logits as bce
            # calculate labels 2-52 against lm logits as bce
            flat_link_logits = link_logits.view(-1, 2)
            flat_labels = labels.view(-1, 52)

            link_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                flat_link_logits, flat_labels[:, :2].float()
            )

            # having trouble producing the correct mask for entities
            lm_mask = flat_labels[:, 2:].min(dim=-1).values >= 0
            lm_predictions = lm_logits.view(-1, self.config.vocab_size)[lm_mask]
            lm_tokens = flat_labels[lm_mask, 2:]

            # need to make the targets for the tokens
            # this is a tensor of mostly zeros where the flat_labels are used to set the 1s
            # there will be 50 1s per entity token, so the tensor is quite sparse
            selected_tokens, token_count = lm_tokens.shape
            lm_targets = torch.sparse_coo_tensor(
                indices=torch.cat([
                    (
                        torch.tensor(range(selected_tokens), device=self.device)
                            .repeat_interleave(token_count)
                            .view(1,-1)
                    ),
                    lm_tokens.view(1, -1)
                ]),
                values=torch.ones(selected_tokens * token_count, device=self.device),
                size=(selected_tokens, self.config.vocab_size),
            )

            token_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                lm_predictions, lm_targets.to_dense()
            )

            loss = link_loss + token_loss
            return ((loss,) + output)
        return output
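The least obvious part of the loss is the sparse multi-hot target construction. Here is a standalone sketch of that trick with tiny made-up sizes (a 10 token vocabulary and 2 PMI ids per entity token) so the result is easy to inspect.

Code
import torch

# stand-ins for config.vocab_size and the per-token PMI ids
vocab_size = 10
lm_tokens = torch.tensor([[1, 3], [0, 7]])  # 2 entity tokens, 2 PMI ids each
selected_tokens, token_count = lm_tokens.shape

lm_targets = torch.sparse_coo_tensor(
    indices=torch.cat([
        # row index of each PMI id: [0, 0, 1, 1]
        torch.arange(selected_tokens).repeat_interleave(token_count).view(1, -1),
        # column index of each PMI id: [1, 3, 0, 7]
        lm_tokens.view(1, -1),
    ]),
    values=torch.ones(selected_tokens * token_count),
    size=(selected_tokens, vocab_size),
).to_dense()
# lm_targets[0] has ones in columns 1 and 3, lm_targets[1] in columns 0 and 7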
Code
import datasets

ds = datasets.Dataset.from_pandas(df)
Code
split = ds.train_test_split(test_size=0.1)
Code
BATCH_SIZE = 32
EPOCHS = 5
MODEL_NAME = "facebook/bart-base"
Code
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = BartForLinks.from_pretrained(MODEL_NAME)
Some weights of BartForLinks were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['link_head.weight', 'link_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Code
from pathlib import Path
from transformers import Trainer, TrainingArguments

MODEL_RUN_FOLDER = Path("/data/blog/2021-08-04-complete-wikipedia-data/runs")
MODEL_RUN_FOLDER.mkdir(parents=True, exist_ok=True)

training_args = TrainingArguments(
    report_to=[],            
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    warmup_ratio=0.06,
    num_train_epochs=EPOCHS,
    evaluation_strategy="steps",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=100,
    # load_best_model_at_end=True,
    # metric_for_best_model="quality",
    # greater_is_better=True,
    
    # no_cuda = True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=split["train"],
    eval_dataset=split["test"],
    tokenizer=tokenizer,
    # compute_metrics=compute_metrics,
)

trainer.train()
[2965/2965 58:09, Epoch 5/5]
Step Training Loss Validation Loss Runtime Samples Per Second
100 0.443200 0.188154 24.273200 86.845000
200 0.172800 0.150623 24.637300 85.561000
300 0.152500 0.141921 25.081800 84.045000
400 0.143400 0.135531 25.065300 84.100000
500 0.143600 0.133825 24.953300 84.478000
600 0.138800 0.142380 24.943000 84.513000
700 0.130600 0.132518 24.893000 84.682000
800 0.126800 0.131047 24.915800 84.605000
900 0.126600 0.132441 24.946000 84.503000
1000 0.127000 0.128251 24.940200 84.522000
1100 0.125100 0.129588 24.885500 84.708000
1200 0.122300 0.131716 24.886600 84.704000
1300 0.114600 0.131400 24.821000 84.928000
1400 0.114400 0.133982 24.927100 84.566000
1500 0.114000 0.127237 24.833000 84.887000
1600 0.114100 0.131260 24.896600 84.670000
1700 0.114400 0.128378 24.928500 84.562000
1800 0.112500 0.130995 24.898400 84.664000
1900 0.107500 0.131809 24.866900 84.771000
2000 0.106900 0.128181 24.917400 84.599000
2100 0.106500 0.131768 24.903500 84.647000
2200 0.105600 0.130586 24.911500 84.620000
2300 0.105000 0.131242 24.905000 84.642000
2400 0.105600 0.131699 24.863200 84.784000
2500 0.100200 0.134511 24.852700 84.820000
2600 0.099500 0.136548 24.955500 84.470000
2700 0.101900 0.132327 24.894600 84.677000
2800 0.099900 0.133143 24.925000 84.574000
2900 0.099800 0.132836 24.918800 84.595000

TrainOutput(global_step=2965, training_loss=0.1295419500006793, metrics={'train_runtime': 3490.3774, 'train_samples_per_second': 0.849, 'total_flos': 2.03123287094784e+16, 'epoch': 5.0, 'init_mem_cpu_alloc_delta': 2571313152, 'init_mem_gpu_alloc_delta': 558664704, 'init_mem_cpu_peaked_delta': 152035328, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 37343232, 'train_mem_gpu_alloc_delta': 1687041024, 'train_mem_cpu_peaked_delta': 308629504, 'train_mem_gpu_peaked_delta': 18292079104})
Code
model.save_pretrained("/data/blog/2021-08-04-complete-wikipedia-data/model")
Code
from typing import *
import torch

@torch.no_grad()
def infer(text: str) -> List[Tuple[str, List[str]]]:
    tokens = tokenizer(
        text,
        return_attention_mask=False,
        return_tensors="pt",
        max_length=256,
        truncation=True,
    )["input_ids"]
    tokens = tokens.to(model.device)

    output = model(tokens)[0]
    lm_output = output[0, :, :model.config.vocab_size]
    link_tokens = lm_output.argsort(dim=-1, descending=True)[:, :50]
    is_link = (output[0, :, -1] > 0).tolist()

    text_tokens = tokenizer.batch_decode(tokens[0, :, None])
    # return [
    #     (token, is_link)
    #     for token, is_link in zip(text_tokens, is_link)
    # ]
    return [
        (token, tokenizer.batch_decode(link[:10]) if is_link else None)
        for token, link, is_link in zip(text_tokens, link_tokens[:, :, None].tolist(), is_link)
    ]
Code
model.eval() ; None
Code
infer("I like to drive my Malibu")
[('<s>', None),
 ('I', None),
 (' like', None),
 (' to', None),
 (' drive', None),
 (' my', None),
 (' Mal',
  [' Mal',
   ' Honda',
   ' Ferrari',
   ' Bentley',
   ' Chevy',
   '.',
   'ibu',
   ':',
   ' Jaguar',
   ' in']),
 ('ibu',
  ['ibu',
   ' convertible',
   ' Chevy',
   ' sedan',
   ' Jaguar',
   ' Bentley',
   ' Corvette',
   ' Model',
   ' Honda',
   ' Ferrari']),
 ('</s>', None)]
Code
infer("I like my Malibu on the rocks")
[('<s>', None),
 ('I', None),
 (' like', None),
 (' my', None),
 (' Mal',
  [' Mal', ':', '.', ' Honda', ' in', ' on', 'ibu', ' from', 'Mal', ',']),
 ('ibu',
  ['ibu',
   ' Beach',
   ' a',
   ' convertible',
   ' by',
   ' to',
   ' in',
   'boat',
   ' infinity',
   ' on']),
 (' on', None),
 (' the', None),
 (' rocks', None),
 ('</s>', None)]
Code
infer("I like to drink my Malibu")
[('<s>', None),
 ('I', None),
 (' like', None),
 (' to', None),
 (' drink', None),
 (' my', None),
 (' Mal',
  [' Mal',
   ' Coconut',
   ':',
   ' *',
   '.',
   ' from',
   ' avocado',
   ' in',
   ' on',
   ' hottest']),
 ('ibu',
  ['ibu',
   ' cocktails',
   ' vodka',
   ' cocktail',
   ' whiskey',
   ' whisky',
   ' sushi',
   ' wine',
   ' bourbon',
   ' brewed']),
 ('</s>', None)]
Code
infer("I like to drive to Malibu")
[('<s>', None),
 ('I', None),
 (' like', None),
 (' to', None),
 (' drive', None),
 (' to', None),
 (' Mal',
  [' Mal',
   ':',
   'ibu',
   ' from',
   ' Venice',
   'Mal',
   ' sunset',
   ' coolest',
   ' in',
   '.']),
 ('ibu',
  ['ibu',
   ' Beach',
   ' beach',
   ' neighborhoods',
   ' Highlands',
   ' Springs',
   ' fueled',
   ' Canyon',
   ' wildfires',
   ' by']),
 ('</s>', None)]
Code
infer("I like to drink my Coke")
[('<s>', None),
 ('I', None),
 (' like', None),
 (' to', None),
 (' drink', None),
 (' my', None),
 (' Coke', None),
 ('</s>', None)]
Code
infer("Florentino Perez hasn't taken losing out on Neymar to Barcelona well at all")
[('<s>', None),
 ('Fl', ['Fl', ' Flore', ' fl', 'FL', 'R', 'fl', 'Er', 'FS', 'Real', '.']),
 ('ore',
  [' Flore',
   'FC',
   ' Pep',
   ' discontent',
   'nt',
   ' Di',
   'Fl',
   ' combinations',
   'cel',
   ' punishments']),
 ('nt',
  [' Flore',
   ' Pep',
   'nt',
   ' combinations',
   ' discontent',
   ':',
   ' Bern',
   ' Messi',
   ' Barcelona',
   ' rightly']),
 ('ino',
  [' Flore',
   ' Pep',
   ':',
   ' Bern',
   ' Barcelona',
   ' Messi',
   ' discontent',
   ' arrivals',
   ' rightly',
   ' massively']),
 (' Perez',
  [' Flore',
   ' Pep',
   ' Perez',
   ':',
   ' Bern',
   ' Messi',
   ' Barcelona',
   ' discontent',
   ' combinations',
   ' Re']),
 (' hasn', None),
 ("'t", None),
 (' taken', None),
 (' losing', None),
 (' out', None),
 (' on', None),
 (' Ney',
  [' Ney',
   ' Messi',
   ':',
   ' stopp',
   ' on',
   ' in',
   ' *',
   ' Ronaldo',
   ' with',
   ' a']),
 ('mar', None),
 (' to', None),
 (' Barcelona', None),
 (' well', None),
 (' at', None),
 (' all', None),
 ('</s>', None)]

So these seem pretty good considering this has trained on only a fraction of the full dataset.

It would be good to make the final preprocessing step more efficient - at the current rate it would take half a year to process the entire dataset! The next thing would be to make the model training faster too; run over the whole dataset, each epoch would take about a day.
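For the training speed, one lever worth trying (untested here, and assuming the GPU supports half precision) is mixed precision, which is just the fp16 flag on TrainingArguments:

Code
# hypothetical tweak: enable mixed precision on top of the existing arguments
training_args = TrainingArguments(
    report_to=[],
    output_dir=MODEL_RUN_FOLDER / "output",
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    fp16=True,  # assumes a CUDA GPU with half-precision support
)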