import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from datasets import load_dataset
from diffusers import DDIMScheduler, DDPMPipeline
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

device = "cpu"

# Load the pretrained pipeline
pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)

# Sample some images with a DDIM Scheduler over 40 steps
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)

import open_clip

clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)

# Transforms to resize and augment an image + normalize to match CLIP's training data
tfms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(224),  # Random CROP each time
        torchvision.transforms.RandomAffine(5),  # One possible random augmentation: skews the image
        torchvision.transforms.RandomHorizontalFlip(),  # You can add additional augmentations if you like
        torchvision.transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


# And define a loss function that takes an image, embeds it and compares with
# the text features of the prompt
def clip_loss(image, text_features):
    image_features = clip_model.encode_image(tfms(image))  # Note: applies the above transforms
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    dists = (
        input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
    )  # Squared Great Circle Distance
    return dists.mean()


import gradio as gr
from PIL import Image


def text2img(prompt, guidance_scale, n_cuts):
    print(f"prompt: {prompt}")

    # We embed a prompt with CLIP as our target
    text = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        text_features = clip_model.encode_text(text)

    x = torch.randn(1, 3, 256, 256).to(device)  # RAM usage is high, you may want only 1 image at a time

    for i, t in tqdm(enumerate(scheduler.timesteps)):

        model_input = scheduler.scale_model_input(x, t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]

        cond_grad = 0

        for cut in range(int(n_cuts)):

            # Set requires grad on x
            x = x.detach().requires_grad_()

            # Get the predicted x0:
            x0 = scheduler.step(noise_pred, t, x).pred_original_sample

            # Calculate loss
            loss = clip_loss(x0, text_features) * guidance_scale

            # Get gradient (scale by n_cuts since we want the average)
            cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts

        # Modify x based on this gradient
        alpha_bar = scheduler.alphas_cumprod[i]
        x = x.detach() + cond_grad * alpha_bar.sqrt()  # Note the additional scaling factor here!
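        # (The sqrt(alpha_bar) term ties the size of the guidance nudge to the
        #  scheduler's noise schedule rather than adding the raw gradient.)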
        # Now step with scheduler
        x = scheduler.step(noise_pred, t, x).prev_sample

    grid = torchvision.utils.make_grid(x, nrow=1)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    im = Image.fromarray(np.array(im * 255).astype(np.uint8))
    return im


# See the gradio docs for the types of inputs and outputs available
inputs = ["text", "number", "number"]
outputs = gr.Image(label="text-guided image")

# And the minimal interface
demo = gr.Interface(
    fn=text2img,
    inputs=inputs,
    outputs=outputs,
    examples=[
        ["Red Rose (still life), red flower painting", 10, 4],
        ["Blue sky, pure and bright, positive mood", 8, 5],
        # You can provide some example inputs to get people started
    ],
)
demo.launch()
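
# Optional sanity check (a rough sketch, not part of the original demo): you can call
# text2img directly to confirm generation works before relying on the UI. Note that
# demo.launch() above blocks until the interface is closed, so run something like this
# in place of (or before) the launch call:
#
#   im = text2img("Red Rose (still life), red flower painting", guidance_scale=10, n_cuts=4)
#   im.save("sample.png")
#
# Passing share=True to demo.launch() gives a temporary public URL for sharing the demo.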