|
|
|
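"""Gradio demo: text- and image-prompted image generation.

CLIP embeddings condition a small diffusion model whose output latents are
decoded by a VQ-VAE-style decoder. The architectures below are placeholders
and no pretrained weights are actually loaded, so the app runs end to end
but the outputs are illustrative only.
"""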
import argparse
import math
import os
import random
import sys
from functools import partial
from pathlib import Path

import gradio as gr
import torch
from omegaconf import OmegaConf
from PIL import Image
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange
from transformers import CLIPProcessor, CLIPModel
from huggingface_hub import hf_hub_download
|
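# Placeholder VQ-VAE-2 style module: only the decoder path is implemented,
# since this demo never encodes images and never loads pretrained weights.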
class VQVAE2(nn.Module):

    def __init__(self, n_embed=8192, embed_dim=256, ch=128):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Conv2d(embed_dim, ch * 4, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(ch * 4, ch * 2, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(ch * 2, ch, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(ch, 3, 4, stride=2, padding=1),
        )

    def decode(self, latents):
        return self.decoder(latents)
|
|
|
|
|
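# Placeholder denoiser: a small transformer over the flattened latent grid,
# conditioned on a timestep embedding and a CLIP embedding. It is a stand-in
# for a real diffusion backbone; no pretrained weights are loaded.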
class Diffusion(nn.Module):

    def __init__(self, n_inputs=3, n_embed=512, n_head=4, n_layer=12):
        super().__init__()
        d_model = n_inputs * 4
        self.time_embed = nn.Embedding(1000, d_model)
        self.cond_embed = nn.Linear(n_embed, d_model)
        # Project the latent channels up to the transformer width.
        self.in_proj = nn.Linear(n_inputs, d_model)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head, dim_feedforward=2048,
                                       dropout=0.1, activation='gelu', batch_first=True)
            for _ in range(n_layer)
        ])
        self.out = nn.Linear(d_model, n_inputs)

    def forward(self, x, t, c):
        bs, ch, h, w = x.shape
        # Flatten the spatial grid into a sequence of tokens: (bs, h * w, ch).
        x = x.permute(0, 2, 3, 1).reshape(bs, h * w, ch)
        x = self.in_proj(x)

        # Clamp timestep indices into the embedding table's valid range.
        t_emb = self.time_embed(t.long().clamp(0, 999))
        c_emb = self.cond_embed(c)
        emb = t_emb + c_emb

        x = x + emb.unsqueeze(1)
        for layer in self.layers:
            x = layer(x)

        x_out = self.out(x)
        x_out = x_out.reshape(bs, h, w, ch).permute(0, 3, 1, 2)
        return x_out
|
|
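# Noise schedule and samplers. Only the PLMS-style sampler below is actually
# implemented; the DDIM and DDPM entry points warn and fall back to it.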
def get_sigmas(n_steps):
    # Simplified schedule over decreasing time values. The endpoint stays
    # strictly positive so the ratio below never divides by zero.
    t = torch.linspace(1, 1e-2, n_steps + 1)
    return ((t[:-1] ** 2) / (t[1:] ** 2) - 1).sqrt()
|
|
|
@torch.no_grad()
def plms_sample(model, x, steps, **kwargs):
    """Simplified PLMS-style sampler used by all sampling entry points."""
    ts = x.new_ones([x.shape[0]])
    sigmas = get_sigmas(steps)
    model_fn = lambda x, t: model(x, t * 1000, **kwargs)

    old_denoised = None

    for i in trange(len(sigmas) - 1, disable=True):
        denoised = model_fn(x, ts * sigmas[i])

        # First step: plain Euler update; afterwards: a two-point multistep estimate.
        if old_denoised is None:
            d = (denoised - x) / sigmas[i]
        else:
            d = (3 * denoised - old_denoised) / 2 - x / sigmas[i]

        x = x + d * (sigmas[i + 1] - sigmas[i])
        old_denoised = denoised

    return x
|
|
|
def ddim_sample(model, x, steps, eta, **kwargs):
    print("Warning: DDIM sampler is not fully implemented. Using PLMS instead.")
    return plms_sample(model, x, steps, **kwargs)


def ddpm_sample(model, x, steps, **kwargs):
    print("Warning: DDPM sampler is not fully implemented. Using PLMS instead.")
    return plms_sample(model, x, steps, **kwargs)
|
|
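# Try to fetch checkpoints from the Hugging Face Hub. If the download fails,
# fall back to empty stub files; the placeholder models below never load the
# weights either way.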
try:
    vqvae_model_path = hf_hub_download(repo_id="dalle-mini/vqgan_imagenet_f16_16384", filename="flax_model.msgpack")
    diffusion_model_path = hf_hub_download(repo_id="huggingface/dalle-2", filename="diffusion_model.ckpt")
except Exception as e:
    print(f"Could not download models ({e}). Please ensure the Hugging Face repo IDs and filenames are correct.")
    print("Using placeholder models, which will not produce good images.")
    Path("vqvae_model.ckpt").touch()
    Path("diffusion_model.ckpt").touch()
    vqvae_model_path = "vqvae_model.ckpt"
    diffusion_model_path = "diffusion_model.ckpt"
|
|
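# parse_prompt splits an optional ':weight' suffix off a prompt, e.g.
# 'a red circle:2.5' -> ('a red circle', 2.5); without a suffix the default
# weight is used. It is not currently wired into generate(), which assigns
# every prompt a weight of 1.0.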
def parse_prompt(prompt, default_weight=3.):
    vals = prompt.rsplit(':', 1)
    vals = vals + ['', default_weight][len(vals):]
    return vals[0], float(vals[1])
|
|
|
def resize_and_center_crop(image, size):
    # Scale so the image covers the target (width, height), then center-crop.
    # TF.center_crop expects (height, width), hence the reversed size.
    fac = max(size[0] / image.size[0], size[1] / image.size[1])
    image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
    return TF.center_crop(image, size[::-1])
|
|
|
|
|
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('loading models... 🛠️')
|
|
|
|
|
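# CLIP supplies the text and image embeddings used to condition the sampler.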
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
|
|
|
|
try:
    # embed_dim=3 matches the 3-channel placeholder latents sampled in generate();
    # a real VQGAN decoder would use its own codebook dimension.
    vqvae = VQVAE2(embed_dim=3)
    print("Skipping VQVAE weight loading due to placeholder architecture.")
except Exception as e:
    print(f"Could not load VQVAE model: {e}. Using placeholder.")
    vqvae = VQVAE2(embed_dim=3)
vqvae.eval().requires_grad_(False).to(device)
|
|
|
|
|
try:
    diffusion_model = Diffusion()
    print("Skipping Diffusion Model weight loading due to placeholder architecture.")
except Exception as e:
    print(f"Could not load Diffusion model: {e}. Using placeholder.")
    diffusion_model = Diffusion()
diffusion_model = diffusion_model.to(device).eval().requires_grad_(False)
|
|
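# End-to-end generation: embed the prompts with CLIP, run the chosen sampler
# with classifier-free-guidance-style mixing of the conditions, then decode
# the resulting latents with the VQ-VAE decoder.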
def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='ddim', eta=None):
    # The unconditional (zero) embedding comes first; its weight is chosen so
    # that all weights sum to 1.
    zero_embed = torch.zeros([1, clip_model.config.projection_dim], device=device)
    target_embeds, weights = [zero_embed], []

    for prompt in prompts:
        inputs = clip_processor(text=prompt, return_tensors="pt").to(device)
        text_embed = clip_model.get_text_features(**inputs).float()
        target_embeds.append(text_embed)
        weights.append(1.0)

    image_prompt_weight = 1.0
    for image_path in images:
        if image_path:
            try:
                img = Image.open(image_path).convert('RGB')
                img = resize_and_center_crop(img, (224, 224))
                inputs = clip_processor(images=img, return_tensors="pt").to(device)
                image_embed = clip_model.get_image_features(**inputs).float()
                target_embeds.append(image_embed)
                weights.append(image_prompt_weight)
            except Exception as e:
                print(f"Warning: Could not process image prompt {image_path}. Error: {e}")

    weights = torch.tensor([1 - sum(weights), *weights], device=device)
    torch.manual_seed(seed)

    def cfg_model_fn(x, t):
        # Run the model once per conditioning embedding and mix the outputs
        # with the weights above (classifier-free guidance).
        n = x.shape[0]
        n_conds = len(target_embeds)
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        embed_in = torch.cat(target_embeds).repeat_interleave(n, 0)

        if isinstance(diffusion_model, Diffusion):
            # The placeholder denoiser expects 512-dimensional conditioning.
            embed_in = embed_in[:, :512]

        vs = diffusion_model(x_in, t_in, embed_in).view([n_conds, n, *x.shape[1:]])
        v = vs.mul(weights[:, None, None, None, None]).sum(0)
        return v

    def run(x, steps):
        if method == 'ddpm':
            return ddpm_sample(cfg_model_fn, x, steps)
        if method == 'ddim':
            return ddim_sample(cfg_model_fn, x, steps, eta)
        if method == 'plms':
            return plms_sample(cfg_model_fn, x, steps)
        raise ValueError(f"Unknown method: {method}")

    batch_size = n
    x = torch.randn([n, 3, 64, 64], device=device)
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = run(x[i:i + cur_batch_size], steps)
        outs = vqvae.decode(out_latents)
        for out in outs:
            pil_ims.append(transforms.ToPILImage()(out.clamp(0, 1).cpu()))
    return pil_ims
|
|
|
|
|
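# Gradio callback: one generated image per call, with a random seed when the
# user does not supply one.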
def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    if seed is None:
        seed = random.randint(0, 10000)
    prompts = [prompt]
    im_prompts = []
    if im_prompt is not None:
        im_prompts = [im_prompt]
    try:
        pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
        return pil_ims[0]
    except Exception as e:
        print(f"ERROR during generation: {e}")
        # Return a solid placeholder image so the UI still shows a result on failure.
        return Image.new('RGB', (256, 256), color='red')
|
|
|
|
|
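# Gradio UI: a text prompt plus an optional image prompt, returning one image.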
iface = gr.Interface(
    fn=gen_ims,
    inputs=[
        gr.Textbox(label="Text prompt"),
        gr.Image(label="Image prompt", type='filepath'),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    examples=[
        # Each example supplies a value per input component; no image prompt is used.
        ["A beautiful sunset over the ocean", None],
        ["A futuristic cityscape at night", None],
        ["A surreal dream-like landscape", None],
    ],
    title='CLIP + Diffusion Model Image Generator',
    description="Generate images from text and image prompts using CLIP and a diffusion model.",
)
|
|
|
|
|
iface.queue().launch()