Wikipedia IOB Link Boundaries

Link spotting seems weak; could using inside-outside-beginning tagging as the classifier help?
Published

August 21, 2021

I’ve had trouble with the boundary detection of the Wikipedia link model. It’s been suggested that the way I am detecting the links is the problem. Instead of using two binary classifiers, a three-class classifier would be better - the IOB classifier {% cite DBLP:journals/corr/cmp-lg-9505040 %}.

This should be straightforward to implement, so let’s get started. I need to adjust the tagging slightly as the rules are slightly different - strictly, the beginning class is only needed to separate links that are sequential (which would otherwise be treated as a single link by the tagging). I’m inclined to require that the first token of every link take the beginning tag to try to train this behaviour in, as I do not think that sequential links are at all common.
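
To make the tagging concrete, here is a toy illustration of how a sentence with one link would be labelled under this scheme (the tokens and the link span are invented for the example):

Code
# Hypothetical tokenization of "I live in New York." where "New York" is a link
OUTSIDE_CLASS = 0    # token is not part of any link
INSIDE_CLASS = 1     # token continues the current link
BEGINNING_CLASS = 2  # token starts a link

labels = [
    OUTSIDE_CLASS,    # I
    OUTSIDE_CLASS,    # live
    OUTSIDE_CLASS,    # in
    BEGINNING_CLASS,  # New  - first token of the link
    INSIDE_CLASS,     # York - continuation of the same link
    OUTSIDE_CLASS,    # .
]
# If a second link started immediately after "York", its first token would also get
# BEGINNING_CLASS, which is what keeps two adjacent links from merging into one.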


Dataset Preparation

The first thing to do is to reprocess the dataset to change the link labels. It’s going from two binary classifiers to a three-class classifier, so this isn’t a big change from last time.

Code
#collapse

from typing import *
import torch
from transformers import AutoTokenizer
from pathlib import Path
import pandas as pd

OUTSIDE_CLASS = 0
INSIDE_CLASS = 1
BEGINNING_CLASS = 2

TITLE_TOKENS = sorted(
    Path("/data/blog/2021-08-01-wikipedia-page-pmi/")
        .glob("*-pmi.gz.parquet")
)

token_df = pd.concat([
    pd.read_parquet(path)
    for path in TITLE_TOKENS
]).set_index("title")

title_to_index = pd.read_parquet(
    "/data/blog/2021-07-30-wikipedia-data-generation/title-to-index.gz.parquet"
)["index"].to_dict()

def encode_row(
    tokenizer: AutoTokenizer,
    row: Union[pd.Series, Dict[str, Any]],
    title_to_index: Dict[str, int],
    max_length: int = 256,
) -> Dict[str, List[int]]:
    """ Tokenize the text of a row and pair every token with a [IOB class, link target index] label. """
    tokenized_text = tokenizer(
        row["text"],
        return_offsets_mapping=True,
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    labels = _to_boundaries(
        token_offsets=tokenized_text["offset_mapping"],
        link_starts=row["start"],
        link_ends=row["end"],
        link_targets=row["link"],
        title_to_index=title_to_index,
    )
    return {
        "input_ids": tokenized_text["input_ids"],
        "attention_mask": tokenized_text["attention_mask"],
        "label": labels,
    }
    
def _to_boundaries(
    token_offsets: List[Tuple[int, int]],
    link_starts: List[int],
    link_ends: List[int],
    link_targets: List[str],
    title_to_index: Dict[str, int],
) -> List[List[int]]:
    """ Convert character-level link spans into per-token [IOB class, link target index] labels.
        Tokens outside any link stay as [OUTSIDE_CLASS, 0]. """
    boundaries = [[OUTSIDE_CLASS, 0]] * len(token_offsets)

    link_iter = zip(link_starts, link_ends, link_targets)
    try:
        link_start, link_end, link_target = _next_link(
            token_start=0,
            links=link_iter,
            title_to_index=title_to_index,
        )

        within = False
        for index, (token_start, token_end) in enumerate(token_offsets):
            if token_start >= link_end:
                link_start, link_end, link_target = _next_link(
                    token_start=0,
                    links=link_iter,
                    title_to_index=title_to_index,
                )
                within = False

            if token_start < link_end and token_end > link_start:
                boundaries[index] = [
                    INSIDE_CLASS if within else BEGINNING_CLASS,
                    link_target,
                ]
                within = True
    except StopIteration:
        pass

    return boundaries

def _next_link(
    token_start: int,
    links: Iterator[Tuple[int, int, str]],
    title_to_index: Dict[str, int],
) -> Tuple[int, int, int]:
    """ Advance to the next link that ends after token_start and has a known target title,
        returning its span and the index of the target. """
    link_start, link_end, link_target = next(links)
    while token_start >= link_end or link_target not in title_to_index:
        link_start, link_end, link_target = next(links)
    return link_start, link_end, title_to_index[link_target]
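
As a quick sanity check of the labelling, encode_row can be exercised on a single made-up row like this (the text, character offsets, and title mapping below are all invented for the example):

Code
# a minimal sketch: the offsets and the title mapping are made up for illustration
example_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
example_title_to_index = {"Malibu, California": 123}

example_row = {
    "text": "I like to drive my Malibu",
    "start": [19],                    # character offset where the link text begins
    "end": [25],                      # character offset where the link text ends
    "link": ["Malibu, California"],   # target page title
}

encoded = encode_row(
    tokenizer=example_tokenizer,
    row=example_row,
    title_to_index=example_title_to_index,
)
# encoded["label"] holds one [IOB class, link target index] pair per token;
# the tokens covering "Malibu" should come out as [BEGINNING_CLASS, 123]
# followed by [INSIDE_CLASS, 123], and everything else stays [0, 0]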
Code
#collapse
from transformers import AutoTokenizer
import pandas as pd

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

df = pd.read_parquet(
    "/data/blog/2021-07-28-wikipedia-link-recognition/enwiki-20210701-pages-articles1.xml-p1p41242.gz.parquet"
)

One thing to consider is how fast this encoding scheme is. I’m not expecting big changes from the previous run of around 4 minutes of CPU time.

Code
%%time

def encode(row) -> Dict[str, Any]:
    return encode_row(
        tokenizer=tokenizer,
        row=row,
        title_to_index=title_to_index,
    )

processed_df = pd.merge(
    df,
    pd.DataFrame(df.apply(encode, axis="columns").tolist()),
    left_index=True,
    right_index=True
)
CPU times: user 3min 9s, sys: 6.22 s, total: 3min 16s
Wall time: 2min 39s

It’s quite a bit better! I wonder if I’ll ever deeply understand Python performance. The only real difference is that each per-token label has gone from three values wide to two.
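
A minimal sketch of that width change per token, assuming the old layout was the [is start, is link, link index] triple used by the previous two-classifier setup:

Code
# previous post: two binary classifiers, three values per token (assumed layout)
old_label = [1, 1, 123]   # [starts a link, is within a link, link target index]

# this post: IOB classifier, two values per token
new_label = [2, 123]      # [IOB class (0 outside, 1 inside, 2 beginning), link target index]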

Anyway, let’s complete the preprocessing of the dataset and then we can train.

Code
#collapse

from transformers import AutoTokenizer
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
DATA_FOLDER = Path("/data/blog/2021-08-21-link-evaluation")

DATA_FOLDER.mkdir(exist_ok=True, parents=True)

def encode(row) -> Dict[str, Any]:
    return encode_row(
        tokenizer=tokenizer,
        row=row,
        title_to_index=title_to_index,
    )

for path in tqdm(sorted(
    Path("/data/blog/2021-07-28-wikipedia-link-recognition/")
        .glob("*.gz.parquet")
)):
    destination = DATA_FOLDER / path.name
    if destination.exists():
        continue

    df = pd.read_parquet(path)
    df = pd.merge(
        df,
        pd.DataFrame(df.apply(encode, axis="columns").tolist()),
        left_index=True,
        right_index=True
    )
    df.to_parquet(destination)

Training

Now we can repeat the training process. Once again this is very similar to the previous post. The one difference is that the boundary loss now uses three-class cross entropy rather than binary cross entropy.

Code
BATCH_SIZE = 32
EPOCHS = 5
MODEL_NAME = "facebook/bart-base"
Code
#collapse
import torch
import numpy as np

token_indices = np.concatenate(token_df.tokens.values).reshape(-1, 50)
token_indices = torch.tensor(token_indices, dtype=torch.int)
token_indices = token_indices.cuda()
token_indices.shape
torch.Size([6328478, 50])
Code
#collapse
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def calculate_boundary_loss(
    predictions: torch.Tensor,
    all_labels: torch.Tensor
) -> torch.Tensor:
    """ Calculate the loss for the boundary predictions (outside, inside, beginning).
        The predictions are only the boundary predictions.
        The labels combine the boundary labels and the link target index. """
    return torch.nn.functional.cross_entropy(
        predictions.reshape(-1, 3),
        all_labels.reshape(-1, 2)[:, 0]
    )

def calculate_link_loss(
    predictions: torch.Tensor,
    all_labels: torch.Tensor, # [2] boundary class, link target index
    token_indices: torch.Tensor, # index -> 50 tokens
    vocab_size: int = tokenizer.vocab_size,
) -> torch.Tensor:
    """ Calculate the loss for the link predictions.
        The labels for this are only valid within a link.
        The predictions are only the link target predictions.
        The labels combine the boundary labels and the link target index. """
    flat_indexes = all_labels.view(-1, 2).long()
    mask = flat_indexes[:, 0] > 0
    flat_indexes = flat_indexes[mask][:, 1]
    flat_predictions = predictions.view(-1, vocab_size)[mask]
    rows = flat_indexes.shape[0]

    targets = torch.zeros(vocab_size * rows, device=predictions.device)
    target_offsets = torch.tensor(range(rows), device=predictions.device) * vocab_size
    target_indexes = (
        token_indices[flat_indexes] + target_offsets[:, None]
    ).flatten()
    targets[target_indexes] = 1

    return torch.nn.functional.binary_cross_entropy_with_logits(
        flat_predictions,
        targets.view(-1, vocab_size)
    )
Code
#collapse
from transformers import BartForConditionalGeneration, BartConfig
from transformers.models.bart.modeling_bart import shift_tokens_right
import torch

class BartForLinks(BartForConditionalGeneration):
    def __init__(self, config: BartConfig) -> None:
        super().__init__(config)
        self.link_head = torch.nn.Linear(in_features=config.d_model, out_features=3, bias=True)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
    ):
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
        )
        link_logits = self.lm_head(outputs[0]) + self.final_logits_bias
        boundary_logits = self.link_head(outputs[0])
        logits = torch.cat([
            link_logits,
            boundary_logits,
        ], dim=-1)

        output = (logits,) + outputs[1:]

        if labels is not None:
            boundary_loss = calculate_boundary_loss(
                predictions=boundary_logits,
                all_labels=labels
            )
            link_loss = calculate_link_loss(
                predictions=link_logits,
                all_labels=labels,
                token_indices=token_indices,
            )
            loss = boundary_loss + link_loss
            return ((loss,) + output)
        return output
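
Since the forward pass concatenates the vocabulary logits and the three boundary logits along the last axis, anything consuming the output has to slice them apart again. A small illustration, with a random tensor standing in for model(tokens)[0]:

Code
import torch

vocab = 50265                             # bart-base vocabulary size
combined = torch.randn(1, 9, vocab + 3)   # pretend model output: (batch, sequence, vocab + 3)

link_logits = combined[..., :vocab]       # per-token scores over the vocabulary
boundary_logits = combined[..., vocab:]   # per-token outside/inside/beginning scores

iob_class = boundary_logits.argmax(dim=-1)   # 0 outside, 1 inside, 2 beginning
is_link = iob_class > 0                      # inside or beginning both count as link tokens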
Code
from transformers import AutoTokenizer
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
DATA_FOLDER = Path("/data/blog/2021-08-21-link-evaluation")

df = pd.read_parquet(sorted(DATA_FOLDER.glob("*.gz.parquet"))[-1])
df = df[["input_ids", "attention_mask", "label"]]
df
input_ids attention_mask label
0 [0, 8773, 312, 7165, 36, 5400, 504, 779, 13668... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0...
1 [0, 33893, 29070, 189, 9115, 7, 35, 1009, 2404... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0...
2 [0, 133, 11183, 2426, 43254, 824, 6, 67, 684, ... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0...
3 [0, 384, 21926, 3990, 7003, 1627, 7461, 36597,... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0...
4 [0, 15167, 6, 23, 9002, 1343, 12, 118, 163, 26... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0...
... ... ... ...
164852 [0, 11094, 3572, 5277, 424, 36, 5400, 195, 759... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0...
164853 [0, 1655, 1738, 312, 1722, 25420, 36, 2036, 49... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0...
164854 [0, 495, 41839, 38847, 17737, 16254, 102, 16, ... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0...
164855 [0, 133, 2197, 9, 5, 5428, 22699, 27975, 6, 67... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0...
164856 [0, 133, 511, 16, 10, 889, 9, 5, 1609, 12, 721... [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0...

164857 rows × 3 columns

Code
import datasets

ds = datasets.Dataset.from_pandas(df)
split = ds.train_test_split(test_size=BATCH_SIZE*100)
split["train"]
Dataset({
    features: ['input_ids', 'attention_mask', 'label'],
    num_rows: 161657
})
Code
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = BartForLinks.from_pretrained(MODEL_NAME)
Some weights of BartForLinks were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['link_head.weight', 'link_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Code
from pathlib import Path
from transformers import Trainer, TrainingArguments

MODEL_RUN_FOLDER = Path("/data/blog/2021-08-21-link-evaluation/runs")
MODEL_RUN_FOLDER.mkdir(parents=True, exist_ok=True)

training_args = TrainingArguments(
    report_to=[],            
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    warmup_ratio=0.06,
    num_train_epochs=EPOCHS,
    evaluation_strategy="steps",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=1_000,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=split["train"],
    eval_dataset=split["test"],
    tokenizer=tokenizer,
)

trainer.train()
[25260/25260 6:20:51, Epoch 5/5]
Step Training Loss Validation Loss Runtime Samples Per Second
1000 0.240600 0.151266 29.664400 107.874000
2000 0.144800 0.139962 29.945300 106.862000
3000 0.138400 0.134879 30.177600 106.039000
4000 0.134200 0.129425 30.204200 105.946000
5000 0.131800 0.131041 30.252100 105.778000
6000 0.124400 0.127548 30.361300 105.397000
7000 0.123700 0.125118 30.153500 106.124000
8000 0.122700 0.124304 30.119300 106.244000
9000 0.121700 0.124913 30.231700 105.849000
10000 0.120600 0.123535 30.430100 105.159000
11000 0.112600 0.128089 30.353300 105.425000
12000 0.111700 0.128253 30.139900 106.172000
13000 0.111200 0.129223 30.279900 105.681000
14000 0.111200 0.127928 30.347000 105.447000
15000 0.111200 0.125633 30.547500 104.755000
16000 0.104300 0.128413 30.297300 105.620000
17000 0.102800 0.126575 30.462400 105.047000
18000 0.102600 0.126771 30.314600 105.560000
19000 0.102200 0.127372 30.014700 106.614000
20000 0.102000 0.125605 30.007300 106.641000
21000 0.097500 0.130568 30.025900 106.575000
22000 0.095000 0.130214 30.281900 105.674000
23000 0.095400 0.129343 30.247800 105.793000
24000 0.095900 0.129448 30.339400 105.473000
25000 0.095400 0.129902 30.009200 106.634000

TrainOutput(global_step=25260, training_loss=0.11793305856881402, metrics={'train_runtime': 22852.2055, 'train_samples_per_second': 1.105, 'total_flos': 1.7309690213384448e+17, 'epoch': 5.0, 'init_mem_cpu_alloc_delta': -154144768, 'init_mem_gpu_alloc_delta': 558667776, 'init_mem_cpu_peaked_delta': 154222592, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 2867200, 'train_mem_gpu_alloc_delta': 1683478528, 'train_mem_cpu_peaked_delta': 330969088, 'train_mem_gpu_peaked_delta': 17412948992})
Code
model.save_pretrained("/data/blog/2021-08-21-link-evaluation/model")

Evaluation

I can just repeat the same evaluation as before. The lack of metrics is really hurting at this point, as these evaluations are in no way systematic and the model takes quite a long time to train. Anyway, just looking at the output should work for now.

Code
model = BartForLinks.from_pretrained("/data/blog/2021-08-21-link-evaluation/model")
model.eval() ; None
Code
#collapse
from typing import *
import torch

@torch.no_grad()
def infer(text: str) -> List[Tuple[str, Optional[List[str]]]]:
    tokens = tokenizer(
        text,
        return_attention_mask=False,
        return_tensors="pt",
        max_length=256,
        truncation=True,
    )["input_ids"]
    tokens = tokens.to(model.device)

    output = model(tokens)[0]
    lm_output = output[0, :, :model.config.vocab_size]
    link_tokens = lm_output.argsort(dim=-1, descending=True)[:, :50]
    is_link = (output[0, :, -1] > 0).tolist()

    text_tokens = tokenizer.batch_decode(tokens[0, :, None])

    return [
        (token, tokenizer.batch_decode(link[:10]) if is_link else None)
        for token, link, is_link in zip(text_tokens, link_tokens[:, :, None].tolist(), is_link)
    ]
Code
infer("I like to drive my Malibu")
[('<s>', None),
 ('I', None),
 (' like', None),
 (' to', None),
 (' drive', None),
 (' my', None),
 (' Mal', ['</s>', '<s>', ' of', ' ', '.', ' *', ' in', ' is', ' a', '�']),
 ('ibu', None),
 ('</s>', None)]

Well, this is absolutely terrible. I’m inclined to think that the difference in scale between binary cross entropy and cross entropy is the underlying problem. Let’s have a look at the actual token overlap between the predictions and the truth.
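
To get a rough feel for how differently the two losses are scaled, they can be compared on random logits; the numbers below are purely illustrative and not measured from the trained model:

Code
import torch
import torch.nn.functional as F

vocab = 50265     # bart-base vocabulary size
tokens = 1000     # pretend number of link tokens

# boundary loss: three-way cross entropy per token
boundary_logits = torch.randn(tokens, 3)
boundary_targets = torch.randint(0, 3, (tokens,))
print(F.cross_entropy(boundary_logits, boundary_targets).item())

# link loss: binary cross entropy averaged over every vocabulary entry,
# with only ~50 positive targets per token and the rest zero
link_logits = torch.randn(tokens, vocab)
link_targets = torch.zeros(tokens, vocab)
link_targets[:, :50] = 1
print(F.binary_cross_entropy_with_logits(link_logits, link_targets).item())
# both start out at the same order of magnitude, but the link loss is an average over
# ~50k mostly-zero targets per token, so the two can drift to very different scales
# during training - which is the kind of mismatch being suspected here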

Code
#collapse
from typing import *
import torch

@torch.no_grad()
def show_tokens(text: str) -> List[Tuple[str, Optional[List[str]]]]:
    tokens = tokenizer(
        text,
        return_attention_mask=False,
        return_tensors="pt",
        max_length=256,
        truncation=True,
    )["input_ids"]
    tokens = tokens.to(model.device)

    output = model(tokens)[0]
    lm_output = output[0, :, :model.config.vocab_size]
    link_tokens = lm_output.argsort(dim=-1, descending=True)[:, :50]
    is_link = (output[0, :, -1] > 0).tolist()

    text_tokens = tokenizer.batch_decode(tokens[0, :, None])

    return [
        (token, link[:10] if is_link else None)
        for token, link, is_link in zip(text_tokens, link_tokens[:, :, None].tolist(), is_link)
    ]
Code
from typing import *
import torch

@torch.no_grad()
def infer(text: str) -> List[Tuple[str, Optional[List[str]]]]:
    # decode the raw token id predictions from simple_infer into readable strings
    return [
        (token, tokenizer.batch_decode(link[:10]) if link else None)
        for token, link in simple_infer(text)
    ]

@torch.no_grad()
def simple_infer(text: str) -> List[Tuple[str, Optional[List[str]]]]:
    tokens = tokenizer(
        text,
        return_attention_mask=False,
        return_tensors="pt",
        max_length=256,
        truncation=True,
    )["input_ids"]
    tokens = tokens.to(model.device)

    output = model(tokens)[0]
    lm_output = output[0, :, :model.config.vocab_size]
    link_tokens = lm_output.argsort(dim=-1, descending=True)[:, :50]
    is_link = (output[0, :, model.config.vocab_size:].argmax(dim=-1) > 0).tolist()

    text_tokens = tokenizer.batch_decode(tokens[0, :, None])

    return [
        (token, link[:10] if is_link else None)
        for token, link, is_link in zip(text_tokens, link_tokens[:, :, None].tolist(), is_link)
    ]
Code
infer("I like to drive my Malibu")
[('<s>', None),
 ('I', None),
 (' like', None),
 (' to', None),
 (' drive', None),
 (' my', None),
 (' Mal', ['</s>', '<s>', ' of', ' ', '.', ' *', ' in', ' is', ' a', '�']),
 ('ibu', ['</s>', '<s>', ' ', ' the', '.', ' *', ' in', ' a', ',', ' is']),
 ('</s>', None)]

It isn’t making a single correct token prediction in the top 10. Given how well it was performing before, I can only assume this is the mismatch between the two loss scores. However, if the model is paying too much attention to the boundary targets, then shouldn’t it be better at spotting Malibu? It’s very strange.