#!/usr/bin/env python
# coding: utf-8
import os
# Uncomment to run on cpu
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["WANDB_DISABLED" ] = "true"
os.environ['WANDB_SILENT' ]= "true"
import random
import re
import torch
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard, shard_prng_key
from PIL import Image, ImageDraw, ImageFont
from functools import partial
from transformers import (
    AutoTokenizer,
    FlaxCLIPModel,
    ViTFeatureExtractor,
    VisionEncoderDecoderModel,
)
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
DALLE_REPO = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
model, params = DalleBart.from_pretrained(
    DALLE_REPO,
    revision=DALLE_COMMIT_ID,
    dtype=jnp.float16,
    # _do_init=False
)
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO,
    revision=VQGAN_COMMIT_ID,
    # _do_init=False
)
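# Captioning pipeline (PyTorch): a ViT image encoder feeding a GPT-2 text decoder,
# all taken from the same vit-gpt2-image-captioning checkpoint.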
device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
viz_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
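# Arrange a list of images into a single strip with an optional caption header
# (not called by run_inference below; kept as a helper).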
def captioned_strip(images, caption=None, rows=1):
    increased_h = 0 if caption is None else 24
    w, h = images[0].size[0], images[0].size[1]
    img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
    for i, img_ in enumerate(images):
        img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
    if caption is not None:
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype("LiberationMono-Bold.ttf", 7)
        draw.text((20, 3), caption, (255, 255, 255), font=font)
    return img
def get_images(indices, params):
    return vqgan.decode_code(indices, params=params)
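# Caption a single PIL image with the ViT-GPT2 model.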
def predict_caption(image, max_length=128, num_beams=4):
    image = image.convert('RGB')
    image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
    clean_text = lambda x: x.replace('<|endoftext|>', '').split('\n')[0]
    caption_ids = viz_model.generate(image, max_length=max_length, num_beams=num_beams)[0]
    caption_text = clean_text(tokenizer.decode(caption_ids))
    return caption_text
# model inference
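# pmap replicates the generation step across devices; the sampling hyper-parameters
# are marked static (compile-time constants) via static_broadcasted_argnums.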
@partial (jax.pmap, axis_name= "batch" , static_broadcasted_argnums= (3 , 4 , 5 , 6 ))
def p_generate(
tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
return model.generate(
** tokenized_prompt,
prng_key= key,
params= params,
top_k= top_k,
top_p= top_p,
temperature= temperature,
condition_scale= condition_scale,
)
# decode image
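# Map VQGAN codebook indices back to pixel space on each device.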
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)
p_get_images = jax.pmap(get_images, "batch")
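# Replicate the DALL·E mini and VQGAN parameters across all local devices for pmap.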
params = replicate(params)
vqgan_params = replicate(vqgan_params)
processor = DalleBartProcessor.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
print("Initialized DalleBartProcessor")
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
print("Initialized FlaxCLIPModel")
def hallucinate(prompt, num_images=8):
    gen_top_k = None
    gen_top_p = None
    temperature = None
    cond_scale = 10.0
    print(f"Prompts: {prompt}")
    prompt = [prompt] * jax.device_count()
    inputs = processor(prompt)
    inputs = replicate(inputs)
    # create a random key
    seed = random.randint(0, 2**32 - 1)
    key = jax.random.PRNGKey(seed)
    images = []
    for i in range(max(num_images // jax.device_count(), 1)):
        key, subkey = jax.random.split(key)
        encoded_images = p_generate(
            inputs,
            shard_prng_key(subkey),
            params,
            gen_top_k,
            gen_top_p,
            temperature,
            cond_scale,
        )
        print(f"Encoded image {i}")
        # remove BOS
        encoded_images = encoded_images.sequences[..., 1:]
        # decode images
        decoded_images = p_decode(encoded_images, vqgan_params)
        decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
        for decoded_img in decoded_images:
            img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
            images.append(img)
        print(f"Finished decoding image {i}")
    return images
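# Round-trip loop: generate an image from the prompt, caption the image, then use
# the caption as the prompt for the next round trip.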
def run_inference(prompt, num_roundtrips=3, num_images=1):
    outputs = []
    for i in range(int(num_roundtrips)):
        images = hallucinate(prompt, num_images=num_images)
        image = images[0]
        print("Generated image")
        caption = predict_caption(image)
        print(f"Predicted caption: {caption}")
        output_title = f"""
        <font size="+3">
        <b>[Roundtrip {i}]</b><br>
        Prompt: {prompt}<br>
        🥑 :<br></font>"""
        output_caption = f"""
        <font size="+3">
        🤖💬 : {caption}<br>
        </font>
        """
        outputs.append(output_title)
        outputs.append(image)
        outputs.append(output_caption)
        prompt = caption
    print("Done.")
    return outputs
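# Gradio UI: one text input and, per round trip, an (HTML title, image, HTML caption)
# output triplet.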
inputs = gr.inputs.Textbox(label="What prompt do you want to start with?", default="cookie monster the horror movie")
# num_roundtrips = gr.inputs.Number(default=2, label="How many roundtrips?")
num_roundtrips = 3
outputs = []
for _ in range(int(num_roundtrips)):
    outputs.append(gr.outputs.HTML(label=""))
    outputs.append(gr.Image(label=""))
    outputs.append(gr.outputs.HTML(label=""))
description = """
Round Trip DALL·E mini alternates between DALL·E image generation and image captioning, inspired by round-trip translation! Note: runtime is very long (~1 hour or more) because the app runs on CPU.
"""
article = "<p style='text-align: center'>Put together by: Najoung Kim | Dall-E Mini code from flax-community/dalle-mini | Caption code from SRDdev/Image-Caption</p>"
gr.Interface(
    fn=run_inference,
    inputs=[inputs],
    outputs=outputs,
    title="Round Trip DALL·E mini 🥑🔁🤖💬",
    description=description,
    article=article,
    theme="default",
    css=".output-image, .input-image, .image-preview {height: 256px !important}",
).launch(enable_queue=False)