Aspect Sentiment Metrics

How to evaluate the performance?
Published

July 20, 2021

I’ve been working on an aspect sentiment model and have successfully trained it. Manual evaluation shows that it can predict some entity sentiment correctly, so the approach seems viable.

What I need now is a more systematic evaluation of the model that allows different training approaches to be compared. The loss is only useful for comparison while the loss calculation stays the same, yet changing the loss calculation is the best way to change the performance.


Metrics

So I am going to write a metric function for this. The huggingface Trainer takes a compute_metrics function which receives a transformers.EvalPrediction object. This is a glorified tuple that has predictions and label_ids, which are both numpy arrays.

I don’t get access to any other statistics, so no loss value or anything.

There are two primary metrics relating to the two tasks being performed - entity extraction and sentiment classification. I want to use sklearn to measure these, and it is possible to evaluate them separately.

Code
from typing import *
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support
)
from transformers import EvalPrediction

def compute_metrics(pred: EvalPrediction) -> Dict[str, float]:
    # each token has 3 label columns: entity start flag, entity end flag, sentiment index
    labels = pred.label_ids.reshape(-1, 3)
    # each token has 5 prediction columns: start logit, end logit, 3 sentiment logits
    predictions = pred.predictions.reshape(-1, 5)

    entity_labels = labels[:, :2]
    entity_predictions = predictions[:, :2] > 0.  # logit > 0 means probability > 0.5

    entity_accuracy = accuracy_score(entity_labels, entity_predictions)

    (
        (entity_start_precision, entity_end_precision),
        (entity_start_recall,    entity_end_recall),
        (entity_start_fscore,    entity_end_fscore),
        _
    ) = precision_recall_fscore_support(entity_labels, entity_predictions)

    # sentiment is only labelled on tokens marked as entity ends
    sentiment_mask = labels[:, 1] > 0
    sentiment_labels = labels[sentiment_mask, 2]
    sentiment_predictions = predictions[sentiment_mask, 2:].argmax(axis=1)

    sentiment_accuracy = accuracy_score(sentiment_labels, sentiment_predictions)

    (
        (sentiment_negative_precision, sentiment_neutral_precision, sentiment_positive_precision),
        (sentiment_negative_recall,    sentiment_neutral_recall,    sentiment_positive_recall),
        (sentiment_negative_fscore,    sentiment_neutral_fscore,    sentiment_positive_fscore),
        _
    ) = precision_recall_fscore_support(sentiment_labels, sentiment_predictions)

    return {
        "quality": (
            entity_end_fscore *
            sentiment_negative_fscore *
            sentiment_neutral_fscore *
            sentiment_positive_fscore
        ),

        "entity_accuracy":        entity_accuracy,
        "entity_start_precision": entity_start_precision,
        "entity_start_recall":    entity_start_recall,
        "entity_start_f1_score":  entity_start_fscore,
        "entity_end_precision":   entity_end_precision,
        "entity_end_recall":      entity_end_recall,
        "entity_end_f1_score":    entity_end_fscore,

        "sentiment_accuracy":           sentiment_accuracy,
        "sentiment_negative_precision": sentiment_negative_precision,
        "sentiment_negative_recall":    sentiment_negative_recall,
        "sentiment_negative_f1_score":  sentiment_negative_fscore,
        "sentiment_neutral_precision":  sentiment_neutral_precision,
        "sentiment_neutral_recall":     sentiment_neutral_recall,
        "sentiment_neutral_f1_score":   sentiment_neutral_fscore,
        "sentiment_positive_precision": sentiment_positive_precision,
        "sentiment_positive_recall":    sentiment_positive_recall,
        "sentiment_positive_f1_score":  sentiment_positive_fscore,
    }

This is a load of different metrics. I thought it would be useful to track the different ways that the model can perform. The quality metric is a made up metric that the trainer can use to pick the best model; its guiding principle is that the entity sentiment predictions should be correct, so it multiplies the entity end F1 score by the three sentiment F1 scores.
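
To make the behaviour of quality concrete, here is a tiny illustration with made up F1 scores - a single weak sub-task drags the whole product down:

Code
# made up F1 scores, purely to illustrate how the quality metric behaves
all_strong = 0.8 * 0.8 * 0.8 * 0.8  # 0.4096
one_weak = 0.8 * 0.8 * 0.8 * 0.1    # 0.0512 - the weak sub-task dominates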


Model Definition and Dataset

The only way to evaluate these metrics is to see how well they describe the model, so we have to train it again. First we have to load the dataset and model. This is the same as the previous blog post, so you can skip to the next section if you wish.

Code
MODEL_NAME = "facebook/bart-base"
MAXIMUM_TOKEN_LENGTH = 128
BATCH_SIZE = 64
EPOCHS = 80
Code
#collapse
from typing import *
from transformers import BartModel, AutoConfig
import torch

class EntitySentimentSequenceClassifier(BartModel):
    def __init__(self, config: AutoConfig) -> None:
        config.num_labels = 5 # start and copy, end and copy, negative, neutral, positive
        super().__init__(config)
        # bart model for sequence classification actually has a more complex classification head
        self.score = torch.nn.Linear(
            in_features=config.d_model,
            out_features=config.num_labels,
            bias=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, ...]:
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = outputs[0]  # last hidden state
        predictions = self.score(hidden_states)

        if labels is not None:
            entity_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                predictions[:, :, :2],
                labels[:, :, :2].float(),
            )

            # flatten to one row per token: 5 logit columns and 3 label columns
            flat_predictions = predictions.reshape(-1, 5)
            flat_labels = labels.reshape(-1, 3)

            # the sentiment loss only applies to tokens marked as entity ends
            end_mask = flat_labels[:, 1] > 0
            sentiment_predictions = flat_predictions[end_mask, 2:]
            sentiment_targets = flat_labels[end_mask, 2]

            sentiment_loss = torch.nn.functional.cross_entropy(
                sentiment_predictions,
                sentiment_targets
            )

            loss = entity_loss + sentiment_loss
            return (loss, predictions)
        return (predictions,)
Code
#collapse
from typing import *
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
sentiment_index = {
    "negative": 0,
    "neutral": 1,
    "positive": 2,
}

def encode(row: Dict[str, Any]) -> Dict[str, Any]:
    text = row["text"]
    entities = row["entities"]
    
    span_starts = {entity["start"] for entity in entities}
    span_ends = {entity["end"] for entity in entities}
    end_sentiments = {
        entity["end"]: sentiment_index[entity["sentiment"]]
        for entity in entities
    }

    tokenized_text = tokenizer(
        text,
        return_offsets_mapping=True,
        max_length=MAXIMUM_TOKEN_LENGTH,
        truncation=True,
        padding="max_length"
    )
    offset_mapping = tokenized_text["offset_mapping"]

    # each token becomes (entity start flag, entity end flag, sentiment index);
    # the start != end check excludes special and padding tokens, which have
    # empty (0, 0) offsets
    boundaries = [
        (
            int(start in span_starts and start != end),
            int(end in span_ends and start != end),
            end_sentiments.get(end, 0)
        )
        for start, end in offset_mapping
    ]
    return {
        "input_ids": tokenized_text["input_ids"],
        "attention_mask": tokenized_text["attention_mask"],
        "label": boundaries,
    }
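
To make the label layout concrete, here is a hypothetical row passed through encode. The character offsets and token boundaries here are assumptions - the exact values depend on the tokenizer:

Code
# a hypothetical example row - "food" is assumed to span characters 4 to 8
example = {
    "text": "The food was terrible",
    "entities": [{"start": 4, "end": 8, "sentiment": "negative"}],
}
encoded = encode(example)

# each label is (entity start flag, entity end flag, sentiment index);
# only tokens that open or close an entity should be non-zero
for token, label in zip(encoded["input_ids"], encoded["label"]):
    if label != (0, 0, 0):
        print(tokenizer.decode([token]), label)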
Code
#hide_output
import pandas as pd
from datasets import Dataset

train_df = pd.read_parquet("/data/blog/2021-07-18-aspect-sentiment-dataset/train.gz.parquet")
validation_df = pd.read_parquet("/data/blog/2021-07-18-aspect-sentiment-dataset/validation.gz.parquet")
test_df = pd.read_parquet("/data/blog/2021-07-18-aspect-sentiment-dataset/test.gz.parquet")

train_ds = Dataset.from_pandas(train_df)
train_ds = train_ds.map(encode)

validation_ds = Dataset.from_pandas(validation_df)
validation_ds = validation_ds.map(encode)

test_ds = Dataset.from_pandas(test_df)
test_ds = test_ds.map(encode)

Training

Now we can train the model and see what the metrics say about it!

Code
#hide_output
model = EntitySentimentSequenceClassifier.from_pretrained(MODEL_NAME)
Some weights of EntitySentimentSequenceClassifier were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['model.score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Code
from pathlib import Path
from transformers import Trainer, TrainingArguments

MODEL_RUN_FOLDER = Path("/data/blog/2021-07-20-aspect-sentiment-metrics/runs")
MODEL_RUN_FOLDER.mkdir(parents=True, exist_ok=True)

training_args = TrainingArguments(
    report_to=[],            
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    warmup_ratio=0.06,
    num_train_epochs=EPOCHS,
    evaluation_strategy="epoch",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="quality",
    greater_is_better=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=validation_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
[5440/5440 1:11:59, Epoch 80/80]
Epoch Training Loss Validation Loss Quality Entity Accuracy Entity Start Precision Entity Start Recall Entity Start F1 Score Entity End Precision Entity End Recall Entity End F1 Score Sentiment Accuracy Sentiment Negative Precision Sentiment Negative Recall Sentiment Negative F1 Score Sentiment Neutral Precision Sentiment Neutral Recall Sentiment Neutral F1 Score Sentiment Positive Precision Sentiment Positive Recall Sentiment Positive F1 Score Runtime Samples Per Second
1 No log 0.797070 0.000000 0.973109 0.538462 0.005255 0.010409 0.000000 0.000000 0.000000 0.680180 0.682353 0.535385 0.600000 0.706128 0.839404 0.767020 0.626741 0.558313 0.590551 1.690100 295.843000
2 1.103500 0.588151 0.005753 0.973172 0.842105 0.120120 0.210250 0.400000 0.006006 0.011834 0.786787 0.788991 0.793846 0.791411 0.850943 0.746689 0.795414 0.713684 0.841191 0.772210 1.764400 283.390000
3 0.570700 0.520855 0.226896 0.975859 0.744952 0.526276 0.616806 0.705128 0.289039 0.410011 0.825826 0.795252 0.824615 0.809668 0.848837 0.846026 0.847430 0.816794 0.796526 0.806533 1.794800 278.578000
4 0.570700 0.494224 0.349037 0.978688 0.695030 0.692943 0.693985 0.655422 0.612613 0.633295 0.824324 0.816199 0.806154 0.811146 0.843234 0.846026 0.844628 0.802469 0.806452 0.804455 1.785500 280.039000
5 0.425600 0.567770 0.359286 0.980750 0.732513 0.715465 0.723889 0.695161 0.647147 0.670295 0.819069 0.828383 0.772308 0.799363 0.821875 0.870861 0.845659 0.807198 0.779156 0.792929 1.769900 282.502000
6 0.311800 0.605942 0.377081 0.980953 0.703806 0.763514 0.732445 0.696035 0.711712 0.703786 0.816066 0.819936 0.784615 0.801887 0.862676 0.811258 0.836177 0.754967 0.848635 0.799065 1.753700 285.108000
7 0.311800 0.680506 0.393409 0.980891 0.680386 0.794294 0.732941 0.676375 0.784535 0.726451 0.819069 0.776163 0.821538 0.798206 0.878403 0.801325 0.838095 0.778032 0.843672 0.809524 1.774600 281.750000
8 0.210300 0.790352 0.394818 0.982375 0.735036 0.756006 0.745374 0.742901 0.726727 0.734725 0.816817 0.742782 0.870769 0.801700 0.853195 0.817881 0.835165 0.836022 0.771712 0.802581 1.779400 280.986000
9 0.152800 0.848886 0.395077 0.982781 0.723803 0.783033 0.752254 0.718001 0.787538 0.751164 0.813814 0.815789 0.763077 0.788553 0.826645 0.852649 0.839446 0.792593 0.796526 0.794554 1.799700 277.817000
10 0.152800 0.971851 0.398642 0.982766 0.719093 0.786036 0.751076 0.717694 0.813063 0.762408 0.807808 0.760563 0.830769 0.794118 0.886320 0.761589 0.819234 0.755459 0.858561 0.803717 1.785400 280.054000
11 0.115100 1.135353 0.386189 0.983609 0.716578 0.804805 0.758133 0.733020 0.818318 0.773324 0.794294 0.779762 0.806154 0.792738 0.896341 0.730132 0.804745 0.704365 0.880893 0.782800 1.782800 280.459000
12 0.085900 0.993660 0.419968 0.984078 0.726662 0.812312 0.767104 0.732981 0.832583 0.779613 0.819820 0.815047 0.800000 0.807453 0.825197 0.867550 0.845843 0.814815 0.764268 0.788732 1.794600 278.615000
13 0.085900 1.026248 0.411750 0.983125 0.711985 0.829580 0.766297 0.687682 0.884384 0.773727 0.811562 0.788571 0.849231 0.817778 0.873840 0.779801 0.824147 0.753950 0.828784 0.789598 1.760800 283.959000
14 0.070800 1.142620 0.408904 0.985297 0.770818 0.785285 0.777984 0.763598 0.822072 0.791757 0.804805 0.751351 0.855385 0.800000 0.858696 0.784768 0.820069 0.780488 0.794045 0.787208 1.790800 279.198000
15 0.062800 1.141223 0.397795 0.984859 0.738420 0.813814 0.774286 0.756592 0.840090 0.796158 0.797297 0.806557 0.756923 0.780952 0.848214 0.786424 0.816151 0.730193 0.846154 0.783908 1.800400 277.722000
16 0.062800 1.261426 0.421695 0.985203 0.744990 0.809309 0.775819 0.773203 0.831832 0.801447 0.814565 0.763006 0.812308 0.786885 0.837134 0.850993 0.844007 0.825269 0.761787 0.792258 1.808500 276.475000
17 0.057400 1.425968 0.406145 0.985234 0.752606 0.813063 0.781667 0.761002 0.843844 0.800285 0.798799 0.787425 0.809231 0.798179 0.870476 0.756623 0.809566 0.727273 0.853598 0.785388 1.804600 277.064000
18 0.050500 1.417094 0.414896 0.985531 0.757322 0.815315 0.785249 0.754980 0.853604 0.801268 0.805556 0.756233 0.840000 0.795918 0.886973 0.766556 0.822380 0.750557 0.836228 0.791080 1.807100 276.682000
19 0.050500 1.370741 0.410895 0.985500 0.750687 0.820571 0.784075 0.752802 0.857357 0.801685 0.804054 0.772455 0.793846 0.783005 0.866055 0.781457 0.821584 0.752759 0.846154 0.796729 1.817800 275.052000
20 0.050800 1.269680 0.427258 0.985750 0.761871 0.795045 0.778104 0.766260 0.849099 0.805556 0.816066 0.792793 0.812308 0.802432 0.813272 0.872517 0.841853 0.843305 0.734491 0.785146 1.770800 282.360000
21 0.042700 1.373681 0.443625 0.985750 0.739837 0.819820 0.777778 0.761747 0.864114 0.809708 0.824324 0.791541 0.806154 0.798780 0.840650 0.855960 0.848236 0.826425 0.791563 0.808619 1.809400 276.332000
22 0.042700 1.320477 0.438937 0.985906 0.761803 0.799550 0.780220 0.782394 0.834084 0.807413 0.822072 0.818770 0.778462 0.798107 0.830400 0.859272 0.844589 0.811558 0.801489 0.806492 1.826300 273.781000
23 0.038400 1.403632 0.444728 0.986078 0.756250 0.817568 0.785714 0.775051 0.853604 0.812433 0.825075 0.848276 0.756923 0.800000 0.822257 0.880795 0.850520 0.812658 0.796526 0.804511 1.763700 283.503000
24 0.035800 1.466172 0.437135 0.986297 0.758766 0.812312 0.784627 0.776119 0.858859 0.815396 0.814565 0.793510 0.827692 0.810241 0.876611 0.788079 0.829991 0.755556 0.843672 0.797186 1.881400 265.760000
25 0.035700 1.437172 0.436925 0.986453 0.767065 0.818318 0.791863 0.775874 0.849850 0.811179 0.819069 0.814103 0.781538 0.797488 0.836334 0.846026 0.841152 0.797066 0.808933 0.802956 1.823300 274.229000
26 0.035700 1.501757 0.428386 0.986516 0.777614 0.792793 0.785130 0.789773 0.834835 0.811679 0.812312 0.787879 0.800000 0.793893 0.850953 0.812914 0.831499 0.778824 0.821340 0.799517 1.809900 276.263000
27 0.033900 1.640986 0.431580 0.986484 0.761672 0.820571 0.790025 0.779442 0.859610 0.817565 0.813063 0.801887 0.784615 0.793157 0.834163 0.832781 0.833471 0.790754 0.806452 0.798526 1.837100 272.163000
28 0.027900 1.660942 0.417174 0.986219 0.762500 0.824324 0.792208 0.771018 0.846847 0.807156 0.807057 0.807818 0.763077 0.784810 0.872029 0.789735 0.828844 0.732218 0.868486 0.794552 1.833100 272.767000
29 0.027900 1.593555 0.439876 0.986422 0.754589 0.833333 0.792009 0.762279 0.873874 0.814271 0.819069 0.791541 0.806154 0.798780 0.842809 0.834437 0.838602 0.806452 0.806452 0.806452 1.825800 273.850000
30 0.031100 1.634480 0.429155 0.986891 0.769391 0.834084 0.800432 0.783784 0.849099 0.815135 0.812312 0.775148 0.806154 0.790347 0.843803 0.822848 0.833194 0.797531 0.801489 0.799505 1.819600 274.788000
31 0.029600 1.771458 0.422438 0.986594 0.764666 0.841592 0.801287 0.770492 0.846847 0.806867 0.807057 0.760989 0.852308 0.804064 0.876190 0.761589 0.814880 0.762980 0.838710 0.799054 1.917700 260.732000
32 0.029600 1.604859 0.437154 0.986812 0.768802 0.828829 0.797688 0.783727 0.846096 0.813718 0.816817 0.789157 0.806154 0.797565 0.867857 0.804636 0.835052 0.772727 0.843672 0.806643 1.860600 268.728000
33 0.032500 1.675310 0.444260 0.986625 0.771348 0.820571 0.795198 0.782456 0.837087 0.808850 0.822823 0.774929 0.836923 0.804734 0.867958 0.816225 0.841297 0.801453 0.821340 0.811275 1.875100 266.655000
34 0.023600 1.821172 0.435406 0.986844 0.780101 0.812312 0.795881 0.788652 0.834835 0.811087 0.814565 0.740260 0.876923 0.802817 0.890595 0.768212 0.824889 0.788732 0.833747 0.810615 1.839900 271.751000
35 0.023600 1.630846 0.452027 0.986734 0.786921 0.804054 0.795395 0.791424 0.817568 0.804284 0.830330 0.788406 0.836923 0.811940 0.850082 0.854305 0.852188 0.836842 0.789082 0.812261 1.866100 267.935000
36 0.023500 1.696693 0.434577 0.986953 0.790087 0.813814 0.801775 0.783498 0.834084 0.808000 0.818318 0.837288 0.760000 0.796774 0.844221 0.834437 0.839301 0.770455 0.841191 0.804270 1.909300 261.869000
37 0.020700 1.783487 0.436999 0.987156 0.787373 0.814565 0.800738 0.792640 0.840841 0.816029 0.816817 0.814103 0.781538 0.797488 0.852740 0.824503 0.838384 0.770642 0.833747 0.800954 1.851100 270.108000
38 0.020700 1.757433 0.431940 0.986969 0.788012 0.809309 0.798519 0.793377 0.827327 0.809996 0.813814 0.811321 0.793846 0.802488 0.860963 0.799669 0.829185 0.757174 0.851117 0.801402 1.893000 264.129000
39 0.019900 1.710060 0.445421 0.987031 0.784838 0.816066 0.800147 0.788502 0.834084 0.810653 0.824324 0.824675 0.781538 0.802528 0.861538 0.834437 0.847771 0.774487 0.843672 0.807601 1.913500 261.300000
40 0.021100 1.679544 0.447749 0.987078 0.788971 0.805556 0.797177 0.804769 0.810811 0.807779 0.824324 0.817901 0.815385 0.816641 0.862369 0.819536 0.840407 0.778802 0.838710 0.807646 1.989600 251.305000
41 0.021100 1.772447 0.439737 0.987078 0.795522 0.800300 0.797904 0.805078 0.809309 0.807188 0.821321 0.809969 0.800000 0.804954 0.857877 0.829470 0.843434 0.779859 0.826303 0.802410 2.427700 205.959000
42 0.017800 1.769264 0.437806 0.986609 0.781726 0.809309 0.795278 0.782361 0.825826 0.803506 0.822823 0.838926 0.769231 0.802568 0.826498 0.867550 0.846527 0.805000 0.799007 0.801993 1.814600 275.537000
43 0.018600 1.792316 0.445057 0.986906 0.794449 0.795045 0.794747 0.801788 0.807808 0.804787 0.826577 0.833333 0.784615 0.808241 0.833068 0.867550 0.849959 0.811083 0.799007 0.805000 1.871400 267.178000
44 0.018600 1.717856 0.443054 0.987062 0.790715 0.805556 0.798066 0.798687 0.822072 0.810211 0.820571 0.825949 0.803077 0.814353 0.845763 0.826159 0.835846 0.781690 0.826303 0.803378 3.030500 164.987000
45 0.016800 1.719061 0.449725 0.987000 0.799849 0.795045 0.797440 0.811450 0.798048 0.804693 0.828829 0.833866 0.803077 0.818182 0.840580 0.864238 0.852245 0.806533 0.796526 0.801498 1.996600 250.426000
46 0.016200 1.789883 0.441381 0.986750 0.784571 0.809309 0.796748 0.800300 0.800300 0.800300 0.825075 0.842809 0.775385 0.807692 0.840131 0.852649 0.846343 0.790476 0.823821 0.806804 1.909100 261.905000
47 0.016200 1.738527 0.448426 0.986875 0.777936 0.820571 0.798685 0.789964 0.827327 0.808214 0.825826 0.820755 0.803077 0.811820 0.857143 0.834437 0.845638 0.786385 0.831266 0.808203 1.898800 263.329000
48 0.015000 1.759438 0.448316 0.986688 0.782734 0.816817 0.799412 0.790087 0.813814 0.801775 0.828078 0.797619 0.824615 0.810893 0.858844 0.836093 0.847315 0.808824 0.818859 0.813810 2.141100 233.524000
49 0.014300 1.739141 0.438471 0.987031 0.785869 0.818318 0.801765 0.793054 0.822823 0.807664 0.822072 0.810897 0.778462 0.794349 0.852596 0.842715 0.847627 0.787234 0.826303 0.806295 2.268300 220.431000
50 0.015200 1.806629 0.437686 0.986938 0.781585 0.822072 0.801317 0.790123 0.816817 0.803248 0.822072 0.832237 0.778462 0.804452 0.832528 0.855960 0.844082 0.798526 0.806452 0.802469 2.116600 236.230000
51 0.015200 1.751909 0.448987 0.987078 0.783803 0.813814 0.798527 0.795488 0.820571 0.807834 0.827327 0.806061 0.818462 0.812214 0.839806 0.859272 0.849427 0.825521 0.786600 0.805591 2.417800 206.803000
52 0.013600 1.756309 0.446924 0.987359 0.792522 0.811562 0.801929 0.799267 0.819069 0.809047 0.825826 0.807339 0.812308 0.809816 0.842020 0.855960 0.848933 0.815857 0.791563 0.803526 2.290800 218.269000
53 0.013600 1.771644 0.449768 0.987203 0.793662 0.808559 0.801041 0.802985 0.807808 0.805389 0.829580 0.804281 0.809231 0.806748 0.857143 0.854305 0.855721 0.808933 0.808933 0.808933 2.209900 226.255000
54 0.013600 1.794925 0.450868 0.987203 0.789627 0.811562 0.800444 0.803254 0.815315 0.809240 0.827327 0.801802 0.821538 0.811550 0.852596 0.842715 0.847627 0.810945 0.808933 0.809938 2.327500 214.822000
55 0.011700 1.794530 0.443402 0.987266 0.801815 0.795796 0.798794 0.805556 0.805556 0.805556 0.825075 0.832787 0.781538 0.806349 0.837097 0.859272 0.848039 0.800983 0.808933 0.804938 2.170300 230.387000
56 0.013500 1.799630 0.439336 0.986969 0.790923 0.798048 0.794469 0.797912 0.803303 0.800599 0.824324 0.799392 0.809231 0.804281 0.847934 0.849338 0.848635 0.809045 0.799007 0.803995 3.287800 152.079000
57 0.013500 1.805381 0.440039 0.987062 0.793208 0.789039 0.791118 0.812117 0.795045 0.803490 0.824324 0.808642 0.806154 0.807396 0.830696 0.869205 0.849515 0.827128 0.771712 0.798460 2.088400 239.419000
58 0.012400 1.781618 0.446814 0.987047 0.791420 0.803303 0.797317 0.793860 0.815315 0.804444 0.825826 0.831210 0.803077 0.816901 0.840722 0.847682 0.844188 0.799511 0.811414 0.805419 2.302100 217.195000
59 0.011700 1.823947 0.449294 0.986969 0.789668 0.803303 0.796427 0.798369 0.808559 0.803432 0.828078 0.811550 0.821538 0.816514 0.841503 0.852649 0.847039 0.820972 0.796526 0.808564 2.424200 206.255000
60 0.011700 1.827804 0.444374 0.986984 0.794872 0.791291 0.793078 0.810225 0.785285 0.797560 0.827327 0.818462 0.818462 0.818462 0.832268 0.862583 0.847154 0.826772 0.781638 0.803571 2.178200 229.545000
61 0.011600 1.809431 0.452122 0.987000 0.802752 0.788288 0.795455 0.813917 0.781532 0.797396 0.831832 0.828125 0.815385 0.821705 0.842532 0.859272 0.850820 0.818182 0.803970 0.811014 2.068400 241.730000
62 0.010800 1.830662 0.449956 0.987328 0.798200 0.798799 0.798499 0.806912 0.806306 0.806609 0.828078 0.828571 0.803077 0.815625 0.832803 0.865894 0.849026 0.820051 0.791563 0.805556 2.177800 229.585000
63 0.010800 1.831762 0.447547 0.987203 0.794074 0.804805 0.799403 0.805263 0.804054 0.804658 0.826577 0.823899 0.806154 0.814930 0.836305 0.854305 0.845209 0.813602 0.801489 0.807500 2.326200 214.944000
64 0.010500 1.829183 0.442622 0.987078 0.792692 0.798048 0.795361 0.807780 0.795045 0.801362 0.825075 0.808642 0.806154 0.807396 0.843234 0.846026 0.844628 0.810945 0.808933 0.809938 2.899100 172.470000
65 0.010200 1.798383 0.440475 0.987062 0.791356 0.797297 0.794316 0.809670 0.792042 0.800759 0.824324 0.806154 0.806154 0.806154 0.837398 0.852649 0.844955 0.818878 0.796526 0.807547 2.067400 241.854000
66 0.010200 1.800962 0.451185 0.987172 0.785922 0.813063 0.799262 0.795903 0.816817 0.806225 0.828829 0.828025 0.800000 0.813772 0.843393 0.855960 0.849630 0.807407 0.811414 0.809406 2.146000 232.988000
67 0.010700 1.809239 0.449513 0.987297 0.792873 0.801802 0.797312 0.810280 0.804805 0.807533 0.826577 0.800604 0.815385 0.807927 0.856655 0.831126 0.843697 0.804819 0.828784 0.816626 2.300700 217.322000
68 0.009400 1.818187 0.440534 0.987187 0.791233 0.799550 0.795370 0.807721 0.801051 0.804372 0.822823 0.827922 0.784615 0.805687 0.847826 0.839404 0.843594 0.784038 0.828784 0.805790 2.070600 241.479000
69 0.009400 1.839509 0.449072 0.987281 0.797156 0.799550 0.798351 0.811972 0.794294 0.803036 0.828829 0.838816 0.784615 0.810811 0.849587 0.850993 0.850289 0.791962 0.831266 0.811138 2.190800 228.229000
70 0.009500 1.833614 0.440818 0.987141 0.787187 0.802553 0.794796 0.805745 0.800300 0.803013 0.823574 0.820513 0.787692 0.803768 0.850420 0.837748 0.844037 0.788235 0.831266 0.809179 2.240400 223.177000
71 0.009400 1.792183 0.437578 0.987266 0.795609 0.789039 0.792311 0.814815 0.792793 0.803653 0.821321 0.805556 0.803077 0.804314 0.841322 0.842715 0.842018 0.803970 0.803970 0.803970 2.265900 220.660000
72 0.009400 1.809693 0.446543 0.987187 0.794623 0.798799 0.796705 0.805891 0.801051 0.803464 0.826577 0.805471 0.815385 0.810398 0.856176 0.837748 0.846862 0.800971 0.818859 0.809816 1.897500 263.507000
73 0.009300 1.782456 0.439629 0.987141 0.793876 0.798048 0.795957 0.812159 0.782282 0.796941 0.825075 0.823718 0.790769 0.806907 0.844884 0.847682 0.846281 0.797101 0.818859 0.807834 2.327800 214.797000
74 0.009500 1.808447 0.442311 0.987141 0.789903 0.798799 0.794326 0.811583 0.789039 0.800152 0.825075 0.804281 0.809231 0.806748 0.856899 0.832781 0.844668 0.796651 0.826303 0.811206 2.126700 235.101000
75 0.008800 1.795766 0.450452 0.987062 0.793284 0.798048 0.795659 0.807663 0.791291 0.799393 0.830330 0.824841 0.796923 0.810642 0.854027 0.842715 0.848333 0.800948 0.838710 0.819394 2.204600 226.797000
76 0.008800 1.804343 0.448934 0.987062 0.790526 0.801802 0.796124 0.805471 0.795796 0.800604 0.828829 0.832258 0.793846 0.812598 0.850000 0.844371 0.847176 0.796209 0.833747 0.814545 2.178600 229.509000
77 0.008800 1.799190 0.445429 0.987016 0.789513 0.802553 0.795979 0.805513 0.789790 0.797574 0.828078 0.824841 0.796923 0.810642 0.847682 0.847682 0.847682 0.801932 0.823821 0.812729 2.220700 225.156000
78 0.009000 1.787097 0.444187 0.987000 0.792354 0.793544 0.792948 0.807220 0.789039 0.798026 0.827327 0.815625 0.803077 0.809302 0.847682 0.847682 0.847682 0.806373 0.816377 0.811344 2.230900 224.124000
79 0.009000 1.790349 0.442649 0.986969 0.791480 0.795045 0.793258 0.806452 0.788288 0.797267 0.826577 0.817610 0.800000 0.808709 0.847430 0.846026 0.846727 0.802920 0.818859 0.810811 2.184300 228.903000
80 0.008700 1.792766 0.445089 0.986984 0.791636 0.795796 0.793710 0.805833 0.788288 0.796964 0.828078 0.824841 0.796923 0.810642 0.847682 0.847682 0.847682 0.801932 0.823821 0.812729 2.210100 226.236000

TrainOutput(global_step=5440, training_loss=0.07378851720405852, metrics={'train_runtime': 4319.8751, 'train_samples_per_second': 1.259, 'total_flos': 3.680907436228608e+16, 'epoch': 80.0, 'init_mem_cpu_alloc_delta': 2128576512, 'init_mem_gpu_alloc_delta': 558472192, 'init_mem_cpu_peaked_delta': 380325888, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 329842688, 'train_mem_gpu_alloc_delta': 2244640256, 'train_mem_cpu_peaked_delta': 720216064, 'train_mem_gpu_peaked_delta': 8832347136})
Code
model.save_pretrained(Path("/data/blog/2021-07-20-aspect-sentiment-metrics/model"))

So that is a monstrous amount of output. It’s interesting to me, but much too much for this blog post.

The summary is that it has an entity end F1 score of around 0.8 and a sentiment accuracy of around 0.82. The SOTA scores for aspect sentiment on this dataset are between 0.83 and 0.85 (source). It’s not terrible.


Evaluation

I hooked the previous version of this up to a gradio app and people had a play with it. One let down was that the app would error if there were no entities in the text. This was due to the specific order of the operations, so that’s something I need to fix this time around.

Code
sentiment_names = ["negative", "neutral", "positive"]

def aspect_sentiment(text: str) -> List[Tuple[str, str]]:
    tokenized_text = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        input_ids = tokenized_text["input_ids"].to(model.device)
        output = model(input_ids=input_ids)[0]
        entity_boundaries = output[:, :, :2] > 0.
        entity_mask = (output[:, :, 1] > 0.).flatten()

        # performing argmax early to avoid problem with no entity predictions
        entity_sentiment = (
            output.reshape(-1, 5)
                [:, 2:]
                .argmax(dim=-1)
                [entity_mask]
        )

    entities = tokenizer.batch_decode([
        [input_id]
        for input_id, boundaries in zip(tokenized_text["input_ids"][0], entity_boundaries[0])
        if True in boundaries
    ])

    return [
        (entity, sentiment_names[sentiment])
        for entity, sentiment in zip(entities, entity_sentiment.tolist())
    ]
Code
aspect_sentiment("the hotel had oversold the rooms, there was no place for us")
[]
Code
aspect_sentiment("The food was terrible, but the view was fantastic")
[(' food', 'negative'), (' view', 'positive')]

So that works. If you’re interested in using gradio, you can see the code required below:

Code
#hide_output
import gradio as gr

def gradio_sentiment(text: str) -> str:
    results = aspect_sentiment(text)
    if not results:
        return "No entities found"
    return "\n".join(f"entity: {entity}, sentiment: {sentiment}" for entity, sentiment in results)

gr.Interface(
    fn=gradio_sentiment,
    inputs=["textbox"],
    outputs="text"
).launch(share=True)
Running locally at: http://127.0.0.1:7860/
This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)
Running on External URL: https://33922.gradio.app
Interface loading below...
Tip: Add interpretation to your model by simply adding `interpretation="default"` to `Interface()`
(<Flask 'gradio.networking'>,
 'http://127.0.0.1:7860/',
 'https://33922.gradio.app')

That’s all fine. What happens if we look at the sentiment per word?

Code
sentiment_names = ["negative", "neutral", "positive"]

def word_sentiment(text: str) -> List[Dict[str, Any]]:
    tokenized_text = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        input_ids = tokenized_text["input_ids"].to(model.device)
        output = model(input_ids=input_ids)[0]
        token_sentiment = (
            output.reshape(-1, 5)
                [:, 2:]
                .softmax(dim=-1)
        )

    return [
        {
            "token": tokenizer.decode(token),
            "negative": round(sentiment[0], 3),
            "neutral": round(sentiment[1], 3),
            "positive": round(sentiment[2], 3),
        }
        for token, sentiment in zip(input_ids[0].tolist(), token_sentiment.tolist())
    ]
Code
word_sentiment("The food was terrible but the view was fantastic")
[{'token': '<s>', 'negative': 0.097, 'neutral': 0.042, 'positive': 0.861},
 {'token': 'The', 'negative': 0.999, 'neutral': 0.0, 'positive': 0.001},
 {'token': ' food', 'negative': 1.0, 'neutral': 0.0, 'positive': 0.0},
 {'token': ' was', 'negative': 1.0, 'neutral': 0.0, 'positive': 0.0},
 {'token': ' terrible',
  'negative': 0.993,
  'neutral': 0.005,
  'positive': 0.002},
 {'token': ' but', 'negative': 0.93, 'neutral': 0.035, 'positive': 0.036},
 {'token': ' the', 'negative': 0.0, 'neutral': 0.0, 'positive': 1.0},
 {'token': ' view', 'negative': 0.0, 'neutral': 0.0, 'positive': 1.0},
 {'token': ' was', 'negative': 0.0, 'neutral': 0.0, 'positive': 1.0},
 {'token': ' fantastic',
  'negative': 0.001,
  'neutral': 0.005,
  'positive': 0.994},
 {'token': '</s>', 'negative': 0.023, 'neutral': 0.159, 'positive': 0.819}]

This output suggests that spans of the text share sentiment levels. It’s quite hard to read this output though, so visualizing it would help. I’m going to translate it into something that can color the tokens.

This is an extremely complex way of coloring text, but it’s what I have available right now. The text will become a graphviz graph where each node is a token colored by its sentiment - red for mostly negative, blue for mostly neutral and green for mostly positive.

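The show_word_sentiment function itself was not included above, so here is a minimal sketch of how it could be implemented, assuming the graphviz Python package - the real implementation is not reproduced in this post. Each color channel is driven by one sentiment probability: negative drives red, positive drives green and neutral drives blue.

Code
# a minimal sketch of show_word_sentiment, assuming the graphviz package
import graphviz

def show_word_sentiment(text: str) -> graphviz.Digraph:
    graph = graphviz.Digraph()
    graph.attr(rankdir="LR")  # lay the tokens out left to right
    previous = None
    for index, entry in enumerate(word_sentiment(text)):
        # blend the sentiment probabilities into an RGB color
        color = "#{:02x}{:02x}{:02x}".format(
            int(entry["negative"] * 255),  # red
            int(entry["positive"] * 255),  # green
            int(entry["neutral"] * 255),   # blue
        )
        name = str(index)
        graph.node(name, label=entry["token"], color=color, fontcolor=color)
        if previous is not None:
            graph.edge(previous, name)
        previous = name
    return graph
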
Code
show_word_sentiment("The food was terrible but the view was fantastic")

Code
show_word_sentiment("Marriott food was better than Hilton's, but the Hilton view was fantastic.")


Potential Improvements

Having built this, I have a few ideas about improving it. Broadly the model does two things, so each improvement targets either entity extraction or token sentiment.

Entity Extraction

Given that the sentiment allocation across the tokens seems reasonable, it may be that entity extraction could be dropped from the model entirely. There are existing tools, such as spacy, that extract entities from text well. The spacy POS tagger has an accuracy of 97% and the named entity recognition system has an F1 score of 0.84 (source).
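
As a sketch of what delegating extraction to spacy might look like (assuming the en_core_web_sm model is installed - this is a hypothetical helper, not part of the model above):

Code
# a sketch of extracting candidate aspect targets with spacy,
# assuming: python -m spacy download en_core_web_sm
from typing import List, Tuple
import spacy

nlp = spacy.load("en_core_web_sm")

def candidate_entities(text: str) -> List[Tuple[str, int, int]]:
    document = nlp(text)
    # named entities plus noun chunks make reasonable aspect candidates
    spans = list(document.ents) + list(document.noun_chunks)
    return [(span.text, span.start_char, span.end_char) for span in spans]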

If entity extraction is retained then it would be perfectly possible to train it separately from the token sentiment. Currently the sentiment is only considered when a token is marked as an entity end. Instead, a sentinel value could be used to exclude the tokens that do not have a known sentiment. This would fully separate entity extraction from token sentiment.
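
A minimal sketch of that idea: label every token without a known sentiment with a sentinel such as -100 and let the cross entropy loss skip those positions, removing the need for the end-token mask:

Code
import torch

# hypothetical flat tensors: three sentiment logits per token and one
# label per token, with -100 as the sentinel for "no known sentiment"
sentiment_logits = torch.randn(6, 3)
sentiment_labels = torch.tensor([-100, 2, -100, 0, -100, 1])

# cross_entropy skips every position labelled with ignore_index
sentiment_loss = torch.nn.functional.cross_entropy(
    sentiment_logits, sentiment_labels, ignore_index=-100
)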

Token Sentiment

Checking the sentiment only on the end token is an artifact of the sequence to sequence approach that was tried first. If the sentiment label were instead applied to every token in the entity, the token sentiment allocation could be more consistent.
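
As a sketch, the boundary computation in encode could assign the sentiment to every token whose character span falls inside an entity (reusing sentiment_index from above; the -100 sentinel is an assumption):

Code
from typing import Any, Dict, List, Tuple

def token_sentiments(
    offset_mapping: List[Tuple[int, int]],
    entities: List[Dict[str, Any]],
) -> List[int]:
    # every token inside an entity span gets that entity's sentiment;
    # everything else gets the sentinel -100
    labels = []
    for token_start, token_end in offset_mapping:
        label = -100
        if token_start != token_end:  # skip special and padding tokens
            for entity in entities:
                if entity["start"] <= token_start and token_end <= entity["end"]:
                    label = sentiment_index[entity["sentiment"]]
                    break
        labels.append(label)
    return labels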

It may also be possible to pretrain the model on document level sentiment and then fine-tune it. There is considerably more document level sentiment data available, so this might help.

Token Relationships

In many ways the sentiment in the text can be viewed as a relationship between different entities. For example, the text Marriott food was better than Hilton’s is a comparison between Marriott food and Hilton food. With a better idea of the relationships it might be possible to represent each relationship as a vector, which could then be clustered to find instances of similar relationships between entity pairs.

This could be achieved with a cosine similarity loss - similar to semantic search. To generate the actual pairings efficiently it might be worth evaluating something similar to CLIP, where the image and text are processed into separate vectors and the final classification is a dot product of the two. This also reminds me of the Adafactor / Adam difference, where Adafactor approximates the full matrix of second-moment values from its row and column statistics.
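
As a sketch of that pairing, each entity would get its own vector, and every pairwise relationship score could come from a single matrix multiply of normalized vectors - the dot product is then the cosine similarity, as in CLIP style training:

Code
import torch

def relationship_scores(entity_vectors: torch.Tensor) -> torch.Tensor:
    # normalize so the dot product between any two rows is cosine similarity
    normalized = torch.nn.functional.normalize(entity_vectors, dim=-1)
    # (entities, dim) @ (dim, entities) -> every pairwise score at once
    return normalized @ normalized.T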

SOTA Review

I need to evaluate the different approaches that have been published. The MAMS SOTA results are better than what this approach currently achieves (though this uses the smaller BART model).