I’ve been working on creating an aspect sentiment model and I have successfully trained it. When performing manual evaluations it certainly can predict some entity sentiment correctly, so the approach seems to be viable.
What I need now is a more systematic evaluation of the model, so that different training approaches can be compared. The loss is only comparable between runs when the loss calculation stays the same, yet changing the loss calculation is the best way to change the performance.
Metrics
So I am going to write a metric function for this. The huggingface trainer takes a compute_metrics function which receives a transformers.EvalPrediction object. This is a glorified tuple that has predictions and label_ids, which are both numpy arrays. I don’t get access to any other statistics, so no loss value or anything.
There are two primary metrics relating to the two tasks that are being performed - entity extraction and sentiment. I want to use sklearn to determine the accuracy of these, and it is possible to evaluate them separately.
Code
from typing import *

from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import EvalPrediction


def compute_metrics(pred: EvalPrediction) -> Dict[str, float]:
    # flatten to one row per token: labels are (start, end, sentiment),
    # predictions are (start, end, negative, neutral, positive) logits
    labels = pred.label_ids.reshape(-1, 3)
    predictions = pred.predictions.reshape(-1, 5)

    # entity boundaries: a logit above zero counts as a positive prediction
    entity_labels = labels[:, :2]
    entity_predictions = predictions[:, :2] > 0.
    entity_accuracy = accuracy_score(entity_labels, entity_predictions)
    (
        (entity_start_precision, entity_end_precision),
        (entity_start_recall, entity_end_recall),
        (entity_start_fscore, entity_end_fscore),
        _
    ) = precision_recall_fscore_support(entity_labels, entity_predictions)

    # sentiment is only scored on tokens that are labelled as an entity end
    sentiment_mask = labels[:, 1] > 0
    sentiment_labels = labels[sentiment_mask, 2]
    sentiment_predictions = predictions[sentiment_mask, 2:].argmax(axis=1)
    sentiment_accuracy = accuracy_score(sentiment_labels, sentiment_predictions)
    # labels=[0, 1, 2] keeps the three-way unpacking valid even if a class
    # happens to be missing from an evaluation batch
    (
        (sentiment_negative_precision, sentiment_neutral_precision, sentiment_positive_precision),
        (sentiment_negative_recall, sentiment_neutral_recall, sentiment_positive_recall),
        (sentiment_negative_fscore, sentiment_neutral_fscore, sentiment_positive_fscore),
        _
    ) = precision_recall_fscore_support(
        sentiment_labels, sentiment_predictions, labels=[0, 1, 2]
    )

    return {
        "quality": (
            entity_end_fscore *
            sentiment_negative_fscore *
            sentiment_neutral_fscore *
            sentiment_positive_fscore
        ),
        "entity_accuracy": entity_accuracy,
        "entity_start_precision": entity_start_precision,
        "entity_start_recall": entity_start_recall,
        "entity_start_f1_score": entity_start_fscore,
        "entity_end_precision": entity_end_precision,
        "entity_end_recall": entity_end_recall,
        "entity_end_f1_score": entity_end_fscore,
        "sentiment_accuracy": sentiment_accuracy,
        "sentiment_negative_precision": sentiment_negative_precision,
        "sentiment_negative_recall": sentiment_negative_recall,
        "sentiment_negative_f1_score": sentiment_negative_fscore,
        "sentiment_neutral_precision": sentiment_neutral_precision,
        "sentiment_neutral_recall": sentiment_neutral_recall,
        "sentiment_neutral_f1_score": sentiment_neutral_fscore,
        "sentiment_positive_precision": sentiment_positive_precision,
        "sentiment_positive_recall": sentiment_positive_recall,
        "sentiment_positive_f1_score": sentiment_positive_fscore,
    }
This is a load of different metrics. I thought it would be useful to track the different ways that the model can perform. The quality metric is a made-up metric that the trainer can use to pick the best model. Its guiding principle is that the entity sentiment predictions should be correct, so it multiplies the entity end F1 score by the three sentiment F1 scores.
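To get a feel for how the quality metric behaves, here is a minimal sketch that recomputes it from the four F1 scores reported for the first two epochs of the training log later in this post. The `quality` helper and the dictionary keys are only illustrative names; the numbers are copied from that log. Because it is a product, a single zero component (the entity end F1 score in epoch 1) drags the whole value to zero.

```python
from math import prod
from typing import Dict

# illustrative helper: multiplies the same four F1 scores as compute_metrics
QUALITY_COMPONENTS = (
    "entity_end_f1_score",
    "sentiment_negative_f1_score",
    "sentiment_neutral_f1_score",
    "sentiment_positive_f1_score",
)


def quality(scores: Dict[str, float]) -> float:
    return prod(scores[name] for name in QUALITY_COMPONENTS)


# F1 scores copied from the training log for epochs 1 and 2
epoch_1 = {
    "entity_end_f1_score": 0.000000,
    "sentiment_negative_f1_score": 0.600000,
    "sentiment_neutral_f1_score": 0.767020,
    "sentiment_positive_f1_score": 0.590551,
}
epoch_2 = {
    "entity_end_f1_score": 0.011834,
    "sentiment_negative_f1_score": 0.791411,
    "sentiment_neutral_f1_score": 0.795414,
    "sentiment_positive_f1_score": 0.772210,
}

print(quality(epoch_1))
print(round(quality(epoch_2), 6))
```

The epoch 2 value agrees with the logged quality of 0.005753 up to rounding, which is a small sanity check that the metric is wired up as intended.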
Model Definition and Dataset
The only way to evaluate these metrics is to see how well they describe the model, so we have to train it again. First we have to load the dataset and model. This is the same as the previous blog post, so you can skip to the next section if you wish.
Code
MODEL_NAME = "facebook/bart-base"
MAXIMUM_TOKEN_LENGTH = 128
BATCH_SIZE = 64
EPOCHS = 80
Code
#collapse
from typing import *

from transformers import BartModel, AutoConfig
import torch


class EntitySentimentSequenceClassifier(BartModel):
    def __init__(self, config: AutoConfig) -> None:
        config.num_labels = 5  # start and copy, end and copy, negative, neutral, positive
        super().__init__(config)

        # bart model for sequence classification actually has a more complex classification head
        self.score = torch.nn.Linear(
            in_features=config.d_model,
            out_features=config.num_labels,
            bias=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, ...]:
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = outputs[0]  # last hidden state
        predictions = self.score(hidden_states)

        if labels is not None:
            entity_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                predictions[:, :, :2],
                labels[:, :, :2].float(),
            )

            flat_predictions = predictions.reshape(-1, 5)
            flat_labels = labels.reshape(-1, 3)
            end_mask = flat_labels[:, 1] > 0
            sentiment_predictions = flat_predictions[end_mask, 2:]
            sentiment_targets = flat_labels[end_mask, 2]
            sentiment_loss = torch.nn.functional.cross_entropy(
                sentiment_predictions,
                sentiment_targets
            )

            loss = entity_loss + sentiment_loss
            return (loss, predictions)
        return (predictions,)
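The entity loss above works on raw logits, which is why both the metric function and the inference code compare the entity outputs against 0 rather than 0.5: a logit of zero corresponds to a sigmoid probability of exactly 0.5. A minimal sketch of that equivalence in plain Python, using made-up logit values rather than real model output:

```python
import math
from typing import List


def sigmoid(logit: float) -> float:
    return 1.0 / (1.0 + math.exp(-logit))


# hypothetical logits for illustration, not taken from the trained model
example_logits: List[float] = [-3.2, -0.1, 0.0, 0.4, 5.0]

# thresholding the logit at zero ...
by_logit = [logit > 0.0 for logit in example_logits]
# ... selects the same tokens as thresholding the probability at one half
by_probability = [sigmoid(logit) > 0.5 for logit in example_logits]

print(by_logit == by_probability)
```

Skipping the sigmoid saves a little work and avoids any rounding near the boundary.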
Code
#collapse
from typing import *

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
sentiment_index = {
    "negative": 0,
    "neutral": 1,
    "positive": 2,
}


def encode(row: Dict[str, Any]) -> Dict[str, Any]:
    text = row["text"]
    entities = row["entities"]
    span_starts = {entity["start"] for entity in entities}
    span_ends = {entity["end"] for entity in entities}
    end_sentiments = {
        entity["end"]: sentiment_index[entity["sentiment"]]
        for entity in entities
    }

    tokenized_text = tokenizer(
        text,
        return_offsets_mapping=True,
        max_length=MAXIMUM_TOKEN_LENGTH,
        truncation=True,
        padding="max_length"
    )
    offset_mapping = tokenized_text["offset_mapping"]
    # one (is_start, is_end, sentiment) triple per token
    boundaries = [
        (
            int(start in span_starts and start != end),
            int(end in span_ends and start != end),
            end_sentiments.get(end, 0)
        )
        for start, end in offset_mapping
    ]

    return {
        "input_ids": tokenized_text["input_ids"],
        "attention_mask": tokenized_text["attention_mask"],
        "label": boundaries,
    }
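To make the labelling concrete, here is a small sketch of the same boundary logic run on hand-written character offsets, so it needs no tokenizer. The offsets are hypothetical (one token per word for the text "The room service was great", with (0, 0) standing in for special tokens) and `label_offsets` is an illustrative helper, not part of the training code.

```python
from typing import Dict, List, Set, Tuple


def label_offsets(
    offsets: List[Tuple[int, int]],
    span_starts: Set[int],
    span_ends: Set[int],
    end_sentiments: Dict[int, int],
) -> List[Tuple[int, int, int]]:
    # mirrors the list comprehension in encode: (is_start, is_end, sentiment)
    return [
        (
            int(start in span_starts and start != end),
            int(end in span_ends and start != end),
            end_sentiments.get(end, 0),
        )
        for start, end in offsets
    ]


# hypothetical offsets for "<s> The room service was great </s>"
offsets = [(0, 0), (0, 3), (4, 8), (9, 16), (17, 20), (21, 26), (0, 0)]
# one entity, "room service", spanning characters 4-16 with positive sentiment (2)
labels = label_offsets(offsets, span_starts={4}, span_ends={16}, end_sentiments={16: 2})

for offset, label in zip(offsets, labels):
    print(offset, label)
```

Only the "service" token carries the sentiment, since the label is keyed on the entity end. Every other token gets the default sentiment of 0, which is harmless because the loss and the metrics mask out everything that is not an entity end. The `start != end` check stops the zero-width special and padding tokens from matching an entity that begins at character 0.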
Code
#hide_output
import pandas as pd
from datasets import Dataset

train_df = pd.read_parquet("/data/blog/2021-07-18-aspect-sentiment-dataset/train.gz.parquet")
validation_df = pd.read_parquet("/data/blog/2021-07-18-aspect-sentiment-dataset/validation.gz.parquet")
test_df = pd.read_parquet("/data/blog/2021-07-18-aspect-sentiment-dataset/test.gz.parquet")

train_ds = Dataset.from_pandas(train_df)
train_ds = train_ds.map(encode)
validation_ds = Dataset.from_pandas(validation_df)
validation_ds = validation_ds.map(encode)
test_ds = Dataset.from_pandas(test_df)
test_ds = test_ds.map(encode)
Training
Now we can train the model and see what the metrics say about it!
Code
#hide_output
model = EntitySentimentSequenceClassifier.from_pretrained(MODEL_NAME)
Some weights of EntitySentimentSequenceClassifier were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['model.score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Code
from pathlib import Path
from transformers import Trainer, TrainingArguments

MODEL_RUN_FOLDER = Path("/data/blog/2021-07-20-aspect-sentiment-metrics/runs")
MODEL_RUN_FOLDER.mkdir(parents=True, exist_ok=True)

training_args = TrainingArguments(
    report_to=[],
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    warmup_ratio=0.06,
    num_train_epochs=EPOCHS,
    evaluation_strategy="epoch",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="quality",
    greater_is_better=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=validation_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
[5440/5440 1:11:59, Epoch 80/80]
Epoch | Training Loss | Validation Loss | Quality | Entity Accuracy | Entity Start Precision | Entity Start Recall | Entity Start F1 Score | Entity End Precision | Entity End Recall | Entity End F1 Score | Sentiment Accuracy | Sentiment Negative Precision | Sentiment Negative Recall | Sentiment Negative F1 Score | Sentiment Neutral Precision | Sentiment Neutral Recall | Sentiment Neutral F1 Score | Sentiment Positive Precision | Sentiment Positive Recall | Sentiment Positive F1 Score | Runtime | Samples Per Second
1 | No log | 0.797070 | 0.000000 | 0.973109 | 0.538462 | 0.005255 | 0.010409 | 0.000000 | 0.000000 | 0.000000 | 0.680180 | 0.682353 | 0.535385 | 0.600000 | 0.706128 | 0.839404 | 0.767020 | 0.626741 | 0.558313 | 0.590551 | 1.690100 | 295.843000
2 | 1.103500 | 0.588151 | 0.005753 | 0.973172 | 0.842105 | 0.120120 | 0.210250 | 0.400000 | 0.006006 | 0.011834 | 0.786787 | 0.788991 | 0.793846 | 0.791411 | 0.850943 | 0.746689 | 0.795414 | 0.713684 | 0.841191 | 0.772210 | 1.764400 | 283.390000
3 | 0.570700 | 0.520855 | 0.226896 | 0.975859 | 0.744952 | 0.526276 | 0.616806 | 0.705128 | 0.289039 | 0.410011 | 0.825826 | 0.795252 | 0.824615 | 0.809668 | 0.848837 | 0.846026 | 0.847430 | 0.816794 | 0.796526 | 0.806533 | 1.794800 | 278.578000
4 | 0.570700 | 0.494224 | 0.349037 | 0.978688 | 0.695030 | 0.692943 | 0.693985 | 0.655422 | 0.612613 | 0.633295 | 0.824324 | 0.816199 | 0.806154 | 0.811146 | 0.843234 | 0.846026 | 0.844628 | 0.802469 | 0.806452 | 0.804455 | 1.785500 | 280.039000
5 | 0.425600 | 0.567770 | 0.359286 | 0.980750 | 0.732513 | 0.715465 | 0.723889 | 0.695161 | 0.647147 | 0.670295 | 0.819069 | 0.828383 | 0.772308 | 0.799363 | 0.821875 | 0.870861 | 0.845659 | 0.807198 | 0.779156 | 0.792929 | 1.769900 | 282.502000
6 | 0.311800 | 0.605942 | 0.377081 | 0.980953 | 0.703806 | 0.763514 | 0.732445 | 0.696035 | 0.711712 | 0.703786 | 0.816066 | 0.819936 | 0.784615 | 0.801887 | 0.862676 | 0.811258 | 0.836177 | 0.754967 | 0.848635 | 0.799065 | 1.753700 | 285.108000
7 | 0.311800 | 0.680506 | 0.393409 | 0.980891 | 0.680386 | 0.794294 | 0.732941 | 0.676375 | 0.784535 | 0.726451 | 0.819069 | 0.776163 | 0.821538 | 0.798206 | 0.878403 | 0.801325 | 0.838095 | 0.778032 | 0.843672 | 0.809524 | 1.774600 | 281.750000
8 | 0.210300 | 0.790352 | 0.394818 | 0.982375 | 0.735036 | 0.756006 | 0.745374 | 0.742901 | 0.726727 | 0.734725 | 0.816817 | 0.742782 | 0.870769 | 0.801700 | 0.853195 | 0.817881 | 0.835165 | 0.836022 | 0.771712 | 0.802581 | 1.779400 | 280.986000
9 | 0.152800 | 0.848886 | 0.395077 | 0.982781 | 0.723803 | 0.783033 | 0.752254 | 0.718001 | 0.787538 | 0.751164 | 0.813814 | 0.815789 | 0.763077 | 0.788553 | 0.826645 | 0.852649 | 0.839446 | 0.792593 | 0.796526 | 0.794554 | 1.799700 | 277.817000
10 | 0.152800 | 0.971851 | 0.398642 | 0.982766 | 0.719093 | 0.786036 | 0.751076 | 0.717694 | 0.813063 | 0.762408 | 0.807808 | 0.760563 | 0.830769 | 0.794118 | 0.886320 | 0.761589 | 0.819234 | 0.755459 | 0.858561 | 0.803717 | 1.785400 | 280.054000
11 | 0.115100 | 1.135353 | 0.386189 | 0.983609 | 0.716578 | 0.804805 | 0.758133 | 0.733020 | 0.818318 | 0.773324 | 0.794294 | 0.779762 | 0.806154 | 0.792738 | 0.896341 | 0.730132 | 0.804745 | 0.704365 | 0.880893 | 0.782800 | 1.782800 | 280.459000
12 | 0.085900 | 0.993660 | 0.419968 | 0.984078 | 0.726662 | 0.812312 | 0.767104 | 0.732981 | 0.832583 | 0.779613 | 0.819820 | 0.815047 | 0.800000 | 0.807453 | 0.825197 | 0.867550 | 0.845843 | 0.814815 | 0.764268 | 0.788732 | 1.794600 | 278.615000
13 | 0.085900 | 1.026248 | 0.411750 | 0.983125 | 0.711985 | 0.829580 | 0.766297 | 0.687682 | 0.884384 | 0.773727 | 0.811562 | 0.788571 | 0.849231 | 0.817778 | 0.873840 | 0.779801 | 0.824147 | 0.753950 | 0.828784 | 0.789598 | 1.760800 | 283.959000
14 | 0.070800 | 1.142620 | 0.408904 | 0.985297 | 0.770818 | 0.785285 | 0.777984 | 0.763598 | 0.822072 | 0.791757 | 0.804805 | 0.751351 | 0.855385 | 0.800000 | 0.858696 | 0.784768 | 0.820069 | 0.780488 | 0.794045 | 0.787208 | 1.790800 | 279.198000
15 | 0.062800 | 1.141223 | 0.397795 | 0.984859 | 0.738420 | 0.813814 | 0.774286 | 0.756592 | 0.840090 | 0.796158 | 0.797297 | 0.806557 | 0.756923 | 0.780952 | 0.848214 | 0.786424 | 0.816151 | 0.730193 | 0.846154 | 0.783908 | 1.800400 | 277.722000
16 | 0.062800 | 1.261426 | 0.421695 | 0.985203 | 0.744990 | 0.809309 | 0.775819 | 0.773203 | 0.831832 | 0.801447 | 0.814565 | 0.763006 | 0.812308 | 0.786885 | 0.837134 | 0.850993 | 0.844007 | 0.825269 | 0.761787 | 0.792258 | 1.808500 | 276.475000
17 | 0.057400 | 1.425968 | 0.406145 | 0.985234 | 0.752606 | 0.813063 | 0.781667 | 0.761002 | 0.843844 | 0.800285 | 0.798799 | 0.787425 | 0.809231 | 0.798179 | 0.870476 | 0.756623 | 0.809566 | 0.727273 | 0.853598 | 0.785388 | 1.804600 | 277.064000
18 | 0.050500 | 1.417094 | 0.414896 | 0.985531 | 0.757322 | 0.815315 | 0.785249 | 0.754980 | 0.853604 | 0.801268 | 0.805556 | 0.756233 | 0.840000 | 0.795918 | 0.886973 | 0.766556 | 0.822380 | 0.750557 | 0.836228 | 0.791080 | 1.807100 | 276.682000
19 | 0.050500 | 1.370741 | 0.410895 | 0.985500 | 0.750687 | 0.820571 | 0.784075 | 0.752802 | 0.857357 | 0.801685 | 0.804054 | 0.772455 | 0.793846 | 0.783005 | 0.866055 | 0.781457 | 0.821584 | 0.752759 | 0.846154 | 0.796729 | 1.817800 | 275.052000
20 | 0.050800 | 1.269680 | 0.427258 | 0.985750 | 0.761871 | 0.795045 | 0.778104 | 0.766260 | 0.849099 | 0.805556 | 0.816066 | 0.792793 | 0.812308 | 0.802432 | 0.813272 | 0.872517 | 0.841853 | 0.843305 | 0.734491 | 0.785146 | 1.770800 | 282.360000
21 | 0.042700 | 1.373681 | 0.443625 | 0.985750 | 0.739837 | 0.819820 | 0.777778 | 0.761747 | 0.864114 | 0.809708 | 0.824324 | 0.791541 | 0.806154 | 0.798780 | 0.840650 | 0.855960 | 0.848236 | 0.826425 | 0.791563 | 0.808619 | 1.809400 | 276.332000
22 | 0.042700 | 1.320477 | 0.438937 | 0.985906 | 0.761803 | 0.799550 | 0.780220 | 0.782394 | 0.834084 | 0.807413 | 0.822072 | 0.818770 | 0.778462 | 0.798107 | 0.830400 | 0.859272 | 0.844589 | 0.811558 | 0.801489 | 0.806492 | 1.826300 | 273.781000
23 | 0.038400 | 1.403632 | 0.444728 | 0.986078 | 0.756250 | 0.817568 | 0.785714 | 0.775051 | 0.853604 | 0.812433 | 0.825075 | 0.848276 | 0.756923 | 0.800000 | 0.822257 | 0.880795 | 0.850520 | 0.812658 | 0.796526 | 0.804511 | 1.763700 | 283.503000
24 | 0.035800 | 1.466172 | 0.437135 | 0.986297 | 0.758766 | 0.812312 | 0.784627 | 0.776119 | 0.858859 | 0.815396 | 0.814565 | 0.793510 | 0.827692 | 0.810241 | 0.876611 | 0.788079 | 0.829991 | 0.755556 | 0.843672 | 0.797186 | 1.881400 | 265.760000
25 | 0.035700 | 1.437172 | 0.436925 | 0.986453 | 0.767065 | 0.818318 | 0.791863 | 0.775874 | 0.849850 | 0.811179 | 0.819069 | 0.814103 | 0.781538 | 0.797488 | 0.836334 | 0.846026 | 0.841152 | 0.797066 | 0.808933 | 0.802956 | 1.823300 | 274.229000
26 | 0.035700 | 1.501757 | 0.428386 | 0.986516 | 0.777614 | 0.792793 | 0.785130 | 0.789773 | 0.834835 | 0.811679 | 0.812312 | 0.787879 | 0.800000 | 0.793893 | 0.850953 | 0.812914 | 0.831499 | 0.778824 | 0.821340 | 0.799517 | 1.809900 | 276.263000
27 | 0.033900 | 1.640986 | 0.431580 | 0.986484 | 0.761672 | 0.820571 | 0.790025 | 0.779442 | 0.859610 | 0.817565 | 0.813063 | 0.801887 | 0.784615 | 0.793157 | 0.834163 | 0.832781 | 0.833471 | 0.790754 | 0.806452 | 0.798526 | 1.837100 | 272.163000
28 | 0.027900 | 1.660942 | 0.417174 | 0.986219 | 0.762500 | 0.824324 | 0.792208 | 0.771018 | 0.846847 | 0.807156 | 0.807057 | 0.807818 | 0.763077 | 0.784810 | 0.872029 | 0.789735 | 0.828844 | 0.732218 | 0.868486 | 0.794552 | 1.833100 | 272.767000
29 | 0.027900 | 1.593555 | 0.439876 | 0.986422 | 0.754589 | 0.833333 | 0.792009 | 0.762279 | 0.873874 | 0.814271 | 0.819069 | 0.791541 | 0.806154 | 0.798780 | 0.842809 | 0.834437 | 0.838602 | 0.806452 | 0.806452 | 0.806452 | 1.825800 | 273.850000
30 | 0.031100 | 1.634480 | 0.429155 | 0.986891 | 0.769391 | 0.834084 | 0.800432 | 0.783784 | 0.849099 | 0.815135 | 0.812312 | 0.775148 | 0.806154 | 0.790347 | 0.843803 | 0.822848 | 0.833194 | 0.797531 | 0.801489 | 0.799505 | 1.819600 | 274.788000
31 | 0.029600 | 1.771458 | 0.422438 | 0.986594 | 0.764666 | 0.841592 | 0.801287 | 0.770492 | 0.846847 | 0.806867 | 0.807057 | 0.760989 | 0.852308 | 0.804064 | 0.876190 | 0.761589 | 0.814880 | 0.762980 | 0.838710 | 0.799054 | 1.917700 | 260.732000
32 | 0.029600 | 1.604859 | 0.437154 | 0.986812 | 0.768802 | 0.828829 | 0.797688 | 0.783727 | 0.846096 | 0.813718 | 0.816817 | 0.789157 | 0.806154 | 0.797565 | 0.867857 | 0.804636 | 0.835052 | 0.772727 | 0.843672 | 0.806643 | 1.860600 | 268.728000
33 | 0.032500 | 1.675310 | 0.444260 | 0.986625 | 0.771348 | 0.820571 | 0.795198 | 0.782456 | 0.837087 | 0.808850 | 0.822823 | 0.774929 | 0.836923 | 0.804734 | 0.867958 | 0.816225 | 0.841297 | 0.801453 | 0.821340 | 0.811275 | 1.875100 | 266.655000
34 | 0.023600 | 1.821172 | 0.435406 | 0.986844 | 0.780101 | 0.812312 | 0.795881 | 0.788652 | 0.834835 | 0.811087 | 0.814565 | 0.740260 | 0.876923 | 0.802817 | 0.890595 | 0.768212 | 0.824889 | 0.788732 | 0.833747 | 0.810615 | 1.839900 | 271.751000
35 | 0.023600 | 1.630846 | 0.452027 | 0.986734 | 0.786921 | 0.804054 | 0.795395 | 0.791424 | 0.817568 | 0.804284 | 0.830330 | 0.788406 | 0.836923 | 0.811940 | 0.850082 | 0.854305 | 0.852188 | 0.836842 | 0.789082 | 0.812261 | 1.866100 | 267.935000
36 | 0.023500 | 1.696693 | 0.434577 | 0.986953 | 0.790087 | 0.813814 | 0.801775 | 0.783498 | 0.834084 | 0.808000 | 0.818318 | 0.837288 | 0.760000 | 0.796774 | 0.844221 | 0.834437 | 0.839301 | 0.770455 | 0.841191 | 0.804270 | 1.909300 | 261.869000
37 | 0.020700 | 1.783487 | 0.436999 | 0.987156 | 0.787373 | 0.814565 | 0.800738 | 0.792640 | 0.840841 | 0.816029 | 0.816817 | 0.814103 | 0.781538 | 0.797488 | 0.852740 | 0.824503 | 0.838384 | 0.770642 | 0.833747 | 0.800954 | 1.851100 | 270.108000
38 | 0.020700 | 1.757433 | 0.431940 | 0.986969 | 0.788012 | 0.809309 | 0.798519 | 0.793377 | 0.827327 | 0.809996 | 0.813814 | 0.811321 | 0.793846 | 0.802488 | 0.860963 | 0.799669 | 0.829185 | 0.757174 | 0.851117 | 0.801402 | 1.893000 | 264.129000
39 | 0.019900 | 1.710060 | 0.445421 | 0.987031 | 0.784838 | 0.816066 | 0.800147 | 0.788502 | 0.834084 | 0.810653 | 0.824324 | 0.824675 | 0.781538 | 0.802528 | 0.861538 | 0.834437 | 0.847771 | 0.774487 | 0.843672 | 0.807601 | 1.913500 | 261.300000
40 | 0.021100 | 1.679544 | 0.447749 | 0.987078 | 0.788971 | 0.805556 | 0.797177 | 0.804769 | 0.810811 | 0.807779 | 0.824324 | 0.817901 | 0.815385 | 0.816641 | 0.862369 | 0.819536 | 0.840407 | 0.778802 | 0.838710 | 0.807646 | 1.989600 | 251.305000
41 | 0.021100 | 1.772447 | 0.439737 | 0.987078 | 0.795522 | 0.800300 | 0.797904 | 0.805078 | 0.809309 | 0.807188 | 0.821321 | 0.809969 | 0.800000 | 0.804954 | 0.857877 | 0.829470 | 0.843434 | 0.779859 | 0.826303 | 0.802410 | 2.427700 | 205.959000
42 | 0.017800 | 1.769264 | 0.437806 | 0.986609 | 0.781726 | 0.809309 | 0.795278 | 0.782361 | 0.825826 | 0.803506 | 0.822823 | 0.838926 | 0.769231 | 0.802568 | 0.826498 | 0.867550 | 0.846527 | 0.805000 | 0.799007 | 0.801993 | 1.814600 | 275.537000
43 | 0.018600 | 1.792316 | 0.445057 | 0.986906 | 0.794449 | 0.795045 | 0.794747 | 0.801788 | 0.807808 | 0.804787 | 0.826577 | 0.833333 | 0.784615 | 0.808241 | 0.833068 | 0.867550 | 0.849959 | 0.811083 | 0.799007 | 0.805000 | 1.871400 | 267.178000
44 | 0.018600 | 1.717856 | 0.443054 | 0.987062 | 0.790715 | 0.805556 | 0.798066 | 0.798687 | 0.822072 | 0.810211 | 0.820571 | 0.825949 | 0.803077 | 0.814353 | 0.845763 | 0.826159 | 0.835846 | 0.781690 | 0.826303 | 0.803378 | 3.030500 | 164.987000
45 | 0.016800 | 1.719061 | 0.449725 | 0.987000 | 0.799849 | 0.795045 | 0.797440 | 0.811450 | 0.798048 | 0.804693 | 0.828829 | 0.833866 | 0.803077 | 0.818182 | 0.840580 | 0.864238 | 0.852245 | 0.806533 | 0.796526 | 0.801498 | 1.996600 | 250.426000
46 | 0.016200 | 1.789883 | 0.441381 | 0.986750 | 0.784571 | 0.809309 | 0.796748 | 0.800300 | 0.800300 | 0.800300 | 0.825075 | 0.842809 | 0.775385 | 0.807692 | 0.840131 | 0.852649 | 0.846343 | 0.790476 | 0.823821 | 0.806804 | 1.909100 | 261.905000
47 | 0.016200 | 1.738527 | 0.448426 | 0.986875 | 0.777936 | 0.820571 | 0.798685 | 0.789964 | 0.827327 | 0.808214 | 0.825826 | 0.820755 | 0.803077 | 0.811820 | 0.857143 | 0.834437 | 0.845638 | 0.786385 | 0.831266 | 0.808203 | 1.898800 | 263.329000
48 | 0.015000 | 1.759438 | 0.448316 | 0.986688 | 0.782734 | 0.816817 | 0.799412 | 0.790087 | 0.813814 | 0.801775 | 0.828078 | 0.797619 | 0.824615 | 0.810893 | 0.858844 | 0.836093 | 0.847315 | 0.808824 | 0.818859 | 0.813810 | 2.141100 | 233.524000
49 | 0.014300 | 1.739141 | 0.438471 | 0.987031 | 0.785869 | 0.818318 | 0.801765 | 0.793054 | 0.822823 | 0.807664 | 0.822072 | 0.810897 | 0.778462 | 0.794349 | 0.852596 | 0.842715 | 0.847627 | 0.787234 | 0.826303 | 0.806295 | 2.268300 | 220.431000
50 | 0.015200 | 1.806629 | 0.437686 | 0.986938 | 0.781585 | 0.822072 | 0.801317 | 0.790123 | 0.816817 | 0.803248 | 0.822072 | 0.832237 | 0.778462 | 0.804452 | 0.832528 | 0.855960 | 0.844082 | 0.798526 | 0.806452 | 0.802469 | 2.116600 | 236.230000
51 | 0.015200 | 1.751909 | 0.448987 | 0.987078 | 0.783803 | 0.813814 | 0.798527 | 0.795488 | 0.820571 | 0.807834 | 0.827327 | 0.806061 | 0.818462 | 0.812214 | 0.839806 | 0.859272 | 0.849427 | 0.825521 | 0.786600 | 0.805591 | 2.417800 | 206.803000
52 | 0.013600 | 1.756309 | 0.446924 | 0.987359 | 0.792522 | 0.811562 | 0.801929 | 0.799267 | 0.819069 | 0.809047 | 0.825826 | 0.807339 | 0.812308 | 0.809816 | 0.842020 | 0.855960 | 0.848933 | 0.815857 | 0.791563 | 0.803526 | 2.290800 | 218.269000
53 | 0.013600 | 1.771644 | 0.449768 | 0.987203 | 0.793662 | 0.808559 | 0.801041 | 0.802985 | 0.807808 | 0.805389 | 0.829580 | 0.804281 | 0.809231 | 0.806748 | 0.857143 | 0.854305 | 0.855721 | 0.808933 | 0.808933 | 0.808933 | 2.209900 | 226.255000
54 | 0.013600 | 1.794925 | 0.450868 | 0.987203 | 0.789627 | 0.811562 | 0.800444 | 0.803254 | 0.815315 | 0.809240 | 0.827327 | 0.801802 | 0.821538 | 0.811550 | 0.852596 | 0.842715 | 0.847627 | 0.810945 | 0.808933 | 0.809938 | 2.327500 | 214.822000
55 | 0.011700 | 1.794530 | 0.443402 | 0.987266 | 0.801815 | 0.795796 | 0.798794 | 0.805556 | 0.805556 | 0.805556 | 0.825075 | 0.832787 | 0.781538 | 0.806349 | 0.837097 | 0.859272 | 0.848039 | 0.800983 | 0.808933 | 0.804938 | 2.170300 | 230.387000
56 | 0.013500 | 1.799630 | 0.439336 | 0.986969 | 0.790923 | 0.798048 | 0.794469 | 0.797912 | 0.803303 | 0.800599 | 0.824324 | 0.799392 | 0.809231 | 0.804281 | 0.847934 | 0.849338 | 0.848635 | 0.809045 | 0.799007 | 0.803995 | 3.287800 | 152.079000
57 | 0.013500 | 1.805381 | 0.440039 | 0.987062 | 0.793208 | 0.789039 | 0.791118 | 0.812117 | 0.795045 | 0.803490 | 0.824324 | 0.808642 | 0.806154 | 0.807396 | 0.830696 | 0.869205 | 0.849515 | 0.827128 | 0.771712 | 0.798460 | 2.088400 | 239.419000
58 | 0.012400 | 1.781618 | 0.446814 | 0.987047 | 0.791420 | 0.803303 | 0.797317 | 0.793860 | 0.815315 | 0.804444 | 0.825826 | 0.831210 | 0.803077 | 0.816901 | 0.840722 | 0.847682 | 0.844188 | 0.799511 | 0.811414 | 0.805419 | 2.302100 | 217.195000
59 | 0.011700 | 1.823947 | 0.449294 | 0.986969 | 0.789668 | 0.803303 | 0.796427 | 0.798369 | 0.808559 | 0.803432 | 0.828078 | 0.811550 | 0.821538 | 0.816514 | 0.841503 | 0.852649 | 0.847039 | 0.820972 | 0.796526 | 0.808564 | 2.424200 | 206.255000
60 | 0.011700 | 1.827804 | 0.444374 | 0.986984 | 0.794872 | 0.791291 | 0.793078 | 0.810225 | 0.785285 | 0.797560 | 0.827327 | 0.818462 | 0.818462 | 0.818462 | 0.832268 | 0.862583 | 0.847154 | 0.826772 | 0.781638 | 0.803571 | 2.178200 | 229.545000
61 | 0.011600 | 1.809431 | 0.452122 | 0.987000 | 0.802752 | 0.788288 | 0.795455 | 0.813917 | 0.781532 | 0.797396 | 0.831832 | 0.828125 | 0.815385 | 0.821705 | 0.842532 | 0.859272 | 0.850820 | 0.818182 | 0.803970 | 0.811014 | 2.068400 | 241.730000
62 | 0.010800 | 1.830662 | 0.449956 | 0.987328 | 0.798200 | 0.798799 | 0.798499 | 0.806912 | 0.806306 | 0.806609 | 0.828078 | 0.828571 | 0.803077 | 0.815625 | 0.832803 | 0.865894 | 0.849026 | 0.820051 | 0.791563 | 0.805556 | 2.177800 | 229.585000
63 | 0.010800 | 1.831762 | 0.447547 | 0.987203 | 0.794074 | 0.804805 | 0.799403 | 0.805263 | 0.804054 | 0.804658 | 0.826577 | 0.823899 | 0.806154 | 0.814930 | 0.836305 | 0.854305 | 0.845209 | 0.813602 | 0.801489 | 0.807500 | 2.326200 | 214.944000
64 | 0.010500 | 1.829183 | 0.442622 | 0.987078 | 0.792692 | 0.798048 | 0.795361 | 0.807780 | 0.795045 | 0.801362 | 0.825075 | 0.808642 | 0.806154 | 0.807396 | 0.843234 | 0.846026 | 0.844628 | 0.810945 | 0.808933 | 0.809938 | 2.899100 | 172.470000
65 | 0.010200 | 1.798383 | 0.440475 | 0.987062 | 0.791356 | 0.797297 | 0.794316 | 0.809670 | 0.792042 | 0.800759 | 0.824324 | 0.806154 | 0.806154 | 0.806154 | 0.837398 | 0.852649 | 0.844955 | 0.818878 | 0.796526 | 0.807547 | 2.067400 | 241.854000
66 | 0.010200 | 1.800962 | 0.451185 | 0.987172 | 0.785922 | 0.813063 | 0.799262 | 0.795903 | 0.816817 | 0.806225 | 0.828829 | 0.828025 | 0.800000 | 0.813772 | 0.843393 | 0.855960 | 0.849630 | 0.807407 | 0.811414 | 0.809406 | 2.146000 | 232.988000
67 | 0.010700 | 1.809239 | 0.449513 | 0.987297 | 0.792873 | 0.801802 | 0.797312 | 0.810280 | 0.804805 | 0.807533 | 0.826577 | 0.800604 | 0.815385 | 0.807927 | 0.856655 | 0.831126 | 0.843697 | 0.804819 | 0.828784 | 0.816626 | 2.300700 | 217.322000
68 | 0.009400 | 1.818187 | 0.440534 | 0.987187 | 0.791233 | 0.799550 | 0.795370 | 0.807721 | 0.801051 | 0.804372 | 0.822823 | 0.827922 | 0.784615 | 0.805687 | 0.847826 | 0.839404 | 0.843594 | 0.784038 | 0.828784 | 0.805790 | 2.070600 | 241.479000
69 | 0.009400 | 1.839509 | 0.449072 | 0.987281 | 0.797156 | 0.799550 | 0.798351 | 0.811972 | 0.794294 | 0.803036 | 0.828829 | 0.838816 | 0.784615 | 0.810811 | 0.849587 | 0.850993 | 0.850289 | 0.791962 | 0.831266 | 0.811138 | 2.190800 | 228.229000
70 | 0.009500 | 1.833614 | 0.440818 | 0.987141 | 0.787187 | 0.802553 | 0.794796 | 0.805745 | 0.800300 | 0.803013 | 0.823574 | 0.820513 | 0.787692 | 0.803768 | 0.850420 | 0.837748 | 0.844037 | 0.788235 | 0.831266 | 0.809179 | 2.240400 | 223.177000
71 | 0.009400 | 1.792183 | 0.437578 | 0.987266 | 0.795609 | 0.789039 | 0.792311 | 0.814815 | 0.792793 | 0.803653 | 0.821321 | 0.805556 | 0.803077 | 0.804314 | 0.841322 | 0.842715 | 0.842018 | 0.803970 | 0.803970 | 0.803970 | 2.265900 | 220.660000
72 | 0.009400 | 1.809693 | 0.446543 | 0.987187 | 0.794623 | 0.798799 | 0.796705 | 0.805891 | 0.801051 | 0.803464 | 0.826577 | 0.805471 | 0.815385 | 0.810398 | 0.856176 | 0.837748 | 0.846862 | 0.800971 | 0.818859 | 0.809816 | 1.897500 | 263.507000
73 | 0.009300 | 1.782456 | 0.439629 | 0.987141 | 0.793876 | 0.798048 | 0.795957 | 0.812159 | 0.782282 | 0.796941 | 0.825075 | 0.823718 | 0.790769 | 0.806907 | 0.844884 | 0.847682 | 0.846281 | 0.797101 | 0.818859 | 0.807834 | 2.327800 | 214.797000
74 | 0.009500 | 1.808447 | 0.442311 | 0.987141 | 0.789903 | 0.798799 | 0.794326 | 0.811583 | 0.789039 | 0.800152 | 0.825075 | 0.804281 | 0.809231 | 0.806748 | 0.856899 | 0.832781 | 0.844668 | 0.796651 | 0.826303 | 0.811206 | 2.126700 | 235.101000
75 | 0.008800 | 1.795766 | 0.450452 | 0.987062 | 0.793284 | 0.798048 | 0.795659 | 0.807663 | 0.791291 | 0.799393 | 0.830330 | 0.824841 | 0.796923 | 0.810642 | 0.854027 | 0.842715 | 0.848333 | 0.800948 | 0.838710 | 0.819394 | 2.204600 | 226.797000
76 | 0.008800 | 1.804343 | 0.448934 | 0.987062 | 0.790526 | 0.801802 | 0.796124 | 0.805471 | 0.795796 | 0.800604 | 0.828829 | 0.832258 | 0.793846 | 0.812598 | 0.850000 | 0.844371 | 0.847176 | 0.796209 | 0.833747 | 0.814545 | 2.178600 | 229.509000
77 | 0.008800 | 1.799190 | 0.445429 | 0.987016 | 0.789513 | 0.802553 | 0.795979 | 0.805513 | 0.789790 | 0.797574 | 0.828078 | 0.824841 | 0.796923 | 0.810642 | 0.847682 | 0.847682 | 0.847682 | 0.801932 | 0.823821 | 0.812729 | 2.220700 | 225.156000
78 | 0.009000 | 1.787097 | 0.444187 | 0.987000 | 0.792354 | 0.793544 | 0.792948 | 0.807220 | 0.789039 | 0.798026 | 0.827327 | 0.815625 | 0.803077 | 0.809302 | 0.847682 | 0.847682 | 0.847682 | 0.806373 | 0.816377 | 0.811344 | 2.230900 | 224.124000
79 | 0.009000 | 1.790349 | 0.442649 | 0.986969 | 0.791480 | 0.795045 | 0.793258 | 0.806452 | 0.788288 | 0.797267 | 0.826577 | 0.817610 | 0.800000 | 0.808709 | 0.847430 | 0.846026 | 0.846727 | 0.802920 | 0.818859 | 0.810811 | 2.184300 | 228.903000
80 | 0.008700 | 1.792766 | 0.445089 | 0.986984 | 0.791636 | 0.795796 | 0.793710 | 0.805833 | 0.788288 | 0.796964 | 0.828078 | 0.824841 | 0.796923 | 0.810642 | 0.847682 | 0.847682 | 0.847682 | 0.801932 | 0.823821 | 0.812729 | 2.210100 | 226.236000
TrainOutput(global_step=5440, training_loss=0.07378851720405852, metrics={'train_runtime': 4319.8751, 'train_samples_per_second': 1.259, 'total_flos': 3.680907436228608e+16, 'epoch': 80.0, 'init_mem_cpu_alloc_delta': 2128576512, 'init_mem_gpu_alloc_delta': 558472192, 'init_mem_cpu_peaked_delta': 380325888, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 329842688, 'train_mem_gpu_alloc_delta': 2244640256, 'train_mem_cpu_peaked_delta': 720216064, 'train_mem_gpu_peaked_delta': 8832347136})
Code
model.save_pretrained(Path("/data/blog/2021-07-20-aspect-sentiment-metrics/model" ))
So that is a monstrous amount of output. It’s interesting for me but much too much for this blog post.
The summary is that the model reaches an entity end F1 score of around 0.8 and a sentiment accuracy of around 0.82. The SOTA scores for aspect sentiment on this dataset are between 0.83 and 0.85 (source). It’s not terrible.
Evaluation
I hooked the previous version of this up to a gradio app and people had a play with it. One letdown was that the app would error if there were no entities in the text. This was due to the specific order of the operations, so that’s something that I need to fix this time around.
Code
sentiment_names = ["negative", "neutral", "positive"]


def aspect_sentiment(text: str) -> List[Tuple[str, str]]:
    tokenized_text = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        input_ids = tokenized_text["input_ids"].to(model.device)
        output = model(input_ids=input_ids)[0]

    entity_boundaries = output[:, :, :2] > 0.
    entity_mask = (output[:, :, 1] > 0.).flatten()
    # performing argmax early to avoid problem with no entity predictions
    entity_sentiment = (
        output.reshape(-1, 5)
        [:, 2:]
        .argmax(dim=-1)
        [entity_mask]
    )
    # note: tokens flagged as a start or an end are decoded here, while the
    # sentiment comes from end tokens only, so the zip below lines up when
    # each entity is a single token
    entities = tokenizer.batch_decode([
        [input_id]
        for input_id, boundaries in zip(tokenized_text["input_ids"][0], entity_boundaries[0])
        if True in boundaries
    ])

    return [
        (entity, sentiment_names[sentiment])
        for entity, sentiment in zip(entities, entity_sentiment.tolist())
    ]
Code
aspect_sentiment("the hotel had oversold the rooms, there was no place for us" )
Code
aspect_sentiment("The food was terrible, but the view was fantastic" )
[(' food', 'negative'), (' view', 'positive')]
So that works. If you’re interested in using gradio, you can see the code required below:
Code
#hide_output
import gradio as gr


def gradio_sentiment(text: str) -> str:
    results = aspect_sentiment(text)
    if not results:
        return "No entities found"
    return "\n".join(f"entity: {entity}, sentiment: {sentiment}" for entity, sentiment in results)


gr.Interface(
    fn=gradio_sentiment,
    inputs=["textbox"],
    outputs="text"
).launch(share=True)
Running locally at: http://127.0.0.1:7860/
This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)
Running on External URL: https://33922.gradio.app
Interface loading below...
Tip: Add interpretation to your model by simply adding `interpretation="default"` to `Interface()`
(<Flask 'gradio.networking'>,
'http://127.0.0.1:7860/',
'https://33922.gradio.app')
That’s all fine. What happens if we look at the sentiment per word?
Code
sentiment_names = ["negative", "neutral", "positive"]


def word_sentiment(text: str) -> List[Dict[str, Any]]:
    tokenized_text = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        input_ids = tokenized_text["input_ids"].to(model.device)
        output = model(input_ids=input_ids)[0]

    # softmax over the three sentiment logits of every token
    token_sentiment = (
        output.reshape(-1, 5)
        [:, 2:]
        .softmax(dim=-1)
    )

    return [
        {
            "token": tokenizer.decode(token),
            "negative": round(sentiment[0], 3),
            "neutral": round(sentiment[1], 3),
            "positive": round(sentiment[2], 3),
        }
        for token, sentiment in zip(input_ids[0].tolist(), token_sentiment.tolist())
    ]
Code
word_sentiment("The food was terrible but the view was fantastic" )
[{'token': '<s>', 'negative': 0.097, 'neutral': 0.042, 'positive': 0.861},
{'token': 'The', 'negative': 0.999, 'neutral': 0.0, 'positive': 0.001},
{'token': ' food', 'negative': 1.0, 'neutral': 0.0, 'positive': 0.0},
{'token': ' was', 'negative': 1.0, 'neutral': 0.0, 'positive': 0.0},
{'token': ' terrible',
'negative': 0.993,
'neutral': 0.005,
'positive': 0.002},
{'token': ' but', 'negative': 0.93, 'neutral': 0.035, 'positive': 0.036},
{'token': ' the', 'negative': 0.0, 'neutral': 0.0, 'positive': 1.0},
{'token': ' view', 'negative': 0.0, 'neutral': 0.0, 'positive': 1.0},
{'token': ' was', 'negative': 0.0, 'neutral': 0.0, 'positive': 1.0},
{'token': ' fantastic',
'negative': 0.001,
'neutral': 0.005,
'positive': 0.994},
{'token': '</s>', 'negative': 0.023, 'neutral': 0.159, 'positive': 0.819}]
This output suggests that spans of the text share sentiment levels. It’s quite hard to decode this output so visualizing it would help. I’m going to translate this to something that can color the tokens.
This is an extremely complex way of coloring the text, but it’s what I have available to me right now. The text will become a graphviz graph where each node is a token with a color related to the sentiment: red is mostly negative, blue is mostly neutral and green is mostly positive.
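The plotting code is not shown in this post, so the following is only a sketch of how such a graph could be produced, not the show_word_sentiment implementation used here. It assumes rows shaped like the word_sentiment output and builds the graphviz DOT source as a plain string. The names `sentiment_color` and `to_dot` are hypothetical. The three probabilities map directly onto the red (negative), green (positive) and blue (neutral) channels.

```python
from typing import Any, Dict, List


def sentiment_color(row: Dict[str, Any]) -> str:
    # hypothetical mapping: negative -> red, positive -> green, neutral -> blue
    red = round(255 * row["negative"])
    green = round(255 * row["positive"])
    blue = round(255 * row["neutral"])
    return f"#{red:02x}{green:02x}{blue:02x}"


def to_dot(rows: List[Dict[str, Any]]) -> str:
    lines = [
        "digraph sentiment {",
        "  rankdir=LR;",
        "  node [shape=box, style=filled, fontcolor=white];",
    ]
    for index, row in enumerate(rows):
        # strip the leading space of BART tokens and escape double quotes
        label = str(row["token"]).strip().replace('"', '\\"')
        lines.append(f'  t{index} [label="{label}", fillcolor="{sentiment_color(row)}"];')
    # chain the tokens in reading order
    for index in range(len(rows) - 1):
        lines.append(f"  t{index} -> t{index + 1};")
    lines.append("}")
    return "\n".join(lines)


# two rows copied from the word_sentiment output above
example_rows = [
    {"token": " food", "negative": 1.0, "neutral": 0.0, "positive": 0.0},
    {"token": " view", "negative": 0.0, "neutral": 0.0, "positive": 1.0},
]
dot_source = to_dot(example_rows)
print(dot_source)
```

The resulting text can be handed to any graphviz renderer. Building the source by hand keeps the sketch free of extra dependencies.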
Code
show_word_sentiment("The food was terrible but the view was fantastic" )
Code
show_word_sentiment("Marriott food was better than Hilton's, but the Hilton view was fantastic." )
Potential Improvements
After making this there are a few ideas I have about improving it. Broadly the model does two things, and the improvements are directed to either entity extraction or token sentiment.
Token Sentiment
Checking the sentiment only for the end token is an artifact of the sequence to sequence approach that was first tried. If the sentiment label were instead applied to every token in the entity, the token sentiment allocation could be more consistent.
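As a rough sketch of what that relabelling could look like, the helper below assigns the entity sentiment to every token whose character span falls inside an entity. It marks all other tokens with -100, which is the default ignore_index of torch.nn.functional.cross_entropy, so they would be skipped by the loss. This is an untested idea rather than something the model above was trained with. The function name and the hand-written offsets are hypothetical.

```python
from typing import List, Tuple

IGNORED = -100  # default ignore_index of torch.nn.functional.cross_entropy


def label_every_entity_token(
    offsets: List[Tuple[int, int]],
    entities: List[Tuple[int, int, int]],  # (start, end, sentiment index)
) -> List[int]:
    labels = []
    for start, end in offsets:
        sentiment = IGNORED
        if start != end:  # skip zero-width special and padding tokens
            for entity_start, entity_end, entity_sentiment in entities:
                if start >= entity_start and end <= entity_end:
                    sentiment = entity_sentiment
                    break
        labels.append(sentiment)
    return labels


# hypothetical offsets for "<s> The room service was great </s>"
offsets = [(0, 0), (0, 3), (4, 8), (9, 16), (17, 20), (21, 26), (0, 0)]
# "room service" covers characters 4-16 and is positive (index 2)
token_labels = label_every_entity_token(offsets, [(4, 16, 2)])
print(token_labels)
```

With labels like these the sentiment loss would no longer need the end-token mask, because the ignored value does the masking.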
It may be possible to pretrain the model on document level sentiment and then refine it. The amount of document level sentiment data is considerably larger so this might help.
Token Relationships
In many ways the sentiment in the text can be viewed as a relationship between different entities in the text. For example the text Marriott food was better than Hilton’s is a comparison between Marriott food and Hilton food. With a better idea of the relationships it might be possible to represent the relationship as a vector, which could then be clustered to find instances of similar relationships between entity pairs.
This would be achievable by using cosine similarity loss, similar to semantic search. To efficiently generate the actual pairing it might be nice to evaluate something similar to CLIP, where the image and text are processed to produce separate vectors and the final classification is a dot product of them. This also reminds me of the adafactor / adam difference, where the individual values can be appropriately estimated from the outer product of row and column statistics.
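For reference, cosine similarity is just the dot product of two vectors after each has been scaled to unit length, which is also how CLIP-style models compare their two embeddings. A minimal sketch in plain Python follows. The three relationship vectors are invented purely for illustration and do not come from any model.

```python
import math
from typing import List, Sequence


def normalize(vector: Sequence[float]) -> List[float]:
    # assumes a non-zero vector
    length = math.sqrt(sum(value * value for value in vector))
    return [value / length for value in vector]


def cosine_similarity(left: Sequence[float], right: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(normalize(left), normalize(right)))


# invented relationship vectors, for illustration only
better_than = [0.9, 0.1, 0.0]
also_better_than = [1.8, 0.2, 0.0]  # same direction, different magnitude
unrelated = [0.0, 0.0, 1.0]

print(cosine_similarity(better_than, also_better_than))
print(cosine_similarity(better_than, unrelated))
```

Vectors pointing in the same direction score close to one regardless of magnitude, and orthogonal vectors score zero. That is the property that would let similar relationships cluster together.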
SOTA Review
I need to evaluate the different approaches that have been published. The MAMS SOTA results are better than what this approach is currently achieving (this uses the smaller BART model though).