import random

import gradio as gr
import torch
from huggingface_hub import cached_download, hf_hub_url
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange
from transformers import CLIPModel, CLIPProcessor

from diffusion_models import Diffusion
from vqvae import VQVAE2

# `sampling` is assumed here to be a local module exposing the `sample()` and
# `plms_sample()` routines used by `generate()` below.
import sampling

vqvae_model_path = cached_download(hf_hub_url("huggingface/vqvae-2", filename="vqvae_model.ckpt"))
diffusion_model_path = cached_download(hf_hub_url("huggingface/dalle-2", filename="diffusion_model.ckpt"))


def parse_prompt(prompt, default_weight=3.):
    """
    🎯 Parses a prompt into text and weight.
    """
    vals = prompt.rsplit(':', 1)
    # Pad with the default weight when the prompt has no ':weight' suffix.
    vals = vals + ['', default_weight][len(vals):]
    return vals[0], float(vals[1])
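
# A quick illustration of the parsing rule (prompt strings here are just examples):
#   parse_prompt("a misty forest:2")  -> ("a misty forest", 2.0)
#   parse_prompt("a misty forest")    -> ("a misty forest", 3.0)   # default weight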


def resize_and_center_crop(image, size):
    """
    ✂️ Resize and center-crop an image to the target size.
    """
    # Scale up just enough to cover both target dimensions, then crop the middle.
    fac = max(size[0] / image.size[0], size[1] / image.size[1])
    image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
    return TF.center_crop(image, size[::-1])
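
# Worked example (dimensions chosen just for illustration): a 1024x768 image with
# size=(224, 224) is scaled by max(224/1024, 224/768) ≈ 0.292 to 298x224, then
# center-cropped to 224x224. TF.center_crop expects (height, width), hence size[::-1].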


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('loading models... 🛠️')

# CLIP provides the text/image embeddings that condition the diffusion model.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# VQ-VAE decoder that turns sampled latents into RGB images.
vqvae = VQVAE2()
vqvae.load_state_dict(torch.load(vqvae_model_path, map_location=device))
vqvae.eval().requires_grad_(False).to(device)

# CLIP-conditioned diffusion model that produces the latents.
diffusion_model = Diffusion()
diffusion_model.load_state_dict(torch.load(diffusion_model_path, map_location=device))
diffusion_model = diffusion_model.to(device).eval().requires_grad_(False)


def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='ddim', eta=0.):
    """
    🖼️ Generates a list of PIL images based on given text and image prompts.
    """
    # The zero embedding serves as the unconditional branch for classifier-free guidance.
    zero_embed = torch.zeros([1, clip_model.config.projection_dim], device=device)
    target_embeds, weights = [zero_embed], []

    # Encode the text prompts with CLIP.
    for prompt in prompts:
        inputs = clip_processor(text=prompt, return_tensors="pt").to(device)
        text_embed = clip_model.get_text_features(**inputs).float()
        target_embeds.append(text_embed)
        weights.append(1.0)

    # Encode the image prompts (given as 'path' or 'path:weight') with CLIP.
    for prompt in images:
        path, weight = parse_prompt(prompt)
        img = Image.open(path).convert('RGB')
        img = resize_and_center_crop(img, (224, 224))
        inputs = clip_processor(images=img, return_tensors="pt").to(device)
        image_embed = clip_model.get_image_features(**inputs).float()
        target_embeds.append(image_embed)
        weights.append(weight)

    # The unconditional branch takes whatever weight is left over, so the weights sum to 1.
    weights = torch.tensor([1 - sum(weights), *weights], device=device)
    torch.manual_seed(seed)
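
    # The helper below implements classifier-free-guidance-style mixing: the diffusion
    # model is evaluated once per conditioning embedding (including the zero/unconditional
    # one), and the outputs are blended with the weights built above, which sum to 1.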
    def cfg_model_fn(x, t):
        n = x.shape[0]
        n_conds = len(target_embeds)
        # Stack the batch once per conditioning embedding so all branches run in one forward pass.
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
        vs = diffusion_model(x_in, t_in, embed_in).view([n_conds, n, *x.shape[1:]])
        # Weighted sum over the conditioning branches.
        v = vs.mul(weights[:, None, None, None, None]).sum(0)
        return v

    def run(x, steps):
        if method == 'ddpm':
            return sampling.sample(cfg_model_fn, x, steps, 1., {})
        if method == 'ddim':
            return sampling.sample(cfg_model_fn, x, steps, eta, {})
        if method == 'plms':
            return sampling.plms_sample(cfg_model_fn, x, steps, {})
        raise ValueError(f'unknown sampling method: {method}')

    batch_size = n
    x = torch.randn([n, 3, 64, 64], device=device)
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]  # evenly spaced timesteps from 1 to 0 (currently unused)
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = run(x[i:i + cur_batch_size], steps)
        outs = vqvae.decode(out_latents)
        for out in outs:
            pil_ims.append(transforms.ToPILImage()(out.clamp(0, 1)))

    return pil_ims


def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    """
    💡 Gradio function to wrap image generation.
    """
    if seed is None:
        seed = random.randint(0, 10000)
    prompts = [prompt]
    im_prompts = []
    if im_prompt is not None:
        im_prompts = [im_prompt]
    pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
    return pil_ims[0]
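
# Hypothetical direct call, bypassing the Gradio UI (prompt text is just an example):
#   gen_ims("a watercolor painting of a fox", n_steps=20).save("fox.png")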


iface = gr.Interface(
    fn=gen_ims,
    inputs=[
        gr.Textbox(label="Text prompt"),
        gr.Image(label="Image prompt", type='filepath'),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    # Each example supplies a value for both inputs; the image prompt is left empty.
    examples=[
        ["A beautiful sunset over the ocean", None],
        ["A futuristic cityscape at night", None],
        ["A surreal dream-like landscape", None],
    ],
    title='CLIP + Diffusion Model Image Generator',
    description="Generate stunning images from text and image prompts using CLIP and a diffusion model.",
)

iface.queue().launch()