How good is the OpenAI Whisper speech to text model? How easy is it to quantize?
speech to text
Published
November 9, 2022
I want to check out the OpenAI Whisper (Alec Radford 2022) model. Since I’ve previously used speech to text as part of my house conversation bot, I might revive that. If I do, it would be nice to be able to serve it entirely from JavaScript, so converting the model to ONNX would be an essential part of that.
The first thing to do will be to load the smallest model I can find. There are several models available and the whisper-tiny.en model is 151 MB compared to 6 GB for the whisper-large model. I should be able to hook it up to gradio and see how well it performs.
Pipeline with Manual Evaluation
I’m going to use the huggingface automatic speech recognition pipeline to handle transcribing the audio. This does a very good job of wrapping up the fiddly details. I have found that passing a filename is a reliable way to handle the oddities of my microphone - it’s a good microphone, but if I pass the numpy array in directly then the sample rate can cause issues.
Gradio makes this slightly more difficult than it should be. The inline version that is rendered in the notebook never works for this, and even the local URL fails to trigger the microphone permission dialog. Instead I have to navigate to the public URL.
Code
from transformers import pipeline
import gradio as gr
import shutil

MODEL_NAME = "openai/whisper-tiny.en"
pipe = pipeline("automatic-speech-recognition", MODEL_NAME)

def transcribe(filepath: str) -> str:
    shutil.copy(filepath, "audio.wav")  # for reproducibility
    text = pipe(filepath)["text"]
    return text

app = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(source="microphone", type="filepath"),
    outputs="text",
)
app.launch(share=True)  # I have to use the public URL to get the microphone to work...
I spoke the opening sentence of Moby Dick. I’m not a great orator but it was well recorded, and the transcription was:
Call me Ishmael. Some years ago, never mind how long precisely, having little or no money in my purse, and nothing particular to interest me, I’m sure. I thought I would say about a little and see the watery part of the world.
The real opening sentence is:
Call me Ishmael. Some years ago - never mind how long precisely - having little or no money in my purse, and nothing particular to interest me on shore, I thought I would sail about a little and see the watery part of the world.
Considering how bad the transcription was the last time I tried it, this is a wild improvement. A larger model might well be able to address the two transcription errors.
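To put a rough number on the difference, a word error rate calculation would do. This is a sketch using the jiwer library, which I haven’t actually run here:

Code
import jiwer

reference = (
    "Call me Ishmael. Some years ago - never mind how long precisely - having little or no money "
    "in my purse, and nothing particular to interest me on shore, I thought I would sail about a "
    "little and see the watery part of the world."
)
hypothesis = (
    "Call me Ishmael. Some years ago, never mind how long precisely, having little or no money "
    "in my purse, and nothing particular to interest me, I'm sure. I thought I would say about a "
    "little and see the watery part of the world."
)

# word error rate: (substitutions + insertions + deletions) / reference word count
print(jiwer.wer(reference, hypothesis))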
Quantization and Optimization
So can this model be optimized? Using the huggingface optimum library was straightforward last time; let’s see if it works with this model.
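The attempt looks something like this - treat the exact optimum classes and arguments as approximate, this is a sketch rather than the precise cell:

Code
from optimum.onnxruntime import ORTModelForCustomTasks, ORTOptimizer, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig, OptimizationConfig

# export the model to ONNX
model = ORTModelForCustomTasks.from_pretrained(MODEL_NAME, from_transformers=True)

# dynamic int8 quantization of the exported weights
quantizer = ORTQuantizer.from_pretrained(model)
quantizer.quantize(
    save_dir="whisper-tiny.en-quantized",
    quantization_config=AutoQuantizationConfig.avx2(is_static=False),
)

# graph optimization - this is the step that fails with the KeyError below
optimizer = ORTOptimizer.from_pretrained(model)
optimizer.optimize(
    save_dir="whisper-tiny.en-optimized",
    optimization_config=OptimizationConfig(optimization_level=2),
)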
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
KeyError: " model type is not supported yet. Only ['albert', 'bart', 'bert', 'big_bird', 'camembert', 'codegen', 'deberta', 'deberta-v2', 'distilbert', 'electra', 'gpt2', 'gpt_neo', 'mt5', 'marian', 'roberta', 'xlm-roberta'] are supported. If you want to support please propose a PR or open up an issue."
So this … appears to have worked? There are quite a few warnings and the optimization stage completely failed. Quantizing the model has reduced the size by about a third from 145 MB to 97 MB.
The model 'ORTModelForCustomTasks' is not supported for automatic-speech-recognition. Supported models are ['SpeechEncoderDecoderModel', 'Speech2TextForConditionalGeneration', 'WhisperForConditionalGeneration', 'Data2VecAudioForCTC', 'HubertForCTC', 'MCTCTForCTC', 'SEWForCTC', 'SEWDForCTC', 'UniSpeechForCTC', 'UniSpeechSatForCTC', 'Wav2Vec2ForCTC', 'Wav2Vec2ConformerForCTC', 'WavLMForCTC'].
So that’s a no. In order to handle this I need to create a wrapper that can map between the ONNX code and the Whisper code.
Manual Invocation
To understand how I can quantize this and create that wrapper I first need to understand what the pipeline and model are doing. This is best done from the outside in, so let’s look more closely at the pipeline first.
Pipeline Investigation
The model behind the pipeline is a sequence-to-sequence encoder-decoder, which means that the output is formed by an autoregressive generation process: the encoded input is combined with the output so far to generate the next token. Huggingface models have a generate method on them which comes from transformers.generation_utils.GenerationMixin. This can repeatedly sample the model output to implement generation strategies such as greedy search, beam search, and probabilistic sampling.
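To make that concrete, this is roughly what the pipeline does with the generate method once all the wrapping is stripped away; the audio_array name here is just a placeholder for the raw waveform:

Code
# a sketch of calling generate directly, rather than the pipeline's exact code;
# audio_array is a placeholder for the raw 16 kHz waveform as a numpy array
features = pipe.feature_extractor(audio_array, sampling_rate=16_000, return_tensors="pt")
generated_ids = pipe.model.generate(features["input_features"])
text = pipe.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]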
We can determine the generation mode that is used for the model by reviewing some of the model configuration parameters.
This code comes from here, but it has changed since the version that I am using, so it might well be different by the time you look at it. For reference, at the time of writing it looks roughly like this (paraphrased from the transformers source):
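Code
# paraphrased from transformers generation_utils.GenerationMixin.generate
is_constraint_gen_mode = constraints is not None or force_words_ids is not None

is_contrastive_search_gen_mode = (
    top_k is not None
    and top_k > 1
    and do_sample is False
    and penalty_alpha is not None
    and penalty_alpha > 0
)

is_greedy_gen_mode = (
    (num_beams == 1)
    and (num_beam_groups == 1)
    and do_sample is False
    and not is_constraint_gen_mode
    and not is_contrastive_search_gen_mode
)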
num_beams, num_beam_groups and do_sample come from the model configuration. You can see how is_constraint_gen_mode and is_contrastive_search_gen_mode are calculated - these fundamentally rely on arguments passed to the generate method, and default to False.
We can determine if the model defaults to greedy generation by running this same test:
Code
(
    (pipe.model.config.num_beams == 1)
    and (pipe.model.config.num_beam_groups == 1)
    and pipe.model.config.do_sample is False
    # and not is_constraint_gen_mode
    # and not is_contrastive_search_gen_mode
)
True
Generation being greedy is not surprising: we are not generating new content, we are instead attempting to transcribe existing content.
So what happens to the audio file before it gets greedily decoded? There is quite a lot of code here but it’s easy to follow. I’ve annotated this with the values at each point.
Code
# MATT - default model kwargs is empty dictionary
model_kwargs = {}

...

# 2. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
    inputs, bos_token_id, model_kwargs
)
# MATT -
# input tensor is a 3d tensor,
# model_input_name is 'input_features' and
# model_kwargs remains an empty dict
batch_size = inputs_tensor.shape[0]  # MATT - 1

# 3. Define other model kwargs
model_kwargs["output_attentions"] = output_attentions  # MATT - False
model_kwargs["output_hidden_states"] = output_hidden_states  # MATT - False
model_kwargs["use_cache"] = use_cache

accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
# MATT - accepts_attention_mask is False
requires_attention_mask = "encoder_outputs" not in model_kwargs
# MATT - requires_attention_mask is True

if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
    # MATT - not executed, accepts_attention_mask is False
    model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
        inputs_tensor, pad_token_id, eos_token_id
    )

if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
    # MATT - this is executed
    # if model is encoder decoder encoder_outputs are created
    # and added to `model_kwargs`
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
        inputs_tensor, model_kwargs, model_input_name
    )
# MATT - model_kwargs now holds
# 'encoder_outputs' which is a BaseModelOutput with last_hidden_state

# 4. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
    # MATT - This is true for the model
    input_ids = self._prepare_decoder_input_ids_for_generation(
        batch_size,  # MATT - 1
        decoder_start_token_id=decoder_start_token_id,  # MATT - None
        bos_token_id=bos_token_id,  # MATT - model.config.bos_token_id
        model_kwargs=model_kwargs,  # MATT - {'encoder_outputs'...}
        device=inputs_tensor.device,
    )
    # MATT - this method is a funny one.
    # the equivalent code is:
    if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
        input_ids = model_kwargs.pop("decoder_input_ids")
    else:
        decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
        if device is None:
            device = self.device
        input_ids = torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
    # MATT - the model_kwargs do not contain decoder_input_ids
    # so we get a single element tensor with the value of decoder_start_token_id
else:
    # if decoder-only then inputs_tensor has to be `input_ids`
    input_ids = inputs_tensor
To summarize, the model passes the inputs through the encoder to produce the encoder_outputs. Then the model prepares the generation by defining the very start of the output - the decoder_start_token_id.
It’s clear to me that this structure both permits a wide variety of models, since the encoder can be skipped or work quite differently, and lets you continue generation by passing decoder_input_ids.
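For example, something like this should continue a transcription from an existing prefix - the extra token ids are made up, and input_features stands in for the mel features from the feature extractor:

Code
# hypothetical: seed generation with an existing decoder prefix
# the 1234 and 5678 token ids are made up for illustration, and input_features
# stands in for the mel features produced by the feature extractor
prefix = torch.tensor([[pipe.model.config.decoder_start_token_id, 1234, 5678]])
continued_ids = pipe.model.generate(input_features, decoder_input_ids=prefix)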
Manual Model Invocation
To replace the default model with another we need to find the different points where a model is invoked and then find a way to adapt them. Given that we create an encoded embedding of the inputs and then run the generation loop, we can start by trying to reproduce the generation using just the three parts (the encoder, the decoder, and the language model head).
Code
from typing import Union
from pathlib import Path

import torch
from transformers import AutoTokenizer, Pipeline
from transformers.pipelines.audio_utils import ffmpeg_read


@torch.inference_mode()
def pipeline_generate(pipe: Pipeline, audio: Union[str, Path], max_length: int = None) -> str:
    if max_length is None:
        max_length = pipe.model.config.max_length
    assert max_length <= pipe.model.config.max_length

    audio_features = load_audio(pipe=pipe, file=audio)
    encoder = pipe.model.model.encoder
    decoder = pipe.model.model.decoder
    lm_head = pipe.model.proj_out
    tokenizer = pipe.tokenizer
    bos_token_id = pipe.model.config.bos_token_id
    eos_token_id = pipe.model.config.eos_token_id

    return manual_generate(
        audio_features=audio_features,
        encoder=encoder,
        decoder=decoder,
        lm_head=lm_head,
        tokenizer=tokenizer,
        max_length=max_length,
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
    )


def manual_generate(
    *,
    audio_features: torch.Tensor,
    encoder,
    decoder,
    lm_head,
    tokenizer: AutoTokenizer,
    max_length: int,
    bos_token_id: int,
    eos_token_id: int,
) -> str:
    encoded_input = encoder(audio_features)[0]  # or accessed using .last_hidden_state
    tokens = torch.tensor([[bos_token_id]], dtype=torch.long, device=audio_features.device)

    # this takes the most probable token at each step until the stopping condition is reached.
    # it would be an error to rely on the eos token alone, so the loop is also bounded by max_length
    while tokens[0, -1] != eos_token_id and tokens.shape[1] < max_length:
        output = decoder(
            input_ids=tokens,
            encoder_hidden_states=encoded_input,
        )[0]  # or accessed using .last_hidden_state
        logits = lm_head(output)  # this has shape (batch_size=1, positions, tokens=51,864)
        tokens = torch.cat([
            tokens,
            logits[:, -1].argmax(dim=-1)[:, None],
        ], dim=1)

    return tokenizer.decode(tokens[0])


def load_audio(pipe: Pipeline, file: Union[str, Path]) -> torch.Tensor:
    with open(file, "rb") as handle:
        audio_bytes = handle.read()

    # resample audio to the sampling rate desired using ffmpeg
    sampling_rate = pipe.feature_extractor.sampling_rate
    audio_input = ffmpeg_read(
        audio_bytes,
        sampling_rate=sampling_rate,
    )

    # this is a transformers.models.whisper.feature_extraction_whisper.WhisperFeatureExtractor
    # > This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the `Short Time
    # > Fourier Transform` which should match pytorch's `torch.stft` equivalent.
    # It's not a model so I'm not quantizing it
    audio_features = pipe.feature_extractor(
        audio_input,
        sampling_rate=sampling_rate,
        return_tensors="pt",
    )
    return audio_features["input_features"]
Code
pipeline_generate(pipe, "audio.wav")
torch.Size([1, 1500, 384])
"<|startoftranscript|><|notimestamps|> Call me Ishmael. Some years ago, never mind how long precisely, having little or no money in my purse, and nothing particular to interest me, I'm sure. I thought I would say about a little and see the watery part of the world.<|endoftext|>"
So we’ve got a reproduction of the original pipeline. If we were able to export the different parts to ONNX then we could try to use this same code on the same input.
Exporting Models to ONNX
I have to use the torch.onnx package to export instead of huggingface as the configuration of this model is not yet supported, as we saw earlier. The three separate parts are reasonably direct pytorch models so it should be possible to export them directly. The configuration options that I am using come from the ONNX python quickstart.
Code
import torch.onnx

torch.onnx.export(
    pipe.model.model.encoder,            # model being run
    torch.randn(1, 80, 3000).to("cpu"),  # model input (or a tuple for multiple inputs)
    (MODEL_FOLDER / "encoder.onnx"),     # where to save the model (can be a file or file-like object)
    export_params=True,                  # store the trained parameter weights inside the model file
    opset_version=10,                    # the ONNX version to export the model to
    do_constant_folding=True,            # whether to execute constant folding for optimization
    input_names=["input_features"],      # the model's input names
    output_names=["last_hidden_state"],  # the model's output names
    dynamic_axes={                       # variable length axes
        "input_features": {0: "batch_size", 1: "tokens"},
        "last_hidden_state": {0: "batch_size", 1: "tokens"},
    },
)
/home/matthew/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:200: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
/home/matthew/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:239: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
Code
import torch.onnx

torch.onnx.export(
    pipe.model.model.decoder,  # model being run
    (
        torch.randint(
            low=0,
            high=pipe.tokenizer.vocab_size,
            size=(1, pipe.model.config.max_length),
            dtype=int,
        ),
        torch.ones(
            1,
            pipe.model.config.max_length,
            dtype=int,
        ),
        torch.randn(
            1,
            pipe.model.config.max_source_positions,
            pipe.model.config.d_model,
        ).to("cpu"),
    ),  # model input (or a tuple for multiple inputs)
    (MODEL_FOLDER / "decoder.onnx"),  # where to save the model (can be a file or file-like object)
    input_names=[
        "input_ids",
        "attention_mask",
        "encoder_hidden_states",
    ],  # the model's input names
    output_names=["last_hidden_state"],  # the model's output names
    dynamic_axes={  # variable length axes
        "input_ids": {0: "batch_size", 1: "tokens"},
        "attention_mask": {0: "batch_size", 1: "tokens"},
        "encoder_hidden_states": {0: "batch_size", 1: "tokens"},
        "last_hidden_state": {0: "batch_size", 1: "tokens"},
    },
)
/home/matthew/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:750: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if input_shape[-1] > 1:
/home/matthew/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:74: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
/home/matthew/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:756: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attention_mask.shape[-1] > input_shape[-1] > 0:
/home/matthew/.local/share/virtualenvs/blog-1tuLwbZm/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:207: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
Code
import torch.onnx

torch.onnx.export(
    pipe.model.proj_out,  # model being run
    (
        torch.randn(
            1,
            pipe.model.config.max_source_positions,
            pipe.model.config.d_model,
        ).to("cpu"),
    ),  # model input (or a tuple for multiple inputs)
    (MODEL_FOLDER / "lm_head.onnx"),  # where to save the model (can be a file or file-like object)
    input_names=[
        "last_hidden_state",
    ],  # the model's input names
    output_names=["lm_logits"],  # the model's output names
    dynamic_axes={  # variable length axes
        "last_hidden_state": {0: "batch_size", 1: "tokens"},
    },
)
To make the onnx inference fit with the manual_generate function above, I need to define some adapters that invoke the onnx runtime correctly. The onnx runtime is called in a slightly different way to the torch modules, and it accepts numpy arrays instead of torch tensors, so each adapter has to convert between the two.
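These are minimal sketches of such adapters rather than the exact classes I ended up with. They assume the three files exported above, load each file into an onnxruntime InferenceSession, assume CPU tensors, and feed the decoder an all-ones attention mask because that input was part of the export:

Code
import numpy as np
import onnxruntime
import torch


class OnnxEncoder:
    """Wraps the exported encoder so it can be called like the torch module."""

    def __init__(self, path):
        self.session = onnxruntime.InferenceSession(str(path))

    def __call__(self, input_features: torch.Tensor):
        (last_hidden_state,) = self.session.run(
            None, {"input_features": input_features.numpy()}
        )
        return [torch.from_numpy(last_hidden_state)]


class OnnxDecoder:
    """Wraps the exported decoder, supplying an all-ones attention mask."""

    def __init__(self, path):
        self.session = onnxruntime.InferenceSession(str(path))

    def __call__(self, *, input_ids: torch.Tensor, encoder_hidden_states: torch.Tensor):
        (last_hidden_state,) = self.session.run(
            None,
            {
                "input_ids": input_ids.numpy(),
                "attention_mask": np.ones(input_ids.shape, dtype=np.int64),
                "encoder_hidden_states": encoder_hidden_states.numpy(),
            },
        )
        return [torch.from_numpy(last_hidden_state)]


class OnnxLMHead:
    """Wraps the exported language model head."""

    def __init__(self, path):
        self.session = onnxruntime.InferenceSession(str(path))

    def __call__(self, last_hidden_state: torch.Tensor):
        (lm_logits,) = self.session.run(
            None, {"last_hidden_state": last_hidden_state.numpy()}
        )
        return torch.from_numpy(lm_logits)


# reuse the manual_generate loop from above with the onnx parts swapped in
manual_generate(
    audio_features=load_audio(pipe=pipe, file="audio.wav"),
    encoder=OnnxEncoder(MODEL_FOLDER / "encoder.onnx"),
    decoder=OnnxDecoder(MODEL_FOLDER / "decoder.onnx"),
    lm_head=OnnxLMHead(MODEL_FOLDER / "lm_head.onnx"),
    tokenizer=pipe.tokenizer,
    max_length=pipe.model.config.max_length,
    bos_token_id=pipe.model.config.bos_token_id,
    eos_token_id=pipe.model.config.eos_token_id,
)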
"<|startoftranscript|><|notimestamps|> Call me Ishmael. Some years ago, never mind how long precisely, having little or no money in my purse, and nothing particular to interest me, I'm sure. I thought I would say about a little and see the watery part of the world.<|endoftext|>"
Woo! It works!!
This took quite a while to get working.
At this point it would be possible to quantize the model and see how it compares.
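If I were to do that, dynamically quantizing the exported files with onnxruntime would be the obvious route. A sketch of that (which I haven’t run here) would be:

Code
from onnxruntime.quantization import QuantType, quantize_dynamic

# quantize the weights of each exported part to int8
for name in ["encoder", "decoder", "lm_head"]:
    quantize_dynamic(
        model_input=str(MODEL_FOLDER / f"{name}.onnx"),
        model_output=str(MODEL_FOLDER / f"{name}-quantized.onnx"),
        weight_type=QuantType.QInt8,
    )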
I wanted to quantize the model to make a static webpage that can use Whisper to recognize speech. Given that the model is already more than 100 MB, I think this is an unrealistic goal. Making a webpage with that kind of download is cruel and unusual, even for me.
I’ve spent quite a while on this and I want to investigate some other things, so perhaps quantization is a task for another time.