Tweet Hashtag Prediction

Predicting the presence of hashtags on tweets
Published

July 21, 2021

Someone asked if it would be possible to predict categories for text, where the training data uses hashtags as the labels. I thought this was an interesting enough problem to have a go at, so here we are.

The plan is to determine the top 50 hashtags for one day of Twitter data and then train a multi-label model to predict the hashtags for a given tweet. I need this to predict multiple labels, so it will require a slight adjustment to the ForSequenceClassification models that Hugging Face provides (they can do single-label classification or regression, but not multi-label classification).
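
To make that adjustment concrete, here is a minimal sketch in plain PyTorch (with made-up logits, not taken from this post): single-label classification scores one class per example with cross entropy, while multi-label classification scores every class independently with binary cross entropy against a binary target vector.

Code
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, -1.0, 0.5]])  # one example, three hypothetical classes

# single-label: exactly one correct class, given as a class index
single_label = torch.tensor([0])
single_loss = F.cross_entropy(logits, single_label)

# multi-label: each class is an independent yes/no, given as a binary vector
multi_labels = torch.tensor([[1.0, 0.0, 1.0]])
multi_loss = F.binary_cross_entropy_with_logits(logits, multi_labels)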

So let’s start by evaluating the data. It’s still downloading, and it’s rather secretive, so I will only be showing aggregate stats for it.

Code
from pathlib import Path
import pandas as pd

DATA_FOLDER = Path("/data/blog/2021-07-21-tweet-hashtag/")
TWEET_FILES = sorted((DATA_FOLDER / "raw").glob("*.gz.parquet"))

df = pd.read_parquet(TWEET_FILES[0])
df = df[["contents"]].rename(columns={"contents": "text"})
Code
df.text.head()
0                hannie go nom https://t.co/8lTgi2oPXe
1    RT @MuhaddasI Oomfs convinced me to post this ...
2                     O verdadeiro não está entre nós.
3    Happy birthday to the one who’s had my heart s...
4    RT @momomosumomo517 💗❤️💛💜\nももクロ・スケジュール メモ (202...
Name: text, dtype: object

So I need to extract the hashtags from this text.

Code
import re

HASHTAG_PATTERN = re.compile(r"#([^\s]+)")
df.text.str.findall(HASHTAG_PATTERN).head()
0                             []
1                             []
2                             []
3    [HappyBirthdayJustinBieber]
4                         [ももクロ]
Name: text, dtype: object
Code
df["hashtags"] = df.text.str.findall(HASHTAG_PATTERN)
df.head()
text hashtags
0 hannie go nom https://t.co/8lTgi2oPXe []
1 RT @MuhaddasI Oomfs convinced me to post this ... []
2 O verdadeiro não está entre nós. []
3 Happy birthday to the one who’s had my heart s... [HappyBirthdayJustinBieber]
4 RT @momomosumomo517 💗❤️💛💜\nももクロ・スケジュール メモ (202... [ももクロ]
Code
(
    df.hashtags
        .explode()
        .value_counts()
        .sort_values(ascending=False)
        [:10]
)
YirmiHezimet60BinHak       551
NoticeTurkishStudents      535
ROSÉ                       534
WhatsHappeningInMyanmar    411
3MPostsForSidShuklaOnIG    392
로제                         391
3YearsWithHopeWorld        319
Mar1Coup                   307
ShopeeJAMBORE              219
1                          196
Name: hashtags, dtype: int64

So these might be the top hashtags in my dataset; however, I question whether they are really categories. This also reminds me that I’m going to need a multilingual model.

I want to make this easy for myself, and I want to be able to interpret the results. So I’m going to restrict this to English and single-word hashtags.

Code
df = pd.read_parquet(TWEET_FILES[0])
df = df[df.language == "en"]
df = df[["contents"]].rename(columns={"contents": "text"})

SIMPLE_HASHTAG_PATTERN = re.compile(r"#([A-Za-z][a-z]+)")
df["hashtags"] = (
    df.text.str
        .findall(SIMPLE_HASHTAG_PATTERN)
        .apply(lambda hashtags: [hashtag.casefold() for hashtag in hashtags])
)
df.hashtags.head()
0         []
1         []
3    [happy]
5         []
7         []
Name: hashtags, dtype: object
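
One thing to note about this simpler pattern: it only captures the leading word of a CamelCase hashtag, which is why the birthday tweet above now yields happy instead of the full tag. A quick demonstration:

Code
SIMPLE_HASHTAG_PATTERN.findall("#HappyBirthdayJustinBieber")
# ['Happy'] - casefolded to 'happy' by the apply above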
Code
(
    df.hashtags
        .explode()
        .value_counts()
        .sort_values(ascending=False)
        [:10]
)
whats       471
mar         292
nengi       135
vote        108
happy       103
sidharth    100
march        98
hope         98
milk         79
blueside     74
Name: hashtags, dtype: int64

This might be better? I need to filter this down to actual categories. It might be necessary to infer the category from the hashtag through a mapping.

For now I’m inclined to just check the hashtags against a dictionary word list.

Code
DICTIONARY_WORDS = {
    word.casefold().strip()
    for word in Path("/usr/share/dict/british-english").read_text().splitlines()
}

hashtags = (
    df.hashtags
        .explode()
        .value_counts()
        .sort_values(ascending=False)
        .to_frame()
        .reset_index()
        .rename(columns={"index": "hashtag", "hashtags": "count"})
)
hashtags[hashtags.hashtag.isin(DICTIONARY_WORDS)][:10]
hashtag count
0 whats 471
1 mar 292
3 vote 108
4 happy 103
6 march 98
7 hope 98
8 milk 79
10 friday 73
11 blue 69
12 myanmar 64

I’ve discussed this more, and it may well be best to start with a list of categories that might be hashtags. Starting with them would ensure that the categories are reasonably high quality. To do this effectively I need to load all of the data, as there just aren’t enough hashtags in this single file.

Code
from tqdm.auto import tqdm

SIMPLE_HASHTAG_PATTERN = re.compile(r"#([A-Za-z][a-z]+)")

def load_file(path: Path) -> pd.DataFrame:
    df = pd.read_parquet(path)
    df = df[df.language == "en"]
    df = df[["contents"]].rename(columns={"contents": "text"})

    df["hashtags"] = (
        df.text.str
            .findall(SIMPLE_HASHTAG_PATTERN)
            .apply(lambda hashtags: [hashtag.casefold() for hashtag in hashtags])
    )
    return df

df = pd.concat([
    load_file(path)
    for path in tqdm(TWEET_FILES[:1_000])
])
Code
df
text hashtags
0 hannie go nom https://t.co/8lTgi2oPXe []
1 RT @MuhaddasI Oomfs convinced me to post this ... []
3 Happy birthday to the one who’s had my heart s... [happy]
5 make it make sense []
7 RT @Findomgoddess25 Here is my verification vi... []
... ... ...
99970 Due to the soggy conditions at WMS, Girls’s Cr... []
99974 RT @httphearts i pray that march will be full ... []
99975 RT @mintaiey art is hard https://t.co/93GBQieXpi []
99976 RT @bontlexn In March I will attract money. I ... []
99982 i might do a picarto stream later today becaus... []

33800080 rows × 2 columns

Code
df.to_parquet(DATA_FOLDER / "english-hashtags.gz.parquet", compression="gzip")
Code
df = pd.read_parquet(DATA_FOLDER / "english-hashtags.gz.parquet")
Code
hashtags = (
    df.hashtags
        .explode()
        .value_counts()
        .sort_values(ascending=False)
        .to_frame()
        .reset_index()
        .rename(columns={"index": "hashtag", "hashtags": "count"})
)
Code
"sport" in hashtags.hashtag.values
True
Code
hashtags.hashtag.head(20)
0         whats
1           mar
2          hope
3           the
4          vote
5         march
6          stop
7         happy
8        shopee
9          blue
10    daechwita
11          way
12       friday
13        nengi
14        actor
15       golden
16          new
17      bitcoin
18         milk
19     sidharth
Name: hashtag, dtype: object
Code
interests = {
    "bts",
    "cricket",
    "arsenal",
    "twitch",
    "author",
    "wrestling",
    "baseball",
    "crypto",
    "sports",
    "hockey",
    "anime",
    "fortnite",
    "f1",
    "golf",
    "travel",
    "mma",
    "soccer",
    "boxing",
    "artist",
    "wine",
    "beer",
    "photographer",
    "softball",
    "vegan",
    "football",
    "tennis",
    "weather",
    "fitness",
    "cannabis",
    "rugby",
    "hair",
    "developer",
    "congress",
    "music",
    "disney",
    "yankees",
    "basketball",
    "overwatch",
    "makeup",
    "cycling",
    "theatre",
    "security"
}
Code
top_hashtags = hashtags[hashtags.hashtag.isin(interests)][:10]
top_interests = set(top_hashtags.hashtag.values)
top_hashtags
hashtag count
20 music 40945
49 crypto 20641
105 twitch 10933
154 artist 8037
207 disney 5982
293 travel 4492
326 fortnite 4076
330 anime 3995
469 fitness 2976
499 security 2808

So of my list there are 10 hashtags that are present in the dataset in reasonable volume. This is enough to get going with; if the model can’t predict these then there is no hope. Regarding the dataset, I think that any tweet with one of these hashtags will be included, and the ratio between the most and least common isn’t terrible (about 15x), as the quick check below shows.
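
As a quick check of that ratio using the counts in the table above:

Code
top_hashtags["count"].max() / top_hashtags["count"].min()
# 40945 / 2808, about 14.6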

So let’s select the appropriate rows and start preprocessing the text. The hashtags have to be removed from the text, otherwise the task would be much too easy!

Code
interesting_rows = (
    df[
        df.hashtags.apply(
            lambda hashtags: bool(top_interests.intersection(hashtags))
        )
    ].copy()
)

interesting_rows["hashtags"] = interesting_rows.hashtags.apply(
    lambda hashtags: sorted({hashtag for hashtag in hashtags if hashtag in top_interests})
)
Code
interesting_rows.head()
text hashtags
135 RT @SwaapCrypto #Ripple fintech giant has agai... [crypto]
1712 RT @out_h2 Tales of Berseria continues! Happy ... [twitch]
1795 The Android 11 Privacy and Security Features Y... [security]
2259 RT @LNdejje Who wants some support? \n1Like/Re... [twitch]
3756 Live! come here! https://t.co/DwyA0a3Ng0\n\n#... [twitch]
Code
import re

HASHTAG_PATTERN = re.compile(r"\s*#([^\s]+)\s*")
# strip the hashtags out, leaving a single space behind
# (regex=True is required for a compiled pattern on newer versions of pandas)
interesting_rows["text"] = interesting_rows.text.str.replace(HASHTAG_PATTERN, " ", regex=True)
interesting_rows.head()
text hashtags
135 RT @SwaapCrypto fintech giant has again unleas... [crypto]
1712 RT @out_h2 Tales of Berseria continues! Happy ... [twitch]
1795 The Android 11 Privacy and Security Features Y... [security]
2259 RT @LNdejje Who wants some support? \n1Like/Re... [twitch]
3756 Live! come here! https://t.co/DwyA0a3Ng0 ... [twitch]

I’m not prone to overthinking things, so I’m going to assume that this worked perfectly and that I can now train the model. The first step is to create the dataset.

Code
MODEL_NAME = "facebook/bart-base"
MAXIMUM_TOKEN_LENGTH = 128
BATCH_SIZE = 32
MAX_STEPS = 5_000
Code
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
Code
categories = sorted(interesting_rows.hashtags.explode().unique())
categories
['anime',
 'artist',
 'crypto',
 'disney',
 'fitness',
 'fortnite',
 'music',
 'security',
 'travel',
 'twitch']
Code
from typing import Any, Dict, List, Optional, Tuple

def encode(row: Dict[str, Any]) -> Dict[str, Any]:
    # tokenize the tweet text, truncating to the maximum token length
    input_ids = tokenizer(
        row["text"],
        max_length=MAXIMUM_TOKEN_LENGTH,
        truncation=True,
        return_attention_mask=False,
    )["input_ids"]
    # encode the hashtags as a binary vector with one entry per category
    hashtags = set(row["hashtags"])
    label = [category in hashtags for category in categories]

    return {"input_ids": input_ids, "label": label}
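
As a quick illustration with a made-up row: a tweet tagged only with crypto becomes a binary vector that is True solely in the crypto position (index 2 of the sorted categories).

Code
encode({"text": "the chain is pumping", "hashtags": ["crypto"]})["label"]
# [False, False, True, False, False, False, False, False, False, False]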
Code
from datasets import Dataset

ds = Dataset.from_pandas(interesting_rows)
ds = ds.remove_columns("__index_level_0__")
ds = ds.map(encode)

split = ds.train_test_split(test_size=1024)
train_ds = split["train"]
test_ds = split["test"]
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Code
from transformers import BartPretrainedModel, BartModel, BartConfig
from transformers.models.bart.modeling_bart import BartClassificationHead
import torch

class BartForMultiLabelSequenceClassification(BartPretrainedModel):
    """Adapted from BartForSequenceClassification, swapping the single-label
    cross entropy loss for a multi-label binary cross entropy loss."""

    def __init__(self, config: BartConfig, **kwargs):
        config.num_labels = 10  # the ten interest categories
        super().__init__(config, **kwargs)
        self.model = BartModel(config)
        self.classification_head = BartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classifier_dropout,
        )
        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        hidden_states = outputs[0]  # last hidden state

        # pool the sequence by taking the hidden state of the final <eos> token,
        # exactly as BartForSequenceClassification does
        eos_mask = input_ids.eq(self.config.eos_token_id)

        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
            :, -1, :
        ]
        logits = self.classification_head(sentence_representation)

        loss = None
        if labels is not None:
            # changed from the original single-label cross entropy loss:
            # binary cross entropy treats each category as an independent
            # yes/no decision, which is what makes this multi-label
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits, labels.float()
            )

        return (loss, logits)
Code
from typing import Dict
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support
)
from transformers import EvalPrediction

def compute_metrics(pred: EvalPrediction) -> Dict[str, float]:
    labels = pred.label_ids
    # a positive logit corresponds to a sigmoid probability above 0.5
    predictions = (pred.predictions > 0).astype(int)
    
    accuracy = accuracy_score(labels, predictions)
    results = {
        "accuracy": accuracy
    }
    for name, precision, recall, fscore in zip(
        categories,
        *precision_recall_fscore_support(
            labels, predictions, zero_division=0.
        )[:3]
    ):
        results[f"{name}_precision"] = precision
        results[f"{name}_recall"] = recall
        results[f"{name}_f1"] = fscore
    return results
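
A quick way to check the metric plumbing is to feed it a made-up prediction (this dummy example is mine, not output from the training run): the logits above zero should line up with the positive labels.

Code
import numpy as np

dummy = EvalPrediction(
    predictions=np.array([[3.0, -2.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]]),
    label_ids=np.array([[1, 0, 1, 0, 0, 0, 0, 0, 0, 0]]),
)
compute_metrics(dummy)["accuracy"]
# 1.0 - the single dummy example is predicted exactly right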
Code
model = BartForMultiLabelSequenceClassification.from_pretrained(MODEL_NAME)
Some weights of the model checkpoint at facebook/bart-base were not used when initializing BartForMultiLabelSequenceClassification: ['final_logits_bias']
- This IS expected if you are initializing BartForMultiLabelSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForMultiLabelSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BartForMultiLabelSequenceClassification were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['classification_head.dense.weight', 'classification_head.dense.bias', 'classification_head.out_proj.weight', 'classification_head.out_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
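
Before committing to half an hour of training, a quick smoke test with a made-up tweet and label (my own sanity check, not part of the original run) confirms that the forward pass returns a scalar loss and one logit per category:

Code
input_ids = tokenizer("to the moon", return_tensors="pt")["input_ids"]
labels = torch.tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])  # crypto only
loss, logits = model(input_ids=input_ids, labels=labels)
loss.shape, logits.shape
# (torch.Size([]), torch.Size([1, 10]))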
Code
from pathlib import Path
from transformers import Trainer, TrainingArguments

MODEL_RUN_FOLDER = Path("/data/blog/2021-07-21-tweet-hashtag/runs")
MODEL_RUN_FOLDER.mkdir(parents=True, exist_ok=True)

training_args = TrainingArguments(
    report_to=[],
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    warmup_ratio=0.06,
    max_steps=MAX_STEPS,
    evaluation_strategy="steps",
    logging_dir=MODEL_RUN_FOLDER / "output",
    logging_steps=100,
    eval_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
[5000/5000 32:22, Epoch 1/2]
Step Training Loss Validation Loss Accuracy Anime Precision Anime Recall Anime F1 Artist Precision Artist Recall Artist F1 Crypto Precision Crypto Recall Crypto F1 Disney Precision Disney Recall Disney F1 Fitness Precision Fitness Recall Fitness F1 Fortnite Precision Fortnite Recall Fortnite F1 Music Precision Music Recall Music F1 Security Precision Security Recall Security F1 Travel Precision Travel Recall Travel F1 Twitch Precision Twitch Recall Twitch F1 Runtime Samples Per Second
100 0.358600 0.190214 0.482422 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.968254 0.598039 0.739394 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.898795 0.855505 0.876616 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 2.761900 370.761000
200 0.157000 0.121153 0.707031 0.000000 0.000000 0.000000 0.666667 0.055556 0.102564 0.861905 0.887255 0.874396 1.000000 0.586957 0.739726 0.875000 0.451613 0.595745 0.000000 0.000000 0.000000 0.972569 0.894495 0.931900 1.000000 0.428571 0.600000 1.000000 0.090909 0.166667 0.760000 0.848214 0.801688 2.797000 366.109000
300 0.117200 0.106900 0.768555 0.714286 0.185185 0.294118 0.445946 0.458333 0.452055 0.915789 0.852941 0.883249 1.000000 0.782609 0.878049 0.882353 0.483871 0.625000 0.000000 0.000000 0.000000 0.968370 0.912844 0.939787 0.909091 0.714286 0.800000 0.878049 0.654545 0.750000 0.898876 0.714286 0.796020 2.981300 343.478000
400 0.104200 0.096251 0.786133 0.923077 0.444444 0.600000 0.537037 0.402778 0.460317 0.923469 0.887255 0.905000 0.971429 0.739130 0.839506 1.000000 0.354839 0.523810 0.000000 0.000000 0.000000 0.957143 0.922018 0.939252 0.638889 0.821429 0.718750 0.822222 0.672727 0.740000 0.897959 0.785714 0.838095 3.061600 334.466000
500 0.093300 0.089162 0.792969 0.516129 0.592593 0.551724 0.647059 0.305556 0.415094 0.952128 0.877451 0.913265 0.970588 0.717391 0.825000 0.894737 0.548387 0.680000 1.000000 0.037037 0.071429 0.971014 0.922018 0.945882 0.694444 0.892857 0.781250 0.878049 0.654545 0.750000 0.892157 0.812500 0.850467 3.033800 337.534000
600 0.084400 0.082491 0.815430 0.517241 0.555556 0.535714 0.729730 0.375000 0.495413 0.940887 0.936275 0.938575 0.925000 0.804348 0.860465 0.900000 0.580645 0.705882 1.000000 0.222222 0.363636 0.980488 0.922018 0.950355 0.851852 0.821429 0.836364 0.918919 0.618182 0.739130 0.902913 0.830357 0.865116 3.068200 333.749000
700 0.087300 0.079204 0.826172 0.631579 0.444444 0.521739 0.702703 0.361111 0.477064 0.958333 0.901961 0.929293 0.972222 0.760870 0.853659 0.944444 0.548387 0.693878 0.750000 0.333333 0.461538 0.978208 0.926606 0.951708 0.774194 0.857143 0.813559 0.930233 0.727273 0.816327 0.833333 0.937500 0.882353 3.038800 336.974000
800 0.082500 0.079864 0.812500 0.500000 0.592593 0.542373 0.700000 0.388889 0.500000 0.919048 0.946078 0.932367 1.000000 0.782609 0.878049 0.950000 0.612903 0.745098 0.900000 0.333333 0.486486 0.973301 0.919725 0.945755 0.827586 0.857143 0.842105 0.923077 0.654545 0.765957 0.953488 0.732143 0.828283 3.046500 336.126000
900 0.081000 0.073797 0.828125 0.714286 0.555556 0.625000 0.638889 0.319444 0.425926 0.972826 0.877451 0.922680 0.902439 0.804348 0.850575 1.000000 0.580645 0.734694 1.000000 0.370370 0.540541 0.955399 0.933486 0.944316 0.750000 0.964286 0.843750 0.951220 0.709091 0.812500 0.910714 0.910714 0.910714 3.028000 338.176000
1000 0.078200 0.071819 0.835938 0.933333 0.518519 0.666667 0.660714 0.513889 0.578125 0.969231 0.926471 0.947368 0.948718 0.804348 0.870588 0.833333 0.645161 0.727273 0.750000 0.333333 0.461538 0.964029 0.922018 0.942556 0.800000 0.857143 0.827586 0.906977 0.709091 0.795918 0.892857 0.892857 0.892857 2.998100 341.544000
1100 0.072100 0.072056 0.842773 0.652174 0.555556 0.600000 0.716981 0.527778 0.608000 0.940887 0.936275 0.938575 0.925000 0.804348 0.860465 0.850000 0.548387 0.666667 0.833333 0.370370 0.512821 0.971292 0.931193 0.950820 0.862069 0.892857 0.877193 0.911111 0.745455 0.820000 0.902655 0.910714 0.906667 3.005700 340.683000
1200 0.071100 0.075315 0.843750 0.666667 0.592593 0.627451 0.696429 0.541667 0.609375 0.897196 0.941176 0.918660 0.926829 0.826087 0.873563 1.000000 0.483871 0.652174 0.750000 0.333333 0.461538 0.966667 0.931193 0.948598 0.787879 0.928571 0.852459 0.814815 0.800000 0.807339 0.939394 0.830357 0.881517 2.995700 341.824000
1300 0.070000 0.069436 0.841797 0.727273 0.592593 0.653061 0.756757 0.388889 0.513761 0.949495 0.921569 0.935323 1.000000 0.782609 0.878049 0.952381 0.645161 0.769231 0.846154 0.407407 0.550000 0.955711 0.940367 0.947977 0.950000 0.678571 0.791667 0.807018 0.836364 0.821429 0.942857 0.883929 0.912442 3.020500 339.022000
1400 0.071800 0.066637 0.845703 0.761905 0.592593 0.666667 0.678571 0.527778 0.593750 0.974490 0.936275 0.955000 1.000000 0.760870 0.864198 0.950000 0.612903 0.745098 0.705882 0.444444 0.545455 0.976019 0.933486 0.954279 0.827586 0.857143 0.842105 0.906977 0.709091 0.795918 0.925234 0.883929 0.904110 3.048800 335.868000
1500 0.066400 0.069763 0.840820 0.708333 0.629630 0.666667 0.738095 0.430556 0.543860 0.969388 0.931373 0.950000 0.972973 0.782609 0.867470 0.900000 0.580645 0.705882 0.733333 0.407407 0.523810 0.980488 0.922018 0.950355 0.892857 0.892857 0.892857 0.716418 0.872727 0.786885 0.933333 0.875000 0.903226 3.058900 334.764000
1600 0.069400 0.067827 0.832031 0.833333 0.555556 0.666667 0.653846 0.472222 0.548387 0.949749 0.926471 0.937965 0.973684 0.804348 0.880952 0.944444 0.548387 0.693878 0.769231 0.370370 0.500000 0.969267 0.940367 0.954598 0.916667 0.785714 0.846154 0.923077 0.654545 0.765957 0.929293 0.821429 0.872038 3.007400 340.489000
1700 0.066000 0.072248 0.839844 0.695652 0.592593 0.640000 0.692308 0.500000 0.580645 0.932039 0.941176 0.936585 0.948718 0.804348 0.870588 0.894737 0.548387 0.680000 0.800000 0.444444 0.571429 0.958042 0.942661 0.950289 0.857143 0.857143 0.857143 0.888889 0.727273 0.800000 0.956044 0.776786 0.857143 3.025400 338.467000
1800 0.064200 0.065662 0.842773 0.695652 0.592593 0.640000 0.725490 0.513889 0.601626 0.973822 0.911765 0.941772 1.000000 0.782609 0.878049 0.888889 0.516129 0.653061 0.909091 0.370370 0.526316 0.980676 0.931193 0.955294 0.838710 0.928571 0.881356 0.888889 0.727273 0.800000 0.902655 0.910714 0.906667 3.030400 337.907000
1900 0.066700 0.063640 0.848633 0.888889 0.592593 0.711111 0.740000 0.513889 0.606557 0.969388 0.931373 0.950000 1.000000 0.782609 0.878049 0.857143 0.580645 0.692308 0.846154 0.407407 0.550000 0.971496 0.938073 0.954492 0.757576 0.892857 0.819672 0.869565 0.727273 0.792079 0.878261 0.901786 0.889868 3.040600 336.777000
2000 0.065400 0.061513 0.852539 0.727273 0.592593 0.653061 0.735849 0.541667 0.624000 0.942308 0.960784 0.951456 0.973684 0.804348 0.880952 0.863636 0.612903 0.716981 0.750000 0.444444 0.558140 0.976019 0.933486 0.954279 0.781250 0.892857 0.833333 1.000000 0.745455 0.854167 0.932692 0.866071 0.898148 3.053000 335.413000
2100 0.062300 0.063823 0.856445 0.800000 0.592593 0.680851 0.693548 0.597222 0.641791 0.946078 0.946078 0.946078 1.000000 0.782609 0.878049 0.909091 0.645161 0.754717 0.785714 0.407407 0.536585 0.983133 0.935780 0.958872 0.857143 0.857143 0.857143 0.906977 0.709091 0.795918 0.908257 0.883929 0.895928 3.044400 336.350000
2200 0.063300 0.063828 0.849609 0.615385 0.592593 0.603774 0.829268 0.472222 0.601770 0.955000 0.936275 0.945545 0.973684 0.804348 0.880952 0.785714 0.709677 0.745763 0.705882 0.444444 0.545455 0.983092 0.933486 0.957647 0.884615 0.821429 0.851852 0.860000 0.781818 0.819048 0.941176 0.857143 0.897196 3.016200 339.498000
2300 0.066600 0.059670 0.852539 0.727273 0.592593 0.653061 0.782609 0.500000 0.610169 0.964286 0.926471 0.945000 1.000000 0.804348 0.891566 0.833333 0.645161 0.727273 0.733333 0.407407 0.523810 0.976415 0.949541 0.962791 0.884615 0.821429 0.851852 0.877551 0.781818 0.826923 0.933333 0.875000 0.903226 3.054500 335.241000
2400 0.063600 0.058983 0.857422 0.842105 0.592593 0.695652 0.750000 0.541667 0.629032 0.965174 0.950980 0.958025 1.000000 0.804348 0.891566 0.833333 0.645161 0.727273 0.800000 0.444444 0.571429 0.987864 0.933486 0.959906 0.892857 0.892857 0.892857 0.880000 0.800000 0.838095 0.925234 0.883929 0.904110 3.051100 335.618000
2500 0.062900 0.057790 0.861328 0.842105 0.592593 0.695652 0.775510 0.527778 0.628099 0.964824 0.941176 0.952854 1.000000 0.782609 0.878049 0.875000 0.677419 0.763636 0.800000 0.444444 0.571429 0.983213 0.940367 0.961313 0.764706 0.928571 0.838710 0.913043 0.763636 0.831683 0.894737 0.910714 0.902655 3.044700 336.324000
2600 0.056300 0.060169 0.860352 0.653846 0.629630 0.641509 0.763636 0.583333 0.661417 0.941463 0.946078 0.943765 1.000000 0.804348 0.891566 0.900000 0.580645 0.705882 0.800000 0.444444 0.571429 0.985542 0.938073 0.961222 0.892857 0.892857 0.892857 0.931818 0.745455 0.828283 0.943396 0.892857 0.917431 3.043800 336.416000
2700 0.060500 0.057206 0.866211 0.941176 0.592593 0.727273 0.750000 0.625000 0.681818 0.969697 0.941176 0.955224 1.000000 0.804348 0.891566 0.904762 0.612903 0.730769 0.800000 0.444444 0.571429 0.985577 0.940367 0.962441 0.764706 0.928571 0.838710 0.875000 0.763636 0.815534 0.855932 0.901786 0.878261 3.030500 337.898000
2800 0.060500 0.055237 0.864258 0.809524 0.629630 0.708333 0.745763 0.611111 0.671756 0.969697 0.941176 0.955224 0.974359 0.826087 0.894118 0.947368 0.580645 0.720000 0.785714 0.407407 0.536585 0.978673 0.947248 0.962704 0.787879 0.928571 0.852459 0.934783 0.781818 0.851485 0.942857 0.883929 0.912442 3.051900 335.527000
2900 0.055400 0.056251 0.874023 0.888889 0.592593 0.711111 0.745455 0.569444 0.645669 0.960784 0.960784 0.960784 1.000000 0.847826 0.917647 0.862069 0.806452 0.833333 0.750000 0.444444 0.558140 0.974057 0.947248 0.960465 0.888889 0.857143 0.872727 0.897959 0.800000 0.846154 0.952381 0.892857 0.921659 3.020500 339.014000
3000 0.049900 0.058357 0.863281 0.720000 0.666667 0.692308 0.741935 0.638889 0.686567 0.925234 0.970588 0.947368 0.972973 0.782609 0.867470 0.863636 0.612903 0.716981 0.733333 0.407407 0.523810 0.987952 0.940367 0.963572 0.812500 0.928571 0.866667 0.895833 0.781818 0.834951 0.950495 0.857143 0.901408 3.079000 332.578000
3100 0.047700 0.055833 0.873047 0.692308 0.666667 0.679245 0.757576 0.694444 0.724638 0.955882 0.955882 0.955882 0.951220 0.847826 0.896552 0.846154 0.709677 0.771930 0.705882 0.444444 0.545455 0.988010 0.944954 0.966002 0.764706 0.928571 0.838710 0.895833 0.781818 0.834951 0.941176 0.857143 0.897196 3.057400 334.920000
3200 0.044400 0.054468 0.869141 0.800000 0.592593 0.680851 0.816327 0.555556 0.661157 0.946078 0.946078 0.946078 0.975000 0.847826 0.906977 0.851852 0.741935 0.793103 0.846154 0.407407 0.550000 0.983294 0.944954 0.963743 0.806452 0.892857 0.847458 0.936170 0.800000 0.862745 0.893805 0.901786 0.897778 3.054000 335.298000
3300 0.046700 0.055581 0.874023 0.655172 0.703704 0.678571 0.769231 0.694444 0.729927 0.964824 0.941176 0.952854 0.952381 0.869565 0.909091 0.916667 0.709677 0.800000 0.916667 0.407407 0.564103 0.983294 0.944954 0.963743 0.862069 0.892857 0.877193 0.933333 0.763636 0.840000 0.919643 0.919643 0.919643 3.048600 335.890000
3400 0.049800 0.054662 0.872070 0.703704 0.703704 0.703704 0.789474 0.625000 0.697674 0.955665 0.950980 0.953317 0.973684 0.804348 0.880952 0.846154 0.709677 0.771930 0.705882 0.444444 0.545455 0.983294 0.944954 0.963743 0.838710 0.928571 0.881356 0.934783 0.781818 0.851485 0.934579 0.892857 0.913242 3.052100 335.512000
3500 0.045000 0.053903 0.877930 0.842105 0.592593 0.695652 0.784615 0.708333 0.744526 0.946860 0.960784 0.953771 0.973684 0.804348 0.880952 0.913043 0.677419 0.777778 0.733333 0.407407 0.523810 0.974118 0.949541 0.961672 0.827586 0.857143 0.842105 0.936170 0.800000 0.862745 0.942857 0.883929 0.912442 3.085000 331.928000
3600 0.051100 0.053717 0.878906 0.791667 0.703704 0.745098 0.771930 0.611111 0.682171 0.974490 0.936275 0.955000 0.975000 0.847826 0.906977 0.833333 0.806452 0.819672 0.764706 0.481481 0.590909 0.983333 0.947248 0.964953 0.806452 0.892857 0.847458 0.901961 0.836364 0.867925 0.970588 0.883929 0.925234 3.037900 337.076000
3700 0.044400 0.053387 0.880859 0.850000 0.629630 0.723404 0.758065 0.652778 0.701493 0.969849 0.946078 0.957816 0.975610 0.869565 0.919540 0.909091 0.645161 0.754717 0.736842 0.518519 0.608696 0.985782 0.954128 0.969697 0.833333 0.892857 0.862069 0.916667 0.800000 0.854369 0.934579 0.892857 0.913242 3.029100 338.056000
3800 0.042800 0.054548 0.878906 0.857143 0.666667 0.750000 0.786885 0.666667 0.721805 0.955882 0.955882 0.955882 0.975610 0.869565 0.919540 0.800000 0.774194 0.786885 0.636364 0.518519 0.571429 0.987893 0.935780 0.961131 0.833333 0.892857 0.862069 0.867925 0.836364 0.851852 0.935185 0.901786 0.918182 3.001900 341.116000
3900 0.045000 0.055396 0.871094 0.562500 0.666667 0.610169 0.785714 0.611111 0.687500 0.960396 0.950980 0.955665 0.951220 0.847826 0.896552 0.913043 0.677419 0.777778 0.750000 0.444444 0.558140 0.978774 0.951835 0.965116 0.827586 0.857143 0.842105 0.918367 0.818182 0.865385 0.951456 0.875000 0.911628 3.014600 339.675000
4000 0.041900 0.053967 0.881836 0.857143 0.666667 0.750000 0.816667 0.680556 0.742424 0.960591 0.955882 0.958231 0.975610 0.869565 0.919540 0.785714 0.709677 0.745763 0.764706 0.481481 0.590909 0.985577 0.940367 0.962441 0.827586 0.857143 0.842105 0.903846 0.854545 0.878505 0.926606 0.901786 0.914027 3.034200 337.487000
4100 0.044700 0.052438 0.884766 0.857143 0.666667 0.750000 0.821429 0.638889 0.718750 0.960784 0.960784 0.960784 1.000000 0.869565 0.930233 0.827586 0.774194 0.800000 0.750000 0.444444 0.558140 0.978723 0.949541 0.963912 0.806452 0.892857 0.847458 0.903846 0.854545 0.878505 0.933962 0.883929 0.908257 2.999700 341.365000
4200 0.042900 0.052500 0.878906 0.818182 0.666667 0.734694 0.790323 0.680556 0.731343 0.960396 0.950980 0.955665 0.974359 0.826087 0.894118 0.846154 0.709677 0.771930 0.857143 0.444444 0.585366 0.980998 0.947248 0.963827 0.833333 0.892857 0.862069 0.918367 0.818182 0.865385 0.909091 0.892857 0.900901 3.038200 337.043000
4300 0.045200 0.052227 0.876953 0.894737 0.629630 0.739130 0.731343 0.680556 0.705036 0.970000 0.950980 0.960396 0.975000 0.847826 0.906977 0.880000 0.709677 0.785714 0.750000 0.444444 0.558140 0.983294 0.944954 0.963743 0.862069 0.892857 0.877193 0.934783 0.781818 0.851485 0.942857 0.883929 0.912442 3.013300 339.832000
4400 0.040300 0.052687 0.875977 0.720000 0.666667 0.692308 0.774194 0.666667 0.716418 0.965000 0.946078 0.955446 0.952381 0.869565 0.909091 0.827586 0.774194 0.800000 0.750000 0.444444 0.558140 0.985577 0.940367 0.962441 0.833333 0.892857 0.862069 0.901961 0.836364 0.867925 0.927273 0.910714 0.918919 2.999700 341.372000
4500 0.040300 0.050848 0.881836 0.760000 0.703704 0.730769 0.836364 0.638889 0.724409 0.965347 0.955882 0.960591 1.000000 0.847826 0.917647 0.884615 0.741935 0.807018 0.750000 0.444444 0.558140 0.976526 0.954128 0.965197 0.827586 0.857143 0.842105 0.901961 0.836364 0.867925 0.951923 0.883929 0.916667 3.021100 338.951000
4600 0.043500 0.051430 0.876953 0.791667 0.703704 0.745098 0.807692 0.583333 0.677419 0.964824 0.941176 0.952854 0.951220 0.847826 0.896552 0.888889 0.774194 0.827586 0.700000 0.518519 0.595745 0.981087 0.951835 0.966240 0.833333 0.892857 0.862069 0.903846 0.854545 0.878505 0.951923 0.883929 0.916667 3.002500 341.044000
4700 0.044900 0.050805 0.882812 0.857143 0.666667 0.750000 0.810345 0.652778 0.723077 0.965000 0.946078 0.955446 0.975610 0.869565 0.919540 0.857143 0.774194 0.813559 0.777778 0.518519 0.622222 0.981087 0.951835 0.966240 0.827586 0.857143 0.842105 0.903846 0.854545 0.878505 0.951923 0.883929 0.916667 3.026100 338.392000
4800 0.044600 0.050638 0.883789 0.863636 0.703704 0.775510 0.827586 0.666667 0.738462 0.960396 0.950980 0.955665 0.975610 0.869565 0.919540 0.884615 0.741935 0.807018 0.736842 0.518519 0.608696 0.983373 0.949541 0.966161 0.827586 0.857143 0.842105 0.900000 0.818182 0.857143 0.934579 0.892857 0.913242 3.005100 340.754000
4900 0.040700 0.050686 0.885742 0.826087 0.703704 0.760000 0.810345 0.652778 0.723077 0.960591 0.955882 0.958231 0.975610 0.869565 0.919540 0.857143 0.774194 0.813559 0.736842 0.518519 0.608696 0.983373 0.949541 0.966161 0.827586 0.857143 0.842105 0.884615 0.836364 0.859813 0.952381 0.892857 0.921659 3.020100 339.064000
5000 0.045900 0.050501 0.884766 0.826087 0.703704 0.760000 0.824561 0.652778 0.728682 0.960591 0.955882 0.958231 0.975610 0.869565 0.919540 0.857143 0.774194 0.813559 0.736842 0.518519 0.608696 0.983373 0.949541 0.966161 0.857143 0.857143 0.857143 0.900000 0.818182 0.857143 0.934579 0.892857 0.913242 3.023300 338.699000

TrainOutput(global_step=5000, training_loss=0.06839489059448242, metrics={'train_runtime': 1942.6625, 'train_samples_per_second': 2.574, 'total_flos': 1.5358470639536976e+16, 'epoch': 1.71, 'init_mem_cpu_alloc_delta': 2160254976, 'init_mem_gpu_alloc_delta': 560850432, 'init_mem_cpu_peaked_delta': 154095616, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 20344832, 'train_mem_gpu_alloc_delta': 2345446400, 'train_mem_cpu_peaked_delta': 942477312, 'train_mem_gpu_peaked_delta': 4135501312})
Code
model.save_pretrained(DATA_FOLDER / "model")
Code
@torch.no_grad()
def infer(text: str) -> List[str]:
    input_ids = tokenizer(
        text, return_attention_mask=False, return_tensors="pt"
    )["input_ids"]
    output = model(input_ids=input_ids.to(model.device))[1][0]
    return [
        category
        for logit, category in zip(output.tolist(), categories)
        if logit > 0.  # a positive logit means a sigmoid probability above 0.5
    ]
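
Thresholding the raw logits at zero is equivalent to keeping the categories whose sigmoid probability exceeds 0.5. If actual confidences are wanted, a small variant along the same lines (my sketch, not used below) can apply the sigmoid explicitly:

Code
@torch.no_grad()
def infer_probabilities(text: str) -> Dict[str, float]:
    input_ids = tokenizer(
        text, return_attention_mask=False, return_tensors="pt"
    )["input_ids"]
    logits = model(input_ids=input_ids.to(model.device))[1][0]
    probabilities = torch.sigmoid(logits)  # map each logit into [0, 1]
    return dict(zip(categories, probabilities.tolist()))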
Code
infer("I made a ton of money on the chain")
['crypto']
Code
infer("I really enjoy seeing the countryside during my morning run")
['fitness']
Code
import gradio as gr

def gradio_categorise(text: str) -> str:
    results = infer(text)
    if not results:
        return "You talk about boring things."

    return ", ".join(results)

gr.Interface(
    fn=gradio_categorise,
    inputs=["textbox"],
    outputs="text"
).launch(share=True)
Running locally at: http://127.0.0.1:7861/
This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)
Running on External URL: https://40419.gradio.app
Interface loading below...
(<Flask 'gradio.networking'>,
 'http://127.0.0.1:7861/',
 'https://40419.gradio.app')