Optimizing a BERT Model using Optimum ONNX

Different ways to optimize a model for performance with quantization
Published: September 21, 2022

This is a continuation of the previous post. The quickstart documentation has a further section that I really should've worked through at the start: in it the ONNX model is optimized before the different quantization settings are applied.

Dataset and Evaluator

This will reuse the dataset and evaluation code from before, so let's load that first.

Code
from typing import TypedDict
from datasets import load_dataset
import evaluate
import pandas as pd

class InputRow(TypedDict):
    label: int
    title: str
    content: str

class TransformedRow(TypedDict):
    label: int
    text: str

def combine(row: InputRow) -> TransformedRow:
    # Join the review title and content into a single text field,
    # ensuring the title ends with sentence punctuation.
    title = row["title"].strip()
    if title and title[-1] not in {".", "!", "?"}:
        title += "."
    content = row["content"].strip()
    if content:
        content = " " + content
    return {
        "label": row["label"],
        "text": title + content
    }

task_evaluator = evaluate.evaluator("text-classification")
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
data = load_dataset("amazon_polarity", split="test[:3000]")
data = data.map(combine)

Maximum Optimization

The quickstart documentation uses an ORTOptimizer with an OptimizationConfig object. Looking at this configuration object, it appears that the primary setting is optimization_level, which defaults to 1. It is effectively an enumeration with the following values:

  • 0 will disable all optimizations
  • 1 will enable basic optimizations
  • 2 will enable basic and extended optimizations, including complex node fusions applied to the nodes assigned to the CPU or CUDA execution provider, making the resulting optimized graph hardware dependent
  • 99 will enable all available optimizations including layout optimizations

I am not too worried about being hardware dependent. It is fine to compile multiple versions of this model for the different CPU and GPU types that will be used, so we can try out level 99 to see how much improvement is possible.
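For reference, these levels map directly onto the OptimizationConfig constructor. This is just a sketch for orientation; only the level 99 configuration is actually used in this post.

Code
from optimum.onnxruntime.configuration import OptimizationConfig

# The four optimization levels as configuration objects.
optimization_levels = {
    "disabled": OptimizationConfig(optimization_level=0),   # no graph optimizations
    "basic":    OptimizationConfig(optimization_level=1),   # the default
    "extended": OptimizationConfig(optimization_level=2),   # adds hardware dependent node fusions
    "all":      OptimizationConfig(optimization_level=99),  # everything, including layout optimizations
}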

Code
from pathlib import Path

QUANTIZATION_FOLDER = Path("/data/blog/2022-09-12-model-quantization")
QUICKSTART_SAVE_FOLDER = QUANTIZATION_FOLDER / "quickstart"
OPTIMIZED_SAVE_FOLDER = QUICKSTART_SAVE_FOLDER / "optimized"
OPTIMIZED_SAVE_FOLDER.mkdir(parents=True, exist_ok=True)

MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
Code
from pathlib import Path
from typing import Dict
from tempfile import TemporaryDirectory

from datasets import Dataset
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig, OptimizationConfig
from transformers import AutoTokenizer, Pipeline, pipeline
import evaluate


def quantize_and_evaluate(
    model_name: str,
    data: Dataset,
    metric: evaluate.CombinedEvaluations,
    optimization_config: OptimizationConfig,
    quantization_config: AutoQuantizationConfig,
    input_column: str = "text",
    **details,
) -> Dict[str, float]:
    with TemporaryDirectory() as directory:
        directory = Path(directory)
        export_to_onnx(
            model_name=model_name,
            directory=directory,
            config=optimization_config,
        )
        dynamic_quantize_onnx_model(
            directory=directory,
            config=quantization_config,
        )
        quantized_pipeline = load_quantized_pipeline(
            directory=directory,
        )
        
        results = task_evaluator.compute(
            model_or_pipeline=quantized_pipeline,
            data=data,
            metric=metric,
            input_column=input_column,
            label_column="label",
            label_mapping={"POSITIVE": 1, "NEGATIVE": 0},
            device=-1, # CPU
        )
        return results | details



def export_to_onnx(model_name: str, directory: Path, config: OptimizationConfig) -> None:
    """
    Load the model from transformers and export it to the ONNX format.
    """
    model = ORTModelForSequenceClassification.from_pretrained(
        model_name, from_transformers=True
    )
    optimizer = ORTOptimizer.from_pretrained(model)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    optimizer.optimize(
        save_dir=directory,
        optimization_config=config,
    )
    model.config.save_pretrained(directory) # saves config.json
    tokenizer.save_pretrained(directory)


def dynamic_quantize_onnx_model(directory: Path, config: AutoQuantizationConfig) -> None:
    """
    Apply dynamic quantization to the model.
    """
    quantizer = ORTQuantizer.from_pretrained(directory, file_name="model_optimized.onnx")

    quantizer.quantize(
        save_dir=directory,
        quantization_config=config,
    )


def load_quantized_pipeline(directory: Path) -> Pipeline:
    """
    Load the quantized model as a text classification pipeline.
    """
    model = ORTModelForSequenceClassification.from_pretrained(
        directory, file_name="model_optimized_quantized.onnx"
    )
    tokenizer = AutoTokenizer.from_pretrained(directory)

    return pipeline("text-classification", model=model, tokenizer=tokenizer)
import pandas as pd

pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        optimization_config=OptimizationConfig(optimization_level=99),
        quantization_config=AutoQuantizationConfig.arm64(
            is_static=False,
            per_channel=False,
        ),
    )
])
2022-09-21 17:04:07.921107289 [W:onnxruntime:, inference_session.cc:1488 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED and the NchwcTransformer enabled. The generated model may contain hardware specific optimizations, and should only be used in the same environment the model was optimized in.
   accuracy        f1  precision  recall  total_time_in_seconds  samples_per_second  latency_in_seconds
0     0.885  0.887951   0.901715  0.8746              25.665156          116.889997            0.008555

I’m including the warnings here as they are quite significant. If I were to move this model to another machine it might stop working or produce wildly inaccurate results.

These results seem really good. If we compare them to the results from the previous post:

name                  accuracy  samples per second  accuracy Δ    speed Δ  speed ratio
baseline              0.888667           55.432209    0.000000   0.000000     1.000000
dynamic               0.886000           92.315347   -0.002667  36.883138     1.665374
static                0.532000           69.555485   -0.356667  14.123276     1.254785
optimized settings    0.886000           98.460482   -0.002667  43.028273     1.776232
onnx optimizations    0.885000          116.889997   -0.003667  61.457788     2.108702

This is a dramatic improvement, and there may be scope for a modest further increase if we reapply the quantization settings and operator selection. These would have to be recalculated from scratch, as the changes to the underlying ONNX model are likely to result in different outcomes. Furthermore, I want to move to a more pragmatic way of selecting those optimizations.

Optimization Settings

The optimization configuration object has more than one setting, and it would be worth evaluating the different combinations. The other settings cover GPU use, mixed precision, and the individual graph fusions; most of them already default to the value I want:

  • optimize_for_gpu: bool = False I am not aiming for GPU so this default is fine.

  • fp16: bool = False Since I want to fully quantize this model moving to mixed precision would likely introduce more chances for accuracy loss.

  • optimize_with_onnxruntime_only: bool = False > Whether to only use ONNX Runtime to optimize the model and no graph fusion in Python

    This is an interesting one as it shows that there are optimizations that are performed in the huggingface library. The graph fusions are the subject of the following settings.

  • disable_gelu: bool = False This would disable the gelu fusion which is a specific optimization.

  • disable_layer_norm: bool = False This would disable the layer normalization fusion which is a specific optimization.

  • disable_attention: bool = False This would disable the attention fusion which is a specific optimization.

  • disable_skip_layer_norm: bool = False This would disable the skip layer normalization fusion which is a specific optimization.

  • disable_bias_skip_layer_norm: bool = False This would disable the fusion of the bias addition with the skip layer normalization.

  • disable_bias_gelu: bool = False This would disable the fusion of the bias addition with the gelu activation.

  • disable_embed_layer_norm: bool = True This disables the embed layer norm fusion, which is incompatible with ONNX Runtime quantization. Since I want to perform that quantization I cannot enable this optimization.

The ones that are most interesting to me are:

  • enable_gelu_approximation: bool = False This is an optimization which replaces gelu with fast gelu and can impact model accuracy.

  • use_mask_index: bool = False > Whether to use mask index instead of raw attention mask in the attention operator.

    It would be interesting to try this out to see if it speeds things up.

  • no_attention_mask: bool = False > Whether to not use attention masks. Only works for bert model type.

    Again another attention mask optimization.

It would be nice to try out altering these to see if they can produce any speed improvement.

Code
pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        optimization_config=OptimizationConfig(
            optimization_level=99,
            enable_gelu_approximation=enable_gelu_approximation,
            use_mask_index=use_mask_index,
            # no_attention_mask=no_attention_mask,
        ),
        quantization_config=AutoQuantizationConfig.arm64(
            is_static=False,
            per_channel=False,
        ),
        enable_gelu_approximation=enable_gelu_approximation,
        use_mask_index=use_mask_index,
        # no_attention_mask=no_attention_mask,
    )
    for enable_gelu_approximation in [False, True]
    for use_mask_index in [False, True]
    # for no_attention_mask in [False, True]
])
2022-09-21 17:08:51.887826391 [W:onnxruntime:, inference_session.cc:1488 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED and the NchwcTransformer enabled. The generated model may contain hardware specific optimizations, and should only be used in the same environment the model was optimized in.
2022-09-21 17:09:28.271454053 [W:onnxruntime:, inference_session.cc:1488 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED and the NchwcTransformer enabled. The generated model may contain hardware specific optimizations, and should only be used in the same environment the model was optimized in.
2022-09-21 17:10:04.574039548 [W:onnxruntime:, inference_session.cc:1488 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED and the NchwcTransformer enabled. The generated model may contain hardware specific optimizations, and should only be used in the same environment the model was optimized in.
2022-09-21 17:10:40.503322709 [W:onnxruntime:, inference_session.cc:1488 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED and the NchwcTransformer enabled. The generated model may contain hardware specific optimizations, and should only be used in the same environment the model was optimized in.
   accuracy        f1  precision    recall  total_time_in_seconds  samples_per_second  latency_in_seconds  enable_gelu_approximation  use_mask_index
0  0.885000  0.887951   0.901715  0.874600              25.870445          115.962443            0.008623                      False           False
1  0.885000  0.887951   0.901715  0.874600              25.910328          115.783946            0.008637                      False            True
2  0.885667  0.888237   0.905046  0.872041              25.906065          115.802998            0.008635                       True           False
3  0.885667  0.888237   0.905046  0.872041              26.353920          113.835056            0.008785                       True            True

Changing these settings does not appear to have any significant positive impact on the model throughput. I'm not confident that these results are statistically different from the baseline, as I've seen that figure vary around 115 samples per second. Applying the optimization at all is still a huge boost.

The no_attention_mask flag had to be disabled as it caused an error. I was expecting it to be applicable as this is a distilled BERT model.

Now I can review the ONNX section of the documentation to see if there is anything else I have missed…

Stuff I missed

One thing that has come up is that the QuantizationConfig object, which was investigated in the last post, is explicitly a holder for the ONNX Runtime settings. It would be worth consulting the ONNX Runtime documentation directly to find a pragmatic way to choose between the different constructors and settings.

There is also an ORTConfig which wraps the QuantizationConfig and OptimizationConfig. This may be easier to handle than the two separate objects as it might make the transformation code simpler.
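To make that concrete, here is a sketch of what building the configuration by hand might look like: the dynamic quantization settings as a plain QuantizationConfig, wrapped together with the optimization settings in an ORTConfig. I haven't run this, and the exact QuantizationConfig and ORTConfig field names are an assumption based on my reading of the documentation.

Code
from onnxruntime.quantization import QuantFormat, QuantizationMode, QuantType
from optimum.onnxruntime.configuration import (
    ORTConfig,
    OptimizationConfig,
    QuantizationConfig,
)

# Roughly what AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
# should resolve to: dynamic int8 quantization in the QOperator format.
quantization_config = QuantizationConfig(
    is_static=False,
    format=QuantFormat.QOperator,
    mode=QuantizationMode.IntegerOps,
    activations_dtype=QuantType.QUInt8,
    weights_dtype=QuantType.QInt8,
    per_channel=False,
    operators_to_quantize=["MatMul"],
)

# ORTConfig bundles the two configuration objects into a single one.
ort_config = ORTConfig(
    optimization=OptimizationConfig(optimization_level=99),
    quantization=quantization_config,
)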

More Justifiable Operator Selection

Reviewing the ORTOptimizer I can see that it has a get_fused_operators method. It would be good to review its output and see how targeting the quantization at those fused operators interacts with the ONNX optimizations.

Code
model = ORTModelForSequenceClassification.from_pretrained(
    MODEL_NAME, from_transformers=True
)
optimizer = ORTOptimizer.from_pretrained(model)
optimizer.get_fused_operators(OPTIMIZED_SAVE_FOLDER / "model_optimized.onnx")
{'Attention': 6,
 'BiasGelu': 6,
 'LayerNormalization': 1,
 'SkipLayerNormalization': 12}
Code
pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        optimization_config=OptimizationConfig(
            optimization_level=99,
        ),
        quantization_config=AutoQuantizationConfig.arm64(
            is_static=False,
            per_channel=False,
            operators_to_quantize=[
                "MatMul", # included because it's such a strong contender
                "Attention",
                "BiasGelu",
                "LayerNormalization",
                "SkipLayerNormalization",
            ]
        ),
    )
])
2022-09-22 10:01:36.821332748 [W:onnxruntime:, inference_session.cc:1488 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED and the NchwcTransformer enabled. The generated model may contain hardware specific optimizations, and should only be used in the same environment the model was optimized in.
   accuracy        f1  precision    recall  total_time_in_seconds  samples_per_second  latency_in_seconds
0  0.889667  0.892498   0.906332  0.879079              22.550959          133.032036            0.007517

I ran this twice and the first time it got ~122 samples per second. This time it’s 133 samples per second. In both cases it’s a huge boost from the 116 samples per second that the earlier optimized model got.

The variance in these results does suggest that the dataset isn’t big enough. There are a few other methods that seem interesting.

Code
optimizer.get_nodes_number_difference(
    onnx_model_path=QUICKSTART_SAVE_FOLDER / "dynamic" / "model.onnx",
    onnx_optimized_model_path=OPTIMIZED_SAVE_FOLDER / "model_optimized.onnx",
)
537
Code
optimizer.get_operators_difference(
    onnx_model_path=QUICKSTART_SAVE_FOLDER / "dynamic" / "model.onnx",
    onnx_optimized_model_path=OPTIMIZED_SAVE_FOLDER / "model_optimized.onnx",
)
{'Erf': 6,
 'Transpose': 24,
 'Unsqueeze': 36,
 'Add': 80,
 'MatMul': 30,
 'Mul': 25,
 'Where': 6,
 'Cast': 5,
 'Softmax': 6,
 'SkipLayerNormalization': -12,
 'Pow': 13,
 'Constant': 78,
 'LayerNormalization': -1,
 'Shape': 18,
 'Reshape': 30,
 'Div': 25,
 'Equal': 6,
 'Gather': 12,
 'Identity': 74,
 'ReduceMean': 26,
 'Sub': 13,
 'Sqrt': 13,
 'Attention': -6,
 'BiasGelu': -6,
 'Concat': 30,
 'Expand': 6}

To me it looks like the optimization has removed the Attention, BiasGelu, LayerNormalization and SkipLayerNormalization nodes. That's quite interesting, as it suggests that the operator-targeted quantization I just performed may have been pointless.

This is also a different list of operators to those returned by inspecting the ONNX model in the previous post.

Code
pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        optimization_config=OptimizationConfig(
            optimization_level=99,
        ),
        quantization_config=AutoQuantizationConfig.arm64(
            is_static=False,
            per_channel=False,
            operators_to_quantize=sorted(
                operator
                for operator, count in optimizer.get_operators_difference(
                    onnx_model_path=QUICKSTART_SAVE_FOLDER / "dynamic" / "model.onnx",
                    onnx_optimized_model_path=OPTIMIZED_SAVE_FOLDER / "model_optimized.onnx",
                ).items()
                if count > 0
            ),
        ),
    )
])
2022-09-22 10:25:12.715598468 [W:onnxruntime:, inference_session.cc:1488 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED and the NchwcTransformer enabled. The generated model may contain hardware specific optimizations, and should only be used in the same environment the model was optimized in.
   accuracy        f1  precision    recall  total_time_in_seconds  samples_per_second  latency_in_seconds
0  0.885333  0.887875   0.904983  0.871401              28.347616          105.829003            0.009449
Code
pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        optimization_config=OptimizationConfig(
            optimization_level=99,
        ),
        quantization_config=AutoQuantizationConfig.avx2(
            is_static=False,
            per_channel=False,
            operators_to_quantize=sorted(
                operator
                for operator, count in optimizer.get_operators_difference(
                    onnx_model_path=QUICKSTART_SAVE_FOLDER / "dynamic" / "model.onnx",
                    onnx_optimized_model_path=OPTIMIZED_SAVE_FOLDER / "model_optimized.onnx",
                ).items()
                if count < 0
            ),
        ),
    )
])
2022-09-22 10:30:37.736121741 [W:onnxruntime:, inference_session.cc:1488 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED and the NchwcTransformer enabled. The generated model may contain hardware specific optimizations, and should only be used in the same environment the model was optimized in.
   accuracy       f1  precision    recall  total_time_in_seconds  samples_per_second  latency_in_seconds
0  0.886333  0.88838   0.909517  0.868202              35.276149           85.043295            0.011759

The fact that neither side (operators added or operators removed) leads to an improvement is quite interesting.

I’ve also been reading about the arm64 / avx2 / avx512 / avx512_vnni divide and it’s pretty simple to check which of these applies. The /proc/cpuinfo file lists the flags for your CPU and will contain those strings (or possibly avx512vnni for the last one). The PyTorch test code has an example check for this that I can reuse:

Code
def get_cpu_flags() -> Dict[str, bool]:
    text = Path("/proc/cpuinfo").read_text()
    flags = {
        key: key in text
        for key in ["arm64", "avx2", "avx512"]
    }
    flags["avx512_vnni"] = ("avx512vnni" in text) or ("avx512_vnni" in text)
    return flags

get_cpu_flags()
{'arm64': False, 'avx2': True, 'avx512': False, 'avx512_vnni': False}

I’ve checked my CPU flags as well as the processor specification and (unfortunately) this is accurate. I should be using avx2.

The quantized model will be running on AWS, so using the flags above will help select the correct quantization configuration for the target instance. It’s interesting that using the wrong AutoQuantizationConfig wasn’t a problem.
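As a sketch of that selection, the hypothetical helper below (not part of the evaluation pipeline) uses the get_cpu_flags function from above to pick the matching AutoQuantizationConfig constructor. The fallback order is my own assumption, and given the results later in this post it is still worth benchmarking the choice rather than trusting the mapping blindly.

Code
from optimum.onnxruntime.configuration import QuantizationConfig

def quantization_config_for_host(
    is_static: bool = False, per_channel: bool = False
) -> QuantizationConfig:
    # Pick the most specific instruction set that the host CPU supports,
    # falling back to the arm64 settings when no AVX flags are present.
    flags = get_cpu_flags()
    if flags["avx512_vnni"]:
        factory = AutoQuantizationConfig.avx512_vnni
    elif flags["avx512"]:
        factory = AutoQuantizationConfig.avx512
    elif flags["avx2"]:
        factory = AutoQuantizationConfig.avx2
    else:
        factory = AutoQuantizationConfig.arm64
    return factory(is_static=is_static, per_channel=per_channel)

quantization_config_for_host()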

Code
from tqdm.auto import tqdm

df = pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        optimization_config=OptimizationConfig(
            optimization_level=99,
        ),
        quantization_config=AutoQuantizationConfig.avx2(
            is_static=False,
            per_channel=False,
            operators_to_quantize=[
                "MatMul", # included because it's such a strong contender
                "Attention",
                "BiasGelu",
                "LayerNormalization",
                "SkipLayerNormalization",
            ]
        ),
    )
    for _ in tqdm(range(20))
])
df.samples_per_second.describe()
count     20.000000
mean     107.057550
std        1.853127
min      101.231848
25%      107.185212
50%      107.631845
75%      108.037973
max      108.878874
Name: samples_per_second, dtype: float64
Code
from tqdm.auto import tqdm

df = pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        optimization_config=OptimizationConfig(
            optimization_level=99,
        ),
        quantization_config=AutoQuantizationConfig.arm64(
            is_static=False,
            per_channel=False,
            operators_to_quantize=[
                "MatMul", # included because it's such a strong contender
                "Attention",
                "BiasGelu",
                "LayerNormalization",
                "SkipLayerNormalization",
            ]
        ),
    )
    for _ in tqdm(range(20))
])
df.samples_per_second.describe()
count     20.000000
mean     127.737016
std        2.591476
min      121.209252
25%      127.422428
50%      128.360943
75%      129.505146
max      131.198863
Name: samples_per_second, dtype: float64

I was uncertain about the variation in individual runs, so I repeated each configuration 20 times. The variation is larger than it was before but the results are pretty clear. On my Intel avx2-capable CPU it is significantly better to use the arm64 quantization settings than the avx2 ones.
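If I wanted to be more rigorous about that claim I could compare the two sets of 20 runs directly. A minimal sketch, assuming the two DataFrames from the cells above had been kept under the hypothetical names df_avx2 and df_arm64 rather than both being called df:

Code
from scipy.stats import ttest_ind

# Hypothetical names: df_avx2 and df_arm64 are the 20-run results from above.
# Welch's t-test on the throughput measurements of the two configurations.
ttest_ind(df_avx2.samples_per_second, df_arm64.samples_per_second, equal_var=False)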

Code
from tqdm.auto import tqdm

df = pd.DataFrame([
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        optimization_config=OptimizationConfig(
            optimization_level=99,
        ),
        quantization_config=AutoQuantizationConfig.avx2(
            is_static=False,
            per_channel=False,
            operators_to_quantize=[
                "MatMul", # included because it's such a strong contender
                "Attention",
                "BiasGelu",
                "LayerNormalization",
                "SkipLayerNormalization",
            ]
        ),
        name="avx2"
    ),
    quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        optimization_config=OptimizationConfig(
            optimization_level=99,
        ),
        quantization_config=AutoQuantizationConfig.arm64(
            is_static=False,
            per_channel=False,
            operators_to_quantize=[
                "MatMul", # included because it's such a strong contender
                "Attention",
                "BiasGelu",
                "LayerNormalization",
                "SkipLayerNormalization",
            ]
        ),
        name="arm64"
    )
])
df
   accuracy        f1  precision    recall  total_time_in_seconds  samples_per_second  latency_in_seconds   name
0  0.887000  0.889899   0.903694  0.876520              30.471591           98.452359            0.010157   avx2
1  0.889667  0.892498   0.906332  0.879079              22.985481          130.517175            0.007662  arm64

This is really crazy. The arm64 version is better in every respect, and the speed difference is dramatic.