ONNX Python vs Java

Is ONNX really a reliable way to move a model from Python to Java?
quantization
Published

November 23, 2022

One of the attractive things about ONNX is that a model developed in Python can be exported and then run from Java. This lets the production model run in Java, which engineers often prefer, while data scientists keep working on it in Python. For this to work, the Java version needs to produce the same output as the Python version. Does it?

This post investigates that question. For most posts on this blog I include everything required to reproduce the results. This Jupyter notebook cannot run Java code, so I have created a separate repository that loads an ONNX model and serves inference over a simple REST API. The code is available here.

Model

The task that we perform is not significant, so I am going to use a pretrained sentiment model from the huggingface hub: nlptown/bert-base-multilingual-uncased-sentiment. It is a multilingual BERT model that rates text from 1 to 5 stars. I export it to the ONNX format with the huggingface optimum library, but run it with plain onnxruntime code. Optimum is not available in Java, so I will not use it for inference. This keeps the comparison as fair as possible.
Code
from pathlib import Path

MODEL_NAME = "nlptown/bert-base-multilingual-uncased-sentiment"

MODEL_FOLDER = Path("/data/blog/2022/11/23/onnx-python-vs-java")
MODEL_FILE = MODEL_FOLDER / "model.onnx"
MODEL_OPTIMIZED_FILE = MODEL_FOLDER / "model_optimized.onnx"
MODEL_QUANTIZED_FILE = MODEL_FOLDER / "model_optimized_quantized.onnx"

MODEL_FOLDER.mkdir(parents=True, exist_ok=True)
Code
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_id = logits.argmax().item()
model.config.id2label[predicted_class_id]
'5 stars'
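The cell above turns the logits into a label by taking the argmax and looking it up in id2label. A softmax over the same logits gives a probability for each label. The sketch below is plain Python and makes two assumptions. First, the label list is hard-coded to match this model's five star ratings. Second, predict is a hypothetical helper and not part of transformers. The example logits are the PyTorch values printed in the comparison tables later in this post.

```python
import math

# Labels in index order, assumed to match model.config.id2label for this model.
LABELS = ["1 star", "2 stars", "3 stars", "4 stars", "5 stars"]


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the largest logit before exponentiating."""
    largest = max(logits)
    exps = [math.exp(value - largest) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def predict(logits: list[float]) -> tuple[str, float]:
    """Hypothetical helper: return the highest-scoring label and its probability."""
    probabilities = softmax(logits)
    best = max(range(len(logits)), key=lambda index: logits[index])
    return LABELS[best], probabilities[best]


# PyTorch logits for "Hello, my dog is cute", as printed later in this post.
label, probability = predict([-1.625722, -1.541344, 0.013823, 0.995189, 1.702700])
print(label, round(probability, 3))
```

The prediction depends only on which logit is largest. Differences of order \(10^{-7}\) therefore cannot change it unless two logits are almost tied.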
Code
from optimum.onnxruntime import ORTModelForSequenceClassification

# Export the huggingface checkpoint to ONNX and save it as MODEL_FILE.
# Note: this rebinds model to the ONNX Runtime version; model.config is reused below.
model = ORTModelForSequenceClassification.from_pretrained(
    MODEL_NAME, from_transformers=True
)
model.save_pretrained(MODEL_FOLDER, file_name="model.onnx")

PyTorch to Unoptimized ONNX

Let’s start by comparing the unoptimized ONNX model to the original PyTorch / huggingface version. We can then run the same comparison against the Java version.

Code
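import numpy as np
import onnxruntime as ort
import torch


def python_inference(
    ort_session: ort.InferenceSession,
    input_ids: torch.Tensor,
    token_type_ids: torch.Tensor,
    attention_mask: torch.Tensor,
) -> np.ndarray:
    """Run the ONNX model from Python and return the logits array."""
    # run(None, ...) returns every model output; the logits are the first one.
    return ort_session.run(
        None,
        {
            "input_ids": input_ids.numpy(),
            "token_type_ids": token_type_ids.numpy(),
            "attention_mask": attention_mask.numpy(),
        },
    )[0]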
Code
import onnxruntime as ort
import numpy as np
import pandas as pd

ort_session = ort.InferenceSession(str(MODEL_FILE))
ort_logits = python_inference(ort_session=ort_session, **inputs)

pd.DataFrame([
    {
        "label": label,
        "pytorch": logits[0, index].item(),
        "py-onnx": ort_logits[0, index],
        "pytorch py-onnx Δ": logits[0, index].item() - ort_logits[0, index],
    }
    for index, label in model.config.id2label.items()
])
   label      pytorch    py-onnx  pytorch py-onnx Δ
0  1 star   -1.625722  -1.625721      -8.344650e-07
1  2 stars  -1.541344  -1.541344      -3.576279e-07
2  3 stars   0.013823   0.013823       3.576279e-07
3  4 stars   0.995189   0.995188       4.172325e-07
4  5 stars   1.702700   1.702700       3.576279e-07

At this point this is working pretty well. For every label, the ONNX output differs from the original by less than \(10^{-6}\).
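Reading the Δ column by eye works for five labels. The same comparison can be made explicit by reporting the largest absolute difference and failing when it exceeds a tolerance. The helpers below are a hypothetical sketch in plain Python. The default tolerance of \(10^{-6}\) is an assumption chosen to match the differences seen above; it is not a value defined by ONNX Runtime.

```python
def max_abs_difference(expected: list[float], actual: list[float]) -> tuple[int, float]:
    """Return (index, difference) for the position where the two lists differ most."""
    if len(expected) != len(actual):
        raise ValueError(f"length mismatch: {len(expected)} != {len(actual)}")
    differences = [abs(e - a) for e, a in zip(expected, actual)]
    worst = max(range(len(differences)), key=differences.__getitem__)
    return worst, differences[worst]


def assert_logits_close(
    expected: list[float],
    actual: list[float],
    labels: list[str],
    tolerance: float = 1e-6,
) -> float:
    """Raise AssertionError naming the worst label if any logit differs by more than tolerance."""
    worst, difference = max_abs_difference(expected, actual)
    if difference > tolerance:
        raise AssertionError(
            f"{labels[worst]} differs by {difference:.3e} (tolerance {tolerance:.1e})"
        )
    return difference
```

In the notebook, this could be called as assert_logits_close(logits[0].tolist(), ort_logits[0].tolist(), list(model.config.id2label.values())).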

How does Java compare?

Code
import numpy as np
import requests
import torch


def java_inference(
    input_ids: torch.Tensor,
    token_type_ids: torch.Tensor,
    attention_mask: torch.Tensor,
) -> np.ndarray:
    """Send the tokenized input to the Java REST service and return its logits."""
    response = requests.post(
        "http://localhost:8080/inference",
        json={
            "inputIds": input_ids.tolist(),
            "tokenTypeIds": token_type_ids.tolist(),
            "attentionMask": attention_mask.tolist(),
        },
        timeout=30,
    )
    response.raise_for_status()  # fail loudly instead of parsing an error body
    return np.array(response.json()["logits"])
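java_inference sends the three tokenizer tensors as nested JSON lists, keeping the batch dimension, and reads a logits field back. The sketch below isolates that request and response shape so it can be checked without a running server. The field names come from the code above. The helper names and the shape check are assumptions about what the Java service expects; they are not taken from its repository.

```python
import json


def build_payload(
    input_ids: list[list[int]],
    token_type_ids: list[list[int]],
    attention_mask: list[list[int]],
) -> str:
    """Hypothetical helper: serialize a tokenized batch in the camelCase shape java_inference sends."""
    # All three tensors must have the same (batch size, sequence lengths) shape.
    shapes = {
        (len(tensor), tuple(len(row) for row in tensor))
        for tensor in (input_ids, token_type_ids, attention_mask)
    }
    if len(shapes) != 1:
        raise ValueError("input_ids, token_type_ids and attention_mask must share a shape")
    return json.dumps(
        {
            "inputIds": input_ids,
            "tokenTypeIds": token_type_ids,
            "attentionMask": attention_mask,
        }
    )


def parse_logits(body: str) -> list[list[float]]:
    """Hypothetical helper: extract one list of logits per input from a JSON response body."""
    logits = json.loads(body)["logits"]
    if not all(isinstance(row, list) for row in logits):
        raise ValueError("expected one list of logits per input in the batch")
    return logits
```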
Code
import pandas as pd

java_logits = java_inference(**inputs)

pd.DataFrame([
    {
        "label": label,
        "pytorch": logits[0, index].item(),
        "py-onnx": ort_logits[0, index],
        "java-onnx": java_logits[0, index],
        "pytorch py-onnx Δ": logits[0, index].item() - ort_logits[0, index],
        "pytorch java-onnx Δ": logits[0, index].item() - java_logits[0, index],
        "py-onnx java-onnx Δ": ort_logits[0, index] - java_logits[0, index],
    }
    for index, label in model.config.id2label.items()
])
   label      pytorch    py-onnx  java-onnx  pytorch py-onnx Δ  pytorch java-onnx Δ  py-onnx java-onnx Δ
0  1 star   -1.625722  -1.625721  -1.625722      -8.344650e-07        -7.890854e-07         4.537964e-08
1  2 stars  -1.541344  -1.541344  -1.541344      -3.576279e-07        -3.658020e-07        -8.174133e-09
2  3 stars   0.013823   0.013823   0.013823       3.576279e-07         3.574395e-07        -1.883955e-10
3  4 stars   0.995189   0.995188   0.995188       4.172325e-07         4.342598e-07         1.702728e-08
4  5 stars   1.702700   1.702700   1.702700       3.576279e-07         3.765106e-07         1.888275e-08

The Java output matches the Python ONNX output to within about \(5 \times 10^{-8}\), and both are within \(10^{-6}\) of PyTorch. So far, so good.
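The differences between the Python and Java outputs above are at most about \(5 \times 10^{-8}\). That is smaller than the gap between neighbouring float32 values near 1, which is \(2^{-23} \approx 1.2 \times 10^{-7}\). One plausible source is the JSON round trip itself. This assumes the Java service writes its float32 logits as short decimal strings. Python then parses that text as a float64, which can differ from the original float32 by up to half a float32 step. The sketch below simulates that round trip with the standard library. shortest_float32_text is a hypothetical stand-in for the Java serializer, not its actual code.

```python
import struct


def to_float32(value: float) -> float:
    """Round a Python float (float64) to the nearest float32, returned as a float."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


def float32_step(value: float) -> float:
    """Distance from |value| (a float32) to the next larger float32."""
    magnitude = abs(to_float32(value))
    bits = struct.unpack("<I", struct.pack("<f", magnitude))[0]
    next_up = struct.unpack("<f", struct.pack("<I", bits + 1))[0]
    return next_up - magnitude


def shortest_float32_text(value: float) -> str:
    """Find a short decimal (1 to 9 significant digits) that maps back to the same float32."""
    target = to_float32(value)
    for digits in range(1, 10):
        text = f"{target:.{digits}g}"
        if to_float32(float(text)) == target:
            return text
    return repr(target)  # not reached in practice: 9 digits identify any float32


logit32 = to_float32(-1.625722)         # a float32 logit, as ONNX Runtime would produce
text = shortest_float32_text(logit32)   # what a JSON serializer might send
parsed = float(text)                    # what np.array(response.json()["logits"]) would hold
error = abs(parsed - logit32)
print(text, error, float32_step(logit32) / 2)
```

If this explanation holds, the remaining Δ is an artifact of the JSON transport rather than of inference. Casting the Java logits back to float32 before comparing them with ort_logits would test that directly.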

PyTorch to Optimized and Quantized ONNX

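Now we can use the huggingface optimum library to optimize and quantize the model, and then compare the results once again. java_inference has no way to choose a model, so the Java service must be serving the quantized model file for the Java comparison.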

Code
from pathlib import Path

from optimum.onnxruntime import (
    ORTModelForSequenceClassification,
    ORTOptimizer,
    ORTQuantizer,
)
from optimum.onnxruntime.configuration import AutoQuantizationConfig, OptimizationConfig


def optimize(
    model_folder: str | Path,
    output_folder: Path,
    optimization_config: OptimizationConfig,
    quantization_config: AutoQuantizationConfig,
) -> None:
    """
    Export, optimize and quantize the model, writing it to the output folder.

    model_folder may be a local folder or a huggingface hub model name.
    """
    model = _to_onnx(model_folder)
    _optimize(model=model, config=optimization_config, output_folder=output_folder)
    _quantize(folder=output_folder, config=quantization_config)

def _to_onnx(model_folder: str | Path) -> ORTModelForSequenceClassification:
    # Export the transformers checkpoint to ONNX; this model has five star-rating labels.
    return ORTModelForSequenceClassification.from_pretrained(
        model_folder,
        from_transformers=True,
        num_labels=5,
    )

def _optimize(
    model: ORTModelForSequenceClassification,
    config: OptimizationConfig,
    output_folder: Path,
) -> None:
    optimizer = ORTOptimizer.from_pretrained(model)

    optimizer.optimize(
        save_dir=output_folder,
        optimization_config=config,
    )
    # Keep the model's own five star-rating labels; saves config.json.
    model.config.save_pretrained(output_folder)

def _quantize(
    folder: Path,
    config: AutoQuantizationConfig,
) -> None:
    quantizer = ORTQuantizer.from_pretrained(folder, file_name="model_optimized.onnx")
    quantizer.quantize(
        save_dir=folder,
        quantization_config=config,
    )

optimize(
    model_folder=MODEL_NAME,
    output_folder=MODEL_FOLDER,
    optimization_config=OptimizationConfig(optimization_level=99),
    quantization_config=AutoQuantizationConfig.avx2(is_static=False),
)
2022-11-23 17:35:04.025416476 [W:onnxruntime:, inference_session.cc:1458 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED and the NchwcTransformer enabled. The generated model may contain hardware specific optimizations, and should only be used in the same environment the model was optimized in.
Code
import onnxruntime as ort
import numpy as np
import pandas as pd

ort_session = ort.InferenceSession(str(MODEL_QUANTIZED_FILE))
ort_quantized_logits = python_inference(ort_session=ort_session, **inputs)

pd.DataFrame([
    {
        "label": label,
        "pytorch": logits[0, index].item(),
        "py-quantized": ort_quantized_logits[0, index],
        "pytorch py-quantized Δ": logits[0, index].item() - ort_quantized_logits[0, index],
    }
    for index, label in model.config.id2label.items()
])
   label      pytorch  py-quantized  pytorch py-quantized Δ
0  1 star   -1.625722     -1.579802               -0.045920
1  2 stars  -1.541344     -1.532239               -0.009105
2  3 stars   0.013823     -0.039851                0.053674
3  4 stars   0.995189      0.932431                0.062757
4  5 stars   1.702700      1.748329               -0.045629
Code
import pandas as pd

java_quantized_logits = java_inference(**inputs)

pd.DataFrame([
    {
        "label": label,
        "pytorch": logits[0, index].item(),
        "py-quantized": ort_quantized_logits[0, index],
        "java-quantized": java_quantized_logits[0, index],
        "pytorch py-quantized Δ": logits[0, index].item() - ort_quantized_logits[0, index],
        "pytorch java-quantized Δ": logits[0, index].item() - java_quantized_logits[0, index],
        "py-quantized java-quantized Δ": ort_quantized_logits[0, index] - java_quantized_logits[0, index],
    }
    for index, label in model.config.id2label.items()
])
   label      pytorch  py-quantized  java-quantized  pytorch py-quantized Δ  pytorch java-quantized Δ  py-quantized java-quantized Δ
0  1 star   -1.625722     -1.579802       -1.579802               -0.045920                 -0.045920                   4.450531e-08
1  2 stars  -1.541344     -1.532239       -1.532239               -0.009105                 -0.009105                  -1.789398e-08
2  3 stars   0.013823     -0.039851       -0.039851                0.053674                  0.053674                   1.735687e-09
3  4 stars   0.995189      0.932431        0.932431                0.062757                  0.062757                   2.175903e-10
4  5 stars   1.702700      1.748329        1.748329               -0.045629                 -0.045629                   1.016235e-09

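Quantization has dramatically increased the difference from PyTorch. The largest change is now about 0.06 (for the 4 stars logit), tens of thousands of times larger than before. The Java output, however, still matches the Python output to within about \(5 \times 10^{-8}\). This is presumably because the ONNX Runtime Java API is a thin wrapper around the same native library that the Python package uses.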
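AutoQuantizationConfig.avx2(is_static=False) asks for dynamic quantization. In this mode the weights are stored as 8-bit integers and activation ranges are computed at run time. The plain-Python sketch below shows the core step on a toy weight vector. It uses a symmetric per-tensor int8 scheme, which is an assumption for illustration; ONNX Runtime's actual scheme, including zero points and per-channel options, may differ. Each weight is rounded to the nearest multiple of a scale, so the error per weight is at most half the scale. Errors like these, accumulated through the network, are one way the logits can drift by a few hundredths.

```python
def quantize_int8(weights: list[float]) -> tuple[list[int], float]:
    """Symmetric per-tensor quantization: map [-max|w|, max|w|] onto [-127, 127]."""
    largest = max(abs(w) for w in weights)
    scale = largest / 127 if largest > 0 else 1.0
    # round() picks the nearest integer code; clamp keeps codes inside the int8 range.
    quantized = [max(-127, min(127, round(w / scale))) for w in weights]
    return quantized, scale


def dequantize_int8(quantized: list[int], scale: float) -> list[float]:
    """Map int8 codes back to approximate float values."""
    return [q * scale for q in quantized]


weights = [0.42, -1.37, 0.0051, 0.91, -0.003]  # toy values, not taken from the model
codes, scale = quantize_int8(weights)
restored = dequantize_int8(codes, scale)
errors = [abs(w - r) for w, r in zip(weights, restored)]
print(codes, scale, max(errors))
```

With a scale of about 0.011, the small weight 0.0051 is rounded to 0. This shows how an individual weight can lose most of its precision.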

Systematic Testing

Let’s make this a more thorough test by running many more texts through both versions and looking for inputs that cause larger differences. I care about consistency rather than classification accuracy, so any dataset with plenty of text will do. For now I am going to use the first 1,000 reviews from the test split of the IMDB dataset on huggingface (Maas et al. 2011).

Maas, Andrew L., Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng, and Christopher Potts. 2011. “Learning Word Vectors for Sentiment Analysis.” In Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies, 142–50. Portland, Oregon, USA: Association for Computational Linguistics. http://www.aclweb.org/anthology/P11-1015.
Code
import datasets

imdb_dataset = datasets.load_dataset("imdb", split="test")
len(imdb_dataset)
Downloading and preparing dataset imdb/plain_text to /home/matthew/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1...
Dataset imdb downloaded and prepared to /home/matthew/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1. Subsequent calls will reuse this data.
25000
Code
imdb_dataset["text"][0]
'I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn\'t match the background, and painfully one-dimensional characters cannot be overcome with a \'sci-fi\' setting. (I\'m sure there are those of you out there who think Babylon 5 is good sci-fi TV. It\'s not. It\'s clichéd and uninspiring.) While US viewers might like emotion and character development, sci-fi is a genre that does not take itself seriously (cf. Star Trek). It may treat important issues, yet not as a serious philosophy. It\'s really difficult to care about the characters here as they are not simply foolish, just missing a spark of life. Their actions and reactions are wooden and predictable, often painful to watch. The makers of Earth KNOW it\'s rubbish as they have to always say "Gene Roddenberry\'s Earth..." otherwise people would not continue watching. Roddenberry\'s ashes must be turning in their orbit as this dull, cheap, poorly edited (watching it without advert breaks really brings this home) trudging Trabant of a show lumbers into space. Spoiler. So, kill off a main character. And then bring him back as another actor. Jeeez! Dallas all over again.'
Code
import numpy as np
import pandas as pd
import onnxruntime as ort
from transformers import AutoTokenizer
from tqdm.auto import tqdm

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
ort_session = ort.InferenceSession(str(MODEL_QUANTIZED_FILE))

def infer(
    ort_session: ort.InferenceSession,
    text: str,
) -> dict[str, object]:
    """Run one text through both runtimes and record the largest logit difference."""
    # truncation=True cuts long reviews down to the model's maximum input length.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    py_outputs = python_inference(ort_session=ort_session, **inputs)
    java_outputs = java_inference(**inputs)
    biggest_difference = np.absolute(py_outputs - java_outputs).max()
    return {
        "text": text,
        "python": py_outputs,
        "java": java_outputs,
        "difference": biggest_difference,
    }

difference_df = pd.DataFrame([
    infer(ort_session=ort_session, text=text)
    for text in tqdm(imdb_dataset["text"][:1_000])
])
difference_df.sort_values(by="difference", ascending=False)
text python java difference
651 Wow, what a racist, profane piece of celluloid... [[4.135885, 1.035452, -1.1389924, -2.0513384, ... [[4.135885, 1.035452, -1.1389924, -2.0513384, ... 2.381897e-07
630 Boring, ridicules and stupid "Submerged" is a ... [[4.88149, 2.02761, -1.1743344, -2.7686796, -2... [[4.88149, 2.02761, -1.1743344, -2.7686796, -2... 2.305603e-07
705 This is one of the worst films i've ever seen,... [[4.21962, 2.1770673, -0.780342, -2.687975, -2... [[4.21962, 2.1770673, -0.780342, -2.687975, -2... 2.278137e-07
152 After 15 minutes watching the movie I was aski... [[4.075547, 2.0027628, -0.44526598, -2.436773,... [[4.075547, 2.0027628, -0.44526598, -2.436773,... 2.183228e-07
476 I went to see this movie tonight, trying to ke... [[0.78645927, 2.5583208, 1.5190651, 0.1253709,... [[0.78645927, 2.5583208, 1.5190651, 0.1253709,... 2.131042e-07
... ... ... ... ...
124 Wow, another Kevin Costner hero movie. Postman... [[-0.5090611, 0.18023343, -0.006972208, 0.2196... [[-0.5090611, 0.18023343, -0.006972208, 0.2196... 3.842659e-09
337 "The Secret Life" starts with the worst possib... [[-0.95750105, 0.44672653, 0.697375, 0.6537277... [[-0.95750105, 0.44672653, 0.697375, 0.6537277... 3.810120e-09
967 This show is what happened to The Screen Saver... [[0.108357675, 0.16106683, -0.44998556, -0.111... [[0.108357675, 0.16106683, -0.44998556, -0.111... 3.755035e-09
200 He who fights with monsters might take care le... [[-0.026390638, 0.28594577, -0.09914864, 0.113... [[-0.026390638, 0.28594577, -0.09914864, 0.113... 3.529415e-09
772 I too am a House Party Fan...House Party I is ... [[-0.47317317, 0.080483854, 0.2053096, 0.18421... [[-0.47317317, 0.080483854, 0.2053096, 0.18421... 3.479652e-09

1000 rows × 4 columns

Across 1,000 reviews, the largest difference between the Python and Java runs of the quantized model is about \(2.4 \times 10^{-7}\). With differences that small, I think that the Java version of ONNX is a safe substitute for Python.
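This check can be turned into a pass/fail test, so that it keeps holding as the model or the ONNX Runtime version changes. The sketch below is a hypothetical helper in plain Python. In the notebook its input would be difference_df["difference"].tolist(). The \(10^{-6}\) tolerance is an assumption, not a threshold from ONNX Runtime. The sample values are the five largest and five smallest differences printed in the table above.

```python
import heapq
import statistics


def summarize_differences(
    differences: list[float],
    tolerance: float = 1e-6,
    worst: int = 3,
) -> dict[str, object]:
    """Summarize per-review max differences and flag any that exceed the tolerance."""
    if not differences:
        raise ValueError("no differences to summarize")
    # Positions of the largest differences, largest first.
    worst_indices = heapq.nlargest(worst, range(len(differences)), key=differences.__getitem__)
    return {
        "count": len(differences),
        "max": max(differences),
        "median": statistics.median(differences),
        "worst_indices": worst_indices,
        "failures": [i for i, d in enumerate(differences) if d > tolerance],
    }


sample = [2.381897e-07, 2.305603e-07, 2.278137e-07, 2.183228e-07, 2.131042e-07,
          3.842659e-09, 3.810120e-09, 3.755035e-09, 3.529415e-09, 3.479652e-09]
summary = summarize_differences(sample)
assert not summary["failures"], summary["failures"]
print(summary["max"], summary["worst_indices"])
```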