import numpy as np
import torch
import torchvision
from diffusers import DDIMScheduler, DDPMPipeline
from PIL import Image
from tqdm.auto import tqdm

device = "cpu"

# Load the pretrained pipeline
pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)

# Sample some images with a DDIM scheduler over 40 steps
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
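
# A minimal unguided sampling sketch for reference (assumption: this helper is not part
# of the original app and is never called by the demo). It shows how the DDIM scheduler
# and the pipeline's UNet combine before any CLIP guidance is added.
def sample_unguided(batch_size=1):
    x = torch.randn(batch_size, 3, 256, 256).to(device)
    for t in tqdm(scheduler.timesteps):
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]
        x = scheduler.step(noise_pred, t, x).prev_sample
    return x  # images roughly in [-1, 1]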
import open_clip

clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)

# Transforms to resize and augment an image + normalize to match CLIP's training data
tfms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(224),  # Random crop each time
        torchvision.transforms.RandomAffine(5),  # One possible random augmentation: skews the image
        torchvision.transforms.RandomHorizontalFlip(),  # You can add additional augmentations if you like
        torchvision.transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
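
# Because the crop, affine and flip transforms above are random, applying `tfms` to the
# same image repeatedly yields different augmented 224x224 views ("cuts"); the guidance
# loop below averages its gradient over several such cuts to reduce noise. A hypothetical
# helper (not used by the demo) that stacks a few views for inspection:
def make_cuts(image, n_cuts=4):
    return torch.cat([tfms(image) for _ in range(n_cuts)])  # (n_cuts * batch, 3, 224, 224)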
# And define a loss function that takes an image, embeds it and compares it with
# the text features of the prompt
def clip_loss(image, text_features):
    image_features = clip_model.encode_image(tfms(image))  # Note: applies the above transforms
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    dists = (
        input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
    )  # Squared great-circle distance
    return dists.mean()
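
# Quick sanity check of the loss (illustrative only, commented out so it doesn't run at
# startup; the prompt and the random image below are made-up values):
# text = open_clip.tokenize(["a red rose"]).to(device)
# with torch.no_grad():
#     text_features = clip_model.encode_text(text)
# print(clip_loss(torch.randn(1, 3, 256, 256).to(device), text_features))  # lower = closer to the prompt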
import gradio as gr
def text2img(prompt, guidance_scale, n_cuts):
    print(f"prompt: {prompt}")

    # We embed a prompt with CLIP as our target
    text = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        text_features = clip_model.encode_text(text)

    x = torch.randn(1, 3, 256, 256).to(device)  # RAM usage is high, you may want only 1 image at a time

    for i, t in tqdm(enumerate(scheduler.timesteps)):

        model_input = scheduler.scale_model_input(x, t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]

        cond_grad = 0

        for cut in range(int(n_cuts)):

            # Set requires grad on x
            x = x.detach().requires_grad_()

            # Get the predicted x0:
            x0 = scheduler.step(noise_pred, t, x).pred_original_sample

            # Calculate loss
            loss = clip_loss(x0, text_features) * guidance_scale

            # Get gradient (scale by n_cuts since we want the average)
            cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts

        # Modify x based on this gradient
        alpha_bar = scheduler.alphas_cumprod[i]
        x = x.detach() + cond_grad * alpha_bar.sqrt()  # Note the additional scaling factor here!

        # Now step with scheduler
        x = scheduler.step(noise_pred, t, x).prev_sample

    grid = torchvision.utils.make_grid(x, nrow=1)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    im = Image.fromarray(np.array(im * 255).astype(np.uint8))
    return im
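
# text2img can also be called directly, without the Gradio UI, e.g. (example values;
# expect a 40-step run on CPU to be slow):
# im = text2img("Red Rose (still life), red flower painting", guidance_scale=10, n_cuts=4)
# im.save("guided_sample.png")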
# See the Gradio docs for the types of inputs and outputs available
inputs = ["text", "number", "number"]
outputs = gr.Image(label="text-guided image")

# And the minimal interface
demo = gr.Interface(
    fn=text2img,
    inputs=inputs,
    outputs=outputs,
    examples=[
        ["Red Rose (still life), red flower painting", 10, 4],
        ["Blue sky, pure and bright, positive mood", 8, 5],  # You can provide some example inputs to get people started
    ],
)
demo.launch()