Running DALL-E locally

Getting the DALL·E Mini model from Hugging Face running locally
Published

June 18, 2022

This is an attempt to get dalle-mini working. I’m using the borisdayma/dalle-mini repo.

To get this working, I tried pinning the following dependencies with Poetry:

poetry add 'git+https://github.com/borisdayma/dalle-mini.git@v0.1.0'
poetry add 'git+https://github.com/patil-suraj/vqgan-jax.git@main'
poetry add 'jax<0.3.2' 'jaxlib<0.3.2'
poetry add 'flax<0.5'

The code below comes from the “Round Trip DALL·E mini” Hugging Face Space.

Code
#!/usr/bin/env python
# coding: utf-8

import os
# Uncomment to force JAX onto the CPU instead of any available accelerator.
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
# Environment flags for the Weights & Biases (wandb) client: disable it and
# silence its console output. They are set before the ML libraries below are
# imported — NOTE(review): presumably so they are seen at import time; confirm.
os.environ["WANDB_DISABLED"] = "true"
os.environ['WANDB_SILENT']="true"

import random
import re
import torch

import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard, shard_prng_key
from PIL import Image, ImageDraw, ImageFont

from functools import partial

from transformers import CLIPProcessor, FlaxCLIPModel, AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel 
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel


# DalleBart text-to-image-token model to load.
# NOTE(review): this id has the "entity/project/name:version" shape of a W&B
# artifact reference rather than a plain Hugging Face repo id — confirm how
# DalleBart.from_pretrained resolves it while wandb is disabled.
DALLE_REPO = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None

# VQGAN model that decodes image tokens to pixels, pinned to a fixed commit.
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# NOTE(review): both loaders below unpack a (model, params) pair, but the
# `_do_init=False` argument is commented out. Whether from_pretrained returns a
# pair without that flag depends on the installed transformers / dalle-mini
# versions — confirm against the pinned versions.
model, params = DalleBart.from_pretrained(
    DALLE_REPO,
    revision=DALLE_COMMIT_ID,
    dtype=jnp.float16,  # half precision; NOTE(review): may be slow on CPU
    # _do_init=False
)
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO,
    revision=VQGAN_COMMIT_ID,
    # _do_init=False
)

# The captioning model runs in PyTorch: use the GPU when torch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ViT-GPT2 image-captioning model. The feature extractor (encoder side), the
# tokenizer (decoder side) and the full model all come from one checkpoint.
encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
viz_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)


def captioned_strip(images, caption=None, rows=1):
    """Tile images into a grid, optionally with a caption band above it.

    Images are placed column by column: image ``i`` goes to column
    ``i // rows`` and row ``i % rows``.

    Args:
        images: Non-empty sequence of PIL images. The size of ``images[0]``
            is used as the cell size for every image.
        caption: Optional text drawn in white in a 24-pixel band above the grid.
        rows: Number of grid rows.

    Returns:
        A new RGB ``PIL.Image`` containing the grid.
    """
    increased_h = 0 if caption is None else 24
    w, h = images[0].size
    # Round the column count up. Otherwise a partially filled last column
    # (len(images) not a multiple of rows) is cropped off the canvas.
    cols = (len(images) + rows - 1) // rows
    img = Image.new("RGB", (cols * w, h * rows + increased_h))
    for i, img_ in enumerate(images):
        img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))

    if caption is not None:
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype("LiberationMono-Bold.ttf", 7)
        except OSError:
            # The TrueType font is not installed on every system; fall back
            # to Pillow's built-in font instead of failing.
            font = ImageFont.load_default()
        draw.text((20, 3), caption, (255, 255, 255), font=font)
    return img


def get_images(indices, params):
    """Decode VQGAN code indices into images with the module-level ``vqgan``."""
    decoded = vqgan.decode_code(indices, params=params)
    return decoded


def predict_caption(image, max_length=128, num_beams=4):
    """Generate a text caption for an image with the ViT-GPT2 captioning model.

    Args:
        image: A PIL image; converted to RGB before feature extraction.
        max_length: Maximum generation length passed to ``viz_model.generate``.
        num_beams: Beam width passed to ``viz_model.generate``.

    Returns:
        The first line of the decoded caption, with ``<|endoftext|>`` markers
        removed and surrounding whitespace stripped.
    """
    pixel_values = feature_extractor(
        image.convert("RGB"), return_tensors="pt"
    ).pixel_values.to(device)
    caption_ids = viz_model.generate(
        pixel_values, max_length=max_length, num_beams=num_beams
    )[0]
    caption_text = tokenizer.decode(caption_ids)
    # Drop end-of-text markers and keep only the first line. Strip the result
    # because the caption becomes the next round's prompt.
    return caption_text.replace("<|endoftext|>", "").split("\n")[0].strip()


# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    """Sample image-token sequences from DalleBart on every device in parallel.

    Positional arguments 3-6 (top_k, top_p, temperature, condition_scale) are
    declared static broadcast arguments of ``jax.pmap``. The same Python value
    is used on every device, and changing them triggers a recompile.
    """
    sampling_options = {
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "condition_scale": condition_scale,
    }
    return model.generate(
        **tokenized_prompt, prng_key=key, params=params, **sampling_options
    )


# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    """Decode VQGAN code indices into images on every device in parallel.

    Delegates to ``get_images``, which wraps ``vqgan.decode_code``.
    """
    return get_images(indices, params)

# pmap'd wrapper around get_images. NOTE(review): this is functionally the
# same as p_decode and is not used anywhere in the visible code.
p_get_images = jax.pmap(get_images, "batch")

# Replicate the parameters across the local devices so they can be passed to
# the pmap'd p_generate / p_decode functions.
params = replicate(params)
vqgan_params = replicate(vqgan_params)

# Tokenizer/processor that turns text prompts into DalleBart inputs.
processor = DalleBartProcessor.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
print("Initialized DalleBartProcessor")
# NOTE(review): `clip` is loaded here but not used in the visible code.
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
print("Initialized FlaxCLIPModel")


def hallucinate(prompt, num_images=8, seed=None):
    """Generate images for ``prompt`` with DALL·E-mini across all JAX devices.

    The single tokenized prompt is replicated to every device, so each
    ``p_generate`` call yields ``jax.device_count()`` images. The loop runs
    just enough times to reach ``num_images`` (at least one step), and the
    result is truncated to that many images (at least one).

    Args:
        prompt: Text prompt.
        num_images: Number of images to return; values below 1 yield one image.
        seed: Optional integer seed for the JAX PRNG. When None, a random seed
            is drawn with ``random.randint``.

    Returns:
        A list of ``PIL.Image`` objects decoded from the VQGAN codes.
    """
    gen_top_k = None
    gen_top_p = None
    temperature = None
    cond_scale = 10.0

    print(f"Prompts: {prompt}")
    device_count = jax.device_count()
    # Tokenize ONE prompt and let replicate() copy it to every device.
    # Tokenizing device_count copies and then replicating would give each
    # device device_count prompts, i.e. device_count**2 images per step.
    inputs = replicate(processor([prompt]))

    # create a random key (or a reproducible one when a seed is given)
    if seed is None:
        seed = random.randint(0, 2**32 - 1)
    key = jax.random.PRNGKey(seed)

    target = max(num_images, 1)
    num_steps = -(-target // device_count)  # ceiling division

    images = []
    for i in range(num_steps):
        key, subkey = jax.random.split(key)
        encoded_images = p_generate(
            inputs,
            shard_prng_key(subkey),
            params,
            gen_top_k,
            gen_top_p,
            temperature,
            cond_scale,
        )
        print(f"Encoded image {i}")
        # remove BOS
        encoded_images = encoded_images.sequences[..., 1:]
        # decode images
        decoded_images = p_decode(encoded_images, vqgan_params)
        decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
        for decoded_img in decoded_images:
            img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
            images.append(img)

        print(f"Finished decoding image {i}")
    return images[:target]


def run_inference(prompt, num_roundtrips=3, num_images=1):
    """Alternate DALL·E-mini image generation and image captioning.

    Each round generates images from the current prompt and captions the
    first image. That caption becomes the next round's prompt.

    Args:
        prompt: Starting text prompt (entered by the user in the Gradio UI).
        num_roundtrips: Number of generate→caption rounds; coerced with int().
        num_images: Images requested from ``hallucinate`` per round; only the
            first one is captioned and returned.

    Returns:
        A flat list with three entries per round: an HTML title, the generated
        PIL image, and an HTML caption. This matches the order of the Gradio
        output components.
    """
    import html  # local import keeps this block self-contained

    outputs = []
    for i in range(int(num_roundtrips)):
        images = hallucinate(prompt, num_images=num_images)
        image = images[0]
        print("Generated image")
        caption = predict_caption(image)
        print(f"Predicted caption: {caption}")

        # The prompt comes from the user and the caption from a model. Escape
        # both before embedding them in HTML so any markup is shown as text.
        output_title = f"""
        <font size="+3">
        <b>[Roundtrip {i}]</b><br>
        Prompt: {html.escape(prompt)}<br>
        🥑 :<br></font>"""
        output_caption = f"""
        <font size="+3">
        🤖💬 : {html.escape(caption)}<br>
        </font>
        """
        outputs.append(output_title)
        outputs.append(image)
        outputs.append(output_caption)
        prompt = caption

    print("Done.")
    return outputs


# Gradio input: the starting prompt for the first round trip.
inputs = gr.inputs.Textbox(label="What prompt do you want to start with?", default="cookie monster the horror movie")
# num_roundtrips = gr.inputs.Number(default=2, label="How many roundtrips?")
# Fixed here rather than read from the UI. run_inference is called with only
# the prompt, so it uses its own default num_roundtrips=3; keep the two equal
# so the number of output slots matches the number of values it returns.
num_roundtrips = 3
outputs = []
for _ in range(int(num_roundtrips)):
    # Per round trip: title (HTML), image, caption (HTML), in the same order
    # run_inference appends them.
    # NOTE(review): this mixes the legacy gr.inputs/gr.outputs namespaces with
    # gr.Image; which of these exist depends on the installed gradio version.
    outputs.append(gr.outputs.HTML(label=""))
    outputs.append(gr.Image(label=""))
    outputs.append(gr.outputs.HTML(label=""))

description = """
Round trip DALL·E-mini iterates between DALL·E generation and image captioning, inspired by round trip translation! FYI: runtime is forever (~1hr or possibly longer) because the app is running on CPU.
"""
article = "<p style='text-align: center'>Put together by: Najoung Kim | Dall-E Mini code from flax-community/dalle-mini | Caption code from SRDdev/Image-Caption</p>"

# Build and launch the web UI. This runs at import time; NOTE(review): consider
# guarding it with `if __name__ == "__main__":` if this module is ever imported.
gr.Interface(
    fn=run_inference,
    inputs=[inputs],
    outputs=outputs,
    title="Round Trip DALL·E mini 🥑🔁🤖💬",
    description=description,
    article=article,
    theme="default",
    css = ".output-image, .input-image, .image-preview {height: 256px !important} "
).launch(enable_queue=False)
Running the script fails with this error:

ImportError: cannot import name 'constant' from 'jax.nn.initializers' (/home/matthew/.cache/pypoetry/virtualenvs/blog-HrtMnrOS-py3.9/lib/python3.9/site-packages/jax/nn/initializers.py)

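I don’t understand this error. I’ve stepped through every available version of jax, and if it’s not one error, it’s another. The failing import suggests that another installed package (most likely flax) expects a different jax version than the pinned `jax<0.3.2` provides. If so, jax and flax have to be pinned as a matching pair rather than one at a time.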

I don’t think there is a dependency set that actually works. I hate Google so much; they break their own libraries all the time.