Quantizing a BERT Model using Optimum ONNX

Different ways to quantize a model

Categories: performance, quantization

Published: September 12, 2022

The state of the art for model quantization has progressed quite a bit. Huggingface seems to be incorporating all of the coolest tech right now, so there is quite a lot of quantization and optimization available through the Optimum library. This post is an investigation of that library, using ONNX and Intel optimizations to try to improve the performance of text sentiment classification.

I am going to run through the different parts of the Optimum documentation and see how easy each one is to get working. It would also be nice to measure the speed and accuracy of each approach to see how they compare.

Dataset

To provide a baseline performance metric we can take the Stanford Sentiment Treebank (Socher et al. 2013). Huggingface provides a way to load and use it.

Socher, Richard, Alex Perelygin, Jean Wu, Jason Chuang, Christopher D. Manning, Andrew Ng, and Christopher Potts. 2013. “Recursive Deep Models for Semantic Compositionality over a Sentiment Treebank.” In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing, 1631–42. Seattle, Washington, USA: Association for Computational Linguistics. https://aclanthology.org/D13-1170.
Code
import pandas as pd
from datasets import load_dataset

dataset = load_dataset("sst2")
df = pd.DataFrame(dataset["train"])
df = df.set_index("idx")
df
sentence label
idx
0 hide new secretions from the parental units 0
1 contains no wit , only labored gags 0
2 that loves its characters and communicates som... 1
3 remains utterly satisfied to remain the same t... 0
4 on the worst revenge-of-the-nerds clichés the ... 0
... ... ...
67344 a delightful comedy 1
67345 anguish , anger and frustration 0
67346 at achieving the modest , crowd-pleasing goals... 1
67347 a patient viewer 1
67348 this new jangle of noise , mayhem and stupidit... 0

67349 rows × 2 columns

That first row looks quite weird, and the text is quite short. Even so, this dataset has been used to train different sentiment models, and there are pretrained versions available that used it. I can use it to quantify the change in accuracy of the model.

Baseline Evaluation

The baseline model that I am going to use is the DistilBERT base uncased finetuned SST-2 model. This is a small model that has already been distilled, so its accuracy is the ceiling that any quantized version should stay close to. It can be run with and without the GPU to provide baseline timing and accuracy results.

This is also an opportunity to try out the huggingface evaluate library. I’m quite hopeful about this as previously I’ve had to construct trainers or use the metrics (which were nice but seemed misaligned with the trainer). If this makes it easier to work with then that would be great.

Code
import evaluate
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

task_evaluator = evaluate.evaluator("text-classification")
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
data = load_dataset("sst2", split="validation")

MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
Code
from typing import Dict, Union
import pandas as pd

def format_results(results: Dict[str, Union[float, Dict]]) -> pd.DataFrame:
    # the bootstrap results are dicts with confidence_interval, standard_error and score
    # the non bootstrap results are just scores.
    # for now map all that to just scores
    def expand_row(name: str, value: Union[float, Dict]) -> Dict[str, Union[str, float]]:
        if isinstance(value, float):
            return {
                "name": name,
                "value": value,
            }
        return {
            "name": name,
            "value": value["score"],
            "confidence_low": value["confidence_interval"][0],
            "confidence_high": value["confidence_interval"][1],
            "std": value["standard_error"],
        }
    df = pd.DataFrame([
        expand_row(name, value)
        for name, value in results.items()
    ])
    df = df.set_index("name")
    return df
format_results(
    task_evaluator.compute(
        model_or_pipeline=model,
        tokenizer=tokenizer,
        data=data,
        metric=clf_metrics,
        input_column="sentence",
        label_column="label",
        label_mapping={"POSITIVE": 1, "NEGATIVE": 0},
        strategy="bootstrap",
        n_resamples=10,
        random_state=0
    )
)
value confidence_low confidence_high std
name
accuracy 0.910550 0.896956 0.917235 0.008721
f1 0.913717 0.901645 0.925092 0.008324
precision 0.897826 0.874706 0.907254 0.010246
recall 0.930180 0.911785 0.946024 0.011576
total_time_in_seconds 2.877721 NaN NaN NaN
samples_per_second 303.017543 NaN NaN NaN
latency_in_seconds 0.003300 NaN NaN NaN
format_results(
    task_evaluator.compute(
        model_or_pipeline=model,
        tokenizer=tokenizer,
        data=data,
        metric=clf_metrics,
        input_column="sentence",
        label_column="label",
        label_mapping={"POSITIVE": 1, "NEGATIVE": 0},
    )
)

If I only inspect the score values then the two approaches produce identical results, and the non-bootstrap one is about 30x quicker. The bootstrap evaluation does provide a lot more information about the quality of the model, so it would be nice to use it for the full evaluations.

The samples per second appear to be comparable between the two, and that is one of the metrics I was most interested in. This is on the GPU though, so how does it compare when moved to the CPU?

model.cpu()

format_results(
    task_evaluator.compute(
        model_or_pipeline=model,
        tokenizer=tokenizer,
        data=data,
        metric=clf_metrics,
        input_column="sentence",
        label_column="label",
        label_mapping={"POSITIVE": 1, "NEGATIVE": 0},
        device=-1, # CPU
    )
)

It’s gone from 319 samples per second to 103. So as a baseline we have:

device accuracy samples/second
GPU 0.910550 319.239023
CPU 0.910550 102.782367

Overall I think that the evaluation framework is excellent and very easy to use.

Now we can try comparing this to different versions of the quantized or optimized model.

Optimum

To start we can try out the quickstart example. I’ve had to adjust the code slightly as the original did not work (opened an issue about that).

ONNX Dynamic Quantization

This will convert the model to the ONNX format and then quantize it. The quantization will be dynamic, which means the quantization ranges for the activations are computed on the fly at inference time rather than fixed in advance. Keeping the model dynamic restricts the optimizations that can be performed on it.

Code
from pathlib import Path

from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig
from transformers import AutoTokenizer, Pipeline, pipeline

QUANTIZATION_FOLDER = Path("/data/blog/2022-09-12-model-quantization")
QUICKSTART_SAVE_FOLDER = QUANTIZATION_FOLDER / "quickstart"
DYNAMIC_SAVE_FOLDER = QUICKSTART_SAVE_FOLDER / "dynamic"
DYNAMIC_SAVE_FOLDER.mkdir(parents=True, exist_ok=True)

MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"


def export_to_onnx(model_name: str, directory: Path) -> None:
    """
    Load the model from transformers and export it to the ONNX format.
    """
    model = ORTModelForSequenceClassification.from_pretrained(
        model_name, from_transformers=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.save_pretrained(directory, file_name="model.onnx")
    tokenizer.save_pretrained(directory)


def dynamic_quantize_onnx_model(directory: Path, **config) -> None:
    """
    Apply dynamic quantization to the model.
    """
    quantization_config = AutoQuantizationConfig.arm64(**config)
    quantizer = ORTQuantizer.from_pretrained(directory)

    quantizer.quantize(
        save_dir=directory,
        quantization_config=quantization_config,
    )


def load_quantized_pipeline(directory: Path) -> Pipeline:
    """
    Load the quantized model as a text classification pipeline.
    """
    model = ORTModelForSequenceClassification.from_pretrained(
        directory, file_name="model_quantized.onnx"
    )
    tokenizer = AutoTokenizer.from_pretrained(directory)

    return pipeline("text-classification", model=model, tokenizer=tokenizer)


export_to_onnx(
    model_name=MODEL_NAME,
    directory=DYNAMIC_SAVE_FOLDER
)
dynamic_quantize_onnx_model(
    directory=DYNAMIC_SAVE_FOLDER,
    is_static=False,
    per_channel=False,
)
quantized_pipeline = load_quantized_pipeline(
    directory=DYNAMIC_SAVE_FOLDER,
)

At this point we have a pipeline which can be used directly:

quantized_pipeline("I love burritos!")
[{'label': 'POSITIVE', 'score': 0.9996811151504517}]

I’ve created the pipeline as it’s a suitable format to use with the evaluate framework. As such we can now compare the performance of the quantized model to the original distilled model.

Code
%%time
#| code-fold: false

format_results(
    task_evaluator.compute(
        model_or_pipeline=quantized_pipeline,
        tokenizer=tokenizer,
        data=data,
        metric=clf_metrics,
        input_column="sentence",
        label_column="label",
        label_mapping={"POSITIVE": 1, "NEGATIVE": 0},
        strategy="bootstrap",
        n_resamples=10,
        random_state=0,
        device=-1, # CPU
    )
)
CPU times: user 1min 47s, sys: 892 ms, total: 1min 48s
Wall time: 1min 10s
value confidence_low confidence_high std
name
accuracy 0.897936 0.885770 0.904523 0.006890
f1 0.901874 0.888681 0.910548 0.007186
precision 0.883369 0.856407 0.895553 0.012170
recall 0.921171 0.903992 0.935958 0.011506
total_time_in_seconds 4.091242 NaN NaN NaN
samples_per_second 213.138190 NaN NaN NaN
latency_in_seconds 0.004692 NaN NaN NaN

This is encouraging. The results from this quantization compare well to the baseline:

device accuracy accuracy Δ samples/second relative speed
Distilled GPU 0.910550 319.239023
Distilled CPU 0.910550 102.782367
Dynamic Quantized CPU 0.897936 -0.012614 269.176181 2.618895

So for a cost of 0.013 accuracy we more than double in speed. This is a good start.

ONNX Static Quantization

Now we can try static quantization. This involves loading a calibration dataset that is used to compute fixed quantization ranges for the activations ahead of time. It's good to use the training split for this as it covers the expected variation in inputs.

Code
from pathlib import Path
from functools import partial

from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoCalibrationConfig, AutoQuantizationConfig
from transformers import AutoTokenizer


STATIC_SAVE_FOLDER = QUICKSTART_SAVE_FOLDER / "static"
STATIC_SAVE_FOLDER.mkdir(parents=True, exist_ok=True)

def static_quantize_onnx_model(directory: Path, **config) -> None:
    """
    Apply static quantization to the model, calibrating on a sample of the training data.
    """
    quantization_config = AutoQuantizationConfig.arm64(**config)
    quantizer = ORTQuantizer.from_pretrained(directory)
    tokenizer = AutoTokenizer.from_pretrained(directory)
    
    def preprocess_fn(row, tokenizer):
        return tokenizer(row["sentence"])

    # Create the calibration dataset
    calibration_dataset = quantizer.get_calibration_dataset(
        "sst2",
        preprocess_function=partial(preprocess_fn, tokenizer=tokenizer),
        num_samples=50,
        dataset_split="train",
    )

    # Create the calibration configuration containing the parameters related to calibration.
    calibration_config = AutoCalibrationConfig.minmax(calibration_dataset)

    # Perform the calibration step: computes the activations quantization ranges
    ranges = quantizer.fit(
        dataset=calibration_dataset,
        calibration_config=calibration_config,
        # onnx_model_path=onnx_path,
        operators_to_quantize=quantization_config.operators_to_quantize,
    )

    # Apply static quantization on the model
    quantizer.quantize(
        save_dir=directory,
        quantization_config=quantization_config,
        calibration_tensors_range=ranges,
    )


export_to_onnx(
    model_name=MODEL_NAME,
    directory=STATIC_SAVE_FOLDER
)
static_quantize_onnx_model(
    directory=STATIC_SAVE_FOLDER,
    is_static=True,
    per_channel=False,
)
quantized_pipeline = load_quantized_pipeline(
    directory=STATIC_SAVE_FOLDER,
)
%%time

format_results(
    task_evaluator.compute(
        model_or_pipeline=quantized_pipeline,
        tokenizer=tokenizer,
        data=data,
        metric=clf_metrics,
        input_column="sentence",
        label_column="label",
        label_mapping={"POSITIVE": 1, "NEGATIVE": 0},
        strategy="bootstrap",
        n_resamples=10,
        random_state=0,
        device=-1, # CPU
    )
)
/home/matthew/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/scipy/stats/_resampling.py:118: RuntimeWarning: invalid value encountered in double_scalars
  a_hat = num / den
/home/matthew/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/scipy/stats/_resampling.py:92: DegenerateDataWarning: The bootstrap distribution is degenerate; the confidence interval is not defined.
  warnings.warn(DegenerateDataWarning(msg))
CPU times: user 1min 47s, sys: 769 ms, total: 1min 48s
Wall time: 1min 9s
value confidence_low confidence_high std
name
accuracy 0.544725 0.520952 0.557125 0.012889
f1 0.191446 0.183737 0.199604 0.008865
precision 1.000000 NaN NaN 0.000000
recall 0.105856 0.101163 0.110867 0.005378
total_time_in_seconds 4.252979 NaN NaN NaN
samples_per_second 205.032759 NaN NaN NaN
latency_in_seconds 0.004877 NaN NaN NaN

It’s good to know that this can be done. The results are absolute trash though:

device accuracy accuracy Δ samples/second relative speed
Distilled GPU 0.910550 319.239023
Distilled CPU 0.910550 102.782367
Dynamic Quantized CPU 0.897936 -0.012614 269.176181 2.618895
Static Quantized CPU 0.544725 -0.365825 205.032759 1.994824

The static quantized model is both slower than the dynamically quantized one and has truly terrible accuracy. Remember that an accuracy of around 0.5 is what you would expect from randomly choosing the answer!
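
Before moving on it is worth noting that the calibration step is another obvious suspect: only 50 samples were used, with minmax ranges. As a hedged sketch of what I would try next (I have not verified that this helps, and I am assuming the entropy calibration config can be swapped in where minmax was used above):

from pathlib import Path

from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoCalibrationConfig, AutoQuantizationConfig
from transformers import AutoTokenizer


def static_quantize_with_more_calibration(directory: Path, num_samples: int = 500, **config) -> None:
    """
    Variation on static_quantize_onnx_model above: more calibration
    samples and entropy-based ranges instead of minmax.
    """
    quantization_config = AutoQuantizationConfig.arm64(**config)
    quantizer = ORTQuantizer.from_pretrained(directory)
    tokenizer = AutoTokenizer.from_pretrained(directory)

    def preprocess_fn(row):
        return tokenizer(row["sentence"])

    calibration_dataset = quantizer.get_calibration_dataset(
        "sst2",
        preprocess_function=preprocess_fn,
        num_samples=num_samples,  # more than the 50 used above
        dataset_split="train",
    )
    # entropy calibration is an assumption on my part, minmax is what was used above
    calibration_config = AutoCalibrationConfig.entropy(calibration_dataset)

    ranges = quantizer.fit(
        dataset=calibration_dataset,
        calibration_config=calibration_config,
        operators_to_quantize=quantization_config.operators_to_quantize,
    )
    quantizer.quantize(
        save_dir=directory,
        quantization_config=quantization_config,
        calibration_tensors_range=ranges,
    )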

AutoQuantizationConfig… arm64??

I'm on a Linux machine with an x86_64 CPU, yet the examples in the quickstart use AutoQuantizationConfig.arm64. Is this a problem? I can try out the different AutoQuantizationConfig factory methods to see what effect they have.

It would also be interesting to try varying the different settings that are available.
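
As a quick sanity check, the CPU flags reported by Linux can tell me which of these factory methods actually matches my hardware. This is a small sketch of my own; the flag names and the fallback order are my assumptions rather than anything from the Optimum documentation:

from optimum.onnxruntime.configuration import AutoQuantizationConfig


def detect_quantization_config(is_static: bool = False):
    """Pick an AutoQuantizationConfig factory based on the flags in /proc/cpuinfo."""
    flags = set()
    with open("/proc/cpuinfo") as handle:
        for line in handle:
            if line.startswith("flags"):
                flags = set(line.split(":", 1)[1].split())
                break
    if "avx512_vnni" in flags:
        return AutoQuantizationConfig.avx512_vnni(is_static=is_static, per_channel=False)
    if "avx512f" in flags:
        return AutoQuantizationConfig.avx512(is_static=is_static, per_channel=False)
    if "avx2" in flags:
        return AutoQuantizationConfig.avx2(is_static=is_static, per_channel=False)
    # fall back to the quickstart default
    return AutoQuantizationConfig.arm64(is_static=is_static, per_channel=False)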

Code
from pathlib import Path

from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

def custom_quantize_onnx_model(directory: Path, config: AutoQuantizationConfig) -> None:
    """
    Apply quantization to the model using the provided configuration.
    """
    quantizer = ORTQuantizer.from_pretrained(directory)

    quantizer.quantize(
        save_dir=directory,
        quantization_config=config,
    )
Code
from tempfile import TemporaryDirectory
from typing import Dict

import evaluate
from datasets import load_dataset, Dataset
from optimum.onnxruntime.configuration import AutoQuantizationConfig
from transformers import AutoTokenizer
from tqdm.auto import tqdm

arm_configs = [
    {
        "config": AutoQuantizationConfig.arm64(
            is_static=False,
            use_symmetric_activations=use_symmetric_activations,
            use_symmetric_weights=use_symmetric_weights,
            per_channel=per_channel,
        ),
        "config_type": "arm64",
        "use_symmetric_activations": use_symmetric_activations,
        "use_symmetric_weights": use_symmetric_weights,
        "per_channel": per_channel,
        "reduce_range": None,
    }
    for use_symmetric_activations in [True, False]
    for use_symmetric_weights in [True, False]
    for per_channel in [True, False]
]
avx2_configs = [
    {
        "config": AutoQuantizationConfig.avx2(
            is_static=False,
            use_symmetric_activations=use_symmetric_activations,
            use_symmetric_weights=use_symmetric_weights,
            per_channel=per_channel,
            reduce_range=reduce_range,
        ),
        "config_type": "avx2",
        "use_symmetric_activations": use_symmetric_activations,
        "use_symmetric_weights": use_symmetric_weights,
        "per_channel": per_channel,
        "reduce_range": reduce_range,
    }
    for use_symmetric_activations in [True, False]
    for use_symmetric_weights in [True, False]
    for per_channel in [True, False]
    for reduce_range in [True, False]
]
avx512_configs = [
    {
        "config": AutoQuantizationConfig.avx512(
            is_static=False,
            use_symmetric_activations=use_symmetric_activations,
            use_symmetric_weights=use_symmetric_weights,
            per_channel=per_channel,
            reduce_range=reduce_range,
        ),
        "config_type": "avx512",
        "use_symmetric_activations": use_symmetric_activations,
        "use_symmetric_weights": use_symmetric_weights,
        "per_channel": per_channel,
        "reduce_range": reduce_range,
    }
    for use_symmetric_activations in [True, False]
    for use_symmetric_weights in [True, False]
    for per_channel in [True, False]
    for reduce_range in [True, False]
]
avx512_vnni_configs = [
    {
        "config": AutoQuantizationConfig.avx512_vnni(
            is_static=False,
            use_symmetric_activations=use_symmetric_activations,
            use_symmetric_weights=use_symmetric_weights,
            per_channel=per_channel,
        ),
        "config_type": "avx512_vnni",
        "use_symmetric_activations": use_symmetric_activations,
        "use_symmetric_weights": use_symmetric_weights,
        "per_channel": per_channel,
        "reduce_range": None,
    }
    for use_symmetric_activations in [True, False]
    for use_symmetric_weights in [True, False]
    for per_channel in [True, False]
]

task_evaluator = evaluate.evaluator("text-classification")
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
data = load_dataset("sst2", split="validation")

MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

def quantize_and_evaluate(
    model_name: str,
    data: Dataset,
    metric: evaluate.CombinedEvaluations,
    config: AutoQuantizationConfig,
    input_column: str = "sentence",
    **details,
) -> Dict[str, float]:
    with TemporaryDirectory() as directory:
        directory = Path(directory)
        export_to_onnx(
            model_name=model_name,
            directory=directory
        )
        custom_quantize_onnx_model(
            directory=directory,
            config=config,
        )
        quantized_pipeline = load_quantized_pipeline(
            directory=directory,
        )
        
        results = task_evaluator.compute(
            model_or_pipeline=quantized_pipeline,
            data=data,
            metric=metric,
            input_column=input_column,
            label_column="label",
            label_mapping={"POSITIVE": 1, "NEGATIVE": 0},
            device=-1, # CPU
        )
        return results | details

results = pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        **config
    )
    for config in tqdm(arm_configs + avx2_configs + avx512_configs + avx512_vnni_configs)
])
(
    results
    [results.accuracy > 0.85]
    [[
        "accuracy",
        "f1",
        "samples_per_second",
        "config_type",
        "use_symmetric_activations",
        "use_symmetric_weights",
        "per_channel",
        "reduce_range",
    ]]
    .sort_values(by=["samples_per_second", "accuracy"], ascending=False)
    .head()
)
accuracy f1 samples_per_second config_type use_symmetric_activations use_symmetric_weights per_channel reduce_range
30 0.908257 0.911308 271.873541 avx512 True False False True
45 0.897936 0.901874 270.687187 avx512_vnni False True False None
34 0.894495 0.898455 270.478760 avx512 False True False True
26 0.894495 0.898455 270.171438 avx512 True True False True
32 0.901376 0.905077 269.470640 avx512 False True True True

A small increase in both speed and accuracy is available by customizing the quantization configuration. (In this run I got 271 samples per second but previously I managed 278, so there is clearly some variation). That puts us at:

device accuracy accuracy Δ samples/second relative speed
Distilled GPU 0.910550 319.239023
Distilled CPU 0.910550 102.782367
Dynamic Quantized CPU 0.897936 -0.012614 269.176181 2.618895
Static Quantized CPU 0.544725 -0.365825 205.032759 1.994824
avx512_vnni Quantized CPU 0.897936 -0.012614 278.353499 2.708183

So searching over the settings gave a further speed improvement of a few percent over the default dynamic quantization.

A larger problem here is that the speed results are not stable. The validation set is processed in a couple of seconds, so small variations in runtime have a large effect on the samples-per-second figure. I also need to run many evaluations, so the dataset cannot be too large either. Having at least 10 seconds' worth of data should stabilize the timings a bit more.

Larger Dataset

I'm going to try using the amazon_polarity dataset (Zhang, Zhao, and LeCun 2015). This is still a two-class sentiment dataset, just substantially larger. The only problem is that the text is split into title and content fields, so they need to be combined into a single input.

Zhang, Xiang, Junbo Zhao, and Yann LeCun. 2015. “Character-Level Convolutional Networks for Text Classification.” arXiv. https://doi.org/10.48550/ARXIV.1509.01626.
Code
from typing import TypedDict
from datasets import load_dataset
import pandas as pd

class InputRow(TypedDict):
    label: int
    title: str
    content: str

class TransformedRow(TypedDict):
    label: int
    text: str

def combine(row: InputRow) -> TransformedRow:
    title = row["title"].strip()
    # make sure the title ends with sentence punctuation before appending the content
    if title and title[-1] not in {".", "!", "?"}:
        title += "."
    content = row["content"].strip()
    if content:
        content = " " + content
    return {
        "label": row["label"],
        "text": title + content
    }

data = load_dataset("amazon_polarity", split="test[:3000]")
data = data.map(combine)
Code
quantized_pipeline = load_quantized_pipeline(
    directory=DYNAMIC_SAVE_FOLDER,
)

pd.DataFrame([
    task_evaluator.compute(
        model_or_pipeline=quantized_pipeline,
        data=data,
        metric=clf_metrics,
        input_column="text",
        label_column="label",
        label_mapping={"POSITIVE": 1, "NEGATIVE": 0},
        device=-1, # CPU
    )
])
accuracy f1 precision recall total_time_in_seconds samples_per_second latency_in_seconds
0 0.886 0.889105 0.901381 0.877159 33.41182 89.788583 0.011137

This has more text so the evaluation has taken longer. I’m hopeful that this will produce more consistent results. One way to check that is to run it several times and see how it varies.

Code
df = pd.DataFrame([
    task_evaluator.compute(
        model_or_pipeline=quantized_pipeline,
        data=data,
        metric=clf_metrics,
        input_column="text",
        label_column="label",
        label_mapping={"POSITIVE": 1, "NEGATIVE": 0},
        device=-1, # CPU
    )
    for _ in range(20)
])
df.samples_per_second.describe()
count    20.000000
mean     92.570765
std       0.917716
min      92.052706
25%      92.325451
50%      92.364204
75%      92.405892
max      96.440913
Name: samples_per_second, dtype: float64

This has a standard deviation of less than 1 sample per second, so I think this dataset is suitable. Any differences of less than about 1 sample per second should still be treated as noise. The baseline and existing comparisons need to be recalculated against this dataset so we can continue.

Code
import evaluate
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

task_evaluator = evaluate.evaluator("text-classification")
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

dynamic_pipeline = load_quantized_pipeline(
    directory=DYNAMIC_SAVE_FOLDER,
)
static_pipeline = load_quantized_pipeline(
    directory=STATIC_SAVE_FOLDER,
)

df = pd.DataFrame([
    {"name": "baseline"} | task_evaluator.compute(
        model_or_pipeline=model,
        tokenizer=tokenizer,
        data=data,
        metric=clf_metrics,
        input_column="text",
        label_column="label",
        label_mapping={"POSITIVE": 1, "NEGATIVE": 0},
        device=-1, # CPU
    ),
    {"name": "dynamic"} | task_evaluator.compute(
        model_or_pipeline=dynamic_pipeline,
        data=data,
        metric=clf_metrics,
        input_column="text",
        label_column="label",
        label_mapping={"POSITIVE": 1, "NEGATIVE": 0},
        device=-1, # CPU
    ),
    {"name": "static"} | task_evaluator.compute(
        model_or_pipeline=static_pipeline,
        data=data,
        metric=clf_metrics,
        input_column="text",
        label_column="label",
        label_mapping={"POSITIVE": 1, "NEGATIVE": 0},
        device=-1, # CPU
    )
])

baseline_row = df[df.name == "baseline"].iloc[0]
baseline_accuracy = baseline_row.accuracy
baseline_speed = baseline_row.samples_per_second

df = df[["name", "accuracy", "samples_per_second"]]
df["accuracy Δ"] = df.accuracy - baseline_accuracy
df["speed Δ"] = df.samples_per_second - baseline_speed
df["speed ratio"] = df.samples_per_second / baseline_speed

df
name accuracy samples_per_second accuracy Δ speed Δ speed ratio
0 baseline 0.888667 55.432209 0.000000 0.000000 1.000000
1 dynamic 0.886000 92.315347 -0.002667 36.883138 1.665374
2 static 0.532000 69.555485 -0.356667 14.123276 1.254785
Code
df = pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        input_column="text",
        **config
    )
    for config in tqdm(arm_configs + avx2_configs + avx512_configs + avx512_vnni_configs)
])
(
    df
    [df.accuracy > 0.85]
    [[
        "accuracy",
        "f1",
        "samples_per_second",
        "config_type",
        "use_symmetric_activations",
        "use_symmetric_weights",
        "per_channel",
        "reduce_range",
    ]]
    .sort_values(by=["samples_per_second", "accuracy"], ascending=False)
    .head()
)
accuracy f1 samples_per_second config_type use_symmetric_activations use_symmetric_weights per_channel reduce_range
1 0.886 0.889105 98.460482 arm64 True True False None
35 0.886 0.889105 96.763367 avx512 False True False False
38 0.886 0.888817 96.608009 avx512 False False False True
5 0.886 0.889105 96.468477 arm64 False True False None
32 0.888 0.890838 96.348063 avx512 False True True True

Putting this all together we get:

name accuracy samples per second accuracy Δ speed Δ speed ratio
baseline 0.888667 55.432209 0.000000 0.000000 1.000000
dynamic 0.886000 92.315347 -0.002667 36.883138 1.665374
static 0.532000 69.555485 -0.356667 14.123276 1.254785
optimised 0.886000 98.460482 -0.002667 43.028273 1.776232

The samples per second are different enough to be significant: a difference of ~6 samples per second is about 6 standard deviations. So optimizing the quantization settings certainly pays off.

The optimized settings themselves are more questionable. Each row in the table above differs from the next by less than 1 sample per second, which is within the noise. Overall I do believe that the top optimizations are a benefit, however I would like a more principled way to select the settings.
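
One cheap way to be more principled would be to re-run each of the leading configurations several times and compare means rather than single measurements. A sketch using the quantize_and_evaluate helper from above (the number of repeats is arbitrary, and note that this re-exports and re-quantizes the model on every run):

import pandas as pd


def evaluate_repeatedly(config: dict, repeats: int = 5) -> pd.Series:
    """Evaluate the same quantization configuration several times and summarise the speed."""
    runs = pd.DataFrame([
        quantize_and_evaluate(
            model_name=MODEL_NAME,
            data=data,
            metric=clf_metrics,
            input_column="text",
            **config,
        )
        for _ in range(repeats)
    ])
    return pd.Series({
        "config_type": config.get("config_type"),
        "accuracy": runs.accuracy.mean(),
        "samples_per_second_mean": runs.samples_per_second.mean(),
        "samples_per_second_std": runs.samples_per_second.std(),
    })

# e.g. pd.DataFrame([evaluate_repeatedly(config) for config in avx512_configs[:2]])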

Operators

The next thing to explore is operators_to_quantize. This is the list of operations within the model that will be quantized, and it defaults to ['MatMul', 'Add'].

It would be good to find a complete list of the supported operators and then start trying out more of them. I’ve started by finding this issue which discusses expanding the list of quantizable operators. I’ll try out the proposed list:

Code
df = pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        input_column="text",
        config=AutoQuantizationConfig.arm64(
            is_static=False,
            use_symmetric_activations=True,
            use_symmetric_weights=True,
            per_channel=False,
            operators_to_quantize=[
                "Conv",
                "MatMul",
                "Attention",
                "LSTM",
                "Gather",
                "Transpose",
                "EmbedLayerNormalization"
            ]
        )
    )
])
df
accuracy f1 precision recall total_time_in_seconds samples_per_second latency_in_seconds
0 0.888 0.89098 0.903884 0.878439 35.22876 85.157695 0.011743

It’s interesting that this model has now become slower. In the ticket they talk about a significant size reduction of nearly half, so it may be that the application of quantization to some operators makes the model smaller but slower.
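
I haven't measured the size change here, but it is easy to check by comparing the exported and quantized files on disk (using the file names that were saved earlier):

from pathlib import Path


def model_size_mb(path: Path) -> float:
    """Size of an ONNX file in megabytes."""
    return path.stat().st_size / 1024**2

print(model_size_mb(DYNAMIC_SAVE_FOLDER / "model.onnx"))
print(model_size_mb(DYNAMIC_SAVE_FOLDER / "model_quantized.onnx"))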

The next thing would be to find all of the operators that exist in the model. Looking at this discussion provides me with a way to inspect the ONNX model to get the operations:

from pathlib import Path
from typing import List
import onnx

def get_model_operators(path: Path) -> List[str]:
    model = onnx.load(path)
    return sorted({
        node.op_type
        for node in model.graph.node
    })

get_model_operators(DYNAMIC_SAVE_FOLDER / "model.onnx")
['Add',
 'Cast',
 'Concat',
 'Constant',
 'Div',
 'Equal',
 'Erf',
 'Expand',
 'Gather',
 'Gemm',
 'Identity',
 'MatMul',
 'Mul',
 'Pow',
 'ReduceMean',
 'Relu',
 'Reshape',
 'Shape',
 'Slice',
 'Softmax',
 'Sqrt',
 'Sub',
 'Transpose',
 'Unsqueeze',
 'Where']

We can now try quantizing every single one of these:

Code
df = pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        input_column="text",
        config=AutoQuantizationConfig.arm64(
            is_static=False,
            use_symmetric_activations=True,
            use_symmetric_weights=True,
            per_channel=False,
            operators_to_quantize=get_model_operators(DYNAMIC_SAVE_FOLDER / "model.onnx")
        )
    )
])
df
accuracy f1 precision recall total_time_in_seconds samples_per_second latency_in_seconds
0 0.888 0.89098 0.903884 0.878439 30.971639 96.862811 0.010324

Given that there are 25 operators in the model, it's not feasible to evaluate all of the combinations (2^25 is roughly 33 million). Instead I can establish a baseline for the model with no quantized operators and then try each operator in turn, comparing it to that baseline.

Code
from optimum.onnxruntime.configuration import AutoQuantizationConfig
from tqdm.auto import tqdm

operators = get_model_operators(DYNAMIC_SAVE_FOLDER / "model.onnx")
configurations = [
    {
        "config": AutoQuantizationConfig.arm64(
            is_static=False,
            use_symmetric_activations=True,
            use_symmetric_weights=True,
            per_channel=False,
            operators_to_quantize=[],
        ),
        "name": "baseline",
    },
] + [
    {
        "config": AutoQuantizationConfig.arm64(
            is_static=False,
            use_symmetric_activations=True,
            use_symmetric_weights=True,
            per_channel=False,
            operators_to_quantize=[operator],
        ),
        "name": operator,
    }
    for operator in operators
]

results = pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        input_column="text",
        **config
    )
    for config in tqdm(configurations)
])

baseline_row = results[results.name == "baseline"].iloc[0]
baseline_speed = baseline_row.samples_per_second
baseline_accuracy = baseline_row.accuracy
results["samples Δ"] = results.samples_per_second - baseline_speed
results["accuracy Δ"] = results.accuracy - baseline_accuracy

(
    results
    [[
        "accuracy",
        "accuracy Δ", 
        "f1",
        "samples_per_second",
        "samples Δ",
        "name",
    ]]
    .sort_values(by="samples Δ", ascending=False)
)
accuracy accuracy Δ f1 samples_per_second samples Δ name
12 0.886000 -0.002667 0.889105 95.060374 26.247635 MatMul
0 0.888667 0.000000 0.890850 68.812739 0.000000 baseline
3 0.888667 0.000000 0.890850 67.759519 -1.053221 Concat
9 0.888667 0.000000 0.890707 67.749849 -1.062891 Gather
8 0.888667 0.000000 0.890850 67.623759 -1.188980 Expand
2 0.888667 0.000000 0.890850 67.571570 -1.241170 Cast
6 0.888667 0.000000 0.890850 67.525588 -1.287151 Equal
13 0.888667 0.000000 0.890850 67.505238 -1.307501 Mul
7 0.888667 0.000000 0.890850 67.495581 -1.317158 Erf
14 0.888667 0.000000 0.890850 67.449766 -1.362973 Pow
4 0.888667 0.000000 0.890850 67.413731 -1.399008 Constant
10 0.888667 0.000000 0.890850 67.126520 -1.686219 Gemm
11 0.888667 0.000000 0.890850 67.087859 -1.724881 Identity
5 0.888667 0.000000 0.890850 66.985536 -1.827204 Div
15 0.888667 0.000000 0.890850 66.862638 -1.950101 ReduceMean
18 0.888667 0.000000 0.890850 66.631532 -2.181207 Shape
17 0.888667 0.000000 0.890850 66.507014 -2.305725 Reshape
19 0.888667 0.000000 0.890850 66.487829 -2.324910 Slice
24 0.888667 0.000000 0.890850 66.392239 -2.420501 Unsqueeze
23 0.888667 0.000000 0.890850 66.343722 -2.469018 Transpose
22 0.888667 0.000000 0.890850 66.328632 -2.484107 Sub
25 0.888667 0.000000 0.890850 65.967035 -2.845704 Where
21 0.888667 0.000000 0.890850 65.839237 -2.973503 Sqrt
20 0.888667 0.000000 0.890850 65.723235 -3.089504 Softmax
1 0.888667 0.000000 0.890850 65.687039 -3.125700 Add
16 0.888667 0.000000 0.890850 65.444072 -3.368667 Relu

With this list of beneficial operators (only MatMul) I can then try quantizing a model:

Code
df = pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        input_column="text",
        config=AutoQuantizationConfig.arm64(
            is_static=False,
            use_symmetric_activations=True,
            use_symmetric_weights=True,
            per_channel=False,
            operators_to_quantize=results[results["samples Δ"] > 0].name.tolist()
        )
    )
])
df

These numbers don’t make sense if the optimizations are independent. What we have is:

operations optimized samples per second
MatMul 92.211473
MatMul and Add 98.460482
7 operations 85.157695
all operations 96.862811

Given that optimizing all operations is better than just optimizing MatMul, but when we evaluate operations independently then only MatMul is beneficial, it must be the case that optimised operations affect each other. If this is the case then there may be a set of optimized operations which perform better than MatMul and Add. It’s not feasible to evaluate all combinations of operator optimizations, so a more efficient means of calculating benefit is required. I wonder if some kind of search would be appropriate here.