Using Huggingface Trainer to train a Sentence Transformer model
Sentence Transformer models create embeddings from text. Can they be trained with the Huggingface Trainer?
training
Published
October 19, 2022
The Sentence Transformer library is a way to turn documents into embeddings (Reimers and Gurevych 2019). I’m interested in training a sentence transformer model using the huggingface trainer to see how easy it would be.
Reimers, Nils, and Iryna Gurevych. 2019. “Sentence-BERT: Sentence Embeddings Using Siamese BERT-Networks.” In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing. Association for Computational Linguistics. https://arxiv.org/abs/1908.10084.
Marelli, Marco, Stefano Menini, Marco Baroni, Luisa Bentivogli, Raffaella Bernardi, and Roberto Zamparelli. 2014. “A SICK Cure for the Evaluation of Compositional Distributional Semantic Models.” In Proceedings of the Ninth International Conference on Language Resources and Evaluation (LREC’14), 216–23. Reykjavik, Iceland: European Language Resources Association (ELRA). http://www.lrec-conf.org/proceedings/lrec2014/pdf/363_Paper.pdf.
I will use the SICK dataset (Marelli et al. 2014), which is a dataset of sentence pairs with both relatedness and entailment scores. The aim will be to embed the documents such that related statements are close to each other in the embedding space. I will not use the entailment score at this time.
I’m going to try two separate training approaches. The first will be to normalize the relatedness score to the range -1 to 1 and then attempt to produce embeddings whose cosine similarity matches that score. The second will be to take a highly related document pair and mix in random documents as distractors. Then we can compare how well the two approaches work.
The training set is quite small so training shouldn’t take too long.
Dataset
The SICK dataset is available on huggingface so let’s get it.
Found cached dataset sick (/home/matthew/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db)
{'id': '1',
'sentence_A': 'A group of kids is playing in a yard and an old man is standing in the background',
'sentence_B': 'A group of boys in a yard is playing and a man is standing in the background',
'label': 1,
'relatedness_score': 4.5,
'entailment_AB': 'A_neutral_B',
'entailment_BA': 'B_neutral_A',
'sentence_A_original': 'A group of children playing in a yard, a man in the background.',
'sentence_B_original': 'A group of children playing in a yard, a man in the background.',
'sentence_A_dataset': 'FLICKR',
'sentence_B_dataset': 'FLICKR'}
As we can see the relatedness_score is a value that ranges between 1 and 5. To be able to use this with the Sentence Transformers CosineSimilarityLoss I need to map the score to between -1 and 1.
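The preprocessing code is not shown here, but the mapping is simple. A minimal sketch, assuming the label column is overwritten with the scaled relatedness score (the entailment label is not used for this):

# A minimal sketch of the normalization, assuming the label column is replaced
# with the scaled relatedness score. (score - 3) / 2 maps the 1..5 range onto -1..1.
sick_ds = sick_ds.map(
    lambda row: {"label": (row["relatedness_score"] - 3) / 2}
)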
I’ve been reviewing the Sentence Transformers training documentation and it looks like the inputs to the model don’t need to be encoded. No doubt part of this library is making it easy to invoke. Let’s try just using the original Sentence Transformers training approach and then we can compare that to huggingface.
Sentence Transformers Training
The training overview just shows a very simple training process involving calling fit over a list of examples. We can recreate that with the SICK training data.
from sentence_transformers import (
    SentenceTransformer,
    InputExample,
    losses,
    evaluation,
)
from torch.utils.data import DataLoader

# Define the model. Either from scratch or by loading a pre-trained model
model = SentenceTransformer(MODEL_NAME)

# Define your train examples. You need more than just two examples...
train_examples = [
    InputExample(
        texts=[row["sentence_A"], row["sentence_B"]],
        label=row["label"],
    )
    for row in sick_ds["train"]
]

evaluator = evaluation.EmbeddingSimilarityEvaluator(
    sick_ds["validation"]["sentence_A"],
    sick_ds["validation"]["sentence_B"],
    sick_ds["validation"]["label"],
    main_similarity=evaluation.SimilarityFunction.COSINE,
)

# Define your train dataset, the dataloader and the train loss
train_dataloader = DataLoader(
    train_examples,
    shuffle=True,
    batch_size=BATCH_SIZE,
)
train_loss = losses.CosineSimilarityLoss(model)

def show_evaluation(score: float, epoch: float, steps: int) -> None:
    if steps == -1:
        print(f"evaluation: {score} epoch {epoch}")
    else:
        print(f"evaluation: {score} epoch {epoch} steps {steps}")

# Tune the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=EPOCHS,
    warmup_steps=WARMUP_STEPS,
    optimizer_params={"lr": LEARNING_RATE},
    evaluator=evaluator,
    evaluation_steps=500,
    callback=show_evaluation,
)
evaluation: 0.7836338920116853 epoch 0
evaluation: 0.8033716010524109 epoch 1
evaluation: 0.8175660042766647 epoch 2
evaluation: 0.8146349323817795 epoch 3
evaluation: 0.8127202214972602 epoch 4
evaluation: 0.8139044909698879 epoch 5
evaluation: 0.8154477934840207 epoch 6
evaluation: 0.815803262430808 epoch 7
evaluation: 0.8139142427290464 epoch 8
evaluation: 0.8152793405072882 epoch 9
It took a couple of minutes to train that. The evaluation score barely changes after the third epoch; it seems that the model learnt most of the task in the first three epochs.
The display of the results isn’t as slick as the huggingface version, which separates the progress bar from the individual row scores as well as formatting the scores nicely.
I kinda think this is ok? Without a point of reference it’s difficult to say.
Let’s see what the predictions are for specific sentences.
Code
import pandas as pd
import numpy as np

df = pd.DataFrame(
    {
        "sentence_a": sick_ds["test"]["sentence_A"],
        "sentence_b": sick_ds["test"]["sentence_B"],
        "target": sick_ds["test"]["label"],
        "prediction": predictions,
        "difference": np.abs(predictions - labels),
    }
)
print(df.head().to_markdown(tablefmt="grid", maxcolwidths=25))
+----+---------------------------+--------------------------+----------+--------------+--------------+
| | sentence_a | sentence_b | target | prediction | difference |
+====+===========================+==========================+==========+==============+==============+
| 0 | There is no boy playing | A group of kids is | 0.15 | 0.154085 | 0.00408456 |
| | outdoors and there is no | playing in a yard and an | | | |
| | man smiling | old man is standing in | | | |
| | | the background | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 1 | A group of boys in a yard | The young boys are | 0.35 | 0.325681 | 0.0243189 |
| | is playing and a man is | playing outdoors and the | | | |
| | standing in the | man is smiling nearby | | | |
| | background | | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 2 | A group of children is | The young boys are | 0 | 0.0895932 | 0.0895932 |
| | playing in the house and | playing outdoors and the | | | |
| | there is no man standing | man is smiling nearby | | | |
| | in the background | | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 3 | A brown dog is attacking | A brown dog is attacking | 0.95 | 0.931177 | 0.0188229 |
| | another animal in front | another animal in front | | | |
| | of the tall man in pants | of the man in pants | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 4 | A brown dog is attacking | A brown dog is helping | 0.3325 | 0.838758 | 0.506258 |
| | another animal in front | another animal in front | | | |
| | of the man in pants | of the man in pants | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
I’ve used this output style to wrap the text nicely. I’m not sure about some of the target values themselves, but the predictions track them closely. I guess this trained well.
Sentence Transformers Details
This was a slick train. Using this library is very nice; it made the training process very simple. To get this working in huggingface I need to understand further how it performs the training loop.
The key method appears to be model.fit, so I am going to start there. This is a chunk of code so I am going to break it down.
Sentence Transformers - Model Card
The method starts with some code around the model card. Presumably this is for uploading to the huggingface hub, as they have over 100 models there.
info_loss_functions = []
for dataloader, loss in train_objectives:
    info_loss_functions.extend(
        ModelCardTemplate.get_train_objective_info(dataloader, loss)
    )
info_loss_functions = "\n\n".join([text for text in info_loss_functions])

info_fit_parameters = json.dumps(
    {
        "evaluator": fullname(evaluator),
        "epochs": epochs,
        "steps_per_epoch": steps_per_epoch,
        "scheduler": scheduler,
        "warmup_steps": warmup_steps,
        "optimizer_class": str(optimizer_class),
        "optimizer_params": optimizer_params,
        "weight_decay": weight_decay,
        "evaluation_steps": evaluation_steps,
        "max_grad_norm": max_grad_norm,
    },
    indent=4,
    sort_keys=True,
)

self._model_card_text = None
self._model_card_vars["{TRAINING_SECTION}"] = ModelCardTemplate.__TRAINING_SECTION__.replace(
    "{LOSS_FUNCTIONS}", info_loss_functions
).replace("{FIT_PARAMETERS}", info_fit_parameters)
Sentence Transformers - Preparation
After this comes the data and objective preparation. This involves putting the datasets into a usable format and creating the optimizers. The optimizers are already handled by the huggingface trainer, and the data handling by the datasets library.
# Use smart batching
for dataloader in dataloaders:
    dataloader.collate_fn = self.smart_batching_collate

loss_models = [loss for _, loss in train_objectives]
for loss_model in loss_models:
    loss_model.to(self._target_device)

self.best_score = -9999999

if steps_per_epoch is None or steps_per_epoch == 0:
    steps_per_epoch = min([len(dataloader) for dataloader in dataloaders])

num_train_steps = int(steps_per_epoch * epochs)

# Prepare optimizers
optimizers = []
schedulers = []
for loss_model in loss_models:
    param_optimizer = list(loss_model.named_parameters())

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": weight_decay,
        },
        {
            "params": [
                p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
    scheduler_obj = self._get_scheduler(
        optimizer,
        scheduler=scheduler,
        warmup_steps=warmup_steps,
        t_total=num_train_steps,
    )

    optimizers.append(optimizer)
    schedulers.append(scheduler_obj)

global_step = 0
data_iterators = [iter(dataloader) for dataloader in dataloaders]
num_train_objectives = len(train_objectives)
Sentence Transformers - Train Loop
Finally comes the training loop. This mixes in the evaluation as well. It looks like quite standard stuff and this is exactly what the huggingface trainer was made to automate.
The one surprising thing about this loop is that it doesn’t really refer to the model (which would be self). This is because the loss_model wraps the model: the loss is constructed with the model and runs the forward pass itself (see the CosineSimilarityLoss sketch further down).
Even with this, the code generally seems good and I’m hopeful that a version can be created that works with huggingface.
Sentence Transformers - Overview
This all hides some of the complexity behind methods which are not immediately obvious. The general approach for training is: prepare and tokenize the data, compute the embeddings and the loss for each batch, then step the optimizer and scheduler while evaluating periodically. These steps are all present in the Sentence Transformers code, however they are not immediately obvious.
The first is the data preparation which involves tokenizing the sentences. This is found in the data collator, which is an interesting choice as it means it has to be done for every batch even after the first epoch. The code for the collator is:
def smart_batching_collate(self, batch):
    """
    Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model
    Here, batch is a list of tuples: [(tokens, label), ...]

    :param batch:
        a batch from a SmartBatchingDataset
    :return:
        a batch of tensors for the model
    """
    num_texts = len(batch[0].texts)
    texts = [[] for _ in range(num_texts)]
    labels = []

    for example in batch:
        for idx, text in enumerate(example.texts):
            texts[idx].append(text)
        labels.append(example.label)

    labels = torch.tensor(labels)

    sentence_features = []
    for idx in range(num_texts):
        tokenized = self.tokenize(texts[idx])
        sentence_features.append(tokenized)

    return sentence_features, labels
In huggingface this would be performed by tokenizing the dataset.
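As an illustration of that pattern (not the code used in this post), pre-tokenizing the two sentence columns with datasets.map might look like the following, assuming a huggingface tokenizer is in scope; the column names here just mirror what the custom collator further down produces:

# Illustrative only: pre-tokenize both sentence columns once up front.
# This post tokenizes in the collator instead, so this is not the code used here.
def tokenize_pair(row):
    return {
        "sentence_A_input_ids": tokenizer(row["sentence_A"], truncation=True)["input_ids"],
        "sentence_B_input_ids": tokenizer(row["sentence_B"], truncation=True)["input_ids"],
    }

tokenized_sick_ds = sick_ds.map(tokenize_pair)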
The training loop itself conceals the location of the model by passing it as a parameter to losses.CosineSimilarityLoss, which runs the model and calculates the loss from its output.
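A paraphrased sketch of that loss (not the library source, which lives in the Sentence Transformers repository) looks like this:

import torch
from torch import nn

class CosineSimilarityLossSketch(nn.Module):
    """Paraphrased sketch of losses.CosineSimilarityLoss:
    the loss owns the model and runs the forward pass itself."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model
        self.loss_fct = nn.MSELoss()

    def forward(self, sentence_features, labels):
        # encode each group of sentence features with the wrapped model
        embeddings = [
            self.model(features)["sentence_embedding"]
            for features in sentence_features
        ]
        # regress the cosine similarity towards the target score
        output = torch.cosine_similarity(embeddings[0], embeddings[1])
        return self.loss_fct(output, labels.view(-1).float())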
This is straightforward again; the difference is just a question of separation of concerns. In huggingface you implement the loss calculation on the model itself, as the loss is assumed to be intimately related to the model. Sentence Transformers has instead chosen to allow the loss to vary over the same model.
Huggingface Custom Model
To create a custom model which can work with the Huggingface Trainer, we can follow the warning in the documentation:
The Trainer class is optimized for 🤗 Transformers models and can have surprising behaviors when you use it on other models. When using it on your own model, make sure:
* your model always return tuples or subclasses of ModelOutput.
* your model can compute the loss if a labels argument is provided and that loss is returned as the first element of the tuple (if your model returns tuples)
* your model can accept multiple label arguments (use the label_names in your TrainingArguments to indicate their name to the Trainer) but none of them should be named “label”.
The important thing to realise here is that we want the full capabilities of the Sentence Transformer library to be available. That means that I want to use their data, loss and model as much as possible. A simple adapter between the two libraries is what is needed.
Given the review of the code and the requirements of the huggingface trainer, we can do this most easily by creating a custom model and collator.
Custom Collator
The collator is required because we need to handle two sentences for each row. Embeddings generated from these sentences are compared with each other to determine proximity.
The default DataCollatorWithPadding is designed for single sentence inputs, which makes it unsuitable. Instead, the collator below tokenizes and pads each text column separately, returning {column}_input_ids and {column}_attention_mask entries that the model can regroup into per-sentence inputs.
Code
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import PreTrainedTokenizerBase
from transformers.tokenization_utils import BatchEncoding
from transformers.utils.generic import PaddingStrategy

@dataclass
class SentenceTransformersCollator:
    """Collator for a SentenceTransformers model.
    This encodes the text columns to {column}_input_ids and {column}_attention_mask columns.
    This works with the two text dataset that is used as the example in the training overview:
    https://www.sbert.net/docs/training/overview.html"""

    tokenizer: PreTrainedTokenizerBase
    text_columns: List[str]

    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __init__(self, tokenizer: PreTrainedTokenizerBase, text_columns: List[str]) -> None:
        self.tokenizer = tokenizer
        self.text_columns = text_columns

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        if "label" in features[0]:
            batch = {"label": torch.tensor([row["label"] for row in features])}
        else:
            batch = {}
        for column in self.text_columns:
            padded = self._encode([row[column] for row in features])
            batch[f"{column}_input_ids"] = padded.input_ids
            batch[f"{column}_attention_mask"] = padded.attention_mask
        return batch

    def _encode(self, texts: List[str]) -> BatchEncoding:
        tokens = self.tokenizer(texts, return_attention_mask=False)
        return self.tokenizer.pad(
            tokens,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
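As a quick usage sketch (assuming the model and sick_ds objects from earlier), the collator turns a list of dataset rows into a single padded batch:

# Illustrative usage of the collator, assuming `model` and `sick_ds` from above.
collator = SentenceTransformersCollator(
    model.tokenizer,
    text_columns=["sentence_A", "sentence_B"],
)
batch = collator([sick_ds["train"][0], sick_ds["train"][1]])
# batch now contains: label, sentence_A_input_ids, sentence_A_attention_mask,
# sentence_B_input_ids and sentence_B_attention_mask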
Custom Model
The model is a union of the sentence transformer model with a loss model. We will need to group the inputs together into a list of dicts, as that is what the loss function expects. Given that huggingface expects a single result row per invocation we also need to reshape the output a little.
This does feel like needless complexity.
Code
from typing import Tuple, Union, List, Dict, Any, Optional

import torch
import torch.nn as nn
import numpy as np
from sentence_transformers.SentenceTransformer import SentenceTransformer
from transformers import AutoTokenizer, DataCollator

class HuggingfaceSentenceTransformersModel(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        text_columns: List[str],
        loss: nn.Module,
    ) -> None:
        super().__init__()
        self.model = model
        self.text_columns = text_columns
        self.loss = loss

    def forward(
        self, label: Optional[torch.Tensor] = None, **inputs
    ) -> Tuple[torch.Tensor, ...]:
        pad_token_id = self.model.tokenizer.pad_token_id
        features = self.collect_features(inputs)
        output = torch.cat([
            self.model(row)["sentence_embedding"][:, None]
            for row in features
        ], dim=1)
        if label is None:
            return (output,)
        loss = self.loss(features, label)
        return (loss, output)

    def collect_features(
        self, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> List[Dict[str, torch.Tensor]]:
        """Turn the inputs from the dataloader into the separate model inputs."""
        return [
            {
                "input_ids": inputs[f"{column}_input_ids"],
                "attention_mask": inputs[f"{column}_attention_mask"],
            }
            for column in self.text_columns
        ]
Training the Compatible Model
Now that we have put everything together we can train the model.
Code
from pathlib import Path
from transformers import (
    Trainer,
    TrainingArguments,
    EvalPrediction,
)
from sentence_transformers import (
    SentenceTransformer,
    losses,
)

model = SentenceTransformer(MODEL_NAME)
train_loss = losses.CosineSimilarityLoss(model)
hf_model = HuggingfaceSentenceTransformersModel(
    model=model,
    loss=train_loss,
    text_columns=TEXT_COLUMNS,
)

evaluator = evaluation.EmbeddingSimilarityEvaluator(
    sick_ds["validation"]["sentence_A"],
    sick_ds["validation"]["sentence_B"],
    sick_ds["validation"]["label"],
    main_similarity=evaluation.SimilarityFunction.COSINE,
)

def compute_metrics(predictions: EvalPrediction) -> Dict[str, float]:
    return {
        "cosine_similarity": evaluator(model)
    }

training_args = TrainingArguments(
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    optim="adamw_torch",
    report_to=[],  # you'd use wandb for weights and biases
    # even shorter as this is testing the model and metrics
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=EPOCHS,
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="cosine_similarity",
    greater_is_better=True,
    no_cuda=False,
    remove_unused_columns=False,
    # output_dir is compulsory
    logging_dir=MODEL_RUN_FOLDER / "output",
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
)

trainer = Trainer(
    model=hf_model,
    args=training_args,
    data_collator=SentenceTransformersCollator(
        model.tokenizer,
        text_columns=TEXT_COLUMNS,
    ),
    train_dataset=sick_ds["train"],
    eval_dataset=sick_ds["test"],
    tokenizer=model.tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
This is really a very small difference. When the number is negative the huggingface model is better than the sentence transformers model. There are places where the huggingface model is better, but really I think these are very similar results. I think that this was a successful train.
Code
import pandas as pd
import numpy as np

df = pd.DataFrame(
    {
        "sentence_a": sick_ds["test"]["sentence_A"],
        "sentence_b": sick_ds["test"]["sentence_B"],
        "target": sick_ds["test"]["label"],
        "prediction": predictions,
        "difference": np.abs(predictions - labels),
    }
)
print(df.head().to_markdown(tablefmt="grid", maxcolwidths=25))
+----+---------------------------+--------------------------+----------+--------------+--------------+
| | sentence_a | sentence_b | target | prediction | difference |
+====+===========================+==========================+==========+==============+==============+
| 0 | There is no boy playing | A group of kids is | 0.15 | 0.136733 | 0.0132667 |
| | outdoors and there is no | playing in a yard and an | | | |
| | man smiling | old man is standing in | | | |
| | | the background | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 1 | A group of boys in a yard | The young boys are | 0.35 | 0.306338 | 0.043662 |
| | is playing and a man is | playing outdoors and the | | | |
| | standing in the | man is smiling nearby | | | |
| | background | | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 2 | A group of children is | The young boys are | 0 | -0.00953974 | 0.00953974 |
| | playing in the house and | playing outdoors and the | | | |
| | there is no man standing | man is smiling nearby | | | |
| | in the background | | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 3 | A brown dog is attacking | A brown dog is attacking | 0.95 | 0.959914 | 0.00991416 |
| | another animal in front | another animal in front | | | |
| | of the tall man in pants | of the man in pants | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 4 | A brown dog is attacking | A brown dog is helping | 0.3325 | 0.816887 | 0.484387 |
| | another animal in front | another animal in front | | | |
| | of the man in pants | of the man in pants | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
The huggingface trained model has the same problem with row 4, where the attacking/helping difference has not been identified. In a way this is encouraging - I am not trying to improve the training, only to change how it is done.
I feel like this could be done in a better way by using a custom Trainer. I’ve done this in the past and it’s another way to separate the loss from the underlying model. Let’s try that next.
Huggingface Custom Trainer
Another way to work with the huggingface trainer is to subclass it, along with subclassing the arguments if desired. This allows you to use an unaltered model and implement compute_loss on the trainer instead. Given the structure of the Sentence Transformers library this might be a better approach.
Code
from typing import Any, Dict, List, Tuple, Union

import torch
from torch import nn
from transformers import Trainer
from sentence_transformers import SentenceTransformer

class SentenceTransformersTrainer(Trainer):
    def __init__(
        self,
        *args,
        text_columns: List[str],
        loss: nn.Module,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.text_columns = text_columns
        self.loss = loss
        self.loss.to(self.model.device)

    def compute_loss(
        self,
        model: SentenceTransformer,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        features = self.collect_features(inputs)
        loss = self.loss(features, inputs["label"])
        if return_outputs:
            output = torch.cat(
                [model(row)["sentence_embedding"][:, None] for row in features], dim=1
            )
            return loss, output
        return loss

    def collect_features(
        self, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> List[Dict[str, torch.Tensor]]:
        """Turn the inputs from the dataloader into the separate model inputs."""
        return [
            {
                "input_ids": inputs[f"{column}_input_ids"],
                "attention_mask": inputs[f"{column}_attention_mask"],
            }
            for column in self.text_columns
        ]
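The train function used in the next cell is not shown in this post, so what follows is only a hypothetical sketch of what it could look like, reusing the collator, constants and loss from earlier; every name and default in it is an assumption rather than the code that produced the results below.

# Hypothetical sketch of the unshown `train` helper, assembled from the pieces
# defined earlier in the post. Argument names and defaults are assumptions.
def train(train_ds, test_ds, evaluation_steps: int = 100) -> SentenceTransformer:
    model = SentenceTransformer(MODEL_NAME)
    loss = losses.CosineSimilarityLoss(model)
    training_args = TrainingArguments(
        output_dir=MODEL_RUN_FOLDER / "output",
        overwrite_output_dir=True,
        per_device_train_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        num_train_epochs=EPOCHS,
        evaluation_strategy="steps",
        eval_steps=evaluation_steps,
        logging_steps=evaluation_steps,
        report_to=[],
        remove_unused_columns=False,
    )
    trainer = SentenceTransformersTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        data_collator=SentenceTransformersCollator(
            model.tokenizer,
            text_columns=TEXT_COLUMNS,
        ),
        text_columns=TEXT_COLUMNS,
        loss=loss,
        tokenizer=model.tokenizer,
    )
    trainer.train()
    return model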
model = train(
    train_ds=sick_ds["train"],
    test_ds=sick_ds["validation"],
    # train dataset size is less than 500 batches, so no training loss is reported
    # this ensures that the logging_steps is smaller than the dataset size in batches to prevent "No log"
    # see https://github.com/huggingface/transformers/issues/8910
    evaluation_steps=100,
)
import pandas as pd
import numpy as np

df = pd.DataFrame(
    {
        "sentence_a": sick_ds["test"]["sentence_A"],
        "sentence_b": sick_ds["test"]["sentence_B"],
        "target": sick_ds["test"]["label"],
        "prediction": predictions,
        "difference": np.abs(predictions - labels),
    }
)
print(df.head().to_markdown(tablefmt="grid", maxcolwidths=25))
+----+---------------------------+--------------------------+----------+--------------+--------------+
| | sentence_a | sentence_b | target | prediction | difference |
+====+===========================+==========================+==========+==============+==============+
| 0 | There is no boy playing | A group of kids is | 0.15 | 0.106 | 0.044 |
| | outdoors and there is no | playing in a yard and an | | | |
| | man smiling | old man is standing in | | | |
| | | the background | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 1 | A group of boys in a yard | The young boys are | 0.35 | 0.377247 | 0.027247 |
| | is playing and a man is | playing outdoors and the | | | |
| | standing in the | man is smiling nearby | | | |
| | background | | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 2 | A group of children is | The young boys are | 0 | 0.0330037 | 0.0330037 |
| | playing in the house and | playing outdoors and the | | | |
| | there is no man standing | man is smiling nearby | | | |
| | in the background | | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 3 | A brown dog is attacking | A brown dog is attacking | 0.95 | 0.972256 | 0.0222558 |
| | another animal in front | another animal in front | | | |
| | of the tall man in pants | of the man in pants | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 4 | A brown dog is attacking | A brown dog is helping | 0.3325 | 0.714325 | 0.381825 |
| | another animal in front | another animal in front | | | |
| | of the man in pants | of the man in pants | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
This seems to be a better implementation, as it leaves the Sentence Transformers model unaltered. It still relies on the custom collator though. I think that’s just because of the paired sentence_A and sentence_B inputs, and you can’t really do anything about that.
It would be nice to be able to tokenize the inputs before collation, as that would be more consistent with the normal use of the trainer. I feel like this could speed up training by reducing the amount of repeated work.
I’ve made a PR to add this trainer and collator to the Sentence Transformers library.
Sentence Similarity with Negative Samples
The next train I was planning on was one where the positive pair was supplemented by several negative examples. To do this I need to reprocess the SICK dataset as I need to limit it to the sentences that entail each other. I can take the entailment sentences and label them as semantically identical and randomly select sentences from the dataset as the negative samples. Once again the size and distribution of the dataset works against this task - the sentences in the dataset often repeat the same concepts and so it’s likely that the negative samples would not be as negative as they really should be.
Even considering all of this I’m going to proceed. This is a demonstration of how to implement this, not an attempt to build a perfect system.
Broadly speaking I am going to use the custom trainer to supply negative examples from a list alongside the positive examples that will be the primary input to the trainer.
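The preparation code is not shown here; a sketch of the filtering described above, assuming the huggingface SICK dataset where a label of 0 means entailment, would be:

# A sketch of the dataset preparation (not the exact code for the run below):
# keep only the entailment pairs and treat each pair as semantically identical.
from datasets import load_dataset

sick_ds = load_dataset("sick")
entailment_ds = sick_ds.filter(lambda row: row["label"] == 0)
train_ds = entailment_ds["train"]
validation_ds = entailment_ds["validation"]
test_ds = entailment_ds["test"]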
Found cached dataset sick (/home/matthew/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db)
Code
len(train_ds), len(validation_ds), len(test_ds)
(1274, 143, 1404)
Code
import pandas as pd

print(
    pd.DataFrame(train_ds)
    [["sentence_A", "sentence_B"]]
    .head()
    .to_markdown(tablefmt="grid", maxcolwidths=25)
)
+----+---------------------------+---------------------------+
| | sentence_A | sentence_B |
+====+===========================+===========================+
| 0 | The young boys are | The kids are playing |
| | playing outdoors and the | outdoors near a man with |
| | man is smiling nearby | a smile |
+----+---------------------------+---------------------------+
| 1 | A man with a jersey is | The ball is being dunked |
| | dunking the ball at a | by a man with a jersey at |
| | basketball game | a basketball game |
+----+---------------------------+---------------------------+
| 2 | Two young women are | Two women are sparring in |
| | sparring in a kickboxing | a kickboxing match |
| | fight | |
+----+---------------------------+---------------------------+
| 3 | Three boys are jumping in | Three kids are jumping in |
| | the leaves | the leaves |
+----+---------------------------+---------------------------+
| 4 | People wearing costumes | Masked people are looking |
| | are gathering in a forest | in the same direction in |
| | and are looking in the | a forest |
| | same direction | |
+----+---------------------------+---------------------------+
Code
from typing import Any, Dict, List, Tuple, Union
import random

import torch
from torch import nn
from transformers import Trainer, BatchEncoding
from sentence_transformers import SentenceTransformer

class SentenceTransformersNegativeSampleTrainingArguments(TrainingArguments):
    def __init__(
        self,
        *args,
        negative_samples: int = 5,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.negative_samples = negative_samples

class SentenceTransformersNegativeSampleTrainer(Trainer):
    def __init__(
        self,
        *args,
        text_columns: List[str],
        sentences: List[BatchEncoding],
        loss: nn.Module,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.text_columns = text_columns
        self.loss = loss
        self.loss.to(self.model.device)
        self.sentences = sentences
        for sentence in self.sentences:
            sentence.to(self.model.device)

    def compute_loss(
        self,
        model: SentenceTransformer,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        features = self.collect_features(inputs)
        batch_size = features[0]["input_ids"].shape[0]
        loss = self.loss(
            features,
            torch.ones(batch_size, device=model.device),
        )
        for negative in random.choices(self.sentences, k=self.args.negative_samples):
            loss += self.loss(
                [features[0], negative],
                torch.ones(batch_size, device=model.device) * -1
            )
        loss = loss / (1 + self.args.negative_samples)
        if return_outputs:
            output = torch.cat(
                [model(row)["sentence_embedding"][:, None] for row in features], dim=1
            )
            return loss, output
        return loss

    def collect_features(
        self, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> List[Dict[str, torch.Tensor]]:
        """Turn the inputs from the dataloader into the separate model inputs."""
        return [
            {
                "input_ids": inputs[f"{column}_input_ids"],
                "attention_mask": inputs[f"{column}_attention_mask"],
            }
            for column in self.text_columns
        ]
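The sentences pool passed to the trainer in the next cell is not shown in this post; a hypothetical way to build it is to tokenize each candidate negative sentence individually, so that it broadcasts against the batch inside the cosine similarity loss:

# Hypothetical construction of the negative sentence pool (not the exact code
# used for the run below). Each entry is a single-sentence BatchEncoding that
# the trainer moves to the model device and samples from at random.
sentences = [
    model.tokenizer(sentence, return_tensors="pt")
    for sentence in train_ds["sentence_B"]
]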
model = train_negative_samples(
    train_ds=train_ds,
    test_ds=validation_ds,
    sentences=sentences,
    # train dataset size is less than 500 batches, so no training loss is reported
    # this ensures that the logging_steps is smaller than the dataset size in batches to prevent "No log"
    # see https://github.com/huggingface/transformers/issues/8910
    evaluation_steps=10,
)
import pandas as pd
import numpy as np

df = pd.DataFrame(
    {
        "sentence_a": sick_ds["test"]["sentence_A"],
        "sentence_b": sick_ds["test"]["sentence_B"],
        "target": sick_ds["test"]["label"],
        "prediction": predictions,
        "difference": np.abs(predictions - labels),
    }
)
print(df.head().to_markdown(tablefmt="grid", maxcolwidths=25))
+----+---------------------------+--------------------------+----------+--------------+--------------+
| | sentence_a | sentence_b | target | prediction | difference |
+====+===========================+==========================+==========+==============+==============+
| 0 | There is no boy playing | A group of kids is | 0.15 | 0.186045 | 0.0360454 |
| | outdoors and there is no | playing in a yard and an | | | |
| | man smiling | old man is standing in | | | |
| | | the background | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 1 | A group of boys in a yard | The young boys are | 0.35 | 0.508509 | 0.158509 |
| | is playing and a man is | playing outdoors and the | | | |
| | standing in the | man is smiling nearby | | | |
| | background | | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 2 | A group of children is | The young boys are | 0 | 0.209752 | 0.209752 |
| | playing in the house and | playing outdoors and the | | | |
| | there is no man standing | man is smiling nearby | | | |
| | in the background | | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 3 | A brown dog is attacking | A brown dog is attacking | 0.95 | 0.965482 | 0.0154816 |
| | another animal in front | another animal in front | | | |
| | of the tall man in pants | of the man in pants | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
| 4 | A brown dog is attacking | A brown dog is helping | 0.3325 | 0.83249 | 0.49999 |
| | another animal in front | another animal in front | | | |
| | of the man in pants | of the man in pants | | | |
+----+---------------------------+--------------------------+----------+--------------+--------------+
This model is much better than I expected bearing in mind that it was trained on a fraction of the data in a noisy fashion. It also wasn’t given exact labels to work with.
It still suffers from the attacking/helping classification problem. I wonder if it started out well and then didn’t really change, as the training did not seem to alter the loss or cosine similarity very much.
Even so this was more about the technique than the results. It would be nice if it was possible to separate out the selection of negative samples from the trainer, but I think this approach is acceptable.