Someone asked if it would be possible to predict categories for text, where the training data uses hashtags as the labels. I thought this was an interesting enough problem to have a go at, so here we are.
The plan is to determine the top 50 hashtags for one day of Twitter data and then train a multi-label model to predict the hashtags for a given tweet. I need this to predict multiple labels, so it will require a slight adjustment to the MODELForSequenceClassification models that Hugging Face provides (out of the box they handle single-label classification or regression).
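The core of that adjustment is the loss function: a single-label head scores the classes with softmax and cross entropy, which forces the classes to compete, while a multi-label head gives each class an independent sigmoid and binary cross entropy. A minimal sketch of the difference, with made-up tensors:
Code
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                       # a batch of 4 tweets, 10 classes

# single label: exactly one class per example, softmax + cross entropy
single_labels = torch.randint(10, (4,))           # one class index per example
single_loss = F.cross_entropy(logits, single_labels)

# multi label: any number of classes per example, sigmoid + binary cross entropy
multi_labels = torch.randint(2, (4, 10)).float()  # one multi-hot vector per example
multi_loss = F.binary_cross_entropy_with_logits(logits, multi_labels)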
So let’s start by evaluating the data. It’s still downloading, and it’s rather sensitive, so I will only be showing aggregate stats for it.
Code
from pathlib import Path
import pandas as pd

DATA_FOLDER = Path("/data/blog/2021-07-21-tweet-hashtag/")
TWEET_FILES = sorted((DATA_FOLDER / "raw").glob("*.gz.parquet"))

df = pd.read_parquet(TWEET_FILES[0])
df = df[["contents"]].rename(columns={"contents": "text"})
Code
df.text.head()
0 hannie go nom https://t.co/8lTgi2oPXe
1 RT @MuhaddasI Oomfs convinced me to post this ...
2 O verdadeiro não está entre nós.
3 Happy birthday to the one who’s had my heart s...
4 RT @momomosumomo517 💗❤️💛💜\nももクロ・スケジュール メモ (202...
Name: text, dtype: object
So I need to extract the hashtags from this text.
Code
import re

HASHTAG_PATTERN = re.compile(r"#([^\s]+)")
df.text.str.findall(HASHTAG_PATTERN).head()
0 []
1 []
2 []
3 [HappyBirthdayJustinBieber]
4 [ももクロ]
Name: text, dtype: object
Code
df["hashtags" ] = df.text.str .findall(HASHTAG_PATTERN)
df.head()
| | text | hashtags |
|---|------|----------|
| 0 | hannie go nom https://t.co/8lTgi2oPXe | [] |
| 1 | RT @MuhaddasI Oomfs convinced me to post this ... | [] |
| 2 | O verdadeiro não está entre nós. | [] |
| 3 | Happy birthday to the one who’s had my heart s... | [HappyBirthdayJustinBieber] |
| 4 | RT @momomosumomo517 💗❤️💛💜\nももクロ・スケジュール メモ (202... | [ももクロ] |
Code
(
    df.hashtags
    .explode()
    .value_counts()
    .sort_values(ascending=False)
    [:10]
)
YirmiHezimet60BinHak 551
NoticeTurkishStudents 535
ROSÉ 534
WhatsHappeningInMyanmar 411
3MPostsForSidShuklaOnIG 392
로제 391
3YearsWithHopeWorld 319
Mar1Coup 307
ShopeeJAMBORE 219
1 196
Name: hashtags, dtype: int64
So these might be the top hashtags in my dataset; however, I question whether they are really categories. This also reminds me that I’m going to need a multilingual model.
I want to make this easy for myself and I want to be able to interpret the results, so I’m going to restrict this to English tweets and single-word hashtags. The simple pattern below only captures the leading word of a CamelCase hashtag, which is why #HappyBirthdayJustinBieber becomes happy.
Code
df = pd.read_parquet(TWEET_FILES[0])
df = df[df.language == "en"]
df = df[["contents"]].rename(columns={"contents": "text"})

SIMPLE_HASHTAG_PATTERN = re.compile(r"#([A-Za-z][a-z]+)")
df["hashtags"] = (
    df.text.str
    .findall(SIMPLE_HASHTAG_PATTERN)
    .apply(lambda hashtags: [hashtag.casefold() for hashtag in hashtags])
)
df.hashtags.head()
0 []
1 []
3 [happy]
5 []
7 []
Name: hashtags, dtype: object
Code
(
    df.hashtags
    .explode()
    .value_counts()
    .sort_values(ascending=False)
    [:10]
)
whats 471
mar 292
nengi 135
vote 108
happy 103
sidharth 100
march 98
hope 98
milk 79
blueside 74
Name: hashtags, dtype: int64
This might be better? I need to filter this down to actual categories. It might be necessary to infer the category from the hashtag through a mapping.
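Such a mapping would have to be written by hand; a hypothetical sketch (these pairs are invented for illustration):
Code
# hypothetical hashtag -> category mapping, invented for illustration
CATEGORY_BY_HASHTAG = {
    "yankees": "baseball",
    "arsenal": "football",
    "daechwita": "music",
}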
For now, though, I’m inclined to just use a dictionary wordlist to filter the hashtags down to real words.
Code
DICTIONARY_WORDS = {
    word.casefold().strip()
    for word in Path("/usr/share/dict/british-english").read_text().splitlines()
}

hashtags = (
    df.hashtags
    .explode()
    .value_counts()
    .sort_values(ascending=False)
    .to_frame()
    .reset_index()
    .rename(columns={"index": "hashtag", "hashtags": "count"})
)
hashtags[hashtags.hashtag.isin(DICTIONARY_WORDS)][:10]
| | hashtag | count |
|---|---------|-------|
| 0 | whats | 471 |
| 1 | mar | 292 |
| 3 | vote | 108 |
| 4 | happy | 103 |
| 6 | march | 98 |
| 7 | hope | 98 |
| 8 | milk | 79 |
| 10 | friday | 73 |
| 11 | blue | 69 |
| 12 | myanmar | 64 |
I’ve discussed this further and it may well be best to start with a list of categories that might appear as hashtags. Starting with them would ensure that the categories are reasonably high quality. To be able to do this effectively I need to load all of the data, as there just aren’t enough hashtags in this single file.
Code
from tqdm.auto import tqdm

SIMPLE_HASHTAG_PATTERN = re.compile(r"#([A-Za-z][a-z]+)")

def load_file(path: Path) -> pd.DataFrame:
    df = pd.read_parquet(path)
    df = df[df.language == "en"]
    df = df[["contents"]].rename(columns={"contents": "text"})
    df["hashtags"] = (
        df.text.str
        .findall(SIMPLE_HASHTAG_PATTERN)
        .apply(lambda hashtags: [hashtag.casefold() for hashtag in hashtags])
    )
    return df

df = pd.concat([
    load_file(path)
    for path in tqdm(TWEET_FILES[:1_000])
])
Code
df
| | text | hashtags |
|---|------|----------|
| 0 | hannie go nom https://t.co/8lTgi2oPXe | [] |
| 1 | RT @MuhaddasI Oomfs convinced me to post this ... | [] |
| 3 | Happy birthday to the one who’s had my heart s... | [happy] |
| 5 | make it make sense | [] |
| 7 | RT @Findomgoddess25 Here is my verification vi... | [] |
| ... | ... | ... |
| 99970 | Due to the soggy conditions at WMS, Girls’s Cr... | [] |
| 99974 | RT @httphearts i pray that march will be full ... | [] |
| 99975 | RT @mintaiey art is hard https://t.co/93GBQieXpi | [] |
| 99976 | RT @bontlexn In March I will attract money. I ... | [] |
| 99982 | i might do a picarto stream later today becaus... | [] |

33800080 rows × 2 columns
Code
df.to_parquet(DATA_FOLDER / "english-hashtags.gz.parquet", compression="gzip")
Code
df = pd.read_parquet(DATA_FOLDER / "english-hashtags.gz.parquet")
Code
hashtags = (
    df.hashtags
    .explode()
    .value_counts()
    .sort_values(ascending=False)
    .to_frame()
    .reset_index()
    .rename(columns={"index": "hashtag", "hashtags": "count"})
)
Code
"sport" in hashtags.hashtag.values
Code
hashtags.hashtag.head(20)
0 whats
1 mar
2 hope
3 the
4 vote
5 march
6 stop
7 happy
8 shopee
9 blue
10 daechwita
11 way
12 friday
13 nengi
14 actor
15 golden
16 new
17 bitcoin
18 milk
19 sidharth
Name: hashtag, dtype: object
Code
interests = {
    "bts", "cricket", "arsenal", "twitch", "author", "wrestling",
    "baseball", "crypto", "sports", "hockey", "anime", "fortnite",
    "f1", "golf", "travel", "mma", "soccer", "boxing", "artist",
    "wine", "beer", "photographer", "softball", "vegan", "football",
    "tennis", "weather", "fitness", "cannabis", "rugby", "hair",
    "developer", "congress", "music", "disney", "yankees",
    "basketball", "overwatch", "makeup", "cycling", "theatre",
    "security",
}
Code
top_hashtags = hashtags[hashtags.hashtag.isin(interests)][:10]
top_interests = set(top_hashtags.hashtag.values)
top_hashtags
| | hashtag | count |
|---|---------|-------|
| 20 | music | 40945 |
| 49 | crypto | 20641 |
| 105 | twitch | 10933 |
| 154 | artist | 8037 |
| 207 | disney | 5982 |
| 293 | travel | 4492 |
| 326 | fortnite | 4076 |
| 330 | anime | 3995 |
| 469 | fitness | 2976 |
| 499 | security | 2808 |
So, of my list, there are 10 hashtags that are present in the dataset in reasonable volume. This is enough to get going with; if the model can’t predict these then there is no hope. Any tweet with one of these hashtags will be included in the dataset, and the ratio between the most and least common isn’t terrible (about 15x).
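That imbalance is easy to check from the counts above; a quick sketch using the top_hashtags frame:
Code
# music (40945) over security (2808) is roughly 14.6
top_hashtags["count"].max() / top_hashtags["count"].min()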
So let’s select the appropriate rows and start preprocessing the text. The hashtags have to be removed from the text, otherwise the task would be much too easy!
Code
interesting_rows = (
    df[
        df.hashtags.apply(
            lambda hashtags: bool(top_interests.intersection(hashtags))
        )
    ].copy()
)
interesting_rows["hashtags"] = interesting_rows.hashtags.apply(
    lambda hashtags: sorted({hashtag for hashtag in hashtags if hashtag in top_interests})
)
Code
interesting_rows.head()
| | text | hashtags |
|---|------|----------|
| 135 | RT @SwaapCrypto #Ripple fintech giant has agai... | [crypto] |
| 1712 | RT @out_h2 Tales of Berseria continues! Happy ... | [twitch] |
| 1795 | The Android 11 Privacy and Security Features Y... | [security] |
| 2259 | RT @LNdejje Who wants some support? \n1Like/Re... | [twitch] |
| 3756 | Live! come here! https://t.co/DwyA0a3Ng0\n\n#... | [twitch] |
Code
import re

HASHTAG_PATTERN = re.compile(r"\s*#([^\s]+)\s*")
# regex=True keeps this working on newer pandas, where the default changed
interesting_rows["text"] = interesting_rows.text.str.replace(HASHTAG_PATTERN, " ", regex=True)
interesting_rows.head()
| | text | hashtags |
|---|------|----------|
| 135 | RT @SwaapCrypto fintech giant has again unleas... | [crypto] |
| 1712 | RT @out_h2 Tales of Berseria continues! Happy ... | [twitch] |
| 1795 | The Android 11 Privacy and Security Features Y... | [security] |
| 2259 | RT @LNdejje Who wants some support? \n1Like/Re... | [twitch] |
| 3756 | Live! come here! https://t.co/DwyA0a3Ng0 ... | [twitch] |
I’m not prone to overthinking things, so I’m going to assume that this worked perfectly and that I can now train the model. The first step of this is to create the dataset.
Code
MODEL_NAME = "facebook/bart-base"
MAXIMUM_TOKEN_LENGTH = 128
BATCH_SIZE = 32
MAX_STEPS = 5_000
Code
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
Code
categories = sorted(interesting_rows.hashtags.explode().unique())
categories
['anime',
'artist',
'crypto',
'disney',
'fitness',
'fortnite',
'music',
'security',
'travel',
'twitch']
Code
from typing import Any, Dict

def encode(row: Dict[str, Any]) -> Dict[str, Any]:
    input_ids = tokenizer(
        row["text"],
        max_length=MAXIMUM_TOKEN_LENGTH,
        truncation=True,
        return_attention_mask=False,
    )["input_ids"]
    hashtags = set(row["hashtags"])
    # one binary entry per category, so a tweet can carry several labels
    label = [category in hashtags for category in categories]
    return {"input_ids": input_ids, "label": label}
Code
from datasets import Dataset

ds = Dataset.from_pandas(interesting_rows)
ds = ds.remove_columns("__index_level_0__")
ds = ds.map(encode)

split = ds.train_test_split(test_size=1024)
train_ds = split["train"]
test_ds = split["test"]
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Code
from transformers import BartPretrainedModel, BartModel, BartConfig
from transformers.models.bart.modeling_bart import BartClassificationHead
from typing import Optional, Tuple
import torch

class BartForMultiLabelSequenceClassification(BartPretrainedModel):
    def __init__(self, config: BartConfig, **kwargs):
        config.num_labels = 10
        super().__init__(config, **kwargs)
        self.model = BartModel(config)
        self.classification_head = BartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classifier_dropout,
        )
        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = outputs[0]  # last hidden state

        # pool the sentence as the hidden state of the final <eos> token
        eos_mask = input_ids.eq(self.config.eos_token_id)
        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        sentence_representation = hidden_states[eos_mask, :].view(
            hidden_states.size(0), -1, hidden_states.size(-1)
        )[:, -1, :]
        logits = self.classification_head(sentence_representation)

        loss = None
        if labels is not None:
            # changed from the single-label cross entropy: every category is an
            # independent binary decision against the multi-hot label vector
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits, labels.float()
            )
        # the stock model returns a Seq2SeqSequenceClassifierOutput here;
        # a plain (loss, logits) tuple is all the Trainer needs
        return (loss, logits)
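One subtlety worth calling out: BART has no [CLS] token, so this head (like the stock BartForSequenceClassification) pools the sentence by taking the hidden state of the final <eos> token. A small sketch of that indexing with made-up shapes and token ids:
Code
import torch

# made-up example: batch of 2, sequence length 6, hidden size 8
hidden_states = torch.randn(2, 6, 8)
input_ids = torch.tensor([
    [0, 5, 9, 2, 1, 1],  # 2 is <eos>, 1 is padding
    [0, 7, 4, 3, 6, 2],
])
eos_mask = input_ids.eq(2)  # True at every <eos> position

# select the hidden state of the last <eos> in each row -> (batch, hidden)
sentence_representation = hidden_states[eos_mask, :].view(
    hidden_states.size(0), -1, hidden_states.size(-1)
)[:, -1, :]
sentence_representation.shape  # torch.Size([2, 8])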
Code
from typing import Dict
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import EvalPrediction

def compute_metrics(pred: EvalPrediction) -> Dict[str, float]:
    labels = pred.label_ids
    # a logit above zero counts as a positive prediction for that label
    predictions = (pred.predictions > 0).astype(int)
    accuracy = accuracy_score(labels, predictions)
    results = {"accuracy": accuracy}
    for name, precision, recall, fscore in zip(
        categories,
        *precision_recall_fscore_support(labels, predictions, zero_division=0.0)[:3],
    ):
        results[f"{name}_precision"] = precision
        results[f"{name}_recall"] = recall
        results[f"{name}_f1"] = fscore
    return results
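As a quick sanity check of the thresholding, this is what compute_metrics would see on a tiny invented batch:
Code
import numpy as np
from transformers import EvalPrediction

# two invented examples over the 10 categories
toy = EvalPrediction(
    predictions=np.array([[2.0, -1.5] + [-3.0] * 8,
                          [-0.5, 3.1] + [-3.0] * 8]),
    label_ids=np.array([[1, 0] + [0] * 8,
                        [0, 1] + [0] * 8]),
)
compute_metrics(toy)["accuracy"]  # 1.0: both multi-hot rows match exactly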
Code
model = BartForMultiLabelSequenceClassification.from_pretrained(MODEL_NAME)
Some weights of the model checkpoint at facebook/bart-base were not used when initializing BartForMultiLabelSequenceClassification: ['final_logits_bias']
- This IS expected if you are initializing BartForMultiLabelSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForMultiLabelSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BartForMultiLabelSequenceClassification were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['classification_head.dense.weight', 'classification_head.dense.bias', 'classification_head.out_proj.weight', 'classification_head.out_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Code
from pathlib import Path
from transformers import Trainer, TrainingArguments

MODEL_RUN_FOLDER = Path("/data/blog/2021-07-21-tweet-hashtag/runs")
MODEL_RUN_FOLDER.mkdir(parents=True, exist_ok=True)

training_args = TrainingArguments(
    report_to=[],
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    warmup_ratio=0.06,
    max_steps=MAX_STEPS,
    evaluation_strategy="steps",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=100,
    eval_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
[5000/5000 32:22, Epoch 1/2]
| Step | Training Loss | Validation Loss | Accuracy | Anime P | Anime R | Anime F1 | Artist P | Artist R | Artist F1 | Crypto P | Crypto R | Crypto F1 | Disney P | Disney R | Disney F1 | Fitness P | Fitness R | Fitness F1 | Fortnite P | Fortnite R | Fortnite F1 | Music P | Music R | Music F1 | Security P | Security R | Security F1 | Travel P | Travel R | Travel F1 | Twitch P | Twitch R | Twitch F1 | Runtime | Samples/s |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 100 | 0.358600 | 0.190214 | 0.482422 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.968254 | 0.598039 | 0.739394 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.898795 | 0.855505 | 0.876616 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 2.761900 | 370.761000 |
| 200 | 0.157000 | 0.121153 | 0.707031 | 0.000000 | 0.000000 | 0.000000 | 0.666667 | 0.055556 | 0.102564 | 0.861905 | 0.887255 | 0.874396 | 1.000000 | 0.586957 | 0.739726 | 0.875000 | 0.451613 | 0.595745 | 0.000000 | 0.000000 | 0.000000 | 0.972569 | 0.894495 | 0.931900 | 1.000000 | 0.428571 | 0.600000 | 1.000000 | 0.090909 | 0.166667 | 0.760000 | 0.848214 | 0.801688 | 2.797000 | 366.109000 |
| 300 | 0.117200 | 0.106900 | 0.768555 | 0.714286 | 0.185185 | 0.294118 | 0.445946 | 0.458333 | 0.452055 | 0.915789 | 0.852941 | 0.883249 | 1.000000 | 0.782609 | 0.878049 | 0.882353 | 0.483871 | 0.625000 | 0.000000 | 0.000000 | 0.000000 | 0.968370 | 0.912844 | 0.939787 | 0.909091 | 0.714286 | 0.800000 | 0.878049 | 0.654545 | 0.750000 | 0.898876 | 0.714286 | 0.796020 | 2.981300 | 343.478000 |
| 400 | 0.104200 | 0.096251 | 0.786133 | 0.923077 | 0.444444 | 0.600000 | 0.537037 | 0.402778 | 0.460317 | 0.923469 | 0.887255 | 0.905000 | 0.971429 | 0.739130 | 0.839506 | 1.000000 | 0.354839 | 0.523810 | 0.000000 | 0.000000 | 0.000000 | 0.957143 | 0.922018 | 0.939252 | 0.638889 | 0.821429 | 0.718750 | 0.822222 | 0.672727 | 0.740000 | 0.897959 | 0.785714 | 0.838095 | 3.061600 | 334.466000 |
| 500 | 0.093300 | 0.089162 | 0.792969 | 0.516129 | 0.592593 | 0.551724 | 0.647059 | 0.305556 | 0.415094 | 0.952128 | 0.877451 | 0.913265 | 0.970588 | 0.717391 | 0.825000 | 0.894737 | 0.548387 | 0.680000 | 1.000000 | 0.037037 | 0.071429 | 0.971014 | 0.922018 | 0.945882 | 0.694444 | 0.892857 | 0.781250 | 0.878049 | 0.654545 | 0.750000 | 0.892157 | 0.812500 | 0.850467 | 3.033800 | 337.534000 |
| 600 | 0.084400 | 0.082491 | 0.815430 | 0.517241 | 0.555556 | 0.535714 | 0.729730 | 0.375000 | 0.495413 | 0.940887 | 0.936275 | 0.938575 | 0.925000 | 0.804348 | 0.860465 | 0.900000 | 0.580645 | 0.705882 | 1.000000 | 0.222222 | 0.363636 | 0.980488 | 0.922018 | 0.950355 | 0.851852 | 0.821429 | 0.836364 | 0.918919 | 0.618182 | 0.739130 | 0.902913 | 0.830357 | 0.865116 | 3.068200 | 333.749000 |
| 700 | 0.087300 | 0.079204 | 0.826172 | 0.631579 | 0.444444 | 0.521739 | 0.702703 | 0.361111 | 0.477064 | 0.958333 | 0.901961 | 0.929293 | 0.972222 | 0.760870 | 0.853659 | 0.944444 | 0.548387 | 0.693878 | 0.750000 | 0.333333 | 0.461538 | 0.978208 | 0.926606 | 0.951708 | 0.774194 | 0.857143 | 0.813559 | 0.930233 | 0.727273 | 0.816327 | 0.833333 | 0.937500 | 0.882353 | 3.038800 | 336.974000 |
| 800 | 0.082500 | 0.079864 | 0.812500 | 0.500000 | 0.592593 | 0.542373 | 0.700000 | 0.388889 | 0.500000 | 0.919048 | 0.946078 | 0.932367 | 1.000000 | 0.782609 | 0.878049 | 0.950000 | 0.612903 | 0.745098 | 0.900000 | 0.333333 | 0.486486 | 0.973301 | 0.919725 | 0.945755 | 0.827586 | 0.857143 | 0.842105 | 0.923077 | 0.654545 | 0.765957 | 0.953488 | 0.732143 | 0.828283 | 3.046500 | 336.126000 |
| 900 | 0.081000 | 0.073797 | 0.828125 | 0.714286 | 0.555556 | 0.625000 | 0.638889 | 0.319444 | 0.425926 | 0.972826 | 0.877451 | 0.922680 | 0.902439 | 0.804348 | 0.850575 | 1.000000 | 0.580645 | 0.734694 | 1.000000 | 0.370370 | 0.540541 | 0.955399 | 0.933486 | 0.944316 | 0.750000 | 0.964286 | 0.843750 | 0.951220 | 0.709091 | 0.812500 | 0.910714 | 0.910714 | 0.910714 | 3.028000 | 338.176000 |
| 1000 | 0.078200 | 0.071819 | 0.835938 | 0.933333 | 0.518519 | 0.666667 | 0.660714 | 0.513889 | 0.578125 | 0.969231 | 0.926471 | 0.947368 | 0.948718 | 0.804348 | 0.870588 | 0.833333 | 0.645161 | 0.727273 | 0.750000 | 0.333333 | 0.461538 | 0.964029 | 0.922018 | 0.942556 | 0.800000 | 0.857143 | 0.827586 | 0.906977 | 0.709091 | 0.795918 | 0.892857 | 0.892857 | 0.892857 | 2.998100 | 341.544000 |
| 1100 | 0.072100 | 0.072056 | 0.842773 | 0.652174 | 0.555556 | 0.600000 | 0.716981 | 0.527778 | 0.608000 | 0.940887 | 0.936275 | 0.938575 | 0.925000 | 0.804348 | 0.860465 | 0.850000 | 0.548387 | 0.666667 | 0.833333 | 0.370370 | 0.512821 | 0.971292 | 0.931193 | 0.950820 | 0.862069 | 0.892857 | 0.877193 | 0.911111 | 0.745455 | 0.820000 | 0.902655 | 0.910714 | 0.906667 | 3.005700 | 340.683000 |
| 1200 | 0.071100 | 0.075315 | 0.843750 | 0.666667 | 0.592593 | 0.627451 | 0.696429 | 0.541667 | 0.609375 | 0.897196 | 0.941176 | 0.918660 | 0.926829 | 0.826087 | 0.873563 | 1.000000 | 0.483871 | 0.652174 | 0.750000 | 0.333333 | 0.461538 | 0.966667 | 0.931193 | 0.948598 | 0.787879 | 0.928571 | 0.852459 | 0.814815 | 0.800000 | 0.807339 | 0.939394 | 0.830357 | 0.881517 | 2.995700 | 341.824000 |
| 1300 | 0.070000 | 0.069436 | 0.841797 | 0.727273 | 0.592593 | 0.653061 | 0.756757 | 0.388889 | 0.513761 | 0.949495 | 0.921569 | 0.935323 | 1.000000 | 0.782609 | 0.878049 | 0.952381 | 0.645161 | 0.769231 | 0.846154 | 0.407407 | 0.550000 | 0.955711 | 0.940367 | 0.947977 | 0.950000 | 0.678571 | 0.791667 | 0.807018 | 0.836364 | 0.821429 | 0.942857 | 0.883929 | 0.912442 | 3.020500 | 339.022000 |
| 1400 | 0.071800 | 0.066637 | 0.845703 | 0.761905 | 0.592593 | 0.666667 | 0.678571 | 0.527778 | 0.593750 | 0.974490 | 0.936275 | 0.955000 | 1.000000 | 0.760870 | 0.864198 | 0.950000 | 0.612903 | 0.745098 | 0.705882 | 0.444444 | 0.545455 | 0.976019 | 0.933486 | 0.954279 | 0.827586 | 0.857143 | 0.842105 | 0.906977 | 0.709091 | 0.795918 | 0.925234 | 0.883929 | 0.904110 | 3.048800 | 335.868000 |
| 1500 | 0.066400 | 0.069763 | 0.840820 | 0.708333 | 0.629630 | 0.666667 | 0.738095 | 0.430556 | 0.543860 | 0.969388 | 0.931373 | 0.950000 | 0.972973 | 0.782609 | 0.867470 | 0.900000 | 0.580645 | 0.705882 | 0.733333 | 0.407407 | 0.523810 | 0.980488 | 0.922018 | 0.950355 | 0.892857 | 0.892857 | 0.892857 | 0.716418 | 0.872727 | 0.786885 | 0.933333 | 0.875000 | 0.903226 | 3.058900 | 334.764000 |
| 1600 | 0.069400 | 0.067827 | 0.832031 | 0.833333 | 0.555556 | 0.666667 | 0.653846 | 0.472222 | 0.548387 | 0.949749 | 0.926471 | 0.937965 | 0.973684 | 0.804348 | 0.880952 | 0.944444 | 0.548387 | 0.693878 | 0.769231 | 0.370370 | 0.500000 | 0.969267 | 0.940367 | 0.954598 | 0.916667 | 0.785714 | 0.846154 | 0.923077 | 0.654545 | 0.765957 | 0.929293 | 0.821429 | 0.872038 | 3.007400 | 340.489000 |
| 1700 | 0.066000 | 0.072248 | 0.839844 | 0.695652 | 0.592593 | 0.640000 | 0.692308 | 0.500000 | 0.580645 | 0.932039 | 0.941176 | 0.936585 | 0.948718 | 0.804348 | 0.870588 | 0.894737 | 0.548387 | 0.680000 | 0.800000 | 0.444444 | 0.571429 | 0.958042 | 0.942661 | 0.950289 | 0.857143 | 0.857143 | 0.857143 | 0.888889 | 0.727273 | 0.800000 | 0.956044 | 0.776786 | 0.857143 | 3.025400 | 338.467000 |
| 1800 | 0.064200 | 0.065662 | 0.842773 | 0.695652 | 0.592593 | 0.640000 | 0.725490 | 0.513889 | 0.601626 | 0.973822 | 0.911765 | 0.941772 | 1.000000 | 0.782609 | 0.878049 | 0.888889 | 0.516129 | 0.653061 | 0.909091 | 0.370370 | 0.526316 | 0.980676 | 0.931193 | 0.955294 | 0.838710 | 0.928571 | 0.881356 | 0.888889 | 0.727273 | 0.800000 | 0.902655 | 0.910714 | 0.906667 | 3.030400 | 337.907000 |
| 1900 | 0.066700 | 0.063640 | 0.848633 | 0.888889 | 0.592593 | 0.711111 | 0.740000 | 0.513889 | 0.606557 | 0.969388 | 0.931373 | 0.950000 | 1.000000 | 0.782609 | 0.878049 | 0.857143 | 0.580645 | 0.692308 | 0.846154 | 0.407407 | 0.550000 | 0.971496 | 0.938073 | 0.954492 | 0.757576 | 0.892857 | 0.819672 | 0.869565 | 0.727273 | 0.792079 | 0.878261 | 0.901786 | 0.889868 | 3.040600 | 336.777000 |
| 2000 | 0.065400 | 0.061513 | 0.852539 | 0.727273 | 0.592593 | 0.653061 | 0.735849 | 0.541667 | 0.624000 | 0.942308 | 0.960784 | 0.951456 | 0.973684 | 0.804348 | 0.880952 | 0.863636 | 0.612903 | 0.716981 | 0.750000 | 0.444444 | 0.558140 | 0.976019 | 0.933486 | 0.954279 | 0.781250 | 0.892857 | 0.833333 | 1.000000 | 0.745455 | 0.854167 | 0.932692 | 0.866071 | 0.898148 | 3.053000 | 335.413000 |
| 2100 | 0.062300 | 0.063823 | 0.856445 | 0.800000 | 0.592593 | 0.680851 | 0.693548 | 0.597222 | 0.641791 | 0.946078 | 0.946078 | 0.946078 | 1.000000 | 0.782609 | 0.878049 | 0.909091 | 0.645161 | 0.754717 | 0.785714 | 0.407407 | 0.536585 | 0.983133 | 0.935780 | 0.958872 | 0.857143 | 0.857143 | 0.857143 | 0.906977 | 0.709091 | 0.795918 | 0.908257 | 0.883929 | 0.895928 | 3.044400 | 336.350000 |
| 2200 | 0.063300 | 0.063828 | 0.849609 | 0.615385 | 0.592593 | 0.603774 | 0.829268 | 0.472222 | 0.601770 | 0.955000 | 0.936275 | 0.945545 | 0.973684 | 0.804348 | 0.880952 | 0.785714 | 0.709677 | 0.745763 | 0.705882 | 0.444444 | 0.545455 | 0.983092 | 0.933486 | 0.957647 | 0.884615 | 0.821429 | 0.851852 | 0.860000 | 0.781818 | 0.819048 | 0.941176 | 0.857143 | 0.897196 | 3.016200 | 339.498000 |
| 2300 | 0.066600 | 0.059670 | 0.852539 | 0.727273 | 0.592593 | 0.653061 | 0.782609 | 0.500000 | 0.610169 | 0.964286 | 0.926471 | 0.945000 | 1.000000 | 0.804348 | 0.891566 | 0.833333 | 0.645161 | 0.727273 | 0.733333 | 0.407407 | 0.523810 | 0.976415 | 0.949541 | 0.962791 | 0.884615 | 0.821429 | 0.851852 | 0.877551 | 0.781818 | 0.826923 | 0.933333 | 0.875000 | 0.903226 | 3.054500 | 335.241000 |
| 2400 | 0.063600 | 0.058983 | 0.857422 | 0.842105 | 0.592593 | 0.695652 | 0.750000 | 0.541667 | 0.629032 | 0.965174 | 0.950980 | 0.958025 | 1.000000 | 0.804348 | 0.891566 | 0.833333 | 0.645161 | 0.727273 | 0.800000 | 0.444444 | 0.571429 | 0.987864 | 0.933486 | 0.959906 | 0.892857 | 0.892857 | 0.892857 | 0.880000 | 0.800000 | 0.838095 | 0.925234 | 0.883929 | 0.904110 | 3.051100 | 335.618000 |
| 2500 | 0.062900 | 0.057790 | 0.861328 | 0.842105 | 0.592593 | 0.695652 | 0.775510 | 0.527778 | 0.628099 | 0.964824 | 0.941176 | 0.952854 | 1.000000 | 0.782609 | 0.878049 | 0.875000 | 0.677419 | 0.763636 | 0.800000 | 0.444444 | 0.571429 | 0.983213 | 0.940367 | 0.961313 | 0.764706 | 0.928571 | 0.838710 | 0.913043 | 0.763636 | 0.831683 | 0.894737 | 0.910714 | 0.902655 | 3.044700 | 336.324000 |
| 2600 | 0.056300 | 0.060169 | 0.860352 | 0.653846 | 0.629630 | 0.641509 | 0.763636 | 0.583333 | 0.661417 | 0.941463 | 0.946078 | 0.943765 | 1.000000 | 0.804348 | 0.891566 | 0.900000 | 0.580645 | 0.705882 | 0.800000 | 0.444444 | 0.571429 | 0.985542 | 0.938073 | 0.961222 | 0.892857 | 0.892857 | 0.892857 | 0.931818 | 0.745455 | 0.828283 | 0.943396 | 0.892857 | 0.917431 | 3.043800 | 336.416000 |
| 2700 | 0.060500 | 0.057206 | 0.866211 | 0.941176 | 0.592593 | 0.727273 | 0.750000 | 0.625000 | 0.681818 | 0.969697 | 0.941176 | 0.955224 | 1.000000 | 0.804348 | 0.891566 | 0.904762 | 0.612903 | 0.730769 | 0.800000 | 0.444444 | 0.571429 | 0.985577 | 0.940367 | 0.962441 | 0.764706 | 0.928571 | 0.838710 | 0.875000 | 0.763636 | 0.815534 | 0.855932 | 0.901786 | 0.878261 | 3.030500 | 337.898000 |
| 2800 | 0.060500 | 0.055237 | 0.864258 | 0.809524 | 0.629630 | 0.708333 | 0.745763 | 0.611111 | 0.671756 | 0.969697 | 0.941176 | 0.955224 | 0.974359 | 0.826087 | 0.894118 | 0.947368 | 0.580645 | 0.720000 | 0.785714 | 0.407407 | 0.536585 | 0.978673 | 0.947248 | 0.962704 | 0.787879 | 0.928571 | 0.852459 | 0.934783 | 0.781818 | 0.851485 | 0.942857 | 0.883929 | 0.912442 | 3.051900 | 335.527000 |
| 2900 | 0.055400 | 0.056251 | 0.874023 | 0.888889 | 0.592593 | 0.711111 | 0.745455 | 0.569444 | 0.645669 | 0.960784 | 0.960784 | 0.960784 | 1.000000 | 0.847826 | 0.917647 | 0.862069 | 0.806452 | 0.833333 | 0.750000 | 0.444444 | 0.558140 | 0.974057 | 0.947248 | 0.960465 | 0.888889 | 0.857143 | 0.872727 | 0.897959 | 0.800000 | 0.846154 | 0.952381 | 0.892857 | 0.921659 | 3.020500 | 339.014000 |
| 3000 | 0.049900 | 0.058357 | 0.863281 | 0.720000 | 0.666667 | 0.692308 | 0.741935 | 0.638889 | 0.686567 | 0.925234 | 0.970588 | 0.947368 | 0.972973 | 0.782609 | 0.867470 | 0.863636 | 0.612903 | 0.716981 | 0.733333 | 0.407407 | 0.523810 | 0.987952 | 0.940367 | 0.963572 | 0.812500 | 0.928571 | 0.866667 | 0.895833 | 0.781818 | 0.834951 | 0.950495 | 0.857143 | 0.901408 | 3.079000 | 332.578000 |
| 3100 | 0.047700 | 0.055833 | 0.873047 | 0.692308 | 0.666667 | 0.679245 | 0.757576 | 0.694444 | 0.724638 | 0.955882 | 0.955882 | 0.955882 | 0.951220 | 0.847826 | 0.896552 | 0.846154 | 0.709677 | 0.771930 | 0.705882 | 0.444444 | 0.545455 | 0.988010 | 0.944954 | 0.966002 | 0.764706 | 0.928571 | 0.838710 | 0.895833 | 0.781818 | 0.834951 | 0.941176 | 0.857143 | 0.897196 | 3.057400 | 334.920000 |
| 3200 | 0.044400 | 0.054468 | 0.869141 | 0.800000 | 0.592593 | 0.680851 | 0.816327 | 0.555556 | 0.661157 | 0.946078 | 0.946078 | 0.946078 | 0.975000 | 0.847826 | 0.906977 | 0.851852 | 0.741935 | 0.793103 | 0.846154 | 0.407407 | 0.550000 | 0.983294 | 0.944954 | 0.963743 | 0.806452 | 0.892857 | 0.847458 | 0.936170 | 0.800000 | 0.862745 | 0.893805 | 0.901786 | 0.897778 | 3.054000 | 335.298000 |
| 3300 | 0.046700 | 0.055581 | 0.874023 | 0.655172 | 0.703704 | 0.678571 | 0.769231 | 0.694444 | 0.729927 | 0.964824 | 0.941176 | 0.952854 | 0.952381 | 0.869565 | 0.909091 | 0.916667 | 0.709677 | 0.800000 | 0.916667 | 0.407407 | 0.564103 | 0.983294 | 0.944954 | 0.963743 | 0.862069 | 0.892857 | 0.877193 | 0.933333 | 0.763636 | 0.840000 | 0.919643 | 0.919643 | 0.919643 | 3.048600 | 335.890000 |
| 3400 | 0.049800 | 0.054662 | 0.872070 | 0.703704 | 0.703704 | 0.703704 | 0.789474 | 0.625000 | 0.697674 | 0.955665 | 0.950980 | 0.953317 | 0.973684 | 0.804348 | 0.880952 | 0.846154 | 0.709677 | 0.771930 | 0.705882 | 0.444444 | 0.545455 | 0.983294 | 0.944954 | 0.963743 | 0.838710 | 0.928571 | 0.881356 | 0.934783 | 0.781818 | 0.851485 | 0.934579 | 0.892857 | 0.913242 | 3.052100 | 335.512000 |
| 3500 | 0.045000 | 0.053903 | 0.877930 | 0.842105 | 0.592593 | 0.695652 | 0.784615 | 0.708333 | 0.744526 | 0.946860 | 0.960784 | 0.953771 | 0.973684 | 0.804348 | 0.880952 | 0.913043 | 0.677419 | 0.777778 | 0.733333 | 0.407407 | 0.523810 | 0.974118 | 0.949541 | 0.961672 | 0.827586 | 0.857143 | 0.842105 | 0.936170 | 0.800000 | 0.862745 | 0.942857 | 0.883929 | 0.912442 | 3.085000 | 331.928000 |
| 3600 | 0.051100 | 0.053717 | 0.878906 | 0.791667 | 0.703704 | 0.745098 | 0.771930 | 0.611111 | 0.682171 | 0.974490 | 0.936275 | 0.955000 | 0.975000 | 0.847826 | 0.906977 | 0.833333 | 0.806452 | 0.819672 | 0.764706 | 0.481481 | 0.590909 | 0.983333 | 0.947248 | 0.964953 | 0.806452 | 0.892857 | 0.847458 | 0.901961 | 0.836364 | 0.867925 | 0.970588 | 0.883929 | 0.925234 | 3.037900 | 337.076000 |
| 3700 | 0.044400 | 0.053387 | 0.880859 | 0.850000 | 0.629630 | 0.723404 | 0.758065 | 0.652778 | 0.701493 | 0.969849 | 0.946078 | 0.957816 | 0.975610 | 0.869565 | 0.919540 | 0.909091 | 0.645161 | 0.754717 | 0.736842 | 0.518519 | 0.608696 | 0.985782 | 0.954128 | 0.969697 | 0.833333 | 0.892857 | 0.862069 | 0.916667 | 0.800000 | 0.854369 | 0.934579 | 0.892857 | 0.913242 | 3.029100 | 338.056000 |
| 3800 | 0.042800 | 0.054548 | 0.878906 | 0.857143 | 0.666667 | 0.750000 | 0.786885 | 0.666667 | 0.721805 | 0.955882 | 0.955882 | 0.955882 | 0.975610 | 0.869565 | 0.919540 | 0.800000 | 0.774194 | 0.786885 | 0.636364 | 0.518519 | 0.571429 | 0.987893 | 0.935780 | 0.961131 | 0.833333 | 0.892857 | 0.862069 | 0.867925 | 0.836364 | 0.851852 | 0.935185 | 0.901786 | 0.918182 | 3.001900 | 341.116000 |
| 3900 | 0.045000 | 0.055396 | 0.871094 | 0.562500 | 0.666667 | 0.610169 | 0.785714 | 0.611111 | 0.687500 | 0.960396 | 0.950980 | 0.955665 | 0.951220 | 0.847826 | 0.896552 | 0.913043 | 0.677419 | 0.777778 | 0.750000 | 0.444444 | 0.558140 | 0.978774 | 0.951835 | 0.965116 | 0.827586 | 0.857143 | 0.842105 | 0.918367 | 0.818182 | 0.865385 | 0.951456 | 0.875000 | 0.911628 | 3.014600 | 339.675000 |
| 4000 | 0.041900 | 0.053967 | 0.881836 | 0.857143 | 0.666667 | 0.750000 | 0.816667 | 0.680556 | 0.742424 | 0.960591 | 0.955882 | 0.958231 | 0.975610 | 0.869565 | 0.919540 | 0.785714 | 0.709677 | 0.745763 | 0.764706 | 0.481481 | 0.590909 | 0.985577 | 0.940367 | 0.962441 | 0.827586 | 0.857143 | 0.842105 | 0.903846 | 0.854545 | 0.878505 | 0.926606 | 0.901786 | 0.914027 | 3.034200 | 337.487000 |
| 4100 | 0.044700 | 0.052438 | 0.884766 | 0.857143 | 0.666667 | 0.750000 | 0.821429 | 0.638889 | 0.718750 | 0.960784 | 0.960784 | 0.960784 | 1.000000 | 0.869565 | 0.930233 | 0.827586 | 0.774194 | 0.800000 | 0.750000 | 0.444444 | 0.558140 | 0.978723 | 0.949541 | 0.963912 | 0.806452 | 0.892857 | 0.847458 | 0.903846 | 0.854545 | 0.878505 | 0.933962 | 0.883929 | 0.908257 | 2.999700 | 341.365000 |
| 4200 | 0.042900 | 0.052500 | 0.878906 | 0.818182 | 0.666667 | 0.734694 | 0.790323 | 0.680556 | 0.731343 | 0.960396 | 0.950980 | 0.955665 | 0.974359 | 0.826087 | 0.894118 | 0.846154 | 0.709677 | 0.771930 | 0.857143 | 0.444444 | 0.585366 | 0.980998 | 0.947248 | 0.963827 | 0.833333 | 0.892857 | 0.862069 | 0.918367 | 0.818182 | 0.865385 | 0.909091 | 0.892857 | 0.900901 | 3.038200 | 337.043000 |
| 4300 | 0.045200 | 0.052227 | 0.876953 | 0.894737 | 0.629630 | 0.739130 | 0.731343 | 0.680556 | 0.705036 | 0.970000 | 0.950980 | 0.960396 | 0.975000 | 0.847826 | 0.906977 | 0.880000 | 0.709677 | 0.785714 | 0.750000 | 0.444444 | 0.558140 | 0.983294 | 0.944954 | 0.963743 | 0.862069 | 0.892857 | 0.877193 | 0.934783 | 0.781818 | 0.851485 | 0.942857 | 0.883929 | 0.912442 | 3.013300 | 339.832000 |
| 4400 | 0.040300 | 0.052687 | 0.875977 | 0.720000 | 0.666667 | 0.692308 | 0.774194 | 0.666667 | 0.716418 | 0.965000 | 0.946078 | 0.955446 | 0.952381 | 0.869565 | 0.909091 | 0.827586 | 0.774194 | 0.800000 | 0.750000 | 0.444444 | 0.558140 | 0.985577 | 0.940367 | 0.962441 | 0.833333 | 0.892857 | 0.862069 | 0.901961 | 0.836364 | 0.867925 | 0.927273 | 0.910714 | 0.918919 | 2.999700 | 341.372000 |
| 4500 | 0.040300 | 0.050848 | 0.881836 | 0.760000 | 0.703704 | 0.730769 | 0.836364 | 0.638889 | 0.724409 | 0.965347 | 0.955882 | 0.960591 | 1.000000 | 0.847826 | 0.917647 | 0.884615 | 0.741935 | 0.807018 | 0.750000 | 0.444444 | 0.558140 | 0.976526 | 0.954128 | 0.965197 | 0.827586 | 0.857143 | 0.842105 | 0.901961 | 0.836364 | 0.867925 | 0.951923 | 0.883929 | 0.916667 | 3.021100 | 338.951000 |
| 4600 | 0.043500 | 0.051430 | 0.876953 | 0.791667 | 0.703704 | 0.745098 | 0.807692 | 0.583333 | 0.677419 | 0.964824 | 0.941176 | 0.952854 | 0.951220 | 0.847826 | 0.896552 | 0.888889 | 0.774194 | 0.827586 | 0.700000 | 0.518519 | 0.595745 | 0.981087 | 0.951835 | 0.966240 | 0.833333 | 0.892857 | 0.862069 | 0.903846 | 0.854545 | 0.878505 | 0.951923 | 0.883929 | 0.916667 | 3.002500 | 341.044000 |
| 4700 | 0.044900 | 0.050805 | 0.882812 | 0.857143 | 0.666667 | 0.750000 | 0.810345 | 0.652778 | 0.723077 | 0.965000 | 0.946078 | 0.955446 | 0.975610 | 0.869565 | 0.919540 | 0.857143 | 0.774194 | 0.813559 | 0.777778 | 0.518519 | 0.622222 | 0.981087 | 0.951835 | 0.966240 | 0.827586 | 0.857143 | 0.842105 | 0.903846 | 0.854545 | 0.878505 | 0.951923 | 0.883929 | 0.916667 | 3.026100 | 338.392000 |
| 4800 | 0.044600 | 0.050638 | 0.883789 | 0.863636 | 0.703704 | 0.775510 | 0.827586 | 0.666667 | 0.738462 | 0.960396 | 0.950980 | 0.955665 | 0.975610 | 0.869565 | 0.919540 | 0.884615 | 0.741935 | 0.807018 | 0.736842 | 0.518519 | 0.608696 | 0.983373 | 0.949541 | 0.966161 | 0.827586 | 0.857143 | 0.842105 | 0.900000 | 0.818182 | 0.857143 | 0.934579 | 0.892857 | 0.913242 | 3.005100 | 340.754000 |
| 4900 | 0.040700 | 0.050686 | 0.885742 | 0.826087 | 0.703704 | 0.760000 | 0.810345 | 0.652778 | 0.723077 | 0.960591 | 0.955882 | 0.958231 | 0.975610 | 0.869565 | 0.919540 | 0.857143 | 0.774194 | 0.813559 | 0.736842 | 0.518519 | 0.608696 | 0.983373 | 0.949541 | 0.966161 | 0.827586 | 0.857143 | 0.842105 | 0.884615 | 0.836364 | 0.859813 | 0.952381 | 0.892857 | 0.921659 | 3.020100 | 339.064000 |
| 5000 | 0.045900 | 0.050501 | 0.884766 | 0.826087 | 0.703704 | 0.760000 | 0.824561 | 0.652778 | 0.728682 | 0.960591 | 0.955882 | 0.958231 | 0.975610 | 0.869565 | 0.919540 | 0.857143 | 0.774194 | 0.813559 | 0.736842 | 0.518519 | 0.608696 | 0.983373 | 0.949541 | 0.966161 | 0.857143 | 0.857143 | 0.857143 | 0.900000 | 0.818182 | 0.857143 | 0.934579 | 0.892857 | 0.913242 | 3.023300 | 338.699000 |
TrainOutput(global_step=5000, training_loss=0.06839489059448242, metrics={'train_runtime': 1942.6625, 'train_samples_per_second': 2.574, 'total_flos': 1.5358470639536976e+16, 'epoch': 1.71, 'init_mem_cpu_alloc_delta': 2160254976, 'init_mem_gpu_alloc_delta': 560850432, 'init_mem_cpu_peaked_delta': 154095616, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 20344832, 'train_mem_gpu_alloc_delta': 2345446400, 'train_mem_cpu_peaked_delta': 942477312, 'train_mem_gpu_peaked_delta': 4135501312})
Code
model.save_pretrained(DATA_FOLDER / "model")
Code
from typing import List

@torch.no_grad()
def infer(text: str) -> List[str]:
    input_ids = tokenizer(
        text, return_attention_mask=False, return_tensors="pt"
    )["input_ids"]
    # forward returns (loss, logits); take the logits for the single row
    output = model(input_ids=input_ids.to(model.device))[1][0]
    return [
        category
        for confidence, category in zip(output.tolist(), categories)
        if confidence > 0.0
    ]
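Since the loss treated every category as an independent binary decision, a raw logit above zero is equivalent to a sigmoid probability above one half, so there is no need to apply the sigmoid at inference time.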
Code
infer("I made a ton of money on the chain" )
Code
infer("I really enjoy seeing the countryside during my morning run" )
Code
import gradio as gr

def gradio_categorise(text: str) -> str:
    results = infer(text)
    if not results:
        return "You talk about boring things."
    return ", ".join(results)

gr.Interface(
    fn=gradio_categorise,
    inputs=["textbox"],
    outputs="text",
).launch(share=True)
Running locally at: http://127.0.0.1:7861/
This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)
Running on External URL: https://40419.gradio.app
Interface loading below...
(<Flask 'gradio.networking'>,
'http://127.0.0.1:7861/',
'https://40419.gradio.app')