Sequence to Aspect Sentiment

Can Sequence to Sequence models be trained to extract entities and mark up sentiment?
Published

July 3, 2021

Sequence to sequence models take a text and produce another text. They can be used for tasks like translation. It’s also possible to use them to better understand the text without having a fixed classification target. This is useful when you want to extract things from the text - like locations or people or activities.


Entity Extraction

I am interested in aspect sentiment. This is the sentiment of an utterance towards an entity. For this, an entity is anything that could be the subject of an emotion: people, food, locations, music, etc.

The first task of the model is to extract the entities so that they can have a sentiment classification. An easy way to do this with a sequence to sequence model is to ask the model to use semaphore tokens to mark the start and end of an entity, otherwise copying the input:

Here the words contained within the brackets in the output are marked up as an entity: for example, the input listen to El Duke would become listen to [ El Duke ]. A sequence to sequence model can choose to issue multiple output tokens for a single input, and equally it can choose to drop tokens.


Sequence To Entity

So let’s start looking at sequence to sequence models. The first thing to do is to find a suitable dataset for this task. Given that we want to recognize nouns, as they are likely to be the entities in question, it might be possible to use an existing POS tagger to generate the target text.

The example spacy code for analyzing text is as follows:

Code
# pip install -U spacy
# python -m spacy download en_core_web_sm
import spacy

# Load English tokenizer, tagger, parser and NER
nlp = spacy.load("en_core_web_sm")

# Process whole documents
text = ("When Sebastian Thrun started working on self-driving cars at "
        "Google in 2007, few people outside of the company took him "
        "seriously. “I can tell you very senior CEOs of major American "
        "car companies would shake my hand and turn away because I wasn’t "
        "worth talking to,” said Thrun, in an interview with Recode earlier "
        "this week.")
doc = nlp(text)

# Analyze syntax
print("Noun phrases:", [chunk.text for chunk in doc.noun_chunks])
print("Verbs:", [token.lemma_ for token in doc if token.pos_ == "VERB"])

# Find named entities, phrases and concepts
for entity in doc.ents:
    print(entity.text, entity.label_)
Noun phrases: ['Sebastian Thrun', 'self-driving cars', 'Google', 'few people', 'the company', 'him', 'I', 'you', 'very senior CEOs', 'major American car companies', 'my hand', 'I', 'Thrun', 'an interview', 'Recode']
Verbs: ['start', 'work', 'drive', 'take', 'can', 'tell', 'would', 'shake', 'turn', 'talk', 'say']
Sebastian Thrun PERSON
Google ORG
2007 DATE
American NORP
Thrun PERSON
Recode LOC
earlier this week DATE

So I would want to exclude any date entities and otherwise mark everything up. Let’s give it a go.

Code
[entity for entity in doc.ents if entity.label_ != "DATE"]
[Sebastian Thrun, Google, American, Thrun, Recode]

The entity objects carry span information, so it should be possible to alter the text to wrap them in brackets. Then I would just need a lot of text to process.
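As a quick check of that span information (reusing the doc from the cell above), each entity exposes its character offsets:

Code
# start_char and end_char are the character offsets of the entity in the original text
[(entity.text, entity.start_char, entity.end_char) for entity in doc.ents if entity.label_ != "DATE"]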

The wikitext dataset looks reasonable for this.

Code
#hide_output
from datasets import load_dataset

wikitext = load_dataset("wikitext", "wikitext-103-v1")
Reusing dataset wikitext (/home/matthew/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20)
Code
wikitext.keys()
dict_keys(['test', 'train', 'validation'])
Code
wikitext["train"][10]["text"][:200]
" The game 's battle system , the <unk> system , is carried over directly from <unk> Chronicles . During missions , players select each unit using a top @-@ down perspective of the battlefield map : on"

The problem here is that the text contains a lot of wiki markup. I’m not super happy training on this.

Code
#hide_output
from datasets import load_dataset

amazon_polarity = load_dataset("amazon_polarity")
Reusing dataset amazon_polarity (/home/matthew/.cache/huggingface/datasets/amazon_polarity/amazon_polarity/3.0.0/ac31acedf6cda6bc2aa50d448f48bbad69a3dd8efc607d2ff1a9e65c2476b4c1)
Code
amazon_polarity.keys()
dict_keys(['train', 'test'])
Code
amazon_polarity["train"][0]["content"]
'This sound track was beautiful! It paints the senery in your mind so well I would recomend it even to people who hate vid. game music! I have played the game Chrono Cross but out of all of the games I have ever played it has the best music! It backs away from crude keyboarding and takes a fresher step with grate guitars and soulful orchestras. It would impress anyone who cares to listen! ^_^'

So this looks good enough. After all, people do write like this so this is reasonably representative.

Entity Extraction

I want to be able to mark up the entities that are within this text. Since I don’t have the time to do this manually, I am going to rely on spaCy to do the annotation for me. The spaCy tagger reports an accuracy of 0.97, so it should be a good start.

How can I mark up the text to show the entities? I can adjust the spaCy example code to extract them, and come up with a list of labels that I am not interested in. It is difficult to find a list of the possible labels; I was able to find this blog post which provides one, but I can’t be sure that it is complete.
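As an aside, the loaded pipeline can also report its own label set, which avoids relying on a blog post. A quick check (spacy.explain gives a short description of each label):

Code
# list the labels that the NER component of this pipeline can produce
ner = nlp.get_pipe("ner")
print(ner.labels)
print({label: spacy.explain(label) for label in ner.labels})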

Code
from typing import *
import spacy

nlp = spacy.load("en_core_web_sm")

def extract_entities(text: str) -> List[str]:
    doc = nlp(text)
    return [
        entity.text
        for entity in doc.ents
        if entity.label_ not in {
            "DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"
        }
    ]

extract_entities(amazon_polarity["train"][0]["content"])
['Chrono Cross']
Code
text = amazon_polarity["train"][1]["content"]

extract_entities(text), text
(["Yasunori Mitsuda's"],
 "I'm reading a lot of reviews saying that this is the best 'game soundtrack' and I figured that I'd write a review to disagree a bit. This in my opinino is Yasunori Mitsuda's ultimate masterpiece. The music is timeless and I'm been listening to it for years now and its beauty simply refuses to fade.The price tag on this is pretty staggering I must say, but if you are going to buy any cd for this much money, this is the only one that I feel would be worth every penny.")
Code
text = amazon_polarity["train"][2]["content"]

extract_entities(text), text
(['Prisoners of Fate',
  'A Distant Promise',
  'Time',
  'Dreamwatch',
  'Chronomantique',
  'Chrono Trigger',
  'Xenogears'],
 'This soundtrack is my favorite music of all time, hands down. The intense sadness of "Prisoners of Fate" (which means all the more if you\'ve played the game) and the hope in "A Distant Promise" and "Girl who Stole the Star" have been an important inspiration to me personally throughout my teen years. The higher energy tracks like "Chrono Cross ~ Time\'s Scar~", "Time of the Dreamwatch", and "Chronomantique" (indefinably remeniscent of Chrono Trigger) are all absolutely superb as well.This soundtrack is amazing music, probably the best of this composer\'s work (I haven\'t heard the Xenogears soundtrack, so I can\'t say for sure), and even if you\'ve never played the game, it would be worth twice the price to buy it.I wish I could give it 6 stars.')
Code
[(entity.text, entity.label_) for entity in nlp(amazon_polarity["train"][2]["content"]).ents]
[('Prisoners of Fate', 'WORK_OF_ART'),
 ('A Distant Promise', 'WORK_OF_ART'),
 ('Time', 'ORG'),
 ('Dreamwatch', 'ORG'),
 ('Chronomantique', 'WORK_OF_ART'),
 ('Chrono Trigger', 'PERSON'),
 ('Xenogears', 'PERSON'),
 ('6', 'CARDINAL')]

These entities seem reasonable.

Entity Tagging

The next thing to do is to update the text, inserting brackets around the entities.

Code
from typing import *
import spacy

nlp = spacy.load("en_core_web_sm")

def markup_entities(text: str) -> str:
    spans = extract_spans(text)
    for span in spans[::-1]: # reverse the list to avoid changing text positions
        text = text[:span.start_char] + f"[{span.text}]" + text[span.end_char:]
    return text

def extract_spans(text: str) -> List[spacy.tokens.span.Span]:
    doc = nlp(text)
    return [
        entity
        for entity in doc.ents
        if entity.label_ not in {"DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"}
    ]
Code
markup_entities(amazon_polarity["train"][2]["content"])
'This soundtrack is my favorite music of all time, hands down. The intense sadness of "[Prisoners of Fate]" (which means all the more if you\'ve played the game) and the hope in "[A Distant Promise]" and "Girl who Stole the Star" have been an important inspiration to me personally throughout my teen years. The higher energy tracks like "Chrono Cross ~ [Time]\'s Scar~", "Time of the [Dreamwatch]", and "[Chronomantique]" (indefinably remeniscent of [Chrono Trigger]) are all absolutely superb as well.This soundtrack is amazing music, probably the best of this composer\'s work (I haven\'t heard the [Xenogears] soundtrack, so I can\'t say for sure), and even if you\'ve never played the game, it would be worth twice the price to buy it.I wish I could give it 6 stars.'

This seems pretty good. So we just need to map this data into something suitable. Then it can be trained.


Seq2Seq Training

Now it’s time to train a seq2seq model for this. The first thing to do is tokenize the inputs and targets.

Code
from datasets import Dataset

train_ds = Dataset.from_dict({
    "text": amazon_polarity["train"]["content"][:1000]
})
test_ds = Dataset.from_dict({
    "text": amazon_polarity["test"]["content"][:1000]
})
Code
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

def encode(text: str) -> Dict[str, List[int]]:
    target = markup_entities(text)
    return {
        "input_ids": tokenizer(
            text,
            return_attention_mask=False,
            padding="max_length",
            max_length=128,
            truncation=True
        )["input_ids"],
        "label": tokenizer(
            target,
            return_attention_mask=False,
            padding="max_length",
            max_length=128,
            truncation=True
        )["input_ids"]
    }

train_ds = train_ds.map(encode, input_columns=["text"])
test_ds = test_ds.map(encode, input_columns=["text"])
train_ds
Dataset({
    features: ['input_ids', 'label', 'text'],
    num_rows: 1000
})
Code
from pathlib import Path
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

MODEL_RUN_FOLDER = Path("/data/blog/2021-07-03-sequence-to-aspect-sentiment/runs")
MODEL_RUN_FOLDER.mkdir(parents=True, exist_ok=True)

training_args = Seq2SeqTrainingArguments(
    report_to=[],            
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=32,
    learning_rate=5e-5,
    num_train_epochs=5,
    evaluation_strategy="epoch",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=100,
    load_best_model_at_end=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    # compute_metrics=compute_metrics,
    # optimizers=(optimizer, None)
)

trainer.train()
[160/160 01:32, Epoch 5/5]
Epoch Training Loss Validation Loss Runtime Samples Per Second
1 No log 1.101301 3.819500 261.817000
2 No log 0.375796 3.856500 259.300000
3 No log 0.128995 3.884000 257.465000
4 1.166900 0.084706 3.918600 255.195000
5 1.166900 0.072723 3.952300 253.015000

TrainOutput(global_step=160, training_loss=0.7644804030656814, metrics={'train_runtime': 92.6352, 'train_samples_per_second': 1.727, 'total_flos': 535374397440000.0, 'epoch': 5.0, 'init_mem_cpu_alloc_delta': 1222438912, 'init_mem_gpu_alloc_delta': 558658048, 'init_mem_cpu_peaked_delta': 381460480, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 398512128, 'train_mem_gpu_alloc_delta': 2325842944, 'train_mem_cpu_peaked_delta': 720609280, 'train_mem_gpu_peaked_delta': 7369723392})

So I don’t really know what the best evaluation metric is for this. If the model just learns to do a direct word-for-word copy then it could look quite accurate very quickly, while a metric that concentrates only on the bracket markers would allow the copied words to break. Two metrics might be good?
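As a rough sketch of that two-metric idea (these helpers are hypothetical and not wired into the trainer), one metric could score only the bracketed entities while the other checks that the surrounding text is copied faithfully:

Code
import re
from collections import Counter

def entity_f1(predicted: str, target: str) -> float:
    """F1 over the bracketed entity strings, ignoring the rest of the text."""
    predicted_entities = Counter(re.findall(r"\[(.*?)\]", predicted))
    target_entities = Counter(re.findall(r"\[(.*?)\]", target))
    true_positives = sum((predicted_entities & target_entities).values())
    if true_positives == 0:
        return 0.0
    precision = true_positives / sum(predicted_entities.values())
    recall = true_positives / sum(target_entities.values())
    return 2 * precision * recall / (precision + recall)

def copy_accuracy(predicted: str, target: str) -> float:
    """Word level accuracy after stripping the entity markers."""
    predicted_words = predicted.replace("[", "").replace("]", "").split()
    target_words = target.replace("[", "").replace("]", "").split()
    matches = sum(p == t for p, t in zip(predicted_words, target_words))
    return matches / max(len(target_words), 1)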

Code
import torch

def predict(text: str, length: int = 50) -> str:
    input_ids = tokenizer(text, return_attention_mask=False)["input_ids"]
    return predict_input_ids(input_ids, length=length)

def predict_input_ids(input_ids: List[int], length: int = 50) -> str:
    output = infer_beam(input_ids, length=length)
    return tokenizer.decode(output)

@torch.no_grad()
def infer_greedy(input_ids: List[int], length: int = 50) -> List[int]:
    input_ids = torch.tensor(input_ids, dtype=torch.long, device=model.device)[None, :]
    output = model.generate(
        input_ids=input_ids,
        min_length=length,
        max_length=length,
    )
    return output[0].tolist()

@torch.no_grad()
def infer_beam(input_ids: List[int], length: int = 50) -> List[int]:
    input_ids = torch.tensor(input_ids, dtype=torch.long, device=model.device)[None, :]
    output = model.generate(
        input_ids=input_ids,
        min_length=length,
        max_length=length,
        num_beams=5,
    )
    return output[0].tolist()
    

text = "Selling Britain’s state-owned water authorities seemed like a good idea to Conservative ministers in the 1980s"
predict(text, length=20), predict(text, length=30) # I searched for this number
('</s><s>Selling Britain’s state-owned water authorities seemed like a good idea to</s>',
 '</s><s>Selling [Britain]’s state-owned water authorities seemed like a good idea to [Conservative] ministers in the 1980s</s>')
Code
markup_entities(text)
'Selling [Britain]’s state-owned water authorities seemed like a good idea to [Conservative] ministers in the 1980s'
Code
tokenizer.decode(test_ds[0]["label"])
"<s>My lovely [Pat] has one of the GREAT voices of her generation. I have listened to this CD for YEARS and I still LOVE IT. When I'm in a good mood it makes me feel better. A bad mood just evaporates like sugar in the rain. This CD just oozes [LIFE]. Vocals are jusat STUUNNING and lyrics just kill. One of life's hidden gems. This is a desert isle CD in my book. Why she never made it big is just beyond me. Everytime I play this, no matter black, white, young, old, male, female</s>"
Code
predict_input_ids(test_ds[0]["input_ids"], length=len(test_ds[0]["label"]))
"</s><s>My lovely [Pat] has one of the GREAT voices of her generation. I have listened to this CD for YEARS and I still LOVE IT. When I'm in a good mood it makes me feel better. A bad mood just evaporates like sugar in the rain. This CD just oozes LIFE. Vocals are jusat STUUNNING and lyrics just kill. One of life's hidden gems. This is a desert isle CD in my book. Why she never made it big is just beyond me. Everytime I play this, no matter black, white, young, old, male, female EVERY</s>"

It looks like working out the target length is quite desirable. The model can also miss some of the entities in the amazon dataset (LIFE was marked up by spacy but not by the model).

Fundamentally this is performing a harder task than the ideal implementation, which would be heavily restricted. That implementation can only:

  • Copy the current token
  • Start a bracket section, OR if in a bracket section finish it

Refining the text generation to adhere to these requirements would be good. I wonder if a three (two?) class classifier would be the best way for this model to operate?

I have to bear in mind that I want to extend this model to predicting aspect sentiment, so any refinement would have to accommodate this. A multi-class classifier could still work. Outputs could be copy, bracket start, bracket end, positive, neutral, negative; the model output would be restricted at each step.

To make such a refinement I need to better understand how the seq2seq approach works and how I can train a model like BART.


Restricted Seq2Seq

So I’ve reviewed the seq2seq code and it appears that it feeds the sequence up to the current point into the model and then extracts the predicted logits for the next token. Turning this into a classification should be straightforward, and I can use AutoModelForSequenceClassification, which broadens the range of possible models significantly.

What I want to do is to fit this within the huggingface trainer. To do this I need to be able to return a loss for the model output. Let’s give this a go, starting with just the entity identification.
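Before that, a reminder of the contract: when labels are supplied, the huggingface Trainer treats the first element of the tuple returned by forward as the loss. A minimal sketch of that shape (SketchClassifier is hypothetical and only illustrates the forward signature, not the restricted model):

Code
import torch

class SketchClassifier(torch.nn.Module):
    """Hypothetical minimal model showing the output contract the Trainer relies on."""
    def __init__(self, vocab_size: int = 100, num_labels: int = 3) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, 16)
        self.score = torch.nn.Linear(16, num_labels)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None):
        hidden = self.embedding(input_ids).mean(dim=1)  # crude pooling over the sequence
        logits = self.score(hidden)
        if labels is None:
            return (logits,)
        # the loss comes first so that the Trainer can pick it up
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return (loss, logits)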

Sequence Alignment

This is quite a tricky thing to implement well, so I want to start by writing an example of aligning the input tokens with the label tokens. The difference between the two sequences is that the input tokens lack the markers that indicate entity boundaries. For example, the input listen to El Duke becomes listen to [ El Duke ]. Taking each separate word as a token this would lead to the following action sequence:

input                  current token  action
                       listen         copy
listen                 to             copy
listen to              El             start entity
listen to [            El             copy
listen to [ El         Duke           copy
listen to [ El Duke    .              end entity
listen to [ El Duke ]  .              copy

Since this is quite involved I want to be able to check my logic. I’m going to write a test to do so.

Code
from enum import Enum
from dataclasses import dataclass

T = TypeVar("T")

@dataclass
class Sequence(Generic[T]):
    tokens: List[T]
    index: int = 0

    def is_oob(self) -> bool:
        return self.index >= len(self.tokens)

    def token(self) -> T:
        return self.tokens[self.index]

    def history(self) -> List[T]:
        return self.tokens[:self.index]

    def increment(self) -> None:
        self.index += 1

class Action(Enum):
    copy = "copy"
    start_entity = "start entity"
    end_entity = "end entity"

@dataclass
class Output(Generic[T]):
    input_tokens: List[T]
    current_token: T
    action: Action
Code
def align(
    input_ids: List[T],
    labels: List[T],
    start_token: T = "[",
    end_token: T = "]",
) -> Iterator[Output]:
    marker_tokens = {start_token, end_token}
    input_sequence = Sequence(tokens=input_ids)
    label_sequence = Sequence(tokens=labels)

    while not (input_sequence.is_oob() and label_sequence.is_oob()):
        if input_sequence.is_oob() or label_sequence.is_oob():
            break # for actual data the input & label pair can be truncated
            raise AssertionError(f"sequences consumed unevenly: {input_sequence}, {label_sequence}")

        input_token = input_sequence.token()
        label_token = label_sequence.token()
        history = label_sequence.history()

        if label_token == start_token:
            yield Output(
                input_tokens=history,
                current_token=input_token,
                action=Action.start_entity
            )
        elif label_token == end_token:
            yield Output(
                input_tokens=history,
                current_token=input_token,
                action=Action.end_entity
            )
        else:
            yield Output(
                input_tokens=history,
                current_token=input_token,
                action=Action.copy
            )
            input_sequence.increment()
        label_sequence.increment()
Code
for action in align("listen to El Duke .".split(), "listen to [ El Duke ] .".split()):
    print(action)
Output(input_tokens=[], current_token='listen', action=<Action.copy: 'copy'>)
Output(input_tokens=['listen'], current_token='to', action=<Action.copy: 'copy'>)
Output(input_tokens=['listen', 'to'], current_token='El', action=<Action.start_entity: 'start entity'>)
Output(input_tokens=['listen', 'to', '['], current_token='El', action=<Action.copy: 'copy'>)
Output(input_tokens=['listen', 'to', '[', 'El'], current_token='Duke', action=<Action.copy: 'copy'>)
Output(input_tokens=['listen', 'to', '[', 'El', 'Duke'], current_token='.', action=<Action.end_entity: 'end entity'>)
Output(input_tokens=['listen', 'to', '[', 'El', 'Duke', ']'], current_token='.', action=<Action.copy: 'copy'>)
Code
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>")
tokenizer(" [ ]")
{'input_ids': [685, 2361], 'attention_mask': [1, 1]}
Code
tokenizer.get_vocab()["Ġ["], tokenizer.get_vocab()["Ġ]"]
(685, 2361)
Code
for action in align(
    tokenizer("listen to El Duke .")["input_ids"],
    tokenizer("listen to [ El Duke ] .")["input_ids"],
    start_token=tokenizer.get_vocab()["Ġ["],
    end_token=tokenizer.get_vocab()["Ġ]"]
):
    print(action)
Output(input_tokens=[], current_token=4868, action=<Action.copy: 'copy'>)
Output(input_tokens=[4868], current_token=268, action=<Action.copy: 'copy'>)
Output(input_tokens=[4868, 268], current_token=284, action=<Action.copy: 'copy'>)
Output(input_tokens=[4868, 268, 284], current_token=2574, action=<Action.start_entity: 'start entity'>)
Output(input_tokens=[4868, 268, 284, 685], current_token=2574, action=<Action.copy: 'copy'>)
Output(input_tokens=[4868, 268, 284, 685, 2574], current_token=11083, action=<Action.copy: 'copy'>)
Output(input_tokens=[4868, 268, 284, 685, 2574, 11083], current_token=764, action=<Action.end_entity: 'end entity'>)
Output(input_tokens=[4868, 268, 284, 685, 2574, 11083, 2361], current_token=764, action=<Action.copy: 'copy'>)
Code
for action in align(
    tokenizer("listen to El Duke .", return_tensors="pt")["input_ids"][0],
    tokenizer("listen to [ El Duke ] .", return_tensors="pt")["input_ids"][0],
    start_token=tokenizer.get_vocab()["Ġ["],
    end_token=tokenizer.get_vocab()["Ġ]"]
):
    print(action)
Output(input_tokens=tensor([], dtype=torch.int64), current_token=tensor(4868), action=<Action.copy: 'copy'>)
Output(input_tokens=tensor([4868]), current_token=tensor(268), action=<Action.copy: 'copy'>)
Output(input_tokens=tensor([4868,  268]), current_token=tensor(284), action=<Action.copy: 'copy'>)
Output(input_tokens=tensor([4868,  268,  284]), current_token=tensor(2574), action=<Action.start_entity: 'start entity'>)
Output(input_tokens=tensor([4868,  268,  284,  685]), current_token=tensor(2574), action=<Action.copy: 'copy'>)
Output(input_tokens=tensor([4868,  268,  284,  685, 2574]), current_token=tensor(11083), action=<Action.copy: 'copy'>)
Output(input_tokens=tensor([ 4868,   268,   284,   685,  2574, 11083]), current_token=tensor(764), action=<Action.end_entity: 'end entity'>)
Output(input_tokens=tensor([ 4868,   268,   284,   685,  2574, 11083,  2361]), current_token=tensor(764), action=<Action.copy: 'copy'>)

I’m pretty happy with this algorithm. It produces the sequence as desired and it even works with tensors. Since this will be done repeatedly, let’s see how fast it is.

Code
input_tokens = tokenizer("listen to El Duke .", return_tensors="pt")["input_ids"][0]
label_tokens = tokenizer("listen to [ El Duke ] .", return_tensors="pt")["input_ids"][0]
start_token=tokenizer.get_vocab()["Ġ["]
end_token=tokenizer.get_vocab()["Ġ]"]
Code
%%timeit

list(align(
    input_tokens,
    label_tokens,
    start_token=start_token,
    end_token=end_token
))
113 µs ± 449 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

It might seem fast but this is pretty slow. This step needs to be done for every sequence in the input and most sequences will be much longer than this. I should optimize this.

Code
import torch

def fast_align(
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    start_token: int = tokenizer.get_vocab()["Ġ["],
    end_token: int = tokenizer.get_vocab()["Ġ]"],
) -> Tuple[torch.Tensor, Action]:
    input_index = 0
    for label_index in range(labels.shape[0]):
        label_token = labels[label_index]
        tokens = torch.cat([labels[:label_index], input_ids[input_index][None]]) # this costs ~3ms

        if label_token == start_token:
            yield (tokens, Action.start_entity)
        elif label_token == end_token:
            yield (tokens, Action.end_entity)
        else:
            yield (tokens, Action.copy)
            input_index += 1
Code
for action in fast_align(
    input_tokens,
    label_tokens,
    start_token=start_token,
    end_token=end_token
):
    print(action)
(tensor([4868]), <Action.copy: 'copy'>)
(tensor([4868,  268]), <Action.copy: 'copy'>)
(tensor([4868,  268,  284]), <Action.copy: 'copy'>)
(tensor([4868,  268,  284, 2574]), <Action.start_entity: 'start entity'>)
(tensor([4868,  268,  284,  685, 2574]), <Action.copy: 'copy'>)
(tensor([ 4868,   268,   284,   685,  2574, 11083]), <Action.copy: 'copy'>)
(tensor([ 4868,   268,   284,   685,  2574, 11083,   764]), <Action.end_entity: 'end entity'>)
(tensor([ 4868,   268,   284,   685,  2574, 11083,  2361,   764]), <Action.copy: 'copy'>)
Code
%%timeit

list(fast_align(
    input_tokens,
    label_tokens,
    start_token=start_token,
    end_token=end_token
))
122 µs ± 378 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

This is slower; however, it returns data that is more in line with what is actually wanted. The problem is the use of torch.cat to add the label onto the history.

More broadly this needs to work on batches. Creating the batches directly might be faster as it can use more tensor operations. What is tricky here is the repeated inputs that need to be introduced.

Code
# aligning this way means the target index is always 1
active_indexes = torch.tensor([
    [1, 0], # action is copy,  [start, copy]
    [0, 1], # action is start, [copy,  start]
    [0, 2], # action is end,   [copy,  end]
])

def batch_align(
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    start_token: int = tokenizer.get_vocab()["Ġ["],
    end_token: int = tokenizer.get_vocab()["Ġ]"],
    pad_token: int = tokenizer.pad_token_id,
    active_indexes: torch.Tensor = active_indexes,
    device: torch.device = torch.device("cpu")
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    label_size = labels.shape[0]

    block = torch.ones((label_size, label_size), dtype=torch.long, device=device)
    history_mask = block.tril(diagonal=-1)
    attention_mask = block.tril()
    padding = block.triu(diagonal=1) * pad_token

    batch = labels.repeat((label_size, 1)) * history_mask

    # this makes the input repeat every time there is a marker token in the label
    # given input:                  token marker token marker token
    # becomes a repetition mask of: True  False  True  False  True
    # which cumsums to:             1     1      2     2      3
    # the associated input_id is one less
    # there is a special case where the first token is a marker, to handle that need to max(0)
    start_token_indices = (labels == start_token)
    end_token_indices = (labels == end_token)
    repetition_mask = ~(start_token_indices | end_token_indices)

    # repetition indices is currently wrong, it repeats the token prior to the entity marker
    # offset by one to address problem with alignment of repetition indices
    # this offsetting costs ~1.5ms !
    repetition_mask = torch.cat([
        torch.tensor([True], device=device),
        repetition_mask[:-1]
    ])
    repetition_indices = repetition_mask.cumsum(dim=0) - 1
    
    batch_diagonal = input_ids[repetition_indices]
    
    batch = batch + batch_diagonal.diag() + padding
    
    # the targets are correctly aligned
    targets = active_indexes[
        start_token_indices.long() + (end_token_indices * 2)
    ]
    
    return batch, attention_mask, targets
Code
batch_align(
    input_tokens,
    label_tokens,
    start_token=start_token,
    end_token=end_token
)
(tensor([[ 4868, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
         [ 4868,   268, 50256, 50256, 50256, 50256, 50256, 50256],
         [ 4868,   268,   284, 50256, 50256, 50256, 50256, 50256],
         [ 4868,   268,   284,  2574, 50256, 50256, 50256, 50256],
         [ 4868,   268,   284,   685,  2574, 50256, 50256, 50256],
         [ 4868,   268,   284,   685,  2574, 11083, 50256, 50256],
         [ 4868,   268,   284,   685,  2574, 11083,   764, 50256],
         [ 4868,   268,   284,   685,  2574, 11083,  2361,   764]]),
 tensor([[1, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 0],
         [1, 1, 1, 1, 1, 1, 1, 1]]),
 tensor([[1, 0],
         [1, 0],
         [1, 0],
         [0, 1],
         [1, 0],
         [1, 0],
         [0, 2],
         [1, 0]]))
Code
%%timeit

batch_align(
    input_tokens,
    label_tokens,
    start_token=start_token,
    end_token=end_token
)
101 µs ± 99.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

These tests have been performed with a very short sequence. The real test would be to compare against an actual sequence from the dataset. Such a sequence would be considerably longer.

I also want to be able to truncate the sequences to keep the training quick. Ideally the maximum sequence length will correspond to the maximum batch size so an entire sequence can be processed in a single pass.

Code
%%timeit

list(align(
    torch.tensor(test_ds[0]["input_ids"], dtype=torch.long),
    torch.tensor(test_ds[0]["label"], dtype=torch.long),
    start_token=start_token,
    end_token=end_token
))
2.04 ms ± 4.59 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Code
%%timeit

list(fast_align(
    torch.tensor(test_ds[0]["input_ids"], dtype=torch.long),
    torch.tensor(test_ds[0]["label"], dtype=torch.long),
    start_token=start_token,
    end_token=end_token
))
2.29 ms ± 3.06 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Code
%%timeit

batch_align(
    torch.tensor(test_ds[0]["input_ids"], dtype=torch.long).cuda(),
    torch.tensor(test_ds[0]["label"], dtype=torch.long).cuda(),
    start_token=start_token,
    end_token=end_token,
    active_indexes=active_indexes.cuda(),
    device=torch.device("cuda")
)
549 µs ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Code
%%timeit

batch_align(
    torch.tensor(test_ds[0]["input_ids"], dtype=torch.long).cpu(),
    torch.tensor(test_ds[0]["label"], dtype=torch.long).cpu(),
    start_token=start_token,
    end_token=end_token,
    active_indexes=active_indexes.cpu(),
    device=torch.device("cpu")
)
474 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

I think that this shows how well the batch-based approach works. It’s interesting that the “fast” version is slower. Optimizing Python is hard.

The final part of this is to incorporate the third entry in the tuple to determine the loss of the output. This entry is the indices of the model logits that are of interest given the current context.

Code
_, _, indicies = batch_align(
    torch.tensor(test_ds[0]["input_ids"], dtype=torch.long).cpu(),
    torch.tensor(test_ds[0]["label"], dtype=torch.long).cpu(),
    start_token=start_token,
    end_token=end_token,
    active_indexes=active_indexes.cpu(),
    device=torch.device("cpu")
)

example_output = torch.rand(indicies.shape[0], 3)
Code
example_output.shape, torch.cat([
    example_output[range(128), indicies[:, 0]][:, None],
    example_output[range(128), indicies[:, 1]][:, None],
], dim=1).shape
(torch.Size([128, 3]), torch.Size([128, 2]))

Unfortunately I can’t find a neater way to select the subset of the outputs using a single expression. It feels like there is a better way to do this; for now it will have to do (torch.gather might be that better way, sketched after the timing below).

Code
%%timeit

torch.nn.functional.cross_entropy(
    torch.cat([
        example_output[range(indicies.shape[0]), indicies[:, 0]][:, None],
        example_output[range(indicies.shape[0]), indicies[:, 1]][:, None],
    ], dim=1),
    torch.ones(indicies.shape[0], dtype=torch.long)
)
72.6 µs ± 63.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
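One possible single-expression alternative here is torch.gather, which selects along the class dimension using the index tensor directly (a quick sketch, not benchmarked or used in the rest of the post):

Code
# picks example_output[i, indicies[i, j]] for each row i and column j, giving a [128, 2] tensor
torch.gather(example_output, dim=1, index=indicies)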

Regenerate Dataset

The dataset needs to be regenerated because the entity markers now need spaces surrounding them. I’m also going to move to GPT-2, as we are no longer restricted to BART now that we have moved away from seq2seq. So let’s get on with it.

Code
from typing import *
import spacy
import regex as re

nlp = spacy.load("en_core_web_sm")

def markup_entities(text: str) -> str:
    # strip any existing brackets and collapse runs of whitespace
    text = re.sub(r"[\[\]\s]+", " ", text)
    spans = extract_spans(text)
    for span in spans[::-1]: # reverse the list to avoid changing text positions
        text = text[:span.start_char].rstrip() + f" [ {span.text} ] " + text[span.end_char:].lstrip()
    return text

def extract_spans(text: str) -> List[spacy.tokens.span.Span]:
    doc = nlp(text)
    return [
        entity
        for entity in doc.ents
        if entity.label_ not in {"DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"}
    ]
Code
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>")
SIZE = 64

def encode(text: str) -> Dict[str, List[int]]:
    target = markup_entities(text)
    return {
        "input_ids": tokenizer(
            text,
            return_attention_mask=False,
            # padding="max_length",
            max_length=SIZE,
            truncation=True
        )["input_ids"],
        "label": tokenizer(
            target,
            return_attention_mask=False,
            # padding="max_length",
            max_length=SIZE,
            truncation=True
        )["input_ids"]
    }
Code
from datasets import Dataset

train_ds = Dataset.from_dict({
    "text": amazon_polarity["train"]["content"][:1000]
})
test_ds = Dataset.from_dict({
    "text": amazon_polarity["test"]["content"][:1000]
})

train_ds = train_ds.map(encode, input_columns=["text"])
test_ds = test_ds.map(encode, input_columns=["text"])
train_ds
Dataset({
    features: ['input_ids', 'label', 'text'],
    num_rows: 1000
})
Code
from transformers import AutoModelForSequenceClassification, GPT2ForSequenceClassification, AutoConfig
import torch
from typing import *

# aligning this way means the target index is always 1
active_indexes = torch.tensor([
    [1, 0], # action is copy,  [start, copy]
    [0, 1], # action is start, [copy,  start]
    [0, 2], # action is end,   [copy,  end]
])

class EntitySequenceClassifier(
    # AutoModelForSequenceClassification
    GPT2ForSequenceClassification
):
    def __init__(self, config: AutoConfig) -> None:
        config.num_labels = 3 # copy token, start entity, end entity
        config.pad_token_id = tokenizer.pad_token_id
        super().__init__(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        try:
            batch, attention_mask, targets = self.batch_align(
                input_ids=input_ids[0],
                labels=labels[0],
                device=self.device,
            )

            batch_size = batch.shape[0]
            output = super().forward(input_ids=batch, attention_mask=attention_mask).logits
            restricted_output = torch.cat([
                output[range(batch_size), targets[:, 0]][:, None],
                output[range(batch_size), targets[:, 1]][:, None],
            ], dim=1)

            loss = torch.nn.functional.cross_entropy(
                restricted_output,
                torch.ones(batch_size, dtype=torch.long, device=self.device)
            )

            return (loss, output)
        except:
            print(f"FAILED: {input_ids.shape}, {labels.shape}, {input_ids[0, :5]}")
            raise

    @staticmethod
    def diagonalize_input(
        input_ids: torch.Tensor,
        pad_token: int = tokenizer.pad_token_id,
    ) -> torch.Tensor:
        """ This converts the input_ids from a sequence into a batch of increasing size.
            Given [1, 2, 3]
            this returns [[1, pad, pad],
                          [1, 2,   pad],
                          [1, 2,   3  ]]
            This can then be used to determine the output for each token in the sequence. """
        # this is no good, need to predict the start of entities and then incorporate that in the input
        batch_size = input_ids.shape[0]
        block = torch.ones((batch_size, batch_size))

        input_block = block.tril() * input_ids.repeat((batch_size, 1))
        padding = block.triu(diagonal=1) * pad_token

        return input_block + padding

    @staticmethod
    def diagonalize_labels(input_ids: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """ This converts the input_ids and labels from a sequence into a batch of increasing size.
            The labels are used for teacher forcing to provide the correct context to the model.
            The input_ids are to provide the final prediction input per row.
            The 
            
            Given input_ids: [i1, i2, i3] and history [h1, start, h2, end, h3] """
    @staticmethod
    def batch_align(
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        start_token: int = tokenizer.get_vocab()["Ġ["],
        end_token: int = tokenizer.get_vocab()["Ġ]"],
        pad_token: int = tokenizer.pad_token_id,
        active_indexes: torch.Tensor = active_indexes,
        device: torch.device = torch.device("cpu")
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        input_size = input_ids.shape[0]
        label_size = labels.shape[0]

        # this makes the input repeat every time there is a marker token in the label
        # given input:                  token marker token marker token
        # becomes a repetition mask of: True  False  True  False  True
        # which cumsums to:             1     1      2     2      3
        # the associated input_id is one less
        # there is a special case where the first token is a marker, to handle that need to max(0)
        start_token_indices = (labels == start_token)
        end_token_indices = (labels == end_token)
        repetition_mask = ~(start_token_indices | end_token_indices)

        # repetition indices is currently wrong, it repeats the token prior to the entity marker
        # offset by one to address problem with alignment of repetition indices
        # this offsetting costs ~1.5ms !
        repetition_mask = torch.cat([
            torch.tensor([True], device=device),
            repetition_mask[:-1]
        ])
        repetition_indices = repetition_mask.cumsum(dim=0) - 1

        if repetition_indices.max() >= input_size:
            # Index 728 has a problem where the repetition mask goes beyond the label range
            # that needs further investigation, for now just truncate
            label_size = (repetition_indices == input_size).long().argmax()
            labels = labels[:label_size]
            repetition_indices = repetition_indices[:label_size]
            start_token_indices = start_token_indices[:label_size]
            end_token_indices = end_token_indices[:label_size]

        block = torch.ones((label_size, label_size), dtype=torch.long, device=device)
        history_mask = block.tril(diagonal=-1)
        attention_mask = block.tril()
        padding = block.triu(diagonal=1) * pad_token

        batch_diagonal = input_ids[repetition_indices]

        batch = labels.repeat((label_size, 1)) * history_mask
        batch = batch + batch_diagonal.diag() + padding

        # the targets are correctly aligned
        targets = active_indexes.to(device)[
            start_token_indices.long() + (end_token_indices * 2)
        ]

        return batch, attention_mask, targets
Code
model = EntitySequenceClassifier.from_pretrained("gpt2")
type(model)
Some weights of EntitySequenceClassifier were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
__main__.EntitySequenceClassifier
Code
from pathlib import Path
from transformers import Trainer, TrainingArguments

MODEL_RUN_FOLDER = Path("/data/blog/2021-07-03-sequence-to-aspect-sentiment/runs")
MODEL_RUN_FOLDER.mkdir(parents=True, exist_ok=True)

training_args = TrainingArguments(
    report_to=[],            
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    learning_rate=5e-5,
    num_train_epochs=5,
    evaluation_strategy="epoch",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=100,
    load_best_model_at_end=True,
    
    # no_cuda=True, # debugging
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    # compute_metrics=compute_metrics,
    # optimizers=(optimizer, None)
)

trainer.train()
[5000/5000 30:30, Epoch 5/5]
Epoch Training Loss Validation Loss Runtime Samples Per Second
1 0.036700 0.049386 77.091200 12.972000
2 0.028600 0.039281 77.222500 12.950000
3 0.019100 0.049895 77.531800 12.898000
4 0.014100 0.053474 77.726500 12.866000
5 0.009000 0.059107 77.843600 12.846000

TrainOutput(global_step=5000, training_loss=0.025022708290815352, metrics={'train_runtime': 1830.5887, 'train_samples_per_second': 2.731, 'total_flos': 209667536824320.0, 'epoch': 5.0, 'init_mem_cpu_alloc_delta': -154005504, 'init_mem_gpu_alloc_delta': 511157248, 'init_mem_cpu_peaked_delta': 154136576, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': -110075904, 'train_mem_gpu_alloc_delta': 2044350464, 'train_mem_cpu_peaked_delta': 129765376, 'train_mem_gpu_peaked_delta': 4782121472})
Code
model.save_pretrained(Path("/data/blog/2021-07-03-sequence-to-aspect-sentiment/model"))
Code
from pathlib import Path

model = EntitySequenceClassifier.from_pretrained(Path("/data/blog/2021-07-03-sequence-to-aspect-sentiment/model"))

Evaluation

I need to be able to run text through this model and see how well it does. So I’m going to rewrite it a bit to allow for this.

The way that the training goes is this. Given the dataset entry:

\[ \begin{aligned} \text{input} &= [ \text{input}_1, \text{input}_2, \text{input}_3 ] \\ \text{label} &= [ \text{label}_1, \text{label}_{start}, \text{label}_2, \text{label}_{end}, \text{label}_3 ] \end{aligned} \]

We first take the input and label and work out where the entity boundaries are, then use these boundaries to repeat the input tokens wherever there is a boundary.

\[ \begin{aligned} \text{boundaries} &= [ \text{False}, \text{True}, \text{False}, \text{True}, \text{False} ] \\ \text{expanded input} &= [ \text{input}_1, \text{input}_2, \text{input}_2, \text{input}_3, \text{input}_3 ] \end{aligned} \]

We can then form the labels into a triangle ready to overlay this expanded input across the diagonal.

\[ \begin{array}{ c c c c c c } \text{expanded label} = &[ [ \text{blank}, & \text{blank}, & \text{blank}, & \text{blank}, & \text{blank} ], \\ & [ \text{label}_1, & \text{blank}, & \text{blank}, & \text{blank}, & \text{blank} ], \\ & [ \text{label}_1, & \text{label}_{start}, & \text{blank}, & \text{blank}, & \text{blank} ], \\ & [ \text{label}_1, & \text{label}_{start}, & \text{label}_2, & \text{blank}, & \text{blank} ], \\ & [ \text{label}_1, & \text{label}_{start}, & \text{label}_2, & \text{label}_{end}, & \text{blank} ] ] \end{array} \]

Then we can add the input, and this is then ready to be passed to the model.

\[ \begin{array}{ c c c c c c } \text{full input} = &[ [ \text{input}_1, & \text{blank}, & \text{blank}, & \text{blank}, & \text{blank} ], \\ & [ \text{label}_1, & \text{input}_2, & \text{blank}, & \text{blank}, & \text{blank} ], \\ & [ \text{label}_1, & \text{label}_{start}, & \text{input}_2, & \text{blank}, & \text{blank} ], \\ & [ \text{label}_1, & \text{label}_{start}, & \text{label}_2, & \text{input}_3, & \text{blank} ], \\ & [ \text{label}_1, & \text{label}_{start}, & \text{label}_2, & \text{label}_{end}, & \text{input}_3 ] ] \end{array} \]

We can then take the outputs at each point in the diagonal to determine what the corresponding action is - copy, start entity or end entity.

Bearing this in mind we can see that the model is trained to predict left to right and it needs the correct history to predict. This means that a prediction requires repeated passes through the model, so it is quite inefficient. There are some improvements to the model structure that could address this.

For now let’s work on the code to do inference over a single row.

Code
from transformers import AutoModelForSequenceClassification, GPT2ForSequenceClassification, AutoConfig
import torch
from typing import *

# aligning this way means the target index is always 1
active_indexes = torch.tensor([
    [1, 0], # action is copy,  [start, copy]
    [0, 1], # action is start, [copy,  start]
    [0, 2], # action is end,   [copy,  end]
])

class InferenceEntitySequenceClassifier(
    # AutoModelForSequenceClassification
    GPT2ForSequenceClassification
):
    def __init__(self, config: AutoConfig) -> None:
        config.num_labels = 3 # copy token, start entity, end entity
        config.pad_token_id = tokenizer.pad_token_id
        super().__init__(config)

    def entities(
        self,
        input_ids: torch.Tensor,
        start_token: int = tokenizer.get_vocab()["Ġ["],
        end_token: int = tokenizer.get_vocab()["Ġ]"],
    ) -> List[List[int]]:
        history = []
        entities = []
        current_entity = None
        
        index = 0
        length = input_ids.shape[0]
        while index < length:
            current_token = input_ids[index].item()
            current_input = torch.tensor(
                history + [current_token], dtype=torch.long, device=self.device
            )
            output = super().forward(input_ids=current_input[None, :]).logits[0]
            if current_entity is None:
                # can copy or start entity
                if output[[0, 1]].argmax() == 1: # start
                    history.append(start_token)
                    current_entity = []
                else:
                    history.append(current_token)
                    index += 1
            else:
                # can copy or end entity
                if output[[0, 2]].argmax() == 1: # end
                    history.append(end_token)
                    entities.append(current_entity)
                    current_entity = None
                else:
                    history.append(current_token)
                    current_entity.append(current_token)
                    index += 1
        if current_entity is not None:
            entities.append(current_entity)
        return entities
Code
from pathlib import Path

model = InferenceEntitySequenceClassifier.from_pretrained(Path("/data/blog/2021-07-03-sequence-to-aspect-sentiment/model"))
Code
text = "Selling Britain’s state-owned water authorities seemed like a good idea to Conservative ministers in the 1980s"
model.entities(input_ids=tokenizer(text, return_tensors="pt")["input_ids"][0])
[[], [], []]
Code
test_ds[10]["text"]
'I currently live in Europe, and this is the book I recommend for my visitors. It covers many countries, colour pictures, and is a nice starter for before you go, and once you are there.'
Code
tokenizer.decode(test_ds[10]["label"])
'I currently live in [ Europe ], and this is the book I recommend for my visitors. It covers many countries, colour pictures, and is a nice starter for before you go, and once you are there.'
Code
model.entities(input_ids=tokenizer(test_ds[10]["text"], return_tensors="pt")["input_ids"][0])
[[], [], [], []]

So this is broken at the moment. The entities finish immediately. I can force it to include at least one token…

Code
from transformers import AutoModelForSequenceClassification, GPT2ForSequenceClassification, AutoConfig
import torch
from typing import *

# aligning this way means the target index is always 1
active_indexes = torch.tensor([
    [1, 0], # action is copy,  [start, copy]
    [0, 1], # action is start, [copy,  start]
    [0, 2], # action is end,   [copy,  end]
])

class InferenceEntitySequenceClassifier(
    # AutoModelForSequenceClassification
    GPT2ForSequenceClassification
):
    def __init__(self, config: AutoConfig) -> None:
        config.num_labels = 3 # copy token, start entity, end entity
        config.pad_token_id = tokenizer.pad_token_id
        super().__init__(config)

    def entities(
        self,
        input_ids: torch.Tensor,
        start_token: int = tokenizer.get_vocab()["Ġ["],
        end_token: int = tokenizer.get_vocab()["Ġ]"],
    ) -> List[List[int]]:
        history = []
        entities = []
        current_entity = None
        
        index = 0
        length = input_ids.shape[0]
        while index < length:
            current_token = input_ids[index].item()
            current_input = torch.tensor(
                history + [current_token], dtype=torch.long, device=self.device
            )
            output = super().forward(input_ids=current_input[None, :]).logits[0]
            if current_entity is None:
                # can copy or start entity
                if output[[0, 1]].argmax() == 1: # start
                    history.append(start_token)
                    current_entity = []
                else:
                    history.append(current_token)
                    index += 1
            elif not current_entity:
                # force entities to have at least one token
                history.append(current_token)
                current_entity.append(current_token)
                index += 1
            else:
                # can copy or end entity
                if output[[0, 2]].argmax() == 1: # end
                    history.append(end_token)
                    entities.append(current_entity)
                    current_entity = None
                else:
                    history.append(current_token)
                    current_entity.append(current_token)
                    index += 1
        if current_entity is not None:
            entities.append(current_entity)
        return entities
Code
from pathlib import Path

model = InferenceEntitySequenceClassifier.from_pretrained(Path("/data/blog/2021-07-03-sequence-to-aspect-sentiment/model"))
Code
text = "Selling Britain’s state-owned water authorities seemed like a good idea to Conservative ministers in the 1980s"
output = model.entities(input_ids=tokenizer(text, return_tensors="pt")["input_ids"][0])
tokenizer.batch_decode(output), output
([' Britain', ' Conservative'], [[5491], [11132]])
Code
output = model.entities(input_ids=tokenizer(test_ds[10]["text"], return_tensors="pt")["input_ids"][0])
tokenizer.batch_decode(output), output
([' Europe'], [[2031]])

So forcing every entity to have at least one token dramatically improves the output. This links into some of the improvements that I want to make:

  • start action can be optimized to start and copy as there is always at least one token in an entity
  • can have a start copy AND end action if start and end are above a threshold for a single token
  • copy could be implicit - start and end could be thresholded instead
  • A model could predict over every token in a single pass - as the outputs are inspected from left to right the state could be maintained so that the appropriate outputs are considered
  • This would mean a move to a multi-label loss (BCE with logits?) as the model predictions for start and end become independent (a rough sketch follows below)
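As a very rough sketch of the last two points (hypothetical shapes only; the real implementation is left for the next notebook), the model would emit independent start and end logits for every token in a single pass and be trained with BCE with logits:

Code
import torch

# hypothetical shapes: per-token start/end logits predicted in one pass
batch_size, sequence_length = 2, 8
logits = torch.randn(batch_size, sequence_length, 2)   # [..., 0] = start entity, [..., 1] = end entity
targets = torch.zeros(batch_size, sequence_length, 2)  # 1.0 wherever an entity starts or ends
targets[0, 3, 0] = 1.0  # e.g. an entity starting at token 3
targets[0, 5, 1] = 1.0  # and ending at token 5

loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, targets)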

I’m going to implement this stuff in another notebook though. This one is already long and very messy.