Using Huggingface Trainer to train a Sentence Transformer model

Sentence Transformer models create embeddings from text; can they be trained with the Huggingface Trainer?
training
Published

October 19, 2022

The Sentence Transformer library is a way to turn documents into embeddings (Reimers and Gurevych 2019). I’m interested in training a sentence transformer model using the huggingface trainer to see how easy it would be.

Reimers, Nils, and Iryna Gurevych. 2019. “Sentence-BERT: Sentence Embeddings Using Siamese BERT-Networks.” In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing. Association for Computational Linguistics. https://arxiv.org/abs/1908.10084.
Marelli, Marco, Stefano Menini, Marco Baroni, Luisa Bentivogli, Raffaella Bernardi, and Roberto Zamparelli. 2014. “A SICK Cure for the Evaluation of Compositional Distributional Semantic Models.” In Proceedings of the Ninth International Conference on Language Resources and Evaluation (LREC’14), 216–23. Reykjavik, Iceland: European Language Resources Association (ELRA). http://www.lrec-conf.org/proceedings/lrec2014/pdf/363_Paper.pdf.

I will use the SICK dataset (Marelli et al. 2014), which is a dataset of sentence pairs with both relatedness and entailment scores. The aim is to embed the documents such that related statements are close to each other in the embedding space. I will not use the entailment score at this time.

I’m going to try two separate training approaches. The first will be to normalize the relatedness score to the range -1 to 1 and then attempt to produce embeddings whose cosine similarity matches that score. The second will be to take a highly related document pair and mix in random documents as distractors. Then we can compare how well the two approaches work.

The training set is quite small so training shouldn’t take too long.

Dataset

The SICK dataset is available on huggingface so let’s get it.

Code
import datasets

sick_ds = datasets.load_dataset("sick")
sick_ds
Found cached dataset sick (/home/matthew/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db)
DatasetDict({
    train: Dataset({
        features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset'],
        num_rows: 4439
    })
    validation: Dataset({
        features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset'],
        num_rows: 495
    })
    test: Dataset({
        features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset'],
        num_rows: 4906
    })
})
Code
sick_ds["train"][0]
{'id': '1',
 'sentence_A': 'A group of kids is playing in a yard and an old man is standing in the background',
 'sentence_B': 'A group of boys in a yard is playing and a man is standing in the background',
 'label': 1,
 'relatedness_score': 4.5,
 'entailment_AB': 'A_neutral_B',
 'entailment_BA': 'B_neutral_A',
 'sentence_A_original': 'A group of children playing in a yard, a man in the background.',
 'sentence_B_original': 'A group of children playing in a yard, a man in the background.',
 'sentence_A_dataset': 'FLICKR',
 'sentence_B_dataset': 'FLICKR'}
Code
max(sick_ds["train"]["relatedness_score"]), min(sick_ds["train"]["relatedness_score"])
(5.0, 1.0)

As we can see, the relatedness_score ranges between 1 and 5. To use it with the Sentence Transformers CosineSimilarityLoss I need to map the score into the range -1 to 1.

Code
from typing import TypedDict

maximum_relatedness = max(sick_ds["train"]["relatedness_score"])
minimum_relatedness = min(sick_ds["train"]["relatedness_score"])


class SickRow(TypedDict):
    id: str
    sentence_A: str
    sentence_B: str
    relatedness_score: float
    entailment_AB: str
    entailment_BA: str
    sentence_A_original: str
    sentence_B_original: str
    sentence_A_dataset: str
    sentence_B_dataset: str


def relatedness_to_label(row: SickRow) -> dict[str, float]:
    relatedness = row["relatedness_score"]
    relatedness = relatedness - minimum_relatedness
    relatedness = relatedness / (maximum_relatedness - minimum_relatedness)
    relatedness = (relatedness - 0.5) * 2
    return {"label": relatedness}


sick_ds = sick_ds.map(relatedness_to_label)
sick_ds = sick_ds.remove_columns(
    [
        "id",
        "relatedness_score",
        "entailment_AB",
        "entailment_BA",
        "sentence_A_original",
        "sentence_B_original",
        "sentence_A_dataset",
        "sentence_B_dataset",
    ]
)
Code
sick_ds
DatasetDict({
    train: Dataset({
        features: ['sentence_A', 'sentence_B', 'label'],
        num_rows: 4439
    })
    validation: Dataset({
        features: ['sentence_A', 'sentence_B', 'label'],
        num_rows: 495
    })
    test: Dataset({
        features: ['sentence_A', 'sentence_B', 'label'],
        num_rows: 4906
    })
})
Code
max(sick_ds["train"]["label"]), min(sick_ds["train"]["label"])
(1.0, -1.0)
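
As a spot check, the first training row has a relatedness score of 4.5, which maps to ((4.5 - 1) / (5 - 1) - 0.5) * 2 = 0.75.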

I’ve been reviewing the Sentence Transformers training documentation and it looks like the inputs to the model don’t need to be tokenized in advance. No doubt part of the appeal of this library is how easy it is to invoke. Let’s try the original Sentence Transformers training approach first and then compare it to huggingface.

Sentence Transformers Training

The training overview just shows a very simple training process involving calling fit over a list of examples. We can recreate that with the SICK training data.

Code
from pathlib import Path

MODEL_NAME = "all-MiniLM-L12-v2"
# MODEL_NAME = "nli-distilroberta-base-v2"
LEARNING_RATE = 2e-5
WARMUP_STEPS = 100
EPOCHS = 10
BATCH_SIZE = 32
TEXT_COLUMNS = ["sentence_A", "sentence_B"]

MODEL_RUN_FOLDER = Path("/data/blog/2022-10-19-sentence-transformers-and-huggingface")
MODEL_RUN_FOLDER.mkdir(parents=True, exist_ok=True)
Code
from sentence_transformers import (
    SentenceTransformer,
    InputExample,
    losses,
    evaluation,
)
from torch.utils.data import DataLoader

# Define the model. Either from scratch or by loading a pre-trained model
model = SentenceTransformer(MODEL_NAME)

# Define your train examples. You need more than just two examples...
train_examples = [
    InputExample(
        texts=[row["sentence_A"], row["sentence_B"]],
        label=row["label"],
    )
    for row in sick_ds["train"]
]
evaluator = evaluation.EmbeddingSimilarityEvaluator(
    sick_ds["validation"]["sentence_A"],
    sick_ds["validation"]["sentence_B"],
    sick_ds["validation"]["label"],
    main_similarity=evaluation.SimilarityFunction.COSINE,
)

# Define your train dataset, the dataloader and the train loss
train_dataloader = DataLoader(
    train_examples,
    shuffle=True,
    batch_size=BATCH_SIZE,
)
train_loss = losses.CosineSimilarityLoss(model)


def show_evaluation(score: float, epoch: float, steps: int) -> None:
    if steps == -1:
        print(f"evaluation: {score} epoch {epoch}")
    else:
        print(f"evaluation: {score} epoch {epoch} steps {steps}")


# Tune the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=EPOCHS,
    warmup_steps=WARMUP_STEPS,
    optimizer_params={"lr": LEARNING_RATE},
    evaluator=evaluator,
    evaluation_steps=500,
    callback=show_evaluation,
)
evaluation: 0.7836338920116853 epoch 0
evaluation: 0.8033716010524109 epoch 1
evaluation: 0.8175660042766647 epoch 2
evaluation: 0.8146349323817795 epoch 3
evaluation: 0.8127202214972602 epoch 4
evaluation: 0.8139044909698879 epoch 5
evaluation: 0.8154477934840207 epoch 6
evaluation: 0.815803262430808 epoch 7
evaluation: 0.8139142427290464 epoch 8
evaluation: 0.8152793405072882 epoch 9

It took a couple of minutes to train that. The evaluation score barely changes after the first few epochs; it seems that the model learnt most of the task in the first three epochs.

The display of the results isn’t as slick as the huggingface version, which separates the progress bar from the individual row scores as well as formatting the scores nicely.

Let’s try running this on the test dataset.

Code
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np

model.eval()

with torch.inference_mode():
    embedding_a = model.encode(sick_ds["test"]["sentence_A"])
    embedding_b = model.encode(sick_ds["test"]["sentence_B"])

labels = sick_ds["test"]["label"]
predictions = F.cosine_similarity(
    torch.from_numpy(embedding_a),
    torch.from_numpy(embedding_b),
).numpy()

sentence_transformers_performance = pd.DataFrame(
    {
        "predictions": predictions,
        "targets": labels,
        "difference": np.abs(predictions - labels),
    }
).describe()
sentence_transformers_performance
predictions targets difference
count 4906.000000 4906.000000 4906.000000
mean 0.416570 0.263954 0.247593
std 0.367427 0.504683 0.232843
min -0.642627 -1.000000 0.000121
25% 0.150933 0.000000 0.073947
50% 0.403467 0.300000 0.176686
75% 0.732815 0.650000 0.343893
max 0.998161 1.000000 1.419088

I kinda think this is ok? Without a point of reference it’s difficult to say.

Let’s see what the predictions are for specific sentences.

Code
import pandas as pd
import numpy as np

df = pd.DataFrame(
    {
        "sentence_a": sick_ds["test"]["sentence_A"],
        "sentence_b": sick_ds["test"]["sentence_B"],
        "target": sick_ds["test"]["label"],
        "prediction": predictions,
        "difference": np.abs(predictions - labels),
    }
)
print(df.head().to_markdown(tablefmt="grid", maxcolwidths=25))
+----+---------------------------+--------------------------+----------+--------------+--------------+
|    | sentence_a                | sentence_b               |   target |   prediction |   difference |
+====+===========================+==========================+==========+==============+==============+
|  0 | There is no boy playing   | A group of kids is       |   0.15   |    0.154085  |   0.00408456 |
|    | outdoors and there is no  | playing in a yard and an |          |              |              |
|    | man smiling               | old man is standing in   |          |              |              |
|    |                           | the background           |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  1 | A group of boys in a yard | The young boys are       |   0.35   |    0.325681  |   0.0243189  |
|    | is playing and a man is   | playing outdoors and the |          |              |              |
|    | standing in the           | man is smiling nearby    |          |              |              |
|    | background                |                          |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  2 | A group of children is    | The young boys are       |   0      |    0.0895932 |   0.0895932  |
|    | playing in the house and  | playing outdoors and the |          |              |              |
|    | there is no man standing  | man is smiling nearby    |          |              |              |
|    | in the background         |                          |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  3 | A brown dog is attacking  | A brown dog is attacking |   0.95   |    0.931177  |   0.0188229  |
|    | another animal in front   | another animal in front  |          |              |              |
|    | of the tall man in pants  | of the man in pants      |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  4 | A brown dog is attacking  | A brown dog is helping   |   0.3325 |    0.838758  |   0.506258   |
|    | another animal in front   | another animal in front  |          |              |              |
|    | of the man in pants       | of the man in pants      |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+

I’ve used this output style to wrap the text nicely. I’m not sure about some of the target scores, but the predictions track them closely. I guess this trained well.

Sentence Transformers Details

This was a slick train. Using this library is very nice; it made the training process very simple. To get this working in huggingface I need to understand in more detail how it performs the training loop.

The key method appears to be model.fit, so I am going to start there. This is a chunk of code so I am going to break it down.

Sentence Transformers - Model Card

The method starts with some code around the model card. Presumably this is for uploading to the huggingface hub, as they have over 100 models there.

info_loss_functions = []
for dataloader, loss in train_objectives:
    info_loss_functions.extend(
        ModelCardTemplate.get_train_objective_info(dataloader, loss)
    )
info_loss_functions = "\n\n".join([text for text in info_loss_functions])

info_fit_parameters = json.dumps(
    {
        "evaluator": fullname(evaluator),
        "epochs": epochs,
        "steps_per_epoch": steps_per_epoch,
        "scheduler": scheduler,
        "warmup_steps": warmup_steps,
        "optimizer_class": str(optimizer_class),
        "optimizer_params": optimizer_params,
        "weight_decay": weight_decay,
        "evaluation_steps": evaluation_steps,
        "max_grad_norm": max_grad_norm,
    },
    indent=4,
    sort_keys=True,
)
self._model_card_text = None
self._model_card_vars[
    "{TRAINING_SECTION}"
] = ModelCardTemplate.__TRAINING_SECTION__.replace(
    "{LOSS_FUNCTIONS}", info_loss_functions
).replace(
    "{FIT_PARAMETERS}", info_fit_parameters
)

Sentence Transformers - Preparation

After this comes the data and objective preparation. This involves putting the datasets into a usable format and creating the optimizers. The optimizers are already handled by the huggingface trainer, and the datasets library takes care of the data handling.

# Use smart batching
for dataloader in dataloaders:
    dataloader.collate_fn = self.smart_batching_collate

loss_models = [loss for _, loss in train_objectives]
for loss_model in loss_models:
    loss_model.to(self._target_device)

self.best_score = -9999999

if steps_per_epoch is None or steps_per_epoch == 0:
    steps_per_epoch = min([len(dataloader) for dataloader in dataloaders])

num_train_steps = int(steps_per_epoch * epochs)

# Prepare optimizers
optimizers = []
schedulers = []
for loss_model in loss_models:
    param_optimizer = list(loss_model.named_parameters())

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": weight_decay,
        },
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
    scheduler_obj = self._get_scheduler(
        optimizer,
        scheduler=scheduler,
        warmup_steps=warmup_steps,
        t_total=num_train_steps,
    )

    optimizers.append(optimizer)
    schedulers.append(scheduler_obj)

global_step = 0
data_iterators = [iter(dataloader) for dataloader in dataloaders]

num_train_objectives = len(train_objectives)
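
For comparison, the huggingface trainer builds this optimizer and scheduler for you: it applies weight decay to everything except the bias and LayerNorm parameters and creates the warmup schedule, all driven by TrainingArguments. A minimal sketch of the equivalent configuration (the hyperparameter values here are illustrative):

from transformers import TrainingArguments

# the Trainer constructs the grouped AdamW optimizer and the
# warmup + linear decay schedule from these arguments itself
training_args = TrainingArguments(
    output_dir="output",
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_steps=100,
    lr_scheduler_type="linear",
    optim="adamw_torch",
)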

Sentence Transformers - Train Loop

Finally comes the training loop. This mixes in the evaluation as well. It looks like quite standard stuff and this is exactly what the huggingface trainer was made to automate.

    skip_scheduler = False
    for epoch in trange(epochs, desc="Epoch", disable=not show_progress_bar):
        training_steps = 0

        for loss_model in loss_models:
            loss_model.zero_grad()
            loss_model.train()

        for _ in trange(
            steps_per_epoch,
            desc="Iteration",
            smoothing=0.05,
            disable=not show_progress_bar,
        ):
            for train_idx in range(num_train_objectives):
                loss_model = loss_models[train_idx]
                optimizer = optimizers[train_idx]
                scheduler = schedulers[train_idx]
                data_iterator = data_iterators[train_idx]

                try:
                    data = next(data_iterator)
                except StopIteration:
                    data_iterator = iter(dataloaders[train_idx])
                    data_iterators[train_idx] = data_iterator
                    data = next(data_iterator)

                features, labels = data
                labels = labels.to(self._target_device)
                features = list(
                    map(
                        lambda batch: batch_to_device(batch, self._target_device),
                        features,
                    )
                )

                if use_amp:
                    with autocast():
                        loss_value = loss_model(features, labels)

                    scale_before_step = scaler.get_scale()
                    scaler.scale(loss_value).backward()
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        loss_model.parameters(), max_grad_norm
                    )
                    scaler.step(optimizer)
                    scaler.update()

                    skip_scheduler = scaler.get_scale() != scale_before_step
                else:
                    loss_value = loss_model(features, labels)
                    loss_value.backward()
                    torch.nn.utils.clip_grad_norm_(
                        loss_model.parameters(), max_grad_norm
                    )
                    optimizer.step()

                optimizer.zero_grad()

                if not skip_scheduler:
                    scheduler.step()

            training_steps += 1
            global_step += 1

            if evaluation_steps > 0 and training_steps % evaluation_steps == 0:
                self._eval_during_training(
                    evaluator,
                    output_path,
                    save_best_model,
                    epoch,
                    training_steps,
                    callback,
                )

                for loss_model in loss_models:
                    loss_model.zero_grad()
                    loss_model.train()

            if (
                checkpoint_path is not None
                and checkpoint_save_steps is not None
                and checkpoint_save_steps > 0
                and global_step % checkpoint_save_steps == 0
            ):
                self._save_checkpoint(
                    checkpoint_path, checkpoint_save_total_limit, global_step
                )

        self._eval_during_training(
            evaluator, output_path, save_best_model, epoch, -1, callback
        )

The one surprising thing about this loop is that it doesn’t really refer to the model (which would be self). This is because the loss_model is defined in a way that wraps the model:

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
train_loss = losses.CosineSimilarityLoss(model)

# ...

model.fit(
    train_objectives=[(train_dataloader, train_loss)],

This means that the following code is actually responsible for back propagation:

loss_value = loss_model(features, labels)
loss_value.backward()

Even with this, the code generally seems good and I’m hopeful that a version can be created that works with huggingface.

Sentence Transformers - Overview

This all hides some of the complexity behind methods which are not immediately obvious. The general approach to training is: prepare the data (tokenization and batching), wrap the model in a loss module, and run a fairly standard optimization loop with periodic evaluation.

These pieces are all present in the sentence transformers code, although it’s not immediately obvious where.

The first is the data preparation, which involves tokenizing the sentences. This is done in the data collator, which is an interesting choice as it means the tokenization has to be repeated for every batch, even after the first epoch. The code for the collator is:

def smart_batching_collate(self, batch):
    """
    Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model
    Here, batch is a list of tuples: [(tokens, label), ...]

    :param batch:
        a batch from a SmartBatchingDataset
    :return:
        a batch of tensors for the model
    """
    num_texts = len(batch[0].texts)
    texts = [[] for _ in range(num_texts)]
    labels = []

    for example in batch:
        for idx, text in enumerate(example.texts):
            texts[idx].append(text)

        labels.append(example.label)

    labels = torch.tensor(labels)

    sentence_features = []
    for idx in range(num_texts):
        tokenized = self.tokenize(texts[idx])
        sentence_features.append(tokenized)

    return sentence_features, labels

In huggingface this would be performed by tokenizing the dataset.
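
As a sketch of how that might look here, reusing the model’s tokenizer (the column names for the token ids are my own convention, not something either library defines):

tokenizer = model.tokenizer

def tokenize(row):
    # tokenize each sentence column separately so the two sentences
    # can later be embedded independently
    return {
        "sentence_A_input_ids": tokenizer(row["sentence_A"])["input_ids"],
        "sentence_B_input_ids": tokenizer(row["sentence_B"])["input_ids"],
    }

tokenized_ds = sick_ds.map(tokenize)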

The training loop itself conceals the location of the model by passing it as a parameter to the losses.CosineSimilarityLoss which transforms the model output by calculating the loss:

class CosineSimilarityLoss(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        loss_fct = nn.MSELoss(),
        cos_score_transformation=nn.Identity()
    ):
        super(CosineSimilarityLoss, self).__init__()
        self.model = model
        self.loss_fct = loss_fct
        self.cos_score_transformation = cos_score_transformation


    def forward(
        self,
        sentence_features: Iterable[Dict[str, Tensor]],
        labels: Tensor
    ):
        embeddings = [
            self.model(sentence_feature)['sentence_embedding']
            for sentence_feature in sentence_features
        ]
        output = self.cos_score_transformation(
            torch.cosine_similarity(embeddings[0], embeddings[1])
        )
        return self.loss_fct(output, labels.view(-1))

This is straightforward again; the difference is really a question of separation of concerns. In huggingface you implement the loss calculation on the model itself, as the loss is assumed to be intimately related to the model. Sentence Transformers has instead chosen to allow the loss to vary over the same model.

Huggingface Custom Model

To create a custom model which can work with the Huggingface Trainer, we can follow the warning in the documentation:

The Trainer class is optimized for 🤗 Transformers models and can have surprising behaviors when you use it on other models. When using it on your own model, make sure:

* your model always return tuples or subclasses of ModelOutput.
* your model can compute the loss if a labels argument is provided and that loss is returned as the first element of the tuple (if your model returns tuples)
* your model can accept multiple label arguments (use the label_names in your TrainingArguments to indicate their name to the Trainer) but none of them should be named “label”.

The important thing to realise here is that we want the full capabilities of the Sentence Transformer library to be available. That means that I want to use their data, loss and model as much as possible. A simple adapter between the two libraries is what is needed.

Given the review of the code and the requirements of the huggingface trainer, we can do this most easily by creating a custom model and collator.

Custom Collator

The collator is required because we need to handle two sentences for each row. Embeddings generated from these sentences are compared with each other to determine proximity.

The default DataCollatorWithPadding is designed for single sentence inputs, which makes it unsuitable here. Instead the collator below tokenizes and pads each text column separately, returning {column}_input_ids and {column}_attention_mask tensors so that each sentence can be embedded on its own.

Code
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import PreTrainedTokenizerBase
from transformers.tokenization_utils import BatchEncoding
from transformers.utils.generic import PaddingStrategy


@dataclass
class SentenceTransformersCollator:
    """Collator for a SentenceTransformers model.
    This encodes the text columns to {column}_input_ids and {column}_attention_mask columns.
    This works with the two text dataset that is used as the example in the training overview:
    https://www.sbert.net/docs/training/overview.html"""

    tokenizer: PreTrainedTokenizerBase
    text_columns: List[str]

    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        if "label" in features[0]:
            batch = {"label": torch.tensor([row["label"] for row in features])}
        else:
            batch = {}
        for column in self.text_columns:
            padded = self._encode([row[column] for row in features])
            batch[f"{column}_input_ids"] = padded.input_ids
            batch[f"{column}_attention_mask"] = padded.attention_mask
        return batch

    def _encode(self, texts: List[str]) -> BatchEncoding:
        tokens = self.tokenizer(texts, return_attention_mask=False)
        return self.tokenizer.pad(
            tokens,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
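
As a quick sanity check of the collator, something like this should produce a padded batch with per-column input ids and attention masks (a sketch reusing the model and dataset loaded earlier; I haven’t shown its output):

collator = SentenceTransformersCollator(
    tokenizer=model.tokenizer,
    text_columns=TEXT_COLUMNS,
)
batch = collator([sick_ds["train"][0], sick_ds["train"][1]])
# expected keys: label, sentence_A_input_ids, sentence_A_attention_mask,
# sentence_B_input_ids, sentence_B_attention_mask
print(sorted(batch.keys()))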

Custom Model

The model is a combination of the sentence transformer model and the loss module. We need to group the inputs together into a list of dicts, as that is what the loss function expects. Given that huggingface expects a single output per input row we also need to reshape the output a little.

This does feel like needless complexity.

Code
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer

class HuggingfaceSentenceTransformersModel(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        text_columns: List[str],
        loss: nn.Module,
    ) -> None:
        super().__init__()
        self.model = model
        self.text_columns = text_columns
        self.loss = loss

    def forward(
        self, label: Optional[torch.Tensor] = None, **inputs
    ) -> Tuple[torch.Tensor, ...]:
        features = self.collect_features(inputs)
        # stack the per-sentence embeddings into a [batch, sentence, dim] output
        output = torch.cat(
            [self.model(row)["sentence_embedding"][:, None] for row in features],
            dim=1,
        )
        if label is None:
            return (output,)
        # the trainer expects the loss as the first element of the returned tuple
        loss = self.loss(features, label)
        return (loss, output)

    def collect_features(
        self, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> List[Dict[str, torch.Tensor]]:
        """Turn the inputs from the dataloader into the separate model inputs."""
        return [
            {
                "input_ids": inputs[f"{column}_input_ids"],
                "attention_mask": inputs[f"{column}_attention_mask"],
            }
            for column in self.text_columns
        ]

Training the Compatible Model

Now that we have put everything together we can train the model.

Code
from pathlib import Path
from transformers import (
    Trainer,
    TrainingArguments,
    EvalPrediction,
)
from sentence_transformers import (
    SentenceTransformer,
    losses,
    evaluation,
)

model = SentenceTransformer(MODEL_NAME)
train_loss = losses.CosineSimilarityLoss(model)
hf_model = HuggingfaceSentenceTransformersModel(
    model=model,
    loss=train_loss,
    text_columns=TEXT_COLUMNS,
)

evaluator = evaluation.EmbeddingSimilarityEvaluator(
    sick_ds["validation"]["sentence_A"],
    sick_ds["validation"]["sentence_B"],
    sick_ds["validation"]["label"],
    main_similarity=evaluation.SimilarityFunction.COSINE,
)
def compute_metrics(predictions: EvalPrediction) -> Dict[str, float]:
    return {
        "cosine_similarity": evaluator(model)
    }

training_args = TrainingArguments(
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    optim="adamw_torch",

    report_to=[], # you'd use wandb for weights and biases

    # even shorter as this is testing the model and metrics
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=EPOCHS,
    logging_steps=100,

    load_best_model_at_end=True,
    metric_for_best_model="cosine_similarity",
    greater_is_better=True,
    
    no_cuda=False,
    remove_unused_columns=False,

    # output_dir is compulsory
    logging_dir=MODEL_RUN_FOLDER / "output",
    output_dir=MODEL_RUN_FOLDER / "output",
    overwrite_output_dir=True,
)

trainer = Trainer(
    model=hf_model,
    args=training_args,
    data_collator=SentenceTransformersCollator(
        model.tokenizer, text_columns=TEXT_COLUMNS,
    ),
    train_dataset=sick_ds["train"],
    eval_dataset=sick_ds["test"],
    tokenizer=model.tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
[1390/1390 02:53, Epoch 10/10]
Epoch Training Loss Validation Loss Cosine Similarity
1 0.171700 0.147226 0.792164
2 0.134900 0.134125 0.805221
3 0.106200 0.128251 0.810922
4 0.092000 0.122461 0.821752
5 0.082800 0.120815 0.818696
6 0.070600 0.118044 0.812771
7 0.068000 0.117768 0.813136
8 0.062000 0.116243 0.810268
9 0.059200 0.115605 0.810959
10 0.055300 0.115676 0.812836

TrainOutput(global_step=1390, training_loss=0.08742413692337146, metrics={'train_runtime': 174.0469, 'train_samples_per_second': 255.046, 'train_steps_per_second': 7.986, 'total_flos': 0.0, 'train_loss': 0.08742413692337146, 'epoch': 10.0})

A very similar training result to the Sentence Transformers train. Once again the best model comes well before the 10th epoch.

Code
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np

model.eval()

with torch.inference_mode():
    embedding_a = model.encode(sick_ds["test"]["sentence_A"])
    embedding_b = model.encode(sick_ds["test"]["sentence_B"])

labels = sick_ds["test"]["label"]
predictions = F.cosine_similarity(
    torch.from_numpy(embedding_a),
    torch.from_numpy(embedding_b),
).numpy()

model_performance = pd.DataFrame(
    {
        "predictions": predictions,
        "targets": labels,
        "difference": np.abs(predictions - labels),
    }
).describe()
model_performance
predictions targets difference
count 4906.000000 4906.000000 4906.000000
mean 0.431406 0.263954 0.253702
std 0.341693 0.504683 0.241056
min -0.497602 -1.000000 0.000082
25% 0.179611 0.000000 0.075927
50% 0.403866 0.300000 0.176595
75% 0.723692 0.650000 0.347679
max 0.995991 1.000000 1.220004

We can compare this to the sentence transformers results:

Code
(model_performance.difference - sentence_transformers_performance.difference).to_frame()
difference
count 0.000000
mean 0.006109
std 0.008213
min -0.000039
25% 0.001981
50% -0.000091
75% 0.003786
max -0.199084

This is really a very small difference. Where the number is negative the huggingface model is better than the sentence transformers model. There are places where the huggingface model is better, but overall these are very similar results and I think this was a successful train.

Code
import pandas as pd
import numpy as np

df = pd.DataFrame(
    {
        "sentence_a": sick_ds["test"]["sentence_A"],
        "sentence_b": sick_ds["test"]["sentence_B"],
        "target": sick_ds["test"]["label"],
        "prediction": predictions,
        "difference": np.abs(predictions - labels),
    }
)
print(df.head().to_markdown(tablefmt="grid", maxcolwidths=25))
+----+---------------------------+--------------------------+----------+--------------+--------------+
|    | sentence_a                | sentence_b               |   target |   prediction |   difference |
+====+===========================+==========================+==========+==============+==============+
|  0 | There is no boy playing   | A group of kids is       |   0.15   |   0.136733   |   0.0132667  |
|    | outdoors and there is no  | playing in a yard and an |          |              |              |
|    | man smiling               | old man is standing in   |          |              |              |
|    |                           | the background           |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  1 | A group of boys in a yard | The young boys are       |   0.35   |   0.306338   |   0.043662   |
|    | is playing and a man is   | playing outdoors and the |          |              |              |
|    | standing in the           | man is smiling nearby    |          |              |              |
|    | background                |                          |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  2 | A group of children is    | The young boys are       |   0      |  -0.00953974 |   0.00953974 |
|    | playing in the house and  | playing outdoors and the |          |              |              |
|    | there is no man standing  | man is smiling nearby    |          |              |              |
|    | in the background         |                          |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  3 | A brown dog is attacking  | A brown dog is attacking |   0.95   |   0.959914   |   0.00991416 |
|    | another animal in front   | another animal in front  |          |              |              |
|    | of the tall man in pants  | of the man in pants      |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  4 | A brown dog is attacking  | A brown dog is helping   |   0.3325 |   0.816887   |   0.484387   |
|    | another animal in front   | another animal in front  |          |              |              |
|    | of the man in pants       | of the man in pants      |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+

The huggingface trained model has the same problem with row 4, where the attacking/helping difference has not been identified. In a way this is encouraging - I am not trying to improve the training, I am changing how it is done.

I feel like this could be done in a better way by using a custom Trainer. I’ve done this in the past and it’s another way to separate the loss from the underlying model. Let’s try that next.

Huggingface Custom Trainer

Another way to work with the huggingface trainer is to subclass it, along with subclassing the arguments if desired. This allows you to use an unaltered model and implement compute_loss instead. Given the structure of the Sentence Transformers library this might be a better approach.

Code
from typing import Any, Dict, List, Tuple, Union

import torch
from torch import nn
from transformers import Trainer

from sentence_transformers import SentenceTransformer

class SentenceTransformersTrainer(Trainer):
    def __init__(
        self,
        *args,
        text_columns: List[str],
        loss: nn.Module,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.text_columns = text_columns
        self.loss = loss
        self.loss.to(self.model.device)

    def compute_loss(
        self,
        model: SentenceTransformer,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        features = self.collect_features(inputs)
        loss = self.loss(features, inputs["label"])
        if return_outputs:
            output = torch.cat(
                [model(row)["sentence_embedding"][:, None] for row in features], dim=1
            )
            return loss, output
        return loss

    def collect_features(
        self, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> List[Dict[str, torch.Tensor]]:
        """Turn the inputs from the dataloader into the separate model inputs."""
        return [
            {
                "input_ids": inputs[f"{column}_input_ids"],
                "attention_mask": inputs[f"{column}_attention_mask"],
            }
            for column in self.text_columns
        ]
Code
from pathlib import Path
from typing import Optional, Dict, Tuple, Union

from transformers import (
    Trainer,
    TrainingArguments,
    EvalPrediction,
)
from sentence_transformers import (
    SentenceTransformer,
    losses,
    evaluation,
)
import datasets
import torch


def train(
    *,
    train_ds: datasets.Dataset,
    test_ds: datasets.Dataset,
    model_name: str = MODEL_NAME,
    batch_size: int = BATCH_SIZE,
    learning_rate: float = LEARNING_RATE,
    fp16: bool = False,
    epochs: Optional[float] = EPOCHS,
    max_steps: int = -1,
    run_folder: Path = MODEL_RUN_FOLDER,
    model_folder: Path = MODEL_RUN_FOLDER,
    # number of steps before moving evaluation results from GPU to CPU see
    # https://discuss.huggingface.co/t/cuda-out-of-memory-when-using-trainer-with-compute-metrics/2941
    eval_accumulation_steps: Optional[int] = None,
    evaluation_strategy: str = "epoch",
    save_strategy: str = "epoch",
    evaluation_steps: int = 500,
) -> SentenceTransformer:
    run_name = "-".join(
        [
            f"{model_name}",
            f"e{epochs}" if max_steps == -1 else f"ms{max_steps}",
            f"bs{batch_size}",
            f"lr{learning_rate}",
        ]
        + (["fp16"] if fp16 else [])
    )
    print(f"Starting {run_name}")

    training_args = TrainingArguments(
        report_to="none",
        output_dir=run_folder,
        num_train_epochs=epochs,
        max_steps=max_steps,
        seed=33,
        eval_accumulation_steps=eval_accumulation_steps,
        #
        # hyperparameters
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        fp16=fp16,
        learning_rate=learning_rate,
        warmup_steps=WARMUP_STEPS,
        optim="adamw_torch",
        #
        # evaluation settings
        evaluation_strategy=evaluation_strategy,
        save_strategy=save_strategy,
        logging_steps=evaluation_steps,
        eval_steps=evaluation_steps,
        save_steps=evaluation_steps,
        #
        # checkpoint settings
        logging_dir=run_folder / "logs",
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="cosine_similarity",
        greater_is_better=True,
        #
        # needed to get sentence_A and sentence_B
        remove_unused_columns=False,
    )

    model = SentenceTransformer(model_name)
    tokenizer = model.tokenizer
    loss = losses.CosineSimilarityLoss(model)
    evaluator = evaluation.EmbeddingSimilarityEvaluator(
        test_ds["sentence_A"],
        test_ds["sentence_B"],
        test_ds["label"],
        main_similarity=evaluation.SimilarityFunction.COSINE,
    )
    def compute_metrics(predictions: EvalPrediction) -> Dict[str, float]:
        return {
            "cosine_similarity": evaluator(model)
        }
    
    data_collator = SentenceTransformersCollator(
        tokenizer=tokenizer,
        text_columns=TEXT_COLUMNS,
    )

    trainer = SentenceTransformersTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        # custom arguments
        loss=loss,
        text_columns=TEXT_COLUMNS,
    )

    trainer.train()
    model.save(str(model_folder / run_name)) # this does not accept Path objects

    return model
Code
model = train(
    train_ds=sick_ds["train"],
    test_ds=sick_ds["validation"],
    # the train dataset has fewer than 500 batches per epoch, so with the default evaluation_steps no training loss would be logged
    # setting it lower keeps logging_steps smaller than the number of batches and prevents "No log"
    # see https://github.com/huggingface/transformers/issues/8910
    evaluation_steps=100,
)
PyTorch: setting up devices
Starting all-MiniLM-L12-v2-e10-bs32-lr2e-05
[1390/1390 01:36, Epoch 10/10]
Epoch Training Loss Validation Loss Cosine Similarity
1 0.168700 0.133041 0.772604
2 0.134500 0.120427 0.805793
3 0.104700 0.114581 0.814234
4 0.092500 0.109890 0.818900
5 0.086400 0.108434 0.811472
6 0.070100 0.106079 0.816036
7 0.067600 0.105010 0.814723
8 0.061300 0.103482 0.817867
9 0.055200 0.103269 0.817235
10 0.058400 0.103041 0.818394

Code
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np

model.eval()

with torch.inference_mode():
    embedding_a = model.encode(sick_ds["test"]["sentence_A"])
    embedding_b = model.encode(sick_ds["test"]["sentence_B"])

labels = sick_ds["test"]["label"]
predictions = F.cosine_similarity(
    torch.from_numpy(embedding_a),
    torch.from_numpy(embedding_b),
).numpy()

trainer_performance = pd.DataFrame(
    {
        "predictions": predictions,
        "targets": labels,
        "difference": np.abs(predictions - labels),
    }
).describe()
trainer_performance
predictions targets difference
count 4906.000000 4906.000000 4906.000000
mean 0.428635 0.263954 0.254517
std 0.346585 0.504683 0.243137
min -0.451704 -1.000000 0.000133
25% 0.168755 0.000000 0.072743
50% 0.400480 0.300000 0.178784
75% 0.729426 0.650000 0.353027
max 0.997572 1.000000 1.255221
Code
(trainer_performance.difference - sentence_transformers_performance.difference).to_frame()
difference
count 0.000000
mean 0.006925
std 0.010294
min 0.000012
25% -0.001203
50% 0.002099
75% 0.009135
max -0.163866
Code
import pandas as pd
import numpy as np

df = pd.DataFrame(
    {
        "sentence_a": sick_ds["test"]["sentence_A"],
        "sentence_b": sick_ds["test"]["sentence_B"],
        "target": sick_ds["test"]["label"],
        "prediction": predictions,
        "difference": np.abs(predictions - labels),
    }
)
print(df.head().to_markdown(tablefmt="grid", maxcolwidths=25))
+----+---------------------------+--------------------------+----------+--------------+--------------+
|    | sentence_a                | sentence_b               |   target |   prediction |   difference |
+====+===========================+==========================+==========+==============+==============+
|  0 | There is no boy playing   | A group of kids is       |   0.15   |    0.106     |    0.044     |
|    | outdoors and there is no  | playing in a yard and an |          |              |              |
|    | man smiling               | old man is standing in   |          |              |              |
|    |                           | the background           |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  1 | A group of boys in a yard | The young boys are       |   0.35   |    0.377247  |    0.027247  |
|    | is playing and a man is   | playing outdoors and the |          |              |              |
|    | standing in the           | man is smiling nearby    |          |              |              |
|    | background                |                          |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  2 | A group of children is    | The young boys are       |   0      |    0.0330037 |    0.0330037 |
|    | playing in the house and  | playing outdoors and the |          |              |              |
|    | there is no man standing  | man is smiling nearby    |          |              |              |
|    | in the background         |                          |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  3 | A brown dog is attacking  | A brown dog is attacking |   0.95   |    0.972256  |    0.0222558 |
|    | another animal in front   | another animal in front  |          |              |              |
|    | of the tall man in pants  | of the man in pants      |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  4 | A brown dog is attacking  | A brown dog is helping   |   0.3325 |    0.714325  |    0.381825  |
|    | another animal in front   | another animal in front  |          |              |              |
|    | of the man in pants       | of the man in pants      |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+

This seems to be a better implementation, as it leaves the Sentence Transformers model unaltered. It still relies on the custom collator, though I think that’s unavoidable given the separate sentence_A and sentence_B inputs.

It would be nice to be able to tokenize the inputs before collation, as that would be more consistent with the normal use of the trainer. I feel like this could speed up training by reducing the amount of repeated work.
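
A sketch of how that could look, assuming the dataset had been pre-tokenized into {column}_input_ids columns with dataset.map as in the earlier sketch (this PreTokenizedCollator is hypothetical, not part of either library):

from dataclasses import dataclass
from typing import Any, Dict, List

import torch
from transformers import PreTrainedTokenizerBase


@dataclass
class PreTokenizedCollator:
    """Hypothetical collator that only pads columns that were tokenized up front."""

    tokenizer: PreTrainedTokenizerBase
    text_columns: List[str]

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        batch = {"label": torch.tensor([row["label"] for row in features])}
        for column in self.text_columns:
            # pad the precomputed input ids; the tokenizer recreates the attention mask
            padded = self.tokenizer.pad(
                [{"input_ids": row[f"{column}_input_ids"]} for row in features],
                padding=True,
                return_tensors="pt",
            )
            batch[f"{column}_input_ids"] = padded["input_ids"]
            batch[f"{column}_attention_mask"] = padded["attention_mask"]
        return batch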

I’ve made a PR to add this trainer and collator to the Sentence Transformers library.

Sentence Similarity with Negative Samples

The next train I was planning is one where the positive pair is supplemented by several negative examples. To do this I need to reprocess the SICK dataset, limiting it to the sentence pairs that entail each other. I can take those entailment pairs, label them as semantically identical, and randomly select sentences from the dataset as the negative samples. Once again the size and distribution of the dataset works against this task - the sentences in the dataset often repeat the same concepts, so it’s likely that the negative samples would not be as negative as they really should be.

Even considering all of this I’m going to proceed. This is a demonstration of how to implement this, not an attempt to build a perfect system.

Broadly speaking I am going to use the custom trainer to supply negative examples from a list alongside the positive examples that will be the primary input to the trainer.

Code
from typing import List, Tuple

import datasets
import pandas as pd

def load_positive_negative_dataset() -> Tuple[
    datasets.Dataset, datasets.Dataset, datasets.Dataset, List[str]
]:
    sick_ds = datasets.load_dataset("sick")
    sentences = sorted(set(
        sentence
        for entry in ["train"]
        for column in ["sentence_A", "sentence_B"]
        for sentence in set(sick_ds[entry][column])
    ))
    train_ds = _entailment_pairs(sick_ds["train"])
    validation_ds = _entailment_pairs(sick_ds["validation"])
    test_ds = _entailment_pairs(sick_ds["test"])

    return train_ds, validation_ds, test_ds, sentences

def _entailment_pairs(ds: datasets.Dataset) -> datasets.Dataset:
    df = pd.DataFrame(ds)
    df = df[df.label == 0] # 0 = entailment
    df = df[["sentence_A", "sentence_B"]]
    return datasets.Dataset.from_pandas(df)
Code
train_ds, validation_ds, test_ds, sentences = load_positive_negative_dataset()
Found cached dataset sick (/home/matthew/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db)
Code
len(train_ds), len(validation_ds), len(test_ds)
(1274, 143, 1404)
Code
import pandas as pd

print(
    pd.DataFrame(train_ds)
        [["sentence_A", "sentence_B"]]
        .head()
        .to_markdown(tablefmt="grid", maxcolwidths=25)
)
+----+---------------------------+---------------------------+
|    | sentence_A                | sentence_B                |
+====+===========================+===========================+
|  0 | The young boys are        | The kids are playing      |
|    | playing outdoors and the  | outdoors near a man with  |
|    | man is smiling nearby     | a smile                   |
+----+---------------------------+---------------------------+
|  1 | A man with a jersey is    | The ball is being dunked  |
|    | dunking the ball at a     | by a man with a jersey at |
|    | basketball game           | a basketball game         |
+----+---------------------------+---------------------------+
|  2 | Two young women are       | Two women are sparring in |
|    | sparring in a kickboxing  | a kickboxing match        |
|    | fight                     |                           |
+----+---------------------------+---------------------------+
|  3 | Three boys are jumping in | Three kids are jumping in |
|    | the leaves                | the leaves                |
+----+---------------------------+---------------------------+
|  4 | People wearing costumes   | Masked people are looking |
|    | are gathering in a forest | in the same direction in  |
|    | and are looking in the    | a forest                  |
|    | same direction            |                           |
+----+---------------------------+---------------------------+
Code
from typing import Any, Dict, List, Tuple, Union
import random

import torch
from torch import nn
from transformers import BatchEncoding, Trainer, TrainingArguments

from sentence_transformers import SentenceTransformer

class SentenceTransformersNegativeSampleTrainingArguments(TrainingArguments):
    def __init__(
        self,
        *args,
        negative_samples: int = 5,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.negative_samples = negative_samples

class SentenceTransformersNegativeSampleTrainer(Trainer):
    def __init__(
        self,
        *args,
        text_columns: List[str],
        sentences: List[BatchEncoding],
        loss: nn.Module,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.text_columns = text_columns
        self.loss = loss
        self.loss.to(self.model.device)
        self.sentences = sentences
        for sentence in self.sentences:
            sentence.to(self.model.device)

    def compute_loss(
        self,
        model: SentenceTransformer,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        features = self.collect_features(inputs)
        batch_size = features[0]["input_ids"].shape[0]
        loss = self.loss(
            features,
            torch.ones(batch_size, device=model.device),
        )
        for negative in random.choices(self.sentences, k=self.args.negative_samples):
            loss += self.loss(
                [features[0], negative],
                torch.ones(batch_size, device=model.device) * -1
            )
        loss = loss / (1 + self.args.negative_samples)
        if return_outputs:
            output = torch.cat(
                [model(row)["sentence_embedding"][:, None] for row in features], dim=1
            )
            return loss, output
        return loss

    def collect_features(
        self, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> List[Dict[str, torch.Tensor]]:
        """Turn the inputs from the dataloader into the separate model inputs."""
        return [
            {
                "input_ids": inputs[f"{column}_input_ids"],
                "attention_mask": inputs[f"{column}_attention_mask"],
            }
            for column in self.text_columns
        ]
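
The compute_loss above averages one positive-pair loss against the k randomly sampled negatives. Since CosineSimilarityLoss is a mean squared error over the cosine similarity, the per-batch objective being minimised is roughly

$$\mathcal{L} = \frac{1}{1 + k}\left[\operatorname{MSE}\big(\cos(a, b),\, 1\big) + \sum_{j=1}^{k} \operatorname{MSE}\big(\cos(a, n_j),\, -1\big)\right]$$

where a and b are the embeddings of the paired sentences and n_j is the embedding of the j-th sampled negative sentence.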
Code
from pathlib import Path
from typing import Optional, Dict, Tuple, Union

from transformers import (
    Trainer,
    TrainingArguments,
    EvalPrediction,
)
from sentence_transformers import (
    SentenceTransformer,
    losses,
    evaluation,
)
import datasets
import torch


def train_negative_samples(
    *,
    train_ds: datasets.Dataset,
    test_ds: datasets.Dataset,
    sentences: List[str],
    model_name: str = MODEL_NAME,
    batch_size: int = BATCH_SIZE,
    learning_rate: float = LEARNING_RATE,
    fp16: bool = False,
    epochs: Optional[float] = EPOCHS,
    max_steps: int = -1,
    run_folder: Path = MODEL_RUN_FOLDER,
    model_folder: Path = MODEL_RUN_FOLDER,
    # number of steps before moving evaluation results from GPU to CPU see
    # https://discuss.huggingface.co/t/cuda-out-of-memory-when-using-trainer-with-compute-metrics/2941
    eval_accumulation_steps: Optional[int] = None,
    evaluation_strategy: str = "epoch",
    save_strategy: str = "epoch",
    evaluation_steps: int = 500,
) -> SentenceTransformer:
    run_name = "-".join(
        [
            f"{model_name}",
            f"e{epochs}" if max_steps == -1 else f"ms{max_steps}",
            f"bs{batch_size}",
            f"lr{learning_rate}",
        ]
        + (["fp16"] if fp16 else [])
    )
    print(f"Starting {run_name}")

    training_args = SentenceTransformersNegativeSampleTrainingArguments(
        report_to="none",
        output_dir=run_folder,
        num_train_epochs=epochs,
        max_steps=max_steps,
        seed=33,
        eval_accumulation_steps=eval_accumulation_steps,
        #
        # hyperparameters
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        fp16=fp16,
        learning_rate=learning_rate,
        warmup_steps=WARMUP_STEPS,
        optim="adamw_torch",
        #
        # evaluation settings
        evaluation_strategy=evaluation_strategy,
        save_strategy=save_strategy,
        logging_steps=evaluation_steps,
        eval_steps=evaluation_steps,
        save_steps=evaluation_steps,
        #
        # checkpoint settings
        logging_dir=run_folder / "logs",
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="cosine_similarity",
        greater_is_better=True,
        #
        # needed to get sentence_A and sentence_B
        remove_unused_columns=False,
    )

    model = SentenceTransformer(model_name)
    tokenizer = model.tokenizer
    loss = losses.CosineSimilarityLoss(model)
    
    # equal positive and negative samples
    evaluator = evaluation.EmbeddingSimilarityEvaluator(
        test_ds["sentence_A"] + test_ds["sentence_A"],
        test_ds["sentence_B"] + random.choices(sentences, k=len(test_ds)),
        ([1]*len(test_ds)) + ([-1]*len(test_ds)),
        main_similarity=evaluation.SimilarityFunction.COSINE,
    )
    def compute_metrics(predictions: EvalPrediction) -> Dict[str, float]:
        return {
            "cosine_similarity": evaluator(model)
        }
    
    data_collator = SentenceTransformersCollator(
        tokenizer=tokenizer,
        text_columns=TEXT_COLUMNS,
    )
    sentence_tokens = [
        tokenizer(sentence, return_tensors="pt")
        for sentence in sentences
    ]

    trainer = SentenceTransformersNegativeSampleTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        # custom arguments
        loss=loss,
        text_columns=TEXT_COLUMNS,
        sentences=sentence_tokens,
    )

    trainer.train()
    model.save(str(model_folder / run_name)) # this does not accept Path objects

    return model
Code
model = train_negative_samples(
    train_ds=train_ds,
    test_ds=validation_ds,
    sentences=sentences,
    # the train dataset has fewer than 500 batches per epoch, so with the default evaluation_steps no training loss would be logged
    # setting it lower keeps logging_steps smaller than the number of batches and prevents "No log"
    # see https://github.com/huggingface/transformers/issues/8910
    evaluation_steps=10,
)
PyTorch: setting up devices
Starting all-MiniLM-L12-v2-e10-bs32-lr2e-05
[400/400 01:43, Epoch 10/10]
Epoch Training Loss Validation Loss Cosine Similarity
1 0.866000 0.853925 0.864421
2 0.867200 0.836017 0.863998
3 0.864700 0.816948 0.863828
4 0.857000 0.855043 0.864337
5 0.857500 0.868054 0.863659
6 0.844300 0.847864 0.863151
7 0.846900 0.826158 0.862897
8 0.827400 0.805237 0.863066
9 0.834800 0.835648 0.862812
10 0.832900 0.852276 0.862897

Code
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np

model.eval()

with torch.inference_mode():
    embedding_a = model.encode(sick_ds["test"]["sentence_A"])
    embedding_b = model.encode(sick_ds["test"]["sentence_B"])

labels = sick_ds["test"]["label"]
predictions = F.cosine_similarity(
    torch.from_numpy(embedding_a),
    torch.from_numpy(embedding_b),
).numpy()

negative_performance = pd.DataFrame(
    {
        "predictions": predictions,
        "targets": labels,
        "difference": np.abs(predictions - labels),
    }
).describe()
negative_performance
predictions targets difference
count 4906.000000 4906.000000 4906.000000
mean 0.598601 0.263954 0.354039
std 0.274311 0.504683 0.293836
min -0.188712 -1.000000 0.000018
25% 0.418188 0.000000 0.102994
50% 0.637474 0.300000 0.280422
75% 0.820952 0.650000 0.536532
max 0.998223 1.000000 1.450934
Code
(negative_performance.difference - sentence_transformers_performance.difference).to_frame()
difference
count 0.000000
mean 0.106446
std 0.060993
min -0.000103
25% 0.029047
50% 0.103737
75% 0.192640
max 0.031847
Code
import pandas as pd
import numpy as np

df = pd.DataFrame(
    {
        "sentence_a": sick_ds["test"]["sentence_A"],
        "sentence_b": sick_ds["test"]["sentence_B"],
        "target": sick_ds["test"]["label"],
        "prediction": predictions,
        "difference": np.abs(predictions - labels),
    }
)
print(df.head().to_markdown(tablefmt="grid", maxcolwidths=25))
+----+---------------------------+--------------------------+----------+--------------+--------------+
|    | sentence_a                | sentence_b               |   target |   prediction |   difference |
+====+===========================+==========================+==========+==============+==============+
|  0 | There is no boy playing   | A group of kids is       |   0.15   |     0.186045 |    0.0360454 |
|    | outdoors and there is no  | playing in a yard and an |          |              |              |
|    | man smiling               | old man is standing in   |          |              |              |
|    |                           | the background           |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  1 | A group of boys in a yard | The young boys are       |   0.35   |     0.508509 |    0.158509  |
|    | is playing and a man is   | playing outdoors and the |          |              |              |
|    | standing in the           | man is smiling nearby    |          |              |              |
|    | background                |                          |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  2 | A group of children is    | The young boys are       |   0      |     0.209752 |    0.209752  |
|    | playing in the house and  | playing outdoors and the |          |              |              |
|    | there is no man standing  | man is smiling nearby    |          |              |              |
|    | in the background         |                          |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  3 | A brown dog is attacking  | A brown dog is attacking |   0.95   |     0.965482 |    0.0154816 |
|    | another animal in front   | another animal in front  |          |              |              |
|    | of the tall man in pants  | of the man in pants      |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+
|  4 | A brown dog is attacking  | A brown dog is helping   |   0.3325 |     0.83249  |    0.49999   |
|    | another animal in front   | another animal in front  |          |              |              |
|    | of the man in pants       | of the man in pants      |          |              |              |
+----+---------------------------+--------------------------+----------+--------------+--------------+

This model is much better than I expected bearing in mind that it was trained on a fraction of the data in a noisy fashion. It also wasn’t given exact labels to work with.

It still suffers from the attacking/helping classification problem. I wonder if it started out well and then didn’t really change, as the training did not seem to alter the loss or cosine similarity very much.

Even so, this was more about the technique than the results. It would be nice if it were possible to separate the selection of negative samples from the trainer, but I think this approach is acceptable.