Single Pass Entity Extraction

Marking up entities in a single pass
Published July 17, 2021

I’ve recently been working on entity extraction using deep learning models. The aim is to train a model to extract the entities and then use that same model to determine the sentiment of the text towards each entity. This is called aspect sentiment.

The first part of achieving this is investigating how to extract the entities. I vastly overcomplicated this process in the last blog post by turning it into a sequence to sequence problem where the entities get surrounded by semaphores.

In this post I am going to revise the approach to be simpler and more computationally efficient.


Entity Boundary Prediction

A sequence to sequence model takes an input sequence and produces an output sequence. I was using this as a way to add markers to the tokens that are entity boundaries (the start or end of an entity). This required processing each token in turn to see if the model was going to mark it as an entity boundary, so both training and inference meant running the model once for each token in the sequence.

The underlying model that is used already produces an output per token in the sequence. If I run that output through a linear layer then I can get predictions for every token in the sequence in a single pass. Since I want entities I can do one of two things: either predict whether each token is part of an entity, or predict the entity boundaries.

If the model predicts that a token is part of an entity then two adjacent entities would be merged into one. If the model predicts entity boundaries then it can distinguish them. Predicting the boundaries is slightly harder so let’s start with that.
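As a rough sketch of the single pass idea (the shapes here are assumptions for illustration, not the final model): take the per-token hidden states from the encoder and push them through one linear layer to get a pair of boundary predictions for every token at once.

Code
import torch

# assumed shapes: a batch of 2 sequences of 64 tokens with 768 dimensional hidden states
hidden_states = torch.randn(2, 64, 768)

# one linear layer turns each token's hidden state into two logits:
# "is this token an entity start?" and "is this token an entity end?"
boundary_head = torch.nn.Linear(in_features=768, out_features=2)

logits = boundary_head(hidden_states)
print(logits.shape)  # torch.Size([2, 64, 2]) - a prediction per token in a single pass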

Code
MODEL_NAME = "facebook/bart-base"
MAXIMUM_TOKEN_LENGTH = 64
BATCH_SIZE = 64

Dataset Preparation

The first thing is to take the text and establish which tokens start or end the entities. spaCy produces spans that have the start and end character, and the Hugging Face tokenizer can also return the character offsets of each token. So we can link up the tokens to the entity spans.

Since we are tokenizing the text, which is required to train the model, we can map the dataset completely to the required form. The spans returned by spaCy are in order, so we can pass over the tokens and work out whether each one is a boundary.
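To make the alignment concrete, here is a toy version of the matching with hand-written offsets (these numbers are illustrative, not real tokenizer output): a token is a start boundary when its start character matches a span’s start character, and an end boundary when its end character matches a span’s end character.

Code
# hand-written character offsets for "I like New York", where the entity
# "New York" covers characters 7 to 15 and special tokens get a (0, 0) offset
offset_mapping = [(0, 0), (0, 1), (2, 6), (7, 10), (11, 15), (0, 0)]
span_starts = {7}
span_ends = {15}

boundaries = [
    (start in span_starts and start != end, end in span_ends and start != end)
    for start, end in offset_mapping
]
print(boundaries)
# the "New" token is a start boundary and the "York" token is an end boundary:
# [(False, False), (False, False), (False, False), (True, False), (False, True), (False, False)]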

Code
#collapse
from typing import *

import spacy
from transformers import AutoTokenizer

nlp = spacy.load("en_core_web_sm")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def encode(text: str) -> Dict[str, Any]:
    spans = extract_spans(text)
    span_starts = {span.start_char for span in spans}
    span_ends = {span.end_char for span in spans}

    tokenized_text = tokenizer(
        text,
        return_offsets_mapping=True,
        max_length=MAXIMUM_TOKEN_LENGTH,
        truncation=True,
        padding="max_length"
    )
    offset_mapping = tokenized_text["offset_mapping"]

    # each label is a pair of (is entity start, is entity end) flags,
    # the start != end check excludes special and padding tokens,
    # which get an empty (0, 0) offset
    boundaries = [
        (
            start in span_starts and start != end,
            end in span_ends and start != end
        )
        for start, end in offset_mapping
    ]
    return {
        "input_ids": tokenized_text["input_ids"],
        "attention_mask": tokenized_text["attention_mask"],
        "label": boundaries,
    }

def extract_spans(text: str) -> List[spacy.tokens.span.Span]:
    # use the spacy named entities, dropping the numeric and date-like types
    doc = nlp(text)
    return [
        entity
        for entity in doc.ents
        if entity.label_ not in {
            "DATE", "TIME", "PERCENT", "MONEY",
            "QUANTITY", "ORDINAL", "CARDINAL"
        }
    ]

Does this approach work? Let’s see what we think the entities are in this text…

Code
text = "Selling Britain's state-owned water authorities seemed like a good idea to Conservative ministers in the 1980s"
output = encode(text)

tokenizer.batch_decode([
    [input_id]
    for input_id, boundaries in zip(output["input_ids"], output["label"])
    if True in boundaries
])
[' Britain', ' Conservative']

That’s great, so let’s prepare our data.

Code
from datasets import load_dataset, Dataset

amazon_polarity = load_dataset("amazon_polarity")

train_ds = Dataset.from_dict({
    "text": amazon_polarity["train"]["content"][:1000]
})
test_ds = Dataset.from_dict({
    "text": amazon_polarity["test"]["content"][:1000]
})

train_ds = train_ds.map(encode, input_columns=["text"])
test_ds = test_ds.map(encode, input_columns=["text"])
train_ds
Reusing dataset amazon_polarity (/home/matthew/.cache/huggingface/datasets/amazon_polarity/amazon_polarity/3.0.0/ac31acedf6cda6bc2aa50d448f48bbad69a3dd8efc607d2ff1a9e65c2476b4c1)
Dataset({
    features: ['attention_mask', 'input_ids', 'label', 'text'],
    num_rows: 1000
})

Entity Prediction Model

Writing this is dramatically simpler than before. It’s a two-output classifier which predicts \([ \text{start entity}, \text{end entity} ]\) for each token. Since both can be true at once (the token is a single token entity) this is a multi-label classifier, so binary cross entropy is the loss function to use.
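To show what that means in practice, here is a small hand-built example of the label format and the loss (the tensors are made up for illustration): every token gets a pair of flags, a single token entity sets both at once, and binary cross entropy treats each flag as its own independent prediction.

Code
import torch

# hypothetical labels for a 6 token sequence with one two-token entity
# (tokens 1 and 2) and one single token entity (token 4)
labels = torch.tensor([
    [0, 0],  # not an entity
    [1, 0],  # entity start
    [0, 1],  # entity end
    [0, 0],
    [1, 1],  # single token entity - start and end at once
    [0, 0],
], dtype=torch.float)

# random logits of the same shape, as the classification head would produce
logits = torch.randn(6, 2)

# multi-label loss: binary cross entropy applied to each flag independently
loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)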

Code
from typing import *
from transformers import BartModel, AutoConfig
import torch

class EntitySequenceClassifier(BartModel):
    def __init__(self, config: AutoConfig) -> None:
        config.num_labels = 2  # entity start, entity end
        super().__init__(config)
        # bart model for sequence classification actually has a more complex classification head
        self.score = torch.nn.Linear(
            in_features=config.d_model,
            out_features=config.num_labels,
            bias=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, ...]:
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = outputs[0]  # last hidden state
        predictions = self.score(hidden_states)

        if labels is not None:
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                predictions,
                labels.float(),
            )
            return (loss, predictions)
        return (predictions,)

Training

We have our dataset and the model, so let’s try training it. At some point I should write some metrics for this. Since it is a vastly more restricted problem than before I think it might be possible to use the sklearn classification report.
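As a first sketch of that (not wired into the training below, and assuming the Trainer hands compute_metrics logits of shape batch × tokens × 2): flatten the per-token boundary predictions and labels and pass them to sklearn.

Code
from typing import Dict

from sklearn.metrics import classification_report
from transformers import EvalPrediction

def compute_metrics(eval_prediction: EvalPrediction) -> Dict[str, float]:
    # a logit above zero corresponds to a sigmoid probability above 0.5
    predictions = eval_prediction.predictions > 0.0
    labels = eval_prediction.label_ids.astype(bool)

    # treat every (token, boundary type) pair as its own binary prediction
    report = classification_report(
        labels.reshape(-1, 2),
        predictions.reshape(-1, 2),
        target_names=["entity start", "entity end"],
        output_dict=True,
        zero_division=0,
    )
    return {
        "start_f1": report["entity start"]["f1-score"],
        "end_f1": report["entity end"]["f1-score"],
    }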

Code
model = EntitySequenceClassifier.from_pretrained(MODEL_NAME)
Some weights of EntitySequenceClassifier were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['model.score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Code
from pathlib import Path
from transformers import Trainer, TrainingArguments

MODEL_RUN_FOLDER = Path("/data/blog/2021-07-17-entity-extraction/runs")
MODEL_RUN_FOLDER.mkdir(parents=True, exist_ok=True)

training_args = TrainingArguments(
    report_to=[],
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    num_train_epochs=5,
    evaluation_strategy="epoch",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=100,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    # compute_metrics=compute_metrics,
)

trainer.train()
[80/80 00:47, Epoch 5/5]
Epoch  Training Loss  Validation Loss  Runtime   Samples Per Second
1      No log         0.040026         1.360600  734.968000
2      No log         0.031034         1.369500  730.175000
3      No log         0.027327         1.384400  722.352000
4      No log         0.026527         1.377100  726.186000
5      No log         0.025954         1.378300  725.523000

TrainOutput(global_step=80, training_loss=0.04551212191581726, metrics={'train_runtime': 48.0473, 'train_samples_per_second': 1.665, 'total_flos': 267690147840000.0, 'epoch': 5.0, 'init_mem_cpu_alloc_delta': 0, 'init_mem_gpu_alloc_delta': 560035840, 'init_mem_cpu_peaked_delta': 0, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 23236608, 'train_mem_gpu_alloc_delta': 2342675968, 'train_mem_cpu_peaked_delta': 154267648, 'train_mem_gpu_peaked_delta': 3379756032})
Code
model.save_pretrained(Path("/data/blog/2021-07-17-entity-extraction/model"))

Evaluation

A proper evaluation would be nice. For now let’s see what entities it can extract from this text.

Code
text = "Selling Britain's state-owned water authorities seemed like a good idea to Conservative ministers in the 1980s"
tokenized_text = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    input_ids = tokenized_text["input_ids"].to(model.device)
    output = model(input_ids=input_ids)[0]
    output = output > 0.  # a logit above zero is a probability above 0.5

tokenizer.batch_decode([
    [input_id]
    for input_id, boundaries in zip(tokenized_text["input_ids"][0], output[0])
    if True in boundaries
])
[' Britain', ' Conservative']
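The decoded output above only shows the boundary tokens themselves. To get full multi-token entities the start and end flags need to be paired up; here is a rough sketch of that pairing, using the model and tokenizer from above (extract_entities is a helper name introduced here, not something from the libraries).

Code
def extract_entities(text: str) -> List[str]:
    tokenized = tokenizer(text, return_offsets_mapping=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(input_ids=tokenized["input_ids"].to(model.device))[0][0]
    is_start, is_end = (logits > 0.).unbind(dim=-1)

    entities = []
    start_char = None
    for (token_start, token_end), starts, ends in zip(
        tokenized["offset_mapping"][0].tolist(), is_start, is_end
    ):
        if token_start == token_end:
            continue  # special tokens have an empty offset
        if starts and start_char is None:
            start_char = token_start  # remember where the entity began
        if ends and start_char is not None:
            entities.append(text[start_char:token_end])
            start_char = None
    return entities

extract_entities(text)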

It trained really fast and it works. This is very encouraging.