Extending Huggingface Model

Customizing a pretrained model while keeping the handy utility methods
Published October 5, 2024

The huggingface models are really neat. One thing I particularly like about them is how easy they are to use: load them, prepare the input, invoke. If we naively wrap one of them in our own model, those utility methods (such as from_pretrained) are lost. It is possible to extend the huggingface classes in a way that keeps them.

Original Model

The model that I’m going to use for this is sentence-transformers/all-MiniLM-L6-v2, which I’ve been using quite a bit to embed documents. It’s small and very fast, which makes it easy to try out. The model card includes custom code to do two additional steps: mean pooling and normalization.
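Written out, the two steps look like this. Here $h_i$ is the last hidden state of token $i$, $m_i \in \{0, 1\}$ is its attention mask value, $\epsilon = 10^{-9}$ is the clamp used in the model card's mean_pooling function, and $\delta = 10^{-12}$ is the default eps of F.normalize. This is a sketch of what the code computes, not notation taken from the model card itself.

```latex
\bar{h} = \frac{\sum_{i=1}^{n} m_i \, h_i}{\max\left(\sum_{i=1}^{n} m_i,\ \epsilon\right)},
\qquad
e = \frac{\bar{h}}{\max\left(\lVert \bar{h} \rVert_2,\ \delta\right)}
```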

Let’s start by using the model as the model card shows, and then wrap that up into a nice huggingface-compatible class.

MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DOCUMENTS = [
    "The huggingface models are really neat.",
    "I like using open source models.",
]
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

@torch.inference_mode()
def embed(documents: list[str]) -> torch.Tensor:
    # Load model from HuggingFace Hub
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        clean_up_tokenization_spaces=True,
    )
    model = AutoModel.from_pretrained(MODEL_NAME)
    
    # Tokenize sentences
    encoded_input = tokenizer(
        documents,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    
    # Compute token embeddings
    model_output = model(**encoded_input)
    
    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    
    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    return sentence_embeddings

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

original_embedding = embed(DOCUMENTS)
original_embedding.shape
torch.Size([2, 384])

It has taken the two documents and turned them into 384-dimensional vector embeddings. Easy enough. Can we make a model that wraps up the additional steps performed after invoking the model?
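The final normalization step has a practical payoff: every embedding has unit length, so comparing two documents reduces to a plain dot product. This minimal sketch uses random tensors as stand-ins for real document embeddings (an assumption made so it runs without the model). It checks that the dot product matches torch's cosine similarity.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Random stand-ins for two 384-dimensional document embeddings, L2-normalized
embeddings = F.normalize(torch.randn(2, 384), p=2, dim=1)

# For unit-length vectors the dot product is the cosine similarity
dot_products = embeddings @ embeddings.T
cosine = F.cosine_similarity(embeddings[:, None, :], embeddings[None, :, :], dim=-1)

print(torch.allclose(dot_products, cosine, atol=1e-6))  # True
```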

Huggingface Pretrained Models

There are already models that extend the base pretrained models: the sequence classification versions do exactly this. When you load one from a base checkpoint, a warning clearly says that some of its weights are newly initialized and need training.
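A small local experiment shows where that warning comes from. This sketch builds a tiny, randomly initialized BertConfig. The sizes are arbitrary assumptions, chosen so nothing needs to be downloaded. The sketch saves a bare BertModel, reloads it as BertForSequenceClassification, and asks from_pretrained for its loading report via output_loading_info=True. The classification head has no weights in the checkpoint, so it should show up as missing.

```python
import tempfile

from transformers import BertConfig, BertForSequenceClassification, BertModel

# Tiny, arbitrary sizes so the example runs quickly without downloading anything
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=3,
)

with tempfile.TemporaryDirectory() as checkpoint_dir:
    # Save only the base model: this checkpoint has no classification head
    BertModel(config).save_pretrained(checkpoint_dir)

    # Reload it with a task-specific head on top and ask for the loading report
    classifier, loading_info = BertForSequenceClassification.from_pretrained(
        checkpoint_dir,
        output_loading_info=True,
    )

# The head was freshly initialized, which is what the "needs training" warning is about
print(sorted(loading_info["missing_keys"]))
```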

Looking at one of them can give some hints. The BERT model is a classic, and huggingface transformers itself was originally a repo called pytorch-pretrained-BERT.

Looking at the task-specific models, I can see that they extend a class called BertPreTrainedModel, which starts like this:

class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True
    _supports_sdpa = True

The key ingredients are these class attributes at the top, together with extending the PreTrainedModel base class. Some of them are documented in the parent class:

Class attributes (overridden by derived classes):

  • config_class ([PretrainedConfig]) – A subclass of [PretrainedConfig] to use as configuration class for this model architecture.
  • load_tf_weights (Callable) – A python method for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
    • model ([PreTrainedModel]) – An instance of the model on which to load the TensorFlow checkpoint.
    • config ([PreTrainedConfig]) – An instance of the configuration associated to the model.
    • path (str) – A path to the TensorFlow checkpoint.
  • base_model_prefix (str) – A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
  • is_parallelizable (bool) – A flag indicating whether this model supports model parallelization.
  • main_input_name (str) – The name of the principal input to the model (often input_ids for NLP models, pixel_values for vision models and input_values for speech models).
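These attributes can be inspected directly on a task-specific model. This is a minimal sketch, again with a tiny arbitrary config so nothing is downloaded. base_model is a PreTrainedModel property that returns the attribute named by base_model_prefix.

```python
from transformers import BertConfig, BertForSequenceClassification

# Class attributes are readable without building a model
print(BertForSequenceClassification.config_class is BertConfig)  # True
print(BertForSequenceClassification.base_model_prefix)  # bert
print(BertForSequenceClassification.main_input_name)  # input_ids

# On an instance, base_model resolves the attribute named by base_model_prefix
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
model = BertForSequenceClassification(config)
print(model.base_model is model.bert)  # True
```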

base_model_prefix is one of the most important settings. It names the attribute that holds the base model. This means that when the referenced checkpoint contains only the base model, from_pretrained can still load those weights into that attribute.
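A way to watch this happen without downloading anything is sketched here. PrefixedWrapper is a hypothetical class for illustration and the config sizes are arbitrary. The sketch saves a bare BertModel, whose weight names have no prefix, and loads it into a wrapper whose base model lives at self.body. It then checks that every weight arrived.

```python
import tempfile

import torch
from transformers import BertConfig, BertModel
from transformers.modeling_utils import PreTrainedModel


class PrefixedWrapper(PreTrainedModel):  # hypothetical wrapper for illustration
    config_class = BertConfig
    base_model_prefix = "body"  # the base model lives at self.body

    def __init__(self, config: BertConfig) -> None:
        super().__init__(config)
        self.body = BertModel(config)
        self.post_init()

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        return self.body(input_ids=input_ids, attention_mask=attention_mask)


config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
base = BertModel(config)

with tempfile.TemporaryDirectory() as checkpoint_dir:
    # Keys are saved as "embeddings...", "encoder...": there is no "body." prefix
    base.save_pretrained(checkpoint_dir)
    # from_pretrained maps them onto wrapped.body using base_model_prefix
    wrapped = PrefixedWrapper.from_pretrained(checkpoint_dir)

base_state = base.state_dict()
wrapped_state = wrapped.body.state_dict()
print(all(torch.equal(base_state[key], wrapped_state[key]) for key in base_state))  # True
```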

We are not loading a TensorFlow checkpoint, so load_tf_weights can be skipped. The documentation doesn’t explain what supports_gradient_checkpointing or _supports_sdpa do, though.

Looking up the terms, I found that gradient checkpointing is a way to train large models whose intermediate activations cannot all be held in memory. Only some activations are stored during the forward pass and the rest are recomputed during the backward pass, which trades extra compute for lower memory use. As the layers we are adding have no parameters, this should continue to work with our new model.

SDPA is scaled dot product attention (torch.nn.functional.scaled_dot_product_attention), a fused PyTorch attention operation that can dispatch to faster kernels such as FlashAttention. This can make transformer models much faster. The BERT model is explicitly listed as supported, and we are not adding any attention layers, so we continue to support SDPA.
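To see what SDPA actually computes, this sketch compares the fused call against the textbook formula softmax(QKᵀ/√d)V on random tensors. The tensor sizes are arbitrary. The fused version may use a different kernel, so the comparison uses a small tolerance rather than exact equality.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, tokens, head_dim): arbitrary small sizes for illustration
query = torch.randn(2, 4, 5, 8)
key = torch.randn(2, 4, 5, 8)
value = torch.randn(2, 4, 5, 8)

# Fused implementation: PyTorch chooses the backend kernel
fused = F.scaled_dot_product_attention(query, key, value)

# The same computation written out by hand
scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
manual = scores.softmax(dim=-1) @ value

print(torch.allclose(fused, manual, atol=1e-5))  # True
```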

We can demonstrate this working by making a model that only wraps an existing model.

from typing import Union, Tuple
import torch
from transformers import BertConfig, BertModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.modeling_utils import PreTrainedModel


class SimplestModelExtension(PreTrainedModel):  # pylint: disable=abstract-method
    config_class = BertConfig
    base_model_prefix = "body"
    supports_gradient_checkpointing = True
    _supports_sdpa = True

    def __init__(self, config: BertConfig, **kwargs) -> None:
        super().__init__(config)
        self.body = BertModel(config) # base_model_prefix

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        return self.body(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
from transformers import AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    clean_up_tokenization_spaces=True,
)
model = SimplestModelExtension.from_pretrained(MODEL_NAME)

tokens = tokenizer("hello world", return_tensors="pt")
with torch.inference_mode():
    output = model(**tokens)
output.last_hidden_state.shape
torch.Size([1, 4, 384])

We can see that the model loads and runs without error, which is just the sort of behaviour we want. However, the output still has one 384-dimensional vector per token, while our embedding model should return a single embedding for the entire input (i.e. a shape of (1, 384)).

Making Our Own

To make this I need to define two modules: one to do mean pooling and another to do the normalization. With these we can then create the embedding model. As these two layers have no parameters, they only exist to define a forward method:

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel, AutoTokenizer


class MeanPooling(nn.Module):
    def forward(
        self,
        last_hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        unattended_tokens = ~attention_mask[..., None].bool()
        last_hidden_state = last_hidden_state.masked_fill(unattended_tokens, 0.0)
        attended_token_count = attention_mask.sum(dim=1)[..., None]
        return last_hidden_state.sum(dim=1) / attended_token_count


class Normalize(nn.Module):
    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        return F.normalize(embeddings, p=2, dim=1)
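As a quick sanity check, this sketch includes a standalone copy of MeanPooling and the model card's mean_pooling, adapted here to take the hidden states directly. The tensor sizes are arbitrary. It confirms that the two formulations agree and that values at padded positions have no effect on the pooled result.

```python
import torch
from torch import nn


class MeanPooling(nn.Module):  # copy of the module above so this block runs on its own
    def forward(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        unattended_tokens = ~attention_mask[..., None].bool()
        last_hidden_state = last_hidden_state.masked_fill(unattended_tokens, 0.0)
        attended_token_count = attention_mask.sum(dim=1)[..., None]
        return last_hidden_state.sum(dim=1) / attended_token_count


def model_card_mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Model card version, taking the hidden states directly instead of model_output
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)


torch.manual_seed(0)
hidden = torch.randn(2, 4, 3)  # (batch, tokens, hidden): arbitrary sizes
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])  # second row ends in two padding tokens

pooled = MeanPooling()(hidden, mask)
print(torch.allclose(pooled, model_card_mean_pooling(hidden, mask)))  # True

# Overwriting the padded positions must not change the result
noisy = hidden.clone()
noisy[1, 2:] = 100.0
print(torch.equal(MeanPooling()(noisy, mask), pooled))  # True
```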

With these modules we can perform the two additional steps after running the base model. It’s important to pass the attention mask to the mean pooling layer so that it can correctly average over the attended tokens.

from typing import Optional

import torch
from transformers import BertConfig, BertModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.modeling_utils import PreTrainedModel


class EmbeddingModel(PreTrainedModel):  # pylint: disable=abstract-method
    config_class = BertConfig
    base_model_prefix = "body"
    supports_gradient_checkpointing = True
    _supports_sdpa = True

    def __init__(self, config: BertConfig, **kwargs) -> None:
        super().__init__(config)
        self.body = BertModel(config)  # attribute name matches base_model_prefix
        self.pool = MeanPooling()
        self.normalize = Normalize()

    # for simplicity this no longer respects return_dict
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # the tokenizer also returns token_type_ids, so pass them through
        # instead of silently dropping them in **kwargs
        output: BaseModelOutputWithPoolingAndCrossAttentions = self.body(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        # average the token vectors, ignoring padding
        embeddings = self.pool(
            last_hidden_state=output.last_hidden_state,
            attention_mask=attention_mask,
        )
        # scale each embedding to unit length
        normalized_embeddings = self.normalize(embeddings)
        return normalized_embeddings

Does this work? Let’s test loading it using from_pretrained and running it over the documents.

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    clean_up_tokenization_spaces=True,
)
model = EmbeddingModel.from_pretrained(MODEL_NAME)

@torch.inference_mode()
def model_embed(documents: list[str]) -> torch.Tensor:
    encoded_input = tokenizer(
        documents,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    return model(**encoded_input)

model_embedding = model_embed(DOCUMENTS)
model_embedding.shape
torch.Size([2, 384])

The shape of the output has changed, as desired. Since this is the original model with the additional steps, these outputs should exactly match the original embeddings we generated:

torch.all(original_embedding == model_embedding)
tensor(True)
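Exact equality is a strict test. It holds here because both versions run the same operations in the same order. A mathematically equivalent rewrite, such as dividing each token by the count before summing, is not guaranteed to match bit for bit. torch.allclose is usually the safer check when comparing embeddings from different implementations. This sketch illustrates the idea on random stand-in tensors.

```python
import torch

torch.manual_seed(0)
token_embeddings = torch.randn(2, 7, 384)  # random stand-in for a model's hidden states

# Two mathematically identical ways to average over the 7 tokens
sum_then_divide = token_embeddings.sum(dim=1) / 7
divide_then_sum = (token_embeddings / 7).sum(dim=1)

# Bitwise equality is not guaranteed, but the values agree to float32 precision
print(torch.allclose(sum_then_divide, divide_then_sum, atol=1e-6))  # True
```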

This certainly appears to be a suitable replacement for the original code!
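Because the class extends PreTrainedModel, the other utility methods come along too. This sketch shows a save_pretrained / from_pretrained round trip. TinyEmbeddingModel is a hypothetical compact version of EmbeddingModel, with the pooling and normalization inlined. It uses a tiny arbitrary config so nothing is downloaded.

```python
import tempfile

import torch
import torch.nn.functional as F
from transformers import BertConfig, BertModel
from transformers.modeling_utils import PreTrainedModel


class TinyEmbeddingModel(PreTrainedModel):  # hypothetical compact version of EmbeddingModel
    config_class = BertConfig
    base_model_prefix = "body"

    def __init__(self, config: BertConfig) -> None:
        super().__init__(config)
        self.body = BertModel(config)
        self.post_init()

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        hidden = self.body(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        mask = attention_mask[..., None].to(hidden.dtype)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return F.normalize(pooled, p=2, dim=1)


config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
model = TinyEmbeddingModel(config).eval()  # eval mode disables dropout
input_ids = torch.randint(0, config.vocab_size, (2, 5))
attention_mask = torch.ones_like(input_ids)

with tempfile.TemporaryDirectory() as checkpoint_dir:
    model.save_pretrained(checkpoint_dir)  # writes config.json plus the weights
    reloaded = TinyEmbeddingModel.from_pretrained(checkpoint_dir)  # returned in eval mode

with torch.inference_mode():
    before = model(input_ids, attention_mask)
    after = reloaded(input_ids, attention_mask)

print(before.shape, torch.allclose(before, after, atol=1e-6))
```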