MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DOCUMENTS = [
"The huggingface models are really neat.",
"I like using open source models.",
]

The huggingface models are really neat. One thing I particularly like about them is how easy they are to use: load them, prepare the input, invoke. If we make our own model that just wraps them in a plain torch.nn.Module then the utility methods (from_pretrained, save_pretrained and friends) are lost. It is possible to extend the huggingface classes in a way that keeps them.
Original Model
The model that I'm going to use for this is sentence-transformers/all-MiniLM-L6-v2, which I've been using quite a bit to embed documents. It's small and very fast, which makes it easy to try out. The model card includes custom code to do two additional steps - mean pooling and normalization.
Let's start by using the model as the model card shows and then wrap that up into a nice huggingface-compatible class.
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
@torch.inference_mode()
def embed(documents: list[str]) -> torch.Tensor:
# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
clean_up_tokenization_spaces=True,
)
model = AutoModel.from_pretrained(MODEL_NAME)
# Tokenize sentences
encoded_input = tokenizer(
documents,
padding=True,
truncation=True,
return_tensors="pt",
)
# Compute token embeddings
model_output = model(**encoded_input)
# Perform pooling
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
# Normalize embeddings
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
return sentence_embeddings
#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
original_embedding = embed(DOCUMENTS)
original_embedding.shape
torch.Size([2, 384])
It has taken the two documents and turned them into 384-dimensional vector embeddings. Easy enough. Can we make a model which wraps up the additional steps that are performed after invoking the model?
Huggingface Pretrained Models
There are already models which extend the base pretrained models; the sequence classification versions do exactly this. When you load them there are warnings clearly saying that the newly added classification head needs training.
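For example, loading a sequence classification head on top of the very checkpoint we are using triggers that warning, since the classifier weights start out randomly initialized (a quick sketch, not from the model card):

from transformers import AutoModelForSequenceClassification

# Loading a classification head on top of the base checkpoint warns that the
# classifier weights are newly initialised and should be trained before use.
classifier = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)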
Looking at one of them can give some hints. The BERT model is a classic, and huggingface transformers itself was originally a repo called pytorch-pretrained-BERT.
Looking at the task specific models I can see that they extend a class called BertPreTrainedModel which starts like this:
class BertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BertConfig
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
supports_gradient_checkpointing = True
    _supports_sdpa = True

The interesting parts are the class attributes at the top, coupled with extending the PreTrainedModel base class. Some of these are documented in the parent class:

Class attributes (overridden by derived classes):
- config_class (PretrainedConfig) – A subclass of PretrainedConfig to use as configuration class for this model architecture.
- load_tf_weights (Callable) – A python method for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
  - model (PreTrainedModel) – An instance of the model on which to load the TensorFlow checkpoint.
  - config (PreTrainedConfig) – An instance of the configuration associated to the model.
  - path (str) – A path to the TensorFlow checkpoint.
- base_model_prefix (str) – A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
- is_parallelizable (bool) – A flag indicating whether this model supports model parallelization.
- main_input_name (str) – The name of the principal input to the model (often input_ids for NLP models, pixel_values for vision models and input_values for speech models).
base_model_prefix is one of the most important settings. With it set, from_pretrained can correctly load a checkpoint into that attribute of the derived class even when the checkpoint contains only the base model.
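To make this concrete: the base checkpoint's state dict has no wrapper prefix at all, while a class that sets base_model_prefix = "body" holds the same weights under a body attribute; from_pretrained uses the prefix to map between the two layouts. A small sketch of what it is reconciling:

from transformers import AutoModel

# The checkpoint stores the base model's weights without any wrapper prefix,
# e.g. keys like "embeddings.word_embeddings.weight". A class that sets
# base_model_prefix = "body" stores the same weights under "body.<key>",
# and from_pretrained bridges the two automatically.
base = AutoModel.from_pretrained(MODEL_NAME)
print(list(base.state_dict())[:3])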
Since we are only loading PyTorch checkpoints, load_tf_weights can be skipped. The documentation doesn't explain what supports_gradient_checkpointing or _supports_sdpa do though.
Looking up the terms I can find that gradient checkpointing is a way to train large models by not keeping every intermediate activation in memory; activations are recomputed during the backward pass instead, trading compute for memory. As we are only adding layers that have no parameters we can be sure that this will still work with our new model.
SDPA is scaled dot product attention, PyTorch's fused attention implementation (which can use flash attention as a backend). It makes transformer models much faster. The BERT model is explicitly listed as supported, and since we are not adding any attention layers we continue to support SDPA.
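Both features can be exercised on the base model directly; a quick sketch (the attn_implementation argument needs a reasonably recent transformers release):

from transformers import AutoModel

# Request the SDPA attention implementation explicitly when loading.
sdpa_model = AutoModel.from_pretrained(MODEL_NAME, attn_implementation="sdpa")

# Gradient checkpointing is toggled through a PreTrainedModel method, so any
# subclass, including the one we are about to write, inherits it for free.
sdpa_model.gradient_checkpointing_enable()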
We can demonstrate this working by making a model that only wraps an existing model.
from typing import Union, Tuple
import torch
from transformers import BertConfig, BertModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.modeling_utils import PreTrainedModel
class SimplestModelExtension(PreTrainedModel): # pylint: disable=abstract-method
config_class = BertConfig
base_model_prefix = "body"
supports_gradient_checkpointing = True
_supports_sdpa = True
def __init__(self, config: BertConfig, **kwargs) -> None:
super().__init__(config)
self.body = BertModel(config) # base_model_prefix
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
**kwargs,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
return self.body(
input_ids=input_ids,
attention_mask=attention_mask,
**kwargs,
        )

from transformers import AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
clean_up_tokenization_spaces=True,
)
model = SimplestModelExtension.from_pretrained(MODEL_NAME)
tokens = tokenizer("hello world", return_tensors="pt")
with torch.inference_mode():
output = model(**tokens)
output.last_hidden_state.shape
torch.Size([1, 4, 384])
We can see that the model loads and runs without error. This is just the sort of behaviour we want. Our embedding model should return a single embedding for the entire input though (i.e. a shape of (1, 384)).
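As a quick check that base_model_prefix really did load the pretrained weights into body, the wrapper should produce exactly the same hidden states as the plain base model (reusing tokens and output from the cell above):

from transformers import AutoModel

# Identical weights and identical inputs should give identical hidden states.
base_model = AutoModel.from_pretrained(MODEL_NAME)
with torch.inference_mode():
    base_output = base_model(**tokens)

print(torch.equal(output.last_hidden_state, base_output.last_hidden_state))  # expect: True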
Making Our Own
To make this I need to define two modules, one to do the mean pooling and another to do the normalization. With these we can then create the embedding model. As these two layers have no parameters they only exist to define a forward method:
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel, AutoTokenizer
class MeanPooling(nn.Module):
def forward(
self,
last_hidden_state: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
unattended_tokens = ~attention_mask[..., None].bool()
last_hidden_state = last_hidden_state.masked_fill(unattended_tokens, 0.0)
attended_token_count = attention_mask.sum(dim=1)[..., None]
return last_hidden_state.sum(dim=1) / attended_token_count
class Normalize(nn.Module):
def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        return F.normalize(embeddings, p=2, dim=1)

With this we can perform the two additional steps after the base embedding model. It's important to pass the attention mask to the mean pooling layer so that it can correctly average over the attended tokens.
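As a quick sanity check of MeanPooling, here is a made-up example with two token positions where only the first is attended to; the pooled output should be exactly that first vector:

# One batch row with two token positions and a hidden size of two. Only the
# first position is attended to, so the second must not affect the average.
hidden = torch.tensor([[[1.0, 2.0], [100.0, 200.0]]])
mask = torch.tensor([[1, 0]])
print(MeanPooling()(last_hidden_state=hidden, attention_mask=mask))  # tensor([[1., 2.]])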
import torch
from transformers import BertConfig, BertModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.modeling_utils import PreTrainedModel
class EmbeddingModel(PreTrainedModel): # pylint: disable=abstract-method
config_class = BertConfig
base_model_prefix = "body"
supports_gradient_checkpointing = True
_supports_sdpa = True
def __init__(self, config: BertConfig, **kwargs) -> None:
super().__init__(config)
self.body = BertModel(config)
self.pool = MeanPooling()
self.normalize = Normalize()
# for simplicity this no longer respects return_dict
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
**kwargs,
) -> torch.Tensor:
output: BaseModelOutputWithPoolingAndCrossAttentions = self.body(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
)
embeddings = self.pool(
last_hidden_state=output.last_hidden_state,
attention_mask=attention_mask,
)
normalized_embeddings = self.normalize(embeddings)
        return normalized_embeddings

Does this work? Let's test loading it using from_pretrained and running it over the documents.
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
clean_up_tokenization_spaces=True,
)
model = EmbeddingModel.from_pretrained(MODEL_NAME)
@torch.inference_mode()
def model_embed(documents: list[str]) -> torch.Tensor:
encoded_input = tokenizer(
documents,
padding=True,
truncation=True,
return_tensors="pt",
)
return model(**encoded_input)
model_embedding = model_embed(DOCUMENTS)
model_embedding.shape
torch.Size([2, 384])
The shape of the output has changed, as desired. Since this is the original model with the additional steps added on, these outputs should exactly match the original embeddings we generated:
torch.all(original_embedding == model_embedding)
tensor(True)
This certainly appears to be a suitable replacement for the original code!
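And because it extends PreTrainedModel the usual utility methods come along for free; for example it can be round-tripped through a temporary directory:

import tempfile

# save_pretrained writes the config and weights; from_pretrained restores the
# full EmbeddingModel, extra layers included, since those hold no parameters.
with tempfile.TemporaryDirectory() as path:
    model.save_pretrained(path)
    reloaded = EmbeddingModel.from_pretrained(path)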