# text2img-clip / app.py
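# Gradio demo of CLIP-guided diffusion sampling: a text prompt is embedded with
# CLIP, and gradients of a CLIP loss steer DDIM sampling from a DDPM that was
# fine-tuned on WikiArt (per the pipeline name below).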
import gradio as gr
import numpy as np
import open_clip
import torch
import torchvision
from diffusers import DDIMScheduler, DDPMPipeline
from PIL import Image
from tqdm.auto import tqdm

device = "cpu"  # this demo runs on CPU; switch to "cuda" if a GPU is available
# Load the pretrained pipeline
pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
# Use a DDIM scheduler so sampling takes only 40 steps
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
# Load CLIP (ViT-B/32) to embed both the prompt and the generated images
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)
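# NOTE: the `preprocess` transform returned above expects PIL images. During
# guidance we need gradients to flow through the image, so the differentiable
# tensor transforms `tfms` defined below are used instead.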
# Transforms to resize and augment an image + normalize to match CLIP's training data
tfms = torchvision.transforms.Compose(
[
torchvision.transforms.RandomResizedCrop(224), # Random CROP each time
torchvision.transforms.RandomAffine(
5
), # One possible random augmentation: skews the image
torchvision.transforms.RandomHorizontalFlip(), # You can add additional augmentations if you like
torchvision.transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
]
)
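# The random crops give CLIP several different "views" of the same image at
# each step, which is where the n_cuts parameter further below comes from.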
# Define a loss function that embeds an image with CLIP and compares it with
# the text features of the prompt
def clip_loss(image, text_features):
image_features = clip_model.encode_image(
tfms(image)
) # Note: applies the above transforms
input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
dists = (
input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
) # Squared Great Circle Distance
return dists.mean()
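# Math note: for unit vectors a and b, the great-circle (angular) distance is
# theta = 2 * arcsin(||a - b|| / 2), so the expression above equals theta**2 / 2,
# i.e. a squared spherical distance between the image and text embeddings.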
def text2img(prompt, guidance_scale, n_cuts):
print(f"prompt: {prompt}")
# We embed a prompt with CLIP as our target
text = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad():  # no autocast needed on CPU
text_features = clip_model.encode_text(text)
    x = torch.randn(1, 3, 256, 256).to(device)  # start from pure noise; a batch of 1 keeps RAM usage manageable
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
model_input = scheduler.scale_model_input(x, t)
# predict the noise residual
with torch.no_grad():
noise_pred = image_pipe.unet(model_input, t)["sample"]
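        # Accumulate a CLIP guidance gradient, averaged over n_cuts randomly
        # augmented crops of the predicted clean image (x0)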
cond_grad = 0
for cut in range(int(n_cuts)):
# Set requires grad on x
x = x.detach().requires_grad_()
# Get the predicted x0:
x0 = scheduler.step(noise_pred, t, x).pred_original_sample
# Calculate loss
loss = clip_loss(x0, text_features) * guidance_scale
# Get gradient (scale by n_cuts since we want the average)
cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts
# Modify x based on this gradient
alpha_bar = scheduler.alphas_cumprod[i]
x = x.detach() + cond_grad * alpha_bar.sqrt() # Note the additional scaling factor here!
# Now step with scheduler
x = scheduler.step(noise_pred, t, x).prev_sample
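    # Convert the final image from [-1, 1] to a uint8 PIL image for display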
grid = torchvision.utils.make_grid(x, nrow=1)
im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
im = Image.fromarray(np.array(im * 255).astype(np.uint8))
return im
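# Quick local sanity check without the UI (illustrative prompt; uncomment to
# run, and expect it to be slow on CPU):
# text2img("A watercolor painting of an otter", guidance_scale=8, n_cuts=4).save("sample.png")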
# See the Gradio docs for the available input and output components
inputs = [
    gr.Textbox(label="Prompt"),
    gr.Number(label="Guidance scale", value=10),
    gr.Number(label="Number of cuts", value=4),
]
outputs = gr.Image(label="text-guided image")
# And the minimal interface
demo = gr.Interface(
fn=text2img,
inputs=inputs,
outputs=outputs,
    # Example inputs (prompt, guidance scale, number of cuts) to get people started
    examples=[
        ["Red Rose (still life), red flower painting", 10, 4],
        ["Blue sky, pure and bright, positive mood", 8, 5],
    ],
)
demo.launch()