Different Ways to Quantize a Model, with Performance Comparisons
quantization
Published
September 12, 2022
The state of the art for model quantization has progressed quite a bit. Hugging Face seems to be incorporating all of the coolest techniques right now, so there is quite a lot of quantization and optimization available through the Optimum library. This post is an investigation of that library, using ONNX and Intel optimizations to try to improve the performance of text sentiment classification.
I am going to run through all of the parts of the Optimum documentation and see how easy each one is to get working. It would also be nice to test the speed and performance of each approach to see how they compare.
Dataset
To provide a baseline performance metric we can take the Stanford Sentiment Treebank (Socher et al. 2013). Hugging Face provides a way to load and use it.
Socher, Richard, Alex Perelygin, Jean Wu, Jason Chuang, Christopher D. Manning, Andrew Ng, and Christopher Potts. 2013. “Recursive Deep Models for Semantic Compositionality over a Sentiment Treebank.” In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing, 1631–42. Seattle, Washington, USA: Association for Computational Linguistics. https://aclanthology.org/D13-1170.
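Loading it is straightforward with the datasets library; a minimal sketch (the validation split is what the evaluations below use):

```python
from datasets import load_dataset

# Load the SST-2 validation split; the same "sst2" dataset name is used later
# for the static quantization calibration set.
sst2 = load_dataset("sst2", split="validation")
sst2.to_pandas().head()
```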
That first row looks quite weird, and the text is quite short. Even so, this dataset has been used to train various sentiment models, and there are even pretrained versions available that used it. I can use it to quantify the change in accuracy of the quantized models.
Baseline Evaluation
The baseline model that I am going to use is the DistilBERT base uncased finetuned SST-2 model. This is a small model that has already been distilled, so it sets the accuracy bar that the quantized versions will be measured against. It can be used with and without the GPU to provide baseline timing and accuracy results.
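As a minimal sketch, the two baseline pipelines might be set up like this (not the exact benchmarking code, just the standard transformers convention for device placement):

```python
from transformers import pipeline

# Baseline pipelines for the distilled model: device=0 selects the first GPU,
# device=-1 runs on the CPU.
gpu_pipeline = pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=0,
)
cpu_pipeline = pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=-1,
)
```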
This is also an opportunity to try out the Hugging Face evaluate library. I’m quite hopeful about this as previously I’ve had to construct trainers or use the metrics directly (which were nice, but seemed misaligned with the trainer). If this makes evaluation easier then that would be great.
```python
from typing import Dict, Union

import pandas as pd


def format_results(results: Dict[str, Union[float, Dict]]) -> pd.DataFrame:
    # the bootstrap results are dicts with confidence_interval, standard_error and score
    # the non bootstrap results are just scores.
    # for now map all that to just scores
    def expand_row(name: str, value: Union[float, Dict]) -> Dict[str, Union[str, float]]:
        if isinstance(value, float):
            return {"name": name, "value": value}
        return {
            "name": name,
            "value": value["score"],
            "confidence_low": value["confidence_interval"][0],
            "confidence_high": value["confidence_interval"][1],
            "std": value["standard_error"],
        }

    df = pd.DataFrame([expand_row(name, value) for name, value in results.items()])
    df = df.set_index("name")
    return df
```
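The results being formatted come from the evaluate evaluator. A sketch of how they might be produced, reusing the gpu_pipeline and sst2 dataset from above (the exact arguments behind the published numbers may differ slightly):

```python
import evaluate
from evaluate import evaluator

# The metrics that appear in the result tables below.
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

task_evaluator = evaluator("text-classification")
results = task_evaluator.compute(
    model_or_pipeline=gpu_pipeline,
    data=sst2,
    metric=clf_metrics,
    input_column="sentence",
    label_mapping={"NEGATIVE": 0, "POSITIVE": 1},
    strategy="bootstrap",  # drop this for the much faster non-bootstrap run
)
format_results(results)
```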
If I only inspect the score value from the results then the two results are absolutely equivalent and the non-bootstrap approach is about 30x quicker. The bootstrap evaluation does provide a lot more information about the quality of the model, so it would be nice to use that for the full evaluations.
The samples per second appears to be comparable between the two and is one of the metrics I was most interested in. This is on the GPU though, so how does it compare when moved to CPU?
It’s gone from 319 samples per second to 103. So as a baseline we have:
| device | accuracy | samples/second |
|--------|----------|----------------|
| GPU    | 0.910550 | 319.239023     |
| CPU    | 0.910550 | 102.782367     |
Overall I think that the evaluation framework is excellent and very easy to use.
Now we can try comparing this to different versions of the quantized or optimized model.
Optimum
To start we can try out the quickstart example. I’ve had to adjust the code slightly as the original did not work (opened an issue about that).
ONNX Dynamic Quantization
This will convert the model to the ONNX format and then quantize it. The quantization will be dynamic, which means that inputs of any length can be provided to the quantized model. Keeping the model dynamic restricts the optimizations that can be performed on it.
Code
```python
from pathlib import Path

from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig
from transformers import AutoTokenizer, Pipeline, pipeline

QUANTIZATION_FOLDER = Path("/data/blog/2022-09-12-model-quantization")
QUICKSTART_SAVE_FOLDER = QUANTIZATION_FOLDER / "quickstart"
DYNAMIC_SAVE_FOLDER = QUICKSTART_SAVE_FOLDER / "dynamic"
DYNAMIC_SAVE_FOLDER.mkdir(parents=True, exist_ok=True)

MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"


def export_to_onnx(model_name: str, directory: Path) -> None:
    """Load the model from transformers and export it to the ONNX format."""
    model = ORTModelForSequenceClassification.from_pretrained(
        model_name, from_transformers=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.save_pretrained(directory, file_name="model.onnx")
    tokenizer.save_pretrained(directory)


def dynamic_quantize_onnx_model(directory: Path, **config) -> None:
    """Apply dynamic quantization to the model."""
    quantization_config = AutoQuantizationConfig.arm64(**config)
    quantizer = ORTQuantizer.from_pretrained(directory)
    quantizer.quantize(
        save_dir=directory,
        quantization_config=quantization_config,
    )


def load_quantized_pipeline(directory: Path) -> Pipeline:
    """Load the quantized model as a text classification pipeline."""
    model = ORTModelForSequenceClassification.from_pretrained(
        directory, file_name="model_quantized.onnx"
    )
    tokenizer = AutoTokenizer.from_pretrained(directory)
    return pipeline("text-classification", model=model, tokenizer=tokenizer)


export_to_onnx(model_name=MODEL_NAME, directory=DYNAMIC_SAVE_FOLDER)
dynamic_quantize_onnx_model(
    directory=DYNAMIC_SAVE_FOLDER,
    is_static=False,
    per_channel=False,
)
quantized_pipeline = load_quantized_pipeline(directory=DYNAMIC_SAVE_FOLDER)
```
At this point we have a pipeline which can be used directly:
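For instance (the sentence and score here are purely illustrative):

```python
# The quantized ONNX pipeline behaves like any other transformers pipeline.
quantized_pipeline("I really enjoyed this film, the acting was superb")
# e.g. [{'label': 'POSITIVE', 'score': 0.99...}]
```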
I’ve created the pipeline as it’s a suitable format to use with the evaluate framework. As such we can now compare the performance of the quantized model to the original distilled model.
CPU times: user 1min 47s, sys: 892 ms, total: 1min 48s
Wall time: 1min 10s
| name | value | confidence_low | confidence_high | std |
|------|-------|----------------|-----------------|-----|
| accuracy | 0.897936 | 0.885770 | 0.904523 | 0.006890 |
| f1 | 0.901874 | 0.888681 | 0.910548 | 0.007186 |
| precision | 0.883369 | 0.856407 | 0.895553 | 0.012170 |
| recall | 0.921171 | 0.903992 | 0.935958 | 0.011506 |
| total_time_in_seconds | 4.091242 | NaN | NaN | NaN |
| samples_per_second | 213.138190 | NaN | NaN | NaN |
| latency_in_seconds | 0.004692 | NaN | NaN | NaN |
This is encouraging. The results from this quantization compare well to the baseline:
| device | accuracy | accuracy Δ | samples/second | relative speed |
|--------|----------|------------|----------------|----------------|
| Distilled GPU | 0.910550 | | 319.239023 | |
| Distilled CPU | 0.910550 | | 102.782367 | |
| Dynamic Quantized CPU | 0.897936 | -0.012614 | 269.176181 | 2.618895 |
So for a cost of 0.013 accuracy we more than double the speed. This is a good start.
ONNX Static Quantization
Now we can try the static quantization. This involves loading a calibration dataset that can be used to determine how to create the static version of the model. It’s good to use the train dataset for this as it covers the required variation of inputs.
Code
```python
from functools import partial
from pathlib import Path

from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoCalibrationConfig, AutoQuantizationConfig
from transformers import AutoTokenizer

STATIC_SAVE_FOLDER = QUICKSTART_SAVE_FOLDER / "static"
STATIC_SAVE_FOLDER.mkdir(parents=True, exist_ok=True)


def static_quantize_onnx_model(directory: Path, **config) -> None:
    """Apply static quantization to the model."""
    quantization_config = AutoQuantizationConfig.arm64(**config)
    quantizer = ORTQuantizer.from_pretrained(directory)
    tokenizer = AutoTokenizer.from_pretrained(directory)

    def preprocess_fn(row, tokenizer):
        return tokenizer(row["sentence"])

    # Create the calibration dataset
    calibration_dataset = quantizer.get_calibration_dataset(
        "sst2",
        preprocess_function=partial(preprocess_fn, tokenizer=tokenizer),
        num_samples=50,
        dataset_split="train",
    )
    # Create the calibration configuration containing the parameters related to calibration.
    calibration_config = AutoCalibrationConfig.minmax(calibration_dataset)
    # Perform the calibration step: computes the activations quantization ranges
    ranges = quantizer.fit(
        dataset=calibration_dataset,
        calibration_config=calibration_config,
        # onnx_model_path=onnx_path,
        operators_to_quantize=quantization_config.operators_to_quantize,
    )
    # Apply static quantization on the model
    quantizer.quantize(
        save_dir=directory,
        quantization_config=quantization_config,
        calibration_tensors_range=ranges,
    )


export_to_onnx(model_name=MODEL_NAME, directory=STATIC_SAVE_FOLDER)
static_quantize_onnx_model(
    directory=STATIC_SAVE_FOLDER,
    is_static=True,
    per_channel=False,
)
quantized_pipeline = load_quantized_pipeline(directory=STATIC_SAVE_FOLDER)
```
/home/matthew/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/scipy/stats/_resampling.py:118: RuntimeWarning: invalid value encountered in double_scalars
a_hat = num / den
/home/matthew/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/scipy/stats/_resampling.py:92: DegenerateDataWarning: The bootstrap distribution is degenerate; the confidence interval is not defined.
warnings.warn(DegenerateDataWarning(msg))
CPU times: user 1min 47s, sys: 769 ms, total: 1min 48s
Wall time: 1min 9s
| name | value | confidence_low | confidence_high | std |
|------|-------|----------------|-----------------|-----|
| accuracy | 0.544725 | 0.520952 | 0.557125 | 0.012889 |
| f1 | 0.191446 | 0.183737 | 0.199604 | 0.008865 |
| precision | 1.000000 | NaN | NaN | 0.000000 |
| recall | 0.105856 | 0.101163 | 0.110867 | 0.005378 |
| total_time_in_seconds | 4.252979 | NaN | NaN | NaN |
| samples_per_second | 205.032759 | NaN | NaN | NaN |
| latency_in_seconds | 0.004877 | NaN | NaN | NaN |
It’s good to know that this can be done. The results are absolute trash though:
| device | accuracy | accuracy Δ | samples/second | relative speed |
|--------|----------|------------|----------------|----------------|
| Distilled GPU | 0.910550 | | 319.239023 | |
| Distilled CPU | 0.910550 | | 102.782367 | |
| Dynamic Quantized CPU | 0.897936 | -0.012614 | 269.176181 | 2.618895 |
| Static Quantized CPU | 0.544725 | -0.365825 | 205.032759 | 1.994824 |
The static quantized model is both slower than the dynamic quantized one and has terrible, terrible accuracy. Remember that an accuracy of 0.5 is what you would expect from randomly choosing the answer!
AutoQuantizationConfig… arm64??
I’m on a Linux machine with an x86_64 CPU. The examples in the quickstart use AutoQuantizationConfig.arm64. Is this a problem? I can try out the different AutoQuantizationConfig factory methods to see what the effects are.
It would also be interesting to try varying the different settings that are available.
Code
```python
from pathlib import Path

from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig


def custom_quantize_onnx_model(directory: Path, config: AutoQuantizationConfig) -> None:
    """Apply dynamic quantization to the model."""
    quantizer = ORTQuantizer.from_pretrained(directory)
    quantizer.quantize(
        save_dir=directory,
        quantization_config=config,
    )
```
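As an example, the avx512_vnni configuration that appears in the table below might be built like this; the flags mirror the earlier dynamic quantization call and are an assumption rather than the exact settings used:

```python
# A sketch of swapping in one of the other AutoQuantizationConfig factory
# methods; avx2 and avx512 can be tried in the same way.
vnni_config = AutoQuantizationConfig.avx512_vnni(
    is_static=False,
    per_channel=False,
)
custom_quantize_onnx_model(directory=DYNAMIC_SAVE_FOLDER, config=vnni_config)
```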
A small increase in both speed and accuracy is available by customizing the quantization configuration. (In this run I got 271 samples per second but previously I managed 278, so there is clearly some variation). That puts us at:
| device | accuracy | accuracy Δ | samples/second | relative speed |
|--------|----------|------------|----------------|----------------|
| Distilled GPU | 0.910550 | | 319.239023 | |
| Distilled CPU | 0.910550 | | 102.782367 | |
| Dynamic Quantized CPU | 0.897936 | -0.012614 | 269.176181 | 2.618895 |
| Static Quantized CPU | 0.544725 | -0.365825 | 205.032759 | 1.994824 |
| avx512_vnni Quantized CPU | 0.897936 | -0.012614 | 278.353499 | 2.708183 |
So searching for the correct settings gave us a further speed improvement of a few percent over the arm64 defaults.
A larger problem here is that the speed results are not stable. The whole dataset can be processed in a few seconds, so small variations in time have a big effect on the measured throughput. I also need to run the evaluation many times, so a hugely larger dataset would be a problem too. Having at least 10 seconds’ worth of data should stabilize the timings a bit more.
Larger Dataset
I’m going to try using the amazon_polarity dataset (Zhang, Zhao, and LeCun 2015). This is still a two-class sentiment dataset, but it is substantially larger. The only problem is that the text is split into title and content fields.
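A sketch of how the dataset might be prepared; taking a fixed-size slice of the test split and joining the two fields with a full stop are both assumptions:

```python
from datasets import load_dataset

# Join the title and content into a single text column so the evaluator can
# use its default "text" input column.
data = load_dataset("amazon_polarity", split="test[:1000]")
data = data.map(
    lambda row: {"text": row["title"] + ". " + row["content"]},
    remove_columns=["title", "content"],
)
```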
This has more text so the evaluation has taken longer. I’m hopeful that this will produce more consistent results. One way to check that is to run it several times and see how it varies.
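A sketch of that check; which pipeline is being timed here is an assumption, while the 20 runs match the summary below:

```python
import pandas as pd
from evaluate import evaluator

# Repeat the evaluation and summarise the throughput across runs.
task_evaluator = evaluator("text-classification")
speeds = pd.Series(
    [
        task_evaluator.compute(
            model_or_pipeline=quantized_pipeline,
            data=data,
            metric="accuracy",
            input_column="text",
            label_mapping={"NEGATIVE": 0, "POSITIVE": 1},
        )["samples_per_second"]
        for _ in range(20)
    ],
    name="samples_per_second",
)
speeds.describe()
```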
```
count    20.000000
mean     92.570765
std       0.917716
min      92.052706
25%      92.325451
50%      92.364204
75%      92.405892
max      96.440913
Name: samples_per_second, dtype: float64
```
This has a standard deviation of less than 1, so I think this dataset is suitable. Any results which differ by less than ~1 sample per second are unreliable. The baseline and existing comparisons need to be recalculated so we can continue.
The samples per second are different enough to be significant. A difference of ~6 is about 6 standard deviations. So optimizing the quantization settings certainly pays off.
The optimized settings themselves are more questionable, as each row differs from the previous one by less than 1 sample per second. Overall I do believe that the top optimizations are a benefit, however I would like a more principled way to select the settings.
Operators
The next thing would be the operators_to_quantize. This is a list of the operations within the model which can be optimized, and defaults to ['MatMul', 'Add'].
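Spelling the default out explicitly, with the same configuration helpers as before:

```python
# The default behaviour spelled out: only MatMul and Add nodes are quantized.
default_config = AutoQuantizationConfig.arm64(
    is_static=False,
    per_channel=False,
    operators_to_quantize=["MatMul", "Add"],
)
custom_quantize_onnx_model(directory=DYNAMIC_SAVE_FOLDER, config=default_config)
```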
It would be good to find a complete list of the supported operators and then start trying out more of them. I’ve started by finding this issue which discusses expanding the list of quantizable operators. I’ll try out the proposed list:
It’s interesting that this model has now become slower. In the ticket they talk about a significant size reduction of nearly half, so it may be that the application of quantization to some operators makes the model smaller but slower.
The next thing would be to find all of the operators that exist in the model. Looking at this discussion provides me with a way to inspect the ONNX model to get the operations:
```python
from pathlib import Path
from typing import List

import onnx


def get_model_operators(path: Path) -> List[str]:
    model = onnx.load(path)
    return sorted({node.op_type for node in model.graph.node})


get_model_operators(DYNAMIC_SAVE_FOLDER / "model.onnx")
```
Given that there are 25 operators in the model it’s not feasible to evaluate all of the combinations (~33 million). Instead I can establish a baseline for the model with no optimizations then try each operator in turn, comparing it to the baseline.
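A sketch of that sweep, assuming quantize_and_evaluate is the helper used by the beam search later on (it quantizes the model with the given configuration and returns the evaluator results):

```python
from typing import Dict, List

# data and clf_metrics are the dataset and combined metrics from the earlier
# evaluations.
operators = get_model_operators(DYNAMIC_SAVE_FOLDER / "model.onnx")


def evaluate_operators(operators_to_quantize: List[str]) -> Dict[str, float]:
    return quantize_and_evaluate(
        model_name=MODEL_NAME,
        data=data,
        metric=clf_metrics,
        input_column="text",
        config=AutoQuantizationConfig.arm64(
            is_static=False,
            per_channel=False,
            operators_to_quantize=operators_to_quantize,
        ),
    )


# Baseline with no quantized operators, then each operator on its own.
baseline = evaluate_operators([])
per_operator = {operator: evaluate_operators([operator]) for operator in operators}
```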
These numbers don’t make sense if the optimizations are independent. What we have is:
| operations optimized | samples per second |
|----------------------|--------------------|
| MatMul | 92.211473 |
| MatMul and Add | 98.460482 |
| 7 operations | 85.157695 |
| all operations | 96.862811 |
Given that optimizing all operations is better than just optimizing MatMul, yet when the operations are evaluated independently only MatMul is beneficial, it must be the case that optimized operations affect each other. If that is true then there may be a set of optimized operations which performs better than MatMul and Add. It’s not feasible to evaluate every combination of operator optimizations, so a more efficient way of estimating the benefit is required. I wonder if some kind of search would be appropriate here.
Operator Beam Search
Let’s start by bringing out the big guns. We can see that there are clear differences between the different operators, and that combining operators also has an effect. This suggests that the problem cannot be solved by merely reviewing the effect of each operator independently.
There are 25 operators in the model. If we were to try every combination of these operators then it would take about a century to fully evaluate (\(2^{25} = 33,554,432\), there are around 31 million seconds in a year, and it’s about 90s per evaluation). So that is totally impractical.
Instead we can use a technique from text generation and apply beam search. This involves performing a breadth-first search while keeping only the top N results at each step. Since some candidate sets have fewer operations than others, we can stop searching as soon as all of the top N operation sets have fewer operations than the sets evaluated in the last round (i.e. the last round did not produce any improvement).
Code
```python
from typing import Dict, List, Set, Tuple

from tqdm.auto import tqdm


def beam_search(
    operations: List[str], top_n: int = 4
) -> Tuple[Set[str], Dict[Set[str], Dict[str, float]]]:
    operations = set(operations)

    def make_config(operators: Set[str]) -> AutoQuantizationConfig:
        return AutoQuantizationConfig.arm64(
            is_static=False,
            use_symmetric_activations=True,
            use_symmetric_weights=True,
            per_channel=False,
            operators_to_quantize=sorted(operators),
        )

    def best(results: Dict[Set[str], Dict[str, float]], top_n: int) -> List[Set[str]]:
        scores = [
            (result["samples_per_second"], operators)
            for operators, result in results.items()
        ]
        scores = sorted(scores, reverse=True, key=lambda row: row[0])
        scores = scores[:top_n]
        return [operators for _, operators in scores]

    def most_operators(results: Dict[Set[str], Dict[str, float]]) -> int:
        return max(len(operators) for operators in results.keys())

    with tqdm(total=2) as progress:
        cache = {
            frozenset(): quantize_and_evaluate(
                model_name=MODEL_NAME,
                data=data,
                metric=clf_metrics,
                input_column="text",
                config=make_config(set()),
            )
        }
        progress.update(1)

        while True:
            tops = best(cache, top_n=top_n)
            maximum = most_operators(cache)
            edge = [operators for operators in tops if len(operators) == maximum]
            if not edge:
                return tops[0], cache

            # generate the list of operations to evaluate
            evaluate = {
                frozenset(operators | {operation})
                for operators in tops
                for operation in operations - operators
            }
            evaluate = {operators for operators in evaluate if operators not in cache}
            if not evaluate:
                return tops[0], cache

            progress.total += len(evaluate)
            for operators in evaluate:
                cache[operators] = quantize_and_evaluate(
                    model_name=MODEL_NAME,
                    data=data,
                    metric=clf_metrics,
                    input_column="text",
                    config=make_config(operators),
                )
                progress.update(1)
```
It’s taken a long time to compute, but 365 evaluations later we have some operations that could be an improvement. Remember that we are interested in the operation sets that result in a speed at least one sample per second faster than Add and MatMul together. The very best result is:
This shows me that there is a significant result across all of these evaluations, but that the difference between the top results is not significant. It might be better to take the most commonly occurring operations from the beneficial evaluations as they are more likely to be consistently better.
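One way to do that, assuming cache is the dictionary of results returned by beam_search and matmul_add_speed is an assumed variable holding the MatMul and Add throughput:

```python
from collections import Counter

# Count how often each operator appears in the evaluations that were at least
# one sample per second faster than the MatMul + Add configuration.
beneficial = [
    operators
    for operators, result in cache.items()
    if result["samples_per_second"] >= matmul_add_speed + 1
]
operator_counts = Counter(
    operator for operators in beneficial for operator in operators
)
most_common_operators = [operator for operator, _ in operator_counts.most_common(6)]
```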
The samples per second for this is within one standard deviation of the very best, so while the measured speed is slightly lower the difference is not meaningful:
| operations optimized | samples per second |
|----------------------|--------------------|
| MatMul | 92.211473 |
| MatMul and Add | 98.460482 |
| MatMul and Add in beam search | 95.037586 |
| Most common operations | 97.719777 |
| Best operations | 98.264263 |
The fact that the recorded speed of MatMul and Add varies so much between the previous evaluation and the beam search is not encouraging. It might be worth using the six operators that I found, but I’m not confident.
At this point this evaluation of the ONNX optimizations seems complete. I am going to write a separate post to evaluate the Intel optimizations that are available. It would also be good to evaluate ONNX Runtime more directly, as its list of quantizable operators includes Attention, which did not appear in the exported model; if that were fixed it might be a significant improvement.