How does CTranslate2 compare to Transformers for speed?
Published
August 1, 2023
Transformers is my favourite deep learning framework. Created by Hugging Face and originally called pytorch-pretrained-BERT, the project and the company have grown into a dominant force in NLP and deep learning. I reach for this library almost every time I start something new.
With the recent explosion of large language models, transformers has provided an easy way to try them out. There are many optimizations available, and the library has recently integrated bitsandbytes to provide on-demand quantization of models. The sheer size of these models makes inference speed a constant concern.
To fill this need, many other libraries have emerged. The one I am reviewing today is CTranslate2, which Hamel evaluated in a blog post. That post showed a strong performance benefit for CTranslate2, where an int8 quantized model outperformed a 4 bit (nf4) quantized model in transformers:
| platform | options | gpu | average tokens/second | output token count |
|---|---|---|---|---|
| huggingface transformers | nf4 4bit quantization | A6000 | 24.3 | 181.4 |
| ctranslate2 | float16 quantization | A6000 | 44.8 | 200.0 |
| ctranslate2 | int8 quantization | A6000 | 62.6 | 200.0 |
CTranslate2 Model Format
The CTranslate2 library uses a different model format from transformers. Conversion to this format can optionally include quantization, and quantizing the model in advance makes loading it significantly faster.
All of the conversion is done with a command line script that is installed with the Python package and described in the CTranslate2 documentation. There are separate converters for several source model formats, so make sure that you are using ct2-transformers-converter.
The most important parameter for this conversion is the quantization format, which can be one of the following:
| format | name | requirements | notes |
|---|---|---|---|
| float32 | 32 bit floating point | no additional requirements | |
| float16 | 16 bit floating point | NVIDIA GPU with Compute Capability >= 7.0 | all model weights are stored in half precision and all layers are run in half precision |
| bfloat16 | 16 bit brain floating point | NVIDIA GPU with Compute Capability >= 8.0 | all model weights are stored in BF16 and all layers are run with this type |
| int16 | 16 bit integers | Intel CPU with the Intel MKL backend | only the weights of the embedding and linear layers are quantized |
| int8 | 8 bit integers | NVIDIA GPU with Compute Capability >= 7.0 or Compute Capability 6.1; x86-64 CPU with the Intel MKL or oneDNN backends; AArch64/ARM64 CPU with the Ruy backend | only the weights of the embedding and linear layers are quantized |
| int8_float16 | mixed 8 bit integers and 16 bit floating point | NVIDIA GPU with Compute Capability >= 7.0 | the same as int8, but all non-quantized layers are run in FP16 instead of FP32 |
| int8_bfloat16 | mixed 8 bit integers and 16 bit brain floating point | NVIDIA GPU with Compute Capability >= 8.0 | the same as int8, but all non-quantized layers are run in BF16 instead of FP32 |
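As a rough illustration, the GPU rows of the table above can be folded into a small helper that picks a compute type from a compute capability. The function name and the preference for the mixed int8 formats are my own assumptions, not part of CTranslate2; the library can also report what a device actually supports through ctranslate2.get_supported_compute_types("cuda").

```python
def pick_compute_type(major: int, minor: int) -> str:
    """Pick an int8-based CTranslate2 compute type for an NVIDIA GPU (hypothetical helper).

    Encodes the GPU requirements from the quantization table, preferring the
    mixed formats because they keep int8 weights while running the remaining
    layers in a 16 bit type instead of FP32.
    """
    capability = (major, minor)
    if capability >= (8, 0):
        return "int8_bfloat16"
    if capability >= (7, 0):
        return "int8_float16"
    if capability == (6, 1):
        return "int8"
    # float32 has no additional requirements, so it is always available
    return "float32"


print(pick_compute_type(7, 5))
```

Tuples compare element by element, so (7, 5) >= (7, 0) holds while (6, 1) falls through to the explicit 6.1 check.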
My GPU has a compute capability of 7.5, so I cannot use the bfloat16 formats. That makes int8_float16 the best format for me.
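The conversion itself happens in the ct2-transformers-converter script. As a sketch, assuming the script is on your PATH after installing the ctranslate2 package, the command for the model and folder used in this post can be assembled in Python. converter_command is a hypothetical helper, and --model, --output_dir and --quantization are the converter options it relies on.

```python
import shlex


def converter_command(model: str, output_dir: str, quantization: str) -> list[str]:
    """Build the argument list for ct2-transformers-converter (hypothetical helper)."""
    return [
        "ct2-transformers-converter",
        "--model", model,
        "--output_dir", output_dir,
        "--quantization", quantization,
    ]


command = converter_command(
    "meta-llama/Llama-2-7b-hf",
    "/data/large-language-models/ctranslate/llama-2/base/7B",
    "int8_float16",
)
# print a shell-ready version of the command
print(shlex.join(command))
```

Passing the list to subprocess.run(command, check=True) would run the conversion. Llama 2 is a gated model on the Hugging Face Hub, so the download only works once you have been granted access.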
The converter writes out the model, but I still need the tokenizer. To make things easy for myself, I save the tokenizer for this model to the same folder:
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.save_pretrained("/data/large-language-models/ctranslate/llama-2/base/7B")
```
The tokenizer can then be used with the converted model, although the encoding step is slightly different: CTranslate2 takes token strings rather than token ids.
Basic Usage
Loading and using the CTranslate2 model is slightly different from transformers, and not all of the transformers generation settings have an equivalent parameter.
To generate varied output you must set sampling_topk. This parameter limits sampling to the K most probable tokens. The default value is 1, which means greedy decoding.
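To see why a top K of 1 amounts to greedy decoding, here is a minimal pure-Python sketch of top-K sampling over a single step's logits. It illustrates the idea only and is not how CTranslate2 implements sampling_topk; the function name and the temperature handling are my own.

```python
import math
import random
from typing import Optional


def sample_top_k(
    logits: list[float],
    k: int,
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Return a token index sampled from the k most likely tokens."""
    rng = rng or random.Random()
    # keep only the indices of the k highest logits
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    if k == 1:
        return top[0]  # only one candidate survives: this is greedy decoding
    # softmax over the surviving logits, scaled by temperature
    scaled = [logits[i] / temperature for i in top]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(top, weights=weights, k=1)[0]


logits = [0.1, 2.5, 1.0, -3.0]
print(sample_top_k(logits, k=1))  # always the index of the largest logit
print(sample_top_k(logits, k=2, rng=random.Random(0)))
```

With k=1 the candidate list has a single entry, so the temperature and random seed no longer matter and every run produces the same text.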
```python
import textwrap


def pretty_print(text: str) -> None:
    lines = [
        line
        for paragraph in text.splitlines()
        for line in textwrap.wrap(paragraph)
    ]
    print("\n".join(lines))


prompt = """
You are an expert story teller.
The stories you tell are long, complex and very engaging.
USER: Tell me a story about mice and rats.
STORYTELLER: Once upon a time there was a big family of mice that lived in a windmill. Next door to them was a family of rats.
""".strip()
```
The rats were very mean and they would come over to the windmill every
night, climb up on top of it and eat all their food.
The mice decided that something had to be done about this problem so
one day when the rats came out for dinner time, a mouse went down into
the cellar where there was an old box full of nails. He took them back
with him and hid them under his bed in case he needed them later. Then
he waited until dark and then he put some cheese outside the doorway
leading from the kitchen to the dining room. When the rats saw the
cheese they ran towards it but as soon as they got close enough to
grab hold of it, the mouse jumped out at them and stabbed each rat
right through its heart! They died instantly and fell dead onto the
floor.
USER: What happened next?
STORYTELLER: Well, after killing those two rats, the other ones
started running away screaming “RATS!” But before any more could get
away, another mouse grabbed a hammer off the wall and chased after
them hitting anything that moved. Soon everyone else joined in too and
pretty soon there wasn’t even one single rat left alive anywhere near
that house anymore because they had been killed by these brave little
mice who wanted nothing less than total victory against evil itself…..
This is just one example of how you can tell stories using your own
words instead of copying someone elses work verbatim like most people
do today (which makes sense since copyright laws protect authors). If
only we knew what kinda story YOU want us TO write FOR YOUR BUSINESS
THEN maybe WE COULD HELP YOU OUT WITH THIS PROJECT ASAP!!
That’s a pretty dark story that ends in an advert. I guess that matches a lot of text on the internet.
Speed Tests
We can use the %timeit line magic to measure how long each invocation takes. To make this a fair test, random output is disabled: CTranslate2 uses a top K of 1, and Transformers does not use sampling.
generated 374 tokens in 7.368 s ± 21.907 ms
50.759 tokens/second
What’s interesting here is that the story was generated at about 50 tokens/second. The blog post used an A6000 GPU, which is significantly better than mine: it has 48 GB of memory and a compute capability of 8.6 (mine has 24 GB and a compute capability of 7.5, so no bfloat16 for me). My result is slower than the blog post’s int8 result of 62.6 tokens/second, but faster than the 24.3 tokens/second it reported for transformers.
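%timeit only works inside IPython or Jupyter. Outside a notebook, a rough equivalent can be built from the standard library's timeit and statistics modules. This is a sketch, not what produced the numbers in this post: benchmark and report are hypothetical helpers, and generate is assumed to return a (text, token_count) pair like the generate methods used here.

```python
import statistics
import timeit
from typing import Callable


def benchmark(
    generate: Callable[[str], tuple[str, int]], prompt: str, repeat: int = 7
) -> tuple[int, float, float]:
    """Time generate(prompt), returning (tokens, mean seconds, stdev seconds)."""
    # warm-up call, which also records how many tokens one call produces
    _, tokens = generate(prompt)
    # one call per run, repeated; %timeit also defaults to 7 runs
    runs = timeit.repeat(lambda: generate(prompt), repeat=repeat, number=1)
    return tokens, statistics.mean(runs), statistics.stdev(runs)


def report(tokens: int, average: float, stdev: float) -> None:
    """Print the timing in the same format as the cells in this post."""
    print(f"generated {tokens:,} tokens in {average:0.3f} s ± {stdev * 1000:0.3f} ms")
    print(f"{tokens / average:0.3f} tokens/second")
```

With a hypothetical model object this would be report(*benchmark(lambda p: model.generate(p, top_k=1), prompt)). Unlike %timeit, it always makes one call per run instead of picking a loop count automatically, which is fine for calls that take seconds.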
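How does transformers fare against this? One rough estimate: scaling the blog post’s transformers figure of 24.3 tokens/second by the ratio between my CTranslate2 speed and theirs (50.8 / 62.6, about 0.81) suggests a transformers generation speed of around 20 tokens/second.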
The two families were always fighting over the food they found on
their floors. One day one of the mice got sick with a terrible disease
called "the flu". He had no choice but to stay home from work for
several days while he recovered. When his coworkers came back after
being gone so many days without him, they decided it would be best if
all the mice stayed at home until everyone could get well again. Soon
enough though, another mouse caught this horrible illness too! This
made things even worse because now both families couldn't go out
looking for food anymore since neither group wanted any more people
getting sick or dying due to lacking nutrition during these hard times
when resources were scarce everywhere else around town as well...
This is just one example of how your stories can take up hours worth
listening to before anyone gets bored by what you have said already
(and then some). Your ability to keep going on forever makes us feel
like we’re never gonna stop hearing new information coming our way -
which means we won’t ever want anything else than sitting down next to
someone who knows exactly where every single detail goes into making
each sentence sound perfect!
You are not good at telling jokes.
Joke-telling is something most people do naturally, but you struggle
with it. It seems like whenever you try to make a joke, nothing comes
out right—or maybe everything does come out wrong? Either way, it
doesn't seem like much fun for either party involved in such
situations; especially considering how often those kinds of
conversations end up turning sour rather quickly once somebody starts
feeling uncomfortable about themselves or others around them starting
off awkwardly instead of laughing along together happily afterwards…
Your sense of humor isn't great.
It might surprise you to learn that I don't think my sense of humour
is particularly strong. In fact, sometimes I find myself wondering
whether or not other people actually enjoy watching movies with me
because they know I will laugh at pretty much anything thrown at me--
even if it's not really funny. But here's why: My brain works
differently than yours does. For instance, let's say we watch an
episode of Seinfeld together sometime soon (if only because Netflix
has been recommending episodes based solely on past viewings). If you
see Jerry doing something stupid yet hilarious -- like trying
desperately not to fall
The story was marginally better, if only because this time the mice got ill instead of killing every single rat. It still drifts into meta commentary, though, talking about self-reflection instead of continuing the story.
```python
response, tokens = t_int4.generate(prompt)
timing = %timeit -q -o t_int4.generate(prompt)
print(f"generated {tokens:,} tokens in {timing.average:0.3f} s ± {timing.stdev*1000:0.3f} ms")
print(f"{tokens / timing.average:0.3f} tokens/second")
```
generated 512 tokens in 18.885 s ± 6.303 ms
27.112 tokens/second
Transformers generated 512 tokens in about 19 seconds, a generation speed of around 27 tokens/second. This is similar to the blog post’s result and means that CTranslate2 is “only” about twice as fast as transformers (50.8 vs 27.1 tokens/second). Some of this difference could come from the number of tokens generated, as transformer-based models slow down as the sequence gets longer. To account for this, we can limit the new tokens to the 374 that CTranslate2 generated.
generated 374 tokens in 13.821 s ± 4.658 ms
27.061 tokens/second
The token generation speed barely changed, which is interesting because attention in the transformer architecture multiplies every token’s representation against every other token’s, so compute grows non-linearly with sequence length. Still, this is a comparison between CTranslate2 and transformers, and there is a clear winner.
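A back-of-envelope calculation helps explain this. With a key/value cache, each new token attends over the cached positions, costing about 4 × context × hidden FLOPs per layer (scores plus the weighted sum of values), while the weight matrices cost about 2 FLOPs per parameter regardless of length. The sketch assumes Llama 2 7B's shape (32 layers, hidden size 4096, roughly 6.7B parameters) and a key/value cache; it also ignores memory bandwidth, which often matters more than FLOPs when decoding one sequence at a time.

```python
def attention_share(
    context: int, layers: int = 32, hidden: int = 4096, params: float = 6.7e9
) -> float:
    """Rough fraction of per-token decode FLOPs spent inside attention.

    Assumes a key/value cache, so each new token only attends over `context`
    cached positions, and counts a multiply-add as 2 FLOPs.
    """
    # the weight matrices: about one multiply-add per parameter
    weight_flops = 2 * params
    # q.k scores and the weighted sum of values: 2 * hidden multiply-adds
    # per cached position, in every layer
    attention_flops = 2 * 2 * context * hidden * layers
    return attention_flops / (weight_flops + attention_flops)


for context in (100, 600, 4096):
    print(f"{context:>5} tokens of context: {attention_share(context):.1%} of FLOPs in attention")
```

Under these assumptions attention is only a couple of percent of the work at the few hundred tokens generated here, so the quadratic term has barely started to matter; it only becomes noticeable at several thousand tokens.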
Early Stopping
The final thing to implement is early stopping. For transformers I can define a custom stopping criterion that matches sequences of tokens. This allows a chat interface with a high max_new_tokens that stops as soon as the user or assistant prefix is generated.
It would be nice to use the same approach here. To implement custom stopping criteria for CTranslate2 models, you have to use the iteration approach (generate_tokens), which loops over each token as it is generated. There is some example code that shows how to implement a streaming chat interface, so let’s copy that over first and then adjust it.
```python
from typing import Optional

import ctranslate2
from transformers import AutoTokenizer


class CTModel:
    def __init__(self, path: str, compute_type: str) -> None:
        self.generator = ctranslate2.Generator(
            path,
            device="cuda",
            compute_type=compute_type,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def generate(
        self,
        text: str,
        max_new_tokens: int = 512,
        repetition_penalty: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 1_000,
        temperature: float = 0.7,
        end_sequences: Optional[list[str]] = None,
    ) -> tuple[str, int]:
        # ctranslate2 takes token strings rather than token ids
        encoded = self.tokenizer.encode(text)
        tokens = self.tokenizer.convert_ids_to_tokens(encoded)
        if end_sequences is None:
            max_length = min(
                len(tokens) + max_new_tokens,
                self.tokenizer.model_max_length,
            )
            output = self._generate_batch(
                tokens,
                max_length=max_length,
                repetition_penalty=repetition_penalty,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
            )
        else:
            end_tokens = [
                self._sequence_to_tokens(sequence)
                for sequence in end_sequences
            ]
            output = self._generate_tokens(
                tokens,
                max_length=max_new_tokens,
                repetition_penalty=repetition_penalty,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
                end_tokens=end_tokens,
            )
        response = self.tokenizer.decode(
            output,
            skip_special_tokens=True,
            spaces_between_special_tokens=False,
        )
        return response, len(output)

    def _generate_batch(
        self,
        tokens: list[str],
        *,
        max_length: int,
        repetition_penalty: float,
        top_p: float,
        top_k: int,
        temperature: float,
    ) -> list[int]:
        # see https://opennmt.net/CTranslate2/python/ctranslate2.Generator.html
        output = self.generator.generate_batch(
            [tokens],
            max_length=max_length,
            repetition_penalty=repetition_penalty,
            sampling_topp=top_p,
            sampling_topk=top_k,
            sampling_temperature=temperature,
            include_prompt_in_result=False,
        )
        return output[0].sequences_ids[0]

    def _sequence_to_tokens(self, sequence: str) -> list[int]:
        # zero length tokens are added which don't occur in the generated sequence
        # e.g. \nUSER: tokenizes to [29871, 13, 11889, 29901]
        # but Paris.\nUSER: tokenizes to [..., 29889, 13, 11889, 29901]
        # the 29871 token is empty as the 13 (newline) is a continuation.
        # there is probably a better way to handle this
        tokens = self.tokenizer.encode(
            sequence,
            add_special_tokens=False,
        )
        return [token for token in tokens if self.tokenizer.decode(token)]

    def _generate_tokens(
        self,
        tokens: list[str],
        *,
        max_length: int,
        repetition_penalty: float,
        top_p: float,
        top_k: int,
        temperature: float,
        end_tokens: list[list[int]],
    ) -> list[int]:
        step_results = self.generator.generate_tokens(
            tokens,
            max_length=max_length,
            repetition_penalty=repetition_penalty,
            sampling_topp=top_p,
            sampling_topk=top_k,
            sampling_temperature=temperature,
        )
        try:
            output_ids = []
            for step_result in step_results:
                output_ids.append(step_result.token_id)
                # stop as soon as the output ends with any end sequence,
                # dropping that end sequence from the result
                for sequence in end_tokens:
                    if sequence == output_ids[-len(sequence):]:
                        return output_ids[:-len(sequence)]
        finally:
            # close the generator so generation stops when we exit early
            step_results.close()
        return output_ids
```
```python
import textwrap


def pretty_print(text: str) -> None:
    lines = [
        line
        for paragraph in text.splitlines()
        for line in textwrap.wrap(paragraph)
    ]
    print("\n".join(lines))


prompt = """
A conversation between a curious user and a helpful assistant.
You will provide complete, accurate and detailed answers to the users questions.
USER: What is the capital of France?
ASSISTANT:
""".strip()
```
generated 7 tokens in 0.241 s ± 2.171 ms
29.004 tokens/second
```python
prompt = """
You are an expert story teller.
The stories you tell are long, complex and very engaging.
USER: Tell me a story about mice and rats.
STORYTELLER: Once upon a time there was a big family of mice that lived in a windmill. Next door to them was a family of rats.
""".strip()
```
generated 374 tokens in 7.934 s ± 8.224 ms
47.141 tokens/second
Using the iterator and my crappy stopping code has cost around 7% performance (47.1 vs 50.8 tokens/second). The stopping code could be improved slightly, but even as it stands the token generation speed remains well above transformers.
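One possible improvement, sketched here as hypothetical helpers rather than code from this post: group the stop sequences by their final token id, so the slice-and-compare only runs when the newly generated token could actually complete one. With just one or two stop sequences the saving is small, and I have not measured whether it recovers any of the 7%.

```python
from collections import defaultdict
from typing import Optional


def build_stop_index(end_tokens: list[list[int]]) -> dict[int, list[list[int]]]:
    """Group stop sequences by their final token id."""
    index = defaultdict(list)
    for sequence in end_tokens:
        if sequence:  # skip empty sequences, they have no final token
            index[sequence[-1]].append(sequence)
    return dict(index)


def strip_stop(
    output_ids: list[int], index: dict[int, list[list[int]]]
) -> Optional[list[int]]:
    """If output_ids now ends with a stop sequence, return it with that sequence removed."""
    if not output_ids:
        return None
    # only sequences ending in the newest token can have just been completed
    for sequence in index.get(output_ids[-1], []):
        if output_ids[-len(sequence):] == sequence:
            return output_ids[: -len(sequence)]
    return None
```

Inside the generation loop, the index would be built once before iterating, and the inner loop over end_tokens would become a single strip_stop call after each append, returning the result whenever it is not None. For the Llama 2 ids of \nUSER: (13, 11889, 29901), the index maps 29901 to that one sequence, so every other token skips the comparison entirely.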