Reproducing DALL-E in PyTorch

Using the recent kuprel rewrite to understand and implement DALL-E
Published

June 28, 2022

I got quite frustrated recently attempting to use the JAX version of DALL-E. The problem was that the dependencies were incompatible and functionality had been renamed or moved between versions. Google might write clean code, but they do so at the cost of backwards compatibility, which makes working with it tiresome.

The DALL-E model has been rewritten in PyTorch by kuprel, so I am hoping to take that code and reproduce it here. Then I can understand each section of the codebase and how it fits together. As a bonus there will be a locally running version that I can play with, as the huggingface hub version times out too frequently.

Code Review

The kuprel version is split into a flax version and a torch version. I’m interested in the torch version so the first thing is to understand how it all fits together.

Dependencies

The project lists torch and flax as dependencies. It also uses wandb without listing it as a dependency, so the list is certainly incomplete.

The immediate problem that I see with this is that the dependencies are unpinned. This was the problem that I had with the jax/flax version. While I know that the repository was created in the last day, having unpinned dependencies is already a bad sign.

The next thing is that it clones another project to use as a vendored dependency: the vqgan_imagenet_f16_16384 huggingface project is cloned into a folder rather than being added as a git submodule. Another bad sign.

I’ve come back this morning and now flax is pinned to 0.4.2 and wandb is in the requirements. Nice.

Python Version

The code has an unstated dependency on Python <3.10.

It manifests itself as another flax/jax problem:

...
  File "/home/matthew/Programming/Python/min-dalle/min_dalle/load_params.py", line 5, in <module>
    from flax import traverse_util, serialization
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/flax/__init__.py", line 18, in <module>
    from . import core as core
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/flax/core/__init__.py", line 15, in <module>
    from .axes_scan import broadcast as broadcast
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/flax/core/axes_scan.py", line 19, in <module>
    import jax
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/jax/__init__.py", line 35, in <module>
    from jax import config as _config_module
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/jax/config.py", line 17, in <module>
    from jax._src.config import config
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/jax/_src/config.py", line 29, in <module>
    from jax._src import lib
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 41, in <module>
    import scipy.signal as _signal
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/signal/__init__.py", line 302, in <module>
    from .filter_design import *
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/signal/filter_design.py", line 16, in <module>
    from scipy import special, optimize, fft as sp_fft
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/optimize/__init__.py", line 421, in <module>
    from ._shgo import shgo
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/optimize/_shgo.py", line 9, in <module>
    from scipy import spatial
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/spatial/__init__.py", line 107, in <module>
    from . import distance, transform
  File "/home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/spatial/transform/__init__.py", line 19, in <module>
    from .rotation import Rotation, Slerp
ImportError: /home/matthew/.cache/pypoetry/virtualenvs/min-dalle-13Fe8Z6x-py3.10/lib/python3.10/site-packages/scipy/spatial/transform/rotation.cpython-310-x86_64-linux-gnu.so: undefined symbol: _PyGen_Send

Downgrading to Python 3.9 fixes this issue. I’ve opened a ticket about this.
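
A quick guard at the top of the project would make this failure mode obvious instead of surfacing as an obscure import error; a trivial sketch:

Code
import sys

# flax/jax/scipy break on Python 3.10 as shown above, so fail fast with a clear message
assert sys.version_info < (3, 10), "this project needs Python 3.9 or earlier"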

CUDA support

The jax that is installed with flax does not support GPU. Yet another way that jax breaks. Boo boo.

Converting this fully to torch will fix this.

Model Data

The vqgan project contains a configuration and a large flax model. I might be able to skip this.

Finally it pulls two artifacts from wandb, namely mini-1:v0 and mega-1-fp16:v14. These look like the weights for the model. It would be great if they could be written to huggingface to make them easier to download.
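
For reference, this is roughly how those artifacts can be fetched with the wandb public API. This is only a sketch: the dalle-mini/dalle-mini entity and project path, and the output directory layout, are my assumptions rather than something the project documents.

Code
import wandb

# sketch: download the two model artifacts through the wandb public API
# (the "dalle-mini/dalle-mini" project path is an assumption)
api = wandb.Api()
for name in ["mini-1:v0", "mega-1-fp16:v14"]:
    artifact = api.artifact(f"dalle-mini/dalle-mini/{name}")
    # e.g. ./pretrained/dalle_bart_mini and ./pretrained/dalle_bart_mega
    artifact.download(root=f"./pretrained/dalle_bart_{name.split('-')[0]}")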

Conversion

I can couple the review of the more concrete parts of the codebase with conversion to my preferred form.

Tokenizer

There is a reimplementation of tokenization in min_dalle/text_tokenizer.py. I wonder if this can be replaced with a huggingface tokenizer. This would need testing.

I’ve been looking into this, and the data that is downloaded from wandb has the files that huggingface uses to load a pretrained tokenizer. The tokenizer declares itself as a DalleBartTokenizer, which appears to come from here. This class extends BartTokenizerFast with a mixin to load from wandb. I’ve managed to load the tokenizer using BartTokenizerFast.from_pretrained; however, that does not establish that it is tokenizing correctly.

Luckily the code prints the tokens for every encoded sentence so it is quite easy to test.

Code
from transformers import BartTokenizerFast

tokenizer = BartTokenizerFast.from_pretrained("/data/dall-e/flax/dalle_bart_mini/")

tokenizer_test_cases = {
    "a monster chasing a leaf": [0, 58, 4673, 18436, 58, 3649, 2],
    # text is lower cased
    "An Elephant jumping for joy": [0, 101, 7575, 13637, 129, 4591, 2],
    # non ascii characters are dropped
    "A Café with a sign outside": [0, 58, 10390, 208, 58, 830, 6132, 2],
}

for text, expected in tokenizer_test_cases.items():
    prepared_text = " " + text # first fix
    prepared_text = prepared_text.lower() # second fix

    actual = tokenizer(prepared_text).input_ids
    if actual == expected:
        print(f"Tokenizer works for {text}")
    else:
        print(f"Tokenizer fails for {text}")
        print(f"Encoded version is: {actual} ({tokenizer.decode(actual)})")
        print(f"Should be: {expected} ({tokenizer.decode(expected)})")
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DalleBartTokenizer'. 
The class this function is called from is 'BartTokenizerFast'.
Tokenizer works for a monster chasing a leaf
Tokenizer works for An Elephant jumping for joy
Tokenizer works for A Café with a sign outside

There is a small amount of preprocessing required, but nothing big. There is a bug report about incorrect behaviour of the tokenizer that the project uses, so it may be worth sticking with the bart tokenizer without any preprocessing. Only testing can resolve this.
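
As a sketch, that preprocessing can be wrapped into a small helper around the tokenizer loaded above (the encode_text name is mine):

Code
def encode_text(text: str) -> list:
    # the project lowercases the text, and the bart tokenizer needs a
    # leading space to reproduce the same first token
    prepared_text = " " + text.lower()
    return tokenizer(prepared_text).input_ids

# expected ids taken from the test cases above
assert encode_text("a monster chasing a leaf") == [0, 58, 4673, 18436, 58, 3649, 2]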

Model Definition

The model is made from an encoder and a decoder. First the tokenized text is encoded, and then the encoder state is converted into image tokens. Those image tokens are finally converted into the image.

Each part is a separate model.

We can see that both the flax and the huggingface data are required to perform DALL-E inference.

The implementation of the model seems ok. It’s not done in the transformers style, as it takes specific options instead of a configuration object. Swapping it over to the configuration approach would be best.
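
As a rough sketch of what that would look like, the options could be collected into a PretrainedConfig subclass and the models could take the config in their constructors. The DalleBartConfig name and the default values here are illustrative, not taken from the project:

Code
from transformers import PretrainedConfig

class DalleBartConfig(PretrainedConfig):
    # hypothetical config class, mirroring the keys that the conversion
    # code below reads out of the flax config.json
    model_type = "dalle_bart"

    def __init__(
        self,
        encoder_layers: int = 12,
        decoder_layers: int = 12,
        d_model: int = 1024,
        encoder_attention_heads: int = 16,
        decoder_attention_heads: int = 16,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.d_model = d_model
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_attention_heads = decoder_attention_heads
        # the vocab sizes, ffn dims and token counts would follow the same pattern

The encoder and decoder would then take a single config argument in their constructors instead of the individual counts.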

Loading Model Data

The weights for the model come from the vqgan project, which is in torch, or from the flax bart version, which is converted on demand. It would be nice to convert and export the flax bart version to torch.

I’ve downloaded the data for the two models (the download failed once before succeeding). The next thing to do will be to convert these to torch weights. I don’t want to install flax/jax for my blog, so I’m going to do the conversion in the dall-e project itself and copy the code over to here.

This is what I have ended up with:

Code
from min_dalle.models.vqgan_detokenizer import VQGanDetokenizer
from min_dalle.models.dalle_bart_encoder_torch import DalleBartEncoderTorch
from min_dalle.models.dalle_bart_decoder_torch import DalleBartDecoderTorch
from min_dalle.load_params import (
    load_vqgan_torch_params,
    convert_dalle_bart_torch_from_flax_params,
    load_dalle_bart_flax_params,
)
from min_dalle.generate_image import load_dalle_bart_metadata
from pathlib import Path
import torch

def convert(name: str) -> None:
    path = f"./pretrained/dalle_bart_{name}"
    config, *_ = load_dalle_bart_metadata(path)
    params = load_dalle_bart_flax_params(path)

    encoder = load_encoder(config, params)
    decoder = load_decoder(config, params)
    detokenizer = load_detokenizer()

    # make sure the output directory exists before saving the state dicts
    Path(f"./converted/{name}").mkdir(parents=True, exist_ok=True)
    torch.save(encoder.state_dict(), f"./converted/{name}/encoder.pt")
    torch.save(decoder.state_dict(), f"./converted/{name}/decoder.pt")
    torch.save(detokenizer.state_dict(), f"./converted/{name}/detokenizer.pt")

def load_encoder(config: dict, params: dict) -> DalleBartEncoderTorch:
    encoder = DalleBartEncoderTorch(
        layer_count = config['encoder_layers'],
        embed_count = config['d_model'],
        attention_head_count = config['encoder_attention_heads'],
        text_vocab_count = config['encoder_vocab_size'],
        text_token_count = config['max_text_length'],
        glu_embed_count = config['encoder_ffn_dim']
    )
    encoder_params = convert_dalle_bart_torch_from_flax_params(
        params.pop('encoder'),
        layer_count=config['encoder_layers'],
        is_encoder=True
    )
    encoder.load_state_dict(encoder_params, strict=False)
    return encoder

def load_decoder(
    config: dict,
    params: dict,
    image_token_count: int = 256
) -> DalleBartDecoderTorch:
    decoder = DalleBartDecoderTorch(
        image_vocab_size = config['image_vocab_size'],
        image_token_count = config['image_length'],
        sample_token_count = image_token_count,
        embed_count = config['d_model'],
        attention_head_count = config['decoder_attention_heads'],
        glu_embed_count = config['decoder_ffn_dim'],
        layer_count = config['decoder_layers'],
        batch_count = 2,
        start_token = config['decoder_start_token_id'],
        is_verbose = True
    )
    decoder_params = convert_dalle_bart_torch_from_flax_params(
        params.pop('decoder'),
        layer_count=config['decoder_layers'],
        is_encoder=False
    )
    decoder.load_state_dict(decoder_params, strict=False)
    return decoder

def load_detokenizer() -> VQGanDetokenizer:
    print("detokenizing image")
    model_path = './pretrained/vqgan'
    params = load_vqgan_torch_params(model_path)
    detokenizer = VQGanDetokenizer()
    detokenizer.load_state_dict(params)
    return detokenizer


convert("mini")
convert("mega")

This loads the three parts of the model and writes the state dict of each out to a file. I should be able to load this up and then save through huggingface after converting the model over.

One thing that is interesting is that the files are substantially larger after exporting as pytorch than the flax files that were downloaded.
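
My guess is that this is because the mega artifact is stored in fp16 (it is literally named mega-1-fp16) while torch.save writes the converted tensors as fp32. If that is right, casting the state dict back to half precision before saving should roughly halve the files; a sketch, which I have not verified:

Code
def save_fp16(module: torch.nn.Module, path: str) -> None:
    # cast every floating point tensor in the state dict to fp16 before saving
    state_dict = {
        key: value.half() if value.is_floating_point() else value
        for key, value in module.state_dict().items()
    }
    torch.save(state_dict, path)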

Model Conversion

The model has been implemented without using the specialized blocks available in pytorch. I wonder if it is possible to replace them.
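
For example, the hand-written attention below could in principle be expressed with nn.MultiheadAttention. This is only a sketch of the idea: the embed and head counts are illustrative, and the existing projection weights would need remapping into the in_proj/out_proj layout, which I have not done.

Code
import torch
from torch import nn

# sketch: nn.MultiheadAttention as a stand-in for the hand written encoder
# self attention. attention_mask below uses True for tokens to attend to,
# while key_padding_mask expects True for positions to ignore, hence the ~.
attention = nn.MultiheadAttention(
    embed_dim=1024, num_heads=16, bias=False, batch_first=True
)
encoder_state = torch.rand(2, 64, 1024)
attention_mask = torch.ones(2, 64, dtype=torch.bool)
output, _ = attention(
    encoder_state, encoder_state, encoder_state,
    key_padding_mask=~attention_mask,
)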

There is a lot of code here so for now I am going to copy it over directly and create a composite model.

Code
#collapse
from typing import List
import torch
from torch import nn, BoolTensor, FloatTensor, LongTensor


class GLUTorch(nn.Module):
    def __init__(self, count_in_out, count_middle):
        super().__init__()
        self.gelu = nn.GELU()
        self.ln0 = nn.LayerNorm(count_in_out)
        self.ln1 = nn.LayerNorm(count_middle)
        self.fc0 = nn.Linear(count_in_out, count_middle, bias=False)
        self.fc1 = nn.Linear(count_in_out, count_middle, bias=False)
        self.fc2 = nn.Linear(count_middle, count_in_out, bias=False)
    
    def forward(self, z: FloatTensor) -> FloatTensor:
        z = self.ln0.forward(z)
        w = self.fc0.forward(z)
        w = self.gelu.forward(w)
        v = self.fc1.forward(z)
        z = self.ln1.forward(w * v)
        z = self.fc2.forward(z)
        return z


class AttentionTorch(nn.Module):
    def __init__(self, head_count: int, embed_count: int):
        super().__init__()
        self.head_count = head_count
        self.embed_count = embed_count

        self.k_proj = nn.Linear(embed_count, embed_count, bias=False)
        self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
        self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
        self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
        self.one = torch.ones((1, 1))
        if torch.cuda.is_available(): self.one = self.one.cuda()
    
    def forward(
        self,
        keys: FloatTensor,
        values: FloatTensor,
        queries: FloatTensor,
        attention_mask: BoolTensor
    ) -> FloatTensor:
        attention_bias = torch.where(
            attention_mask,
            self.one * 0,
            self.one * (-torch.inf),
        )
        attention_weights: FloatTensor = torch.einsum(
            'bqhc,bkhc->bhqk',
            queries, 
            keys
        )
        attention_weights += attention_bias[:, None, None, :]
        attention_weights = torch.softmax(attention_weights, -1)
        attention_output: FloatTensor = torch.einsum(
            "bhqk,bkhc->bqhc",
            attention_weights, 
            values
        )
        shape = attention_output.shape[:2] + (self.embed_count,)
        attention_output = attention_output.reshape(shape)
        attention_output = self.out_proj.forward(attention_output)
        return attention_output
Code
#collapse
from typing import List
import torch
from torch import nn, BoolTensor, FloatTensor, LongTensor


class DalleBartEncoderTorch(nn.Module):
    def __init__(
        self,
        layer_count: int,
        embed_count: int,
        attention_head_count: int,
        text_vocab_count: int,
        text_token_count: int,
        glu_embed_count: int
    ):
        super().__init__()
        self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
        self.embed_positions = nn.Embedding(text_token_count, embed_count)
        self.layers: List[EncoderLayerTorch] = nn.ModuleList([
            EncoderLayerTorch(
                embed_count = embed_count,
                head_count = attention_head_count,
                glu_embed_count = glu_embed_count
            ) 
            for _ in range(layer_count)
        ])
        self.layernorm_embedding = nn.LayerNorm(embed_count)
        self.final_ln = nn.LayerNorm(embed_count)
        self.token_indices = torch.arange(text_token_count).to(torch.long)
        if torch.cuda.is_available(): 
            self.token_indices = self.token_indices.cuda()

    def forward(self, text_tokens: LongTensor) -> FloatTensor:
        attention_mask = text_tokens.not_equal(1)
        batch_count = text_tokens.shape[0]
        pose_tokens = torch.stack([self.token_indices] * batch_count)
        encoder_state = (
            self.embed_tokens.forward(text_tokens) +
            self.embed_positions.forward(pose_tokens)
        )
        encoder_state = self.layernorm_embedding.forward(encoder_state)
        for layer in self.layers:
            encoder_state = layer.forward(encoder_state, attention_mask)
        encoder_state = self.final_ln.forward(encoder_state)
        return encoder_state


class EncoderSelfAttentionTorch(AttentionTorch):
    def forward(
        self,
        encoder_state: FloatTensor,
        attention_mask: BoolTensor
    ) -> FloatTensor:
        shape_split = encoder_state.shape[:2] + (self.head_count, -1)
        keys = self.k_proj.forward(encoder_state).reshape(shape_split)
        values = self.v_proj.forward(encoder_state).reshape(shape_split)
        queries = self.q_proj.forward(encoder_state).reshape(shape_split)
        queries /= queries.shape[-1] ** 0.5
        return super().forward(keys, values, queries, attention_mask)


class EncoderLayerTorch(nn.Module):
    def __init__(self, embed_count: int, head_count: int, glu_embed_count: int):
        super().__init__()
        self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
        self.self_attn = EncoderSelfAttentionTorch(head_count, embed_count)
        self.self_attn_layer_norm = nn.LayerNorm(embed_count)
        self.glu = GLUTorch(embed_count, glu_embed_count)
    
    def forward(
        self,
        encoder_state: FloatTensor,
        attention_mask: BoolTensor
    ) -> FloatTensor:
        residual = encoder_state
        encoder_state = self.pre_self_attn_layer_norm.forward(encoder_state)
        encoder_state = self.self_attn.forward(encoder_state, attention_mask)
        encoder_state = self.self_attn_layer_norm.forward(encoder_state)
        encoder_state = residual + encoder_state
        residual = encoder_state
        encoder_state = self.glu.forward(encoder_state)
        encoder_state = residual + encoder_state
        return encoder_state
Code
#collapse
from typing import Tuple, List
import torch
from torch import nn, BoolTensor, FloatTensor, LongTensor


class DalleBartDecoderTorch(nn.Module):
    def __init__(
        self,
        image_vocab_size: int,
        image_token_count: int,
        sample_token_count: int,
        embed_count: int,
        attention_head_count: int,
        glu_embed_count: int,
        layer_count: int,
        batch_count: int,
        start_token: int,
        is_verbose: bool
    ):
        super().__init__()
        self.is_verbose = is_verbose
        self.layer_count = layer_count
        self.sample_token_count = sample_token_count
        self.condition_factor = 10.0
        self.image_token_count = image_token_count
        self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count)
        self.embed_positions = nn.Embedding(image_token_count, embed_count)
        self.layers: List[DecoderLayerTorch] = nn.ModuleList([
            DecoderLayerTorch(
                image_token_count,
                attention_head_count,
                embed_count,
                glu_embed_count
            ) 
            for _ in range(layer_count)
        ])
        self.layernorm_embedding = nn.LayerNorm(embed_count)
        self.final_ln = nn.LayerNorm(embed_count)
        self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
        self.keys_values_state_shape = (
            layer_count * 2 * batch_count,
            image_token_count,
            attention_head_count,
            embed_count // attention_head_count
        )
        self.zero_prob = torch.zeros([1])
        self.token_indices = torch.arange(self.sample_token_count)
        self.start_token = torch.tensor([start_token]).to(torch.long)
        if torch.cuda.is_available():
            self.zero_prob = self.zero_prob.cuda()
            self.token_indices = self.token_indices.cuda()
            self.start_token = self.start_token.cuda()


    def decode_step(
        self,
        text_tokens: LongTensor,
        encoder_state: FloatTensor,
        keys_values_state: FloatTensor,
        prev_token_and_index: LongTensor
    ) -> Tuple[LongTensor, FloatTensor]:
        attention_mask = text_tokens.not_equal(1)
        batch_count = encoder_state.shape[0]
        prev_token = torch.cat([prev_token_and_index[:1]] * batch_count)
        token_index = torch.cat([prev_token_and_index[1:]] * batch_count)
        decoder_state = self.embed_tokens.forward(prev_token)
        decoder_state += self.embed_positions.forward(token_index)
        decoder_state = self.layernorm_embedding.forward(decoder_state)
        decoder_state = decoder_state[:, None]
        keys_values = []
        for i, layer in enumerate(self.layers):
            j1, j2 = i * 2 * batch_count, (i + 1) * 2 * batch_count
            decoder_state, keys_values_layer = layer.forward(
                decoder_state,
                encoder_state,
                keys_values_state[j1:j2],
                attention_mask,
                token_index[:1]
            )
            keys_values.append(keys_values_layer)
        keys_values = torch.cat(keys_values, dim=0)
        decoder_state = self.final_ln(decoder_state)
        logits = self.lm_head(decoder_state)
        a = self.condition_factor
        logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]

        top_logits = logits.sort(descending=True)[0][:50]
        probs = torch.where(
            logits < top_logits[-1],
            self.zero_prob,
            torch.exp(logits - top_logits[0])
        )
        return probs, keys_values


    def forward(
        self,
        text_tokens: LongTensor,
        encoder_state: FloatTensor
    ) -> LongTensor:
        image_tokens: List[LongTensor] = []
        keys_values_state = torch.zeros(self.keys_values_state_shape)
        if torch.cuda.is_available(): 
            keys_values_state = keys_values_state.cuda()
        image_token = self.start_token

        for i in range(self.sample_token_count):
            token_index = self.token_indices[i:i+1]
            probs, keys_values_state = self.decode_step(
                text_tokens = text_tokens,
                encoder_state = encoder_state,
                keys_values_state = keys_values_state,
                prev_token_and_index = torch.cat([image_token, token_index])
            )

            image_token = torch.multinomial(probs, 1)
            image_tokens += [image_token]
            
        return torch.cat(image_tokens)


class DecoderCrossAttentionTorch(AttentionTorch):
    def forward(
        self,
        decoder_state: FloatTensor,
        encoder_state: FloatTensor,
        attention_mask: BoolTensor
    ) -> FloatTensor:
        keys = self.k_proj.forward(encoder_state)
        values = self.v_proj.forward(encoder_state)
        queries = self.q_proj.forward(decoder_state)
        query_shape = queries.shape[:2] + (self.head_count, -1)
        key_value_shape = keys.shape[:2] + (self.head_count, -1)
        keys = keys.reshape(key_value_shape)
        values = values.reshape(key_value_shape)
        queries = queries.reshape(query_shape)
        queries /= queries.shape[-1] ** 0.5
        return super().forward(keys, values, queries, attention_mask)


class DecoderSelfAttentionTorch(AttentionTorch):
    def forward(
        self, 
        decoder_state: FloatTensor,
        keys_values: FloatTensor,
        attention_mask: BoolTensor,
        token_mask: BoolTensor
    ) -> Tuple[FloatTensor, FloatTensor]:
        batch_count = decoder_state.shape[0]
        shape = (batch_count, 1) + keys_values.shape[2:]
        keys = self.k_proj.forward(decoder_state).view(shape)
        values = self.v_proj.forward(decoder_state).view(shape)
        keys_values = torch.where(
            token_mask[None, :, None, None], 
            torch.cat([keys, values]), 
            keys_values
        )
        queries = self.q_proj.forward(decoder_state).reshape(shape)
        queries /= queries.shape[-1] ** 0.5
        keys, values = keys_values[:batch_count], keys_values[batch_count:]
        decoder_state = super().forward(keys, values, queries, attention_mask)
        return decoder_state, keys_values


class DecoderLayerTorch(nn.Module):
    def __init__(
        self, 
        image_token_count: int,
        head_count: int, 
        embed_count: int,
        glu_embed_count: int
    ):
        super().__init__()
        self.image_token_count = image_token_count
        self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
        self.self_attn = DecoderSelfAttentionTorch(head_count, embed_count)
        self.self_attn_layer_norm = nn.LayerNorm(embed_count)
        self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
        self.encoder_attn = DecoderCrossAttentionTorch(head_count, embed_count)
        self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
        self.glu = GLUTorch(embed_count, glu_embed_count)

        self.token_indices = torch.arange(self.image_token_count)
        if torch.cuda.is_available():
            self.token_indices = self.token_indices.cuda()

    def forward(
        self,
        decoder_state: FloatTensor,
        encoder_state: FloatTensor,
        keys_values_state: FloatTensor,
        attention_mask: BoolTensor,
        token_index: LongTensor
    ) -> Tuple[FloatTensor, FloatTensor]:
        # Self Attention
        residual = decoder_state
        decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
        self_attn_mask = self.token_indices < token_index + 1
        token_mask = self.token_indices == token_index
        self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
        decoder_state, keys_values_state = self.self_attn.forward(
            decoder_state,
            keys_values_state,
            self_attn_mask,
            token_mask
        )
        decoder_state = self.self_attn_layer_norm.forward(decoder_state)
        decoder_state = residual + decoder_state

        # Cross Attention
        residual = decoder_state
        decoder_state = self.pre_encoder_attn_layer_norm.forward(decoder_state)
        decoder_state = self.encoder_attn.forward(
            decoder_state,
            encoder_state,
            attention_mask
        )
        decoder_state = self.encoder_attn_layer_norm.forward(decoder_state)
        decoder_state = residual + decoder_state

        # Feed forward
        residual = decoder_state
        decoder_state = self.glu.forward(decoder_state)
        decoder_state = residual + decoder_state

        return decoder_state, keys_values_state
Code
#collapse
import torch
from torch import Tensor
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
torch.set_grad_enabled(False)

BATCH_COUNT: int = 1



class VQGanDetokenizer(Module):
    def __init__(self):
        super().__init__()
        m, n = 2 ** 14, 2 ** 8
        self.embedding = Embedding(m, n)
        self.post_quant_conv = Conv2d(n, n, 1)
        self.decoder = Decoder()

    def forward(self, z: Tensor) -> Tensor:
        z = self.embedding.forward(z)
        z = z.view((BATCH_COUNT, 2 ** 4, 2 ** 4, 2 ** 8))
        z = z.permute(0, 3, 1, 2).contiguous()
        z = self.post_quant_conv.forward(z)
        z = self.decoder.forward(z)
        z = z.permute(0, 2, 3, 1)
        z = z.clip(0.0, 1.0) * 255
        return z[0]

class ResnetBlock(Module):
    def __init__(self, log2_count_in: int, log2_count_out: int):
        super().__init__()
        m, n = 2 ** log2_count_in, 2 ** log2_count_out
        self.is_middle = m == n
        self.norm1 = GroupNorm(2 ** 5, m)
        self.conv1 = Conv2d(m, n, 3, padding=1)
        self.norm2 = GroupNorm(2 ** 5, n)
        self.conv2 = Conv2d(n, n, 3, padding=1)
        if not self.is_middle:
            self.nin_shortcut = Conv2d(m, n, 1)

    def forward(self, x: Tensor) -> Tensor:
        h = x
        h = self.norm1.forward(h)
        h *= torch.sigmoid(h)
        h = self.conv1.forward(h)
        h = self.norm2.forward(h)
        h *= torch.sigmoid(h)
        h = self.conv2(h)
        if not self.is_middle:
            x = self.nin_shortcut.forward(x)
        return x + h


class AttentionBlock(Module):
    def __init__(self):
        super().__init__()
        n = 2 ** 9
        self.norm = GroupNorm(2 ** 5, n)
        self.q = Conv2d(n, n, 1)
        self.k = Conv2d(n, n, 1)
        self.v = Conv2d(n, n, 1)
        self.proj_out = Conv2d(n, n, 1)

    def forward(self, x: Tensor) -> Tensor:
        n = 2 ** 9
        h = x
        h = self.norm(h)
        q = self.q.forward(h)
        k = self.k.forward(h)
        v = self.v.forward(h)
        q = q.reshape(BATCH_COUNT, n, 2 ** 8)
        q = q.permute(0, 2, 1)
        k = k.reshape(BATCH_COUNT, n, 2 ** 8)
        w = torch.bmm(q, k)
        w /= n ** 0.5
        w = torch.softmax(w, dim=2)
        v = v.reshape(BATCH_COUNT, n, 2 ** 8)
        w = w.permute(0, 2, 1)
        h = torch.bmm(v, w)
        h = h.reshape(BATCH_COUNT, n, 2 ** 4, 2 ** 4)
        h = self.proj_out.forward(h)
        return x + h

class MiddleLayer(Module):
    def __init__(self):
        super().__init__()
        self.block_1 = ResnetBlock(9, 9)
        self.attn_1 = AttentionBlock()
        self.block_2 = ResnetBlock(9, 9)
    
    def forward(self, h: Tensor) -> Tensor:
        h = self.block_1.forward(h)
        h = self.attn_1.forward(h)
        h = self.block_2.forward(h)
        return h

class Upsample(Module):
    def __init__(self, log2_count):
        super().__init__()
        n = 2 ** log2_count
        self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2)
        self.conv = Conv2d(n, n, 3, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        x = self.upsample.forward(x)
        x = self.conv.forward(x)
        return x

class UpsampleBlock(Module):
    def __init__(
        self, 
        log2_count_in: int, 
        log2_count_out: int, 
        has_attention: bool, 
        has_upsample: bool
    ):
        super().__init__()
        self.has_attention = has_attention
        self.has_upsample = has_upsample
        self.block = ModuleList([
            ResnetBlock(log2_count_in, log2_count_out),
            ResnetBlock(log2_count_out, log2_count_out),
            ResnetBlock(log2_count_out, log2_count_out)
        ])
        if has_attention:
            self.attn = ModuleList([
                AttentionBlock(),
                AttentionBlock(),
                AttentionBlock()
            ])
        else:
            self.attn = ModuleList()

        if has_upsample:
            self.upsample = Upsample(log2_count_out)


    def forward(self, h: Tensor) -> Tensor:
        for j in range(3):
            h = self.block[j].forward(h)
            if self.has_attention:
                h = self.attn[j].forward(h)
        if self.has_upsample:
            h = self.upsample.forward(h)
        return h

class Decoder(Module):
    def __init__(self):
        super().__init__()

        self.conv_in = Conv2d(2 ** 8, 2 ** 9, 3, padding=1)
        self.mid = MiddleLayer()

        self.up = ModuleList([
            UpsampleBlock(7, 7, False, False),
            UpsampleBlock(8, 7, False, True),
            UpsampleBlock(8, 8, False, True),
            UpsampleBlock(9, 8, False, True),
            UpsampleBlock(9, 9, True, True)
        ])

        self.norm_out = GroupNorm(2 ** 5, 2 ** 7)
        self.conv_out = Conv2d(2 ** 7, 3, 3, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        z = self.conv_in.forward(z)
        z = self.mid.forward(z)

        for i in reversed(range(5)):
            z = self.up[i].forward(z)

        z = self.norm_out.forward(z)
        z *= torch.sigmoid(z)
        z = self.conv_out.forward(z)
        return z

With these definitions of the three parts of the model we can create the overall model. This will take the tokenized sentence and return an image.

Code
from __future__ import annotations
import json
from pathlib import Path
import torch
from torch import nn
from PIL import Image

class OverallModel(nn.Module):
    @staticmethod
    def from_pretrained(folder: Path) -> OverallModel:
        config = json.loads((folder / "config.json").read_text())
        model = OverallModel(config)
        model.encoder.load_state_dict(torch.load(folder / "encoder.pt"))
        model.decoder.load_state_dict(torch.load(folder / "decoder.pt"))
        model.detokenizer.load_state_dict(torch.load(folder / "detokenizer.pt"))
        return model
        
    def __init__(self, config) -> None:
        super().__init__()
        image_token_count = 256
        self.max_text_length = config["max_text_length"]
        self.encoder = DalleBartEncoderTorch(
            layer_count = config["encoder_layers"],                                                                                                                               
            embed_count = config["d_model"],                                                                                                                                      
            attention_head_count = config["encoder_attention_heads"],                                                                                                             
            text_vocab_count = config["encoder_vocab_size"],                                                                                                                      
            text_token_count = config["max_text_length"],                                                                                                                         
            glu_embed_count = config["encoder_ffn_dim"],
        )
        self.decoder = DalleBartDecoderTorch(
            image_vocab_size = config["image_vocab_size"],                                                                                                                        
            image_token_count = config["image_length"],                                                                                                                           
            sample_token_count = image_token_count,                                                                                                                               
            embed_count = config["d_model"],                                                                                                                                      
            attention_head_count = config["decoder_attention_heads"],                                                                                                             
            glu_embed_count = config["decoder_ffn_dim"],                                                                                                                          
            layer_count = config["decoder_layers"],                                                                                                                               
            batch_count = 2,                                                                                                                                                      
            start_token = config["decoder_start_token_id"],                                                                                                                       
            is_verbose = True,
        )
        self.detokenizer = VQGanDetokenizer()

    @torch.inference_mode()
    def forward(self, tokens):
        pt_tokens = torch.ones(
            (2, self.max_text_length),
            dtype=torch.long,
            device=next(self.parameters()).device
        )
        pt_tokens[0, :len(tokens)] = tokens
        pt_tokens[1, 0] = tokens[0]
        pt_tokens[1, 1] = tokens[-1]

        encoded = self.encoder(pt_tokens)
        decoded = self.decoder(pt_tokens, encoded)
        detokenized = self.detokenizer(decoded)

        array = detokenized.to(torch.uint8).cpu().detach().numpy()
        return Image.fromarray(array)

Evaluation

We can try out the model with different prompts. As you can see I’ve had quite a lot of fun with this.

The model runs on the GPU successfully and the mega version takes around 16GB of space. A single inference takes a few seconds.

Code
from pathlib import Path
from transformers import BartTokenizerFast
from PIL import Image

folder = Path("/data/dall-e/torch/mega/")
model = OverallModel.from_pretrained(folder)
model.eval()
model.cuda()

tokenizer = BartTokenizerFast.from_pretrained(folder)

def generate(text: str) -> Image.Image:
    tokens = tokenizer(text, return_tensors="pt").input_ids
    return model(tokens[0].cuda())
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DalleBartTokenizer'. 
The class this function is called from is 'BartTokenizerFast'.
Code
%%time
generate("A regal lion wearing a crown")
CPU times: user 5.73 s, sys: 20.2 ms, total: 5.75 s
Wall time: 5.55 s

Code
generate("A rose entwined with ivy in stained glass")

Code
generate("A screenshot of factorio")

Code
generate("A screenshot of minecraft")

Code
generate("A screenshot of diablo")

Code
generate("Our first Azmodan kill")

Code
generate("My painted miniature")

Code
generate("My painted orc miniature")

Code
generate("My painted elven lord miniature")

Code
generate("My painted hippogriff miniature")

Code
generate("A photo of our new puppy")

Code
generate("A stained glass window of our new puppy")

Code
generate("A rocket flying to the moon")

Code
generate("A face")

Code
generate("frida kahlo")

Code
generate("a goat climbing a mountain")

Code
generate("A house made out of clay")

Code
generate("The gingerbread house")

Code
generate("logo of an armchair in the shape of an avocado")

Code
generate("logo of a hacker")

Code
generate("logo of vampire survivors")

Code
generate("spider with googly eyes")

Code
generate("an angel with peacock feather wings")

Code
generate("a googly eyed angel with peacock feather wings")

Code
generate("a googly eyes angel with peacock feather wings")

Code
generate("a googly eyes angel with peacock feather wings")

Code
generate("the correct set of weights for my multilingual internalized prompt model")