Using the recent kuprel rewrite to understand and implement DALL-E
Published
June 28, 2022
I got quite frustrated recently attempting to use the JAX version of DALL-E. The problem was that the dependencies were incompatible and functionality had been renamed or moved between versions. Google might write clean code, but they do so at the cost of backwards compatibility, which makes working with it tiresome.
The DALL-E model has been rewritten in PyTorch by kuprel, so I am hoping to take that code and reproduce it here. Then I can understand each section of the codebase and how it fits together. As a bonus there will be a locally running version that I can play with; the huggingface hub version times out too frequently.
Code Review
The kuprel code is split into a flax version and a torch version. I'm interested in the torch one, so the first step is to understand how it all fits together.
Dependencies
The project lists torch and flax as dependencies. It also uses wandb without listing it as a dependency, so the list is certainly incomplete.
The immediate problem I see is that the dependencies are unpinned. This is exactly the problem I had with the jax/flax version. While I know the repository was created in the last day, unpinned dependencies are already a bad sign.
The next thing is that it clones another project to vendor it into the tree. The vqgan_imagenet_f16_16384 huggingface project is cloned into a folder, without using git submodules to do it. Another bad sign.
I’ve come back this morning and now flax is pinned to 0.4.2 and wandb is in the requirements. Nice.
Python Version
The code has an unstated dependency on Python <3.10.
It manifests itself as another flax/jax problem:
...
File "/home/matthew/Programming/Python/min-dalle/min_dalle/load_params.py", line 5, in <module>
from flax import traverse_util, serialization
File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/flax/__init__.py", line 18, in <module>
from . import core as core
File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/flax/core/__init__.py", line 15, in <module>
from .axes_scan import broadcast as broadcast
File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/flax/core/axes_scan.py", line 19, in <module>
import jax
File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/jax/__init__.py", line 35, in <module>
from jax import config as _config_module
File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/jax/config.py", line 17, in <module>
from jax._src.config import config
File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/jax/_src/config.py", line 29, in <module>
from jax._src import lib
File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 41, in <module>
import scipy.signal as _signal
File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/signal/__init__.py", line 302, in <module>
from .filter_design import *
File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/signal/filter_design.py", line 16, in <module>
from scipy import special, optimize, fft as sp_fft
File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/optimize/__init__.py", line 421, in <module>
from ._shgo import shgo
File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/optimize/_shgo.py", line 9, in <module>
from scipy import spatial
File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/spatial/__init__.py", line 107, in <module>
from . import distance, transform
File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/spatial/transform/__init__.py", line 19, in <module>
from .rotation import Rotation, Slerp
ImportError: /home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/spatial/transform/rotation.cpython-310-x86_64-linux-gnu.so: undefined symbol: _PyGen_Send
Downgrading to Python 3.9 fixes this issue. I’ve opened a ticket about this.
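Until that is resolved, the requirement could at least be made to fail fast. A sketch of a guard, with the bound taken from the error above:

Code

import sys

# fail fast on unsupported interpreters; the upper bound reflects the
# scipy _PyGen_Send incompatibility seen in the traceback above
if sys.version_info >= (3, 10):
    raise RuntimeError("this project currently requires Python < 3.10")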
CUDA support
The jax that is installed with flax does not support GPU. Yet another way that jax breaks. Boo boo.
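This is easy to confirm, since the default jax wheel that flax pulls in is CPU-only:

Code

import torch
import jax

print(torch.cuda.is_available())  # True here, with a CUDA build of torch
print(jax.devices())              # only CpuDevice entries with the default wheel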
Converting this fully to torch will fix that.
Model Data
The vqgan project contains a configuration and a large flax model. I might be able to skip this.
Finally, it pulls two artifacts from wandb, mini-1:v0 and mega-1-fp16:v14. These look like the model weights. It would be great if they could be uploaded to huggingface to make them easier to download.
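For now they have to come from wandb directly. A sketch of pulling them with the wandb API; the dalle-mini/dalle-mini project path is my assumption:

Code

import wandb

api = wandb.Api()
for name in ["mini-1:v0", "mega-1-fp16:v14"]:
    # download returns the local folder the artifact files were written to
    folder = api.artifact(f"dalle-mini/dalle-mini/{name}").download()
    print(f"{name} -> {folder}")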
Conversion
I can couple the review of the more concrete parts of the codebase with conversion to my preferred form.
Tokenizer
There is a reimplementation of tokenization in min_dalle/text_tokenizer.py. I wonder if this can be replaced with a huggingface tokenizer. That would need testing.
I’ve been looking into this, and the data downloaded from wandb includes the files that huggingface uses to load a pretrained tokenizer. The tokenizer declares itself as a DalleBartTokenizer, which appears to come from here. That class extends BartTokenizerFast with a mixin to load from wandb. I’ve managed to load the tokenizer using BartTokenizerFast.from_pretrained; however, that does not establish that it is tokenizing correctly.
Luckily the code prints the tokens for every encoded sentence so it is quite easy to test.
Code
from transformers import BartTokenizerFast

tokenizer = BartTokenizerFast.from_pretrained("/data/dall-e/flax/dalle_bart_mini/")

tokenizer_test_cases = {
    "a monster chasing a leaf": [0, 58, 4673, 18436, 58, 3649, 2],
    # text is lower cased
    "An Elephant jumping for joy": [0, 101, 7575, 13637, 129, 4591, 2],
    # non ascii characters are dropped
    "A Café with a sign outside": [0, 58, 10390, 208, 58, 830, 6132, 2],
}

for text, expected in tokenizer_test_cases.items():
    prepared_text = " " + text  # first fix
    prepared_text = prepared_text.lower()  # second fix
    actual = tokenizer(prepared_text).input_ids
    if actual == expected:
        print(f"Tokenizer works for {text}")
    else:
        print(f"Tokenizer fails for {text}")
        print(f"Encoded version is: {actual} ({tokenizer.decode(actual)})")
        print(f"Should be: {expected} ({tokenizer.decode(expected)})")
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization.
The tokenizer class you load from this checkpoint is 'DalleBartTokenizer'.
The class this function is called from is 'BartTokenizerFast'.
Tokenizer works for a monster chasing a leaf
Tokenizer works for An Elephant jumping for joy
Tokenizer works for A Café with a sign outside
There is a small amount of preprocessing required, but nothing big. There is an open bug about incorrect behaviour of the tokenizer used in the project, so it may be worth sticking with the bart tokenizer without any preprocessing. Only testing can resolve this.
Model Definition
The model is made from an encoder and a decoder. First the tokenized text is encoded, and then the encoder state is converted into image tokens. Those image tokens are finally converted into the image.
Each part is a separate model.
We can see that both the flax and the huggingface data are required to perform DALL-E inference.
The implementation of the model seems ok. It's not done in the transformers style, as it takes specific constructor options instead of a configuration object. Swapping it over to the configuration approach would be best.
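To illustrate, the transformers approach bundles those options into a PretrainedConfig subclass. A sketch of what that could look like here; the class name and field values are hypothetical:

Code

from transformers import PretrainedConfig

class DalleBartConfig(PretrainedConfig):
    model_type = "dalle-bart"

    def __init__(
        self,
        layer_count: int = 12,
        embed_count: int = 1024,
        attention_head_count: int = 16,
        text_vocab_count: int = 50264,
        image_vocab_count: int = 16384,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.layer_count = layer_count
        self.embed_count = embed_count
        self.attention_head_count = attention_head_count
        self.text_vocab_count = text_vocab_count
        self.image_vocab_count = image_vocab_count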
Loading Model Data
The model data comes from the vqgan project, which is in torch, or from a flax bart version that is converted on demand. It would be nice to convert the flax bart version to torch once and export it.
I’ve downloaded the data for the two models, which failed once before succeeding. The next thing to do is to convert these to torch weights. I don’t want to install flax/jax on my blog, so I’m going to do the conversion in the dall-e project itself and copy the code over here.
That code loads the three parts of the model and writes the state dict of each out to a file. I should be able to load these up and then save them through huggingface after converting the model over.
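A minimal sketch of that export, assuming the three torch modules already hold the converted weights; the output folder is my own choice, not part of the min-dalle code:

Code

import torch

def export_state_dicts(models: dict, folder: str) -> None:
    for name, model in models.items():
        # save only the parameters, so they can be restored later with
        # model.load_state_dict(torch.load(path))
        torch.save(model.state_dict(), f"{folder}/{name}.pt")

# export_state_dicts(
#     {"encoder": encoder, "decoder": decoder, "detokenizer": detokenizer},
#     "/data/dall-e/torch",
# )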
One interesting thing is that the files are substantially larger after being exported as pytorch.
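One possible cause is a dtype change during the conversion, which is easy to check against the exported files (the path matches the hypothetical export above):

Code

import torch

# float32 tensors here would be one explanation for the growth over
# the fp16 flax artifact
state_dict = torch.load("/data/dall-e/torch/decoder.pt")
print({str(tensor.dtype) for tensor in state_dict.values()})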
Model Conversion
The model has been implemented without using the specialized blocks available in pytorch. I wonder if it is possible to replace them.
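For example, the hand-rolled attention could perhaps be swapped for torch.nn.MultiheadAttention. A sketch with illustrative dimensions, not the real model ones:

Code

import torch

attention = torch.nn.MultiheadAttention(embed_dim=1024, num_heads=16, batch_first=True)

hidden = torch.randn(1, 64, 1024)  # (batch, sequence, embedding)
output, _ = attention(hidden, hidden, hidden, need_weights=False)  # self-attention
print(output.shape)  # torch.Size([1, 64, 1024])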
There is a lot of code here so for now I am going to copy it over directly and create a composite model.
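The composite model then gets wrapped in the generate function used below. A sketch of its overall shape, assuming the three modules and the huggingface tokenizer from earlier are passed in; the names and signatures are illustrative rather than the exact min-dalle API:

Code

import torch
from PIL import Image

@torch.inference_mode()
def generate(text, tokenizer, encoder, decoder, detokenizer):
    # the preprocessing established earlier: leading space and lower casing
    tokens = tokenizer(" " + text.lower(), return_tensors="pt").input_ids
    encoder_state = encoder(tokens)        # text tokens -> hidden state
    image_tokens = decoder(encoder_state)  # hidden state -> image tokens
    pixels = detokenizer(image_tokens)     # image tokens -> pixel values
    return Image.fromarray(pixels.to(torch.uint8).numpy())

In the notebook the models are bound in advance, so the calls below take only the prompt.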
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization.
The tokenizer class you load from this checkpoint is 'DalleBartTokenizer'.
The class this function is called from is 'BartTokenizerFast'.
Code
%%time
generate("A regal lion wearing a crown")
CPU times: user 5.73 s, sys: 20.2 ms, total: 5.75 s
Wall time: 5.55 s
Code
generate("A rose entwined with ivy in stained glass")
Code
generate("A screenshot of factorio")
Code
generate("A screenshot of minecraft")
Code
generate("A screenshot of diablo")
Code
generate("Our first Azmodan kill")
Code
generate("My painted miniature")
Code
generate("My painted orc miniature")
Code
generate("My painted elven lord miniature")
Code
generate("My painted hippogriff miniature")
Code
generate("A photo of our new puppy")
Code
generate("A stained glass window of our new puppy")
Code
generate("A rocket flying to the moon")
Code
generate("A face")
Code
generate("frida kahlo")
Code
generate("a goat climbing a mountain")
Code
generate("A house made out of clay")
Code
generate("The gingerbread house")
Code
generate("logo of an armchair in the shape of an avocado")
Code
generate("logo of a hacker")
Code
generate("logo of vampire survivors")
Code
generate("spider with googly eyes")
Code
generate("an angel with peacock feather wings")
Code
generate("a googly eyed angel with peacock feather wings")
Code
generate("a googly eyes angel with peacock feather wings")
Code
generate("a googly eyes angel with peacock feather wings")
Code
generate("the correct set of weights for my multilingual internalized prompt model")