import numpy as np

import torch
import torch.nn.functional as F
import torchvision

from datasets import load_dataset
from diffusers import DDIMScheduler, DDPMPipeline

from matplotlib import pyplot as plt

from PIL import Image
from torchvision import transforms

from tqdm.auto import tqdm


device = "cpu"


# Load the pretrained pipeline
pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)

# Sample some images with a DDIM Scheduler over 40 steps
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
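# DDIM lets us sample in far fewer steps than the full DDPM training schedule (typically 1,000);
# 40 steps keeps this demo reasonably quick, at some cost in sample quality.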


import open_clip


clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)

# Transforms to resize and augment an image + normalize to match CLIP's training data
tfms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(224),  # Random CROP each time
        torchvision.transforms.RandomAffine(5),  # One possible augmentation: a small random rotation (up to 5 degrees)
        torchvision.transforms.RandomHorizontalFlip(),  # You can add additional augmentations if you like
        torchvision.transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
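# Each randomly cropped, augmented "cut" gives a slightly different view of an image;
# averaging the CLIP loss gradient over several cuts (n_cuts below) tends to give a
# smoother, more robust guidance signal than a single deterministic resize would.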

# And define a loss function that takes an image, embeds it, and compares it with
# the text features of the prompt
def clip_loss(image, text_features):
    image_features = clip_model.encode_image(
        tfms(image)
    )  # Note: applies the above transforms
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    dists = (
        input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
    )  # Squared Great Circle Distance
    return dists.mean()
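

# Optional sanity check (illustrative only; the names and prompt below are made up):
# embed a prompt and confirm clip_loss returns a scalar for a random image batch.
# example_text = open_clip.tokenize(["a watercolor painting of a rose"]).to(device)
# with torch.no_grad():
#     example_features = clip_model.encode_text(example_text)
# print(clip_loss(torch.randn(1, 3, 256, 256).to(device), example_features))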


import gradio as gr


def text2img(prompt, guidance_scale, n_cuts):
    print(f"prompt: {prompt}")

    # We embed a prompt with CLIP as our target
    text = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        text_features = clip_model.encode_text(text)
    
    x = torch.randn(1, 3, 256, 256).to(device)  # Start from random noise; RAM usage is high, so we generate one image at a time
    for i, t in tqdm(enumerate(scheduler.timesteps)):
        model_input = scheduler.scale_model_input(x, t)
        # predict the noise residual
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]
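        # noise_pred is computed without tracking gradients; the guidance gradient below
        # only needs to flow from the loss back to x through the analytic x0 estimate
        # that scheduler.step derives from x and this (detached) noise prediction.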

        cond_grad = 0
        for cut in range(int(n_cuts)):
            # Set requires grad on x
            x = x.detach().requires_grad_()
            # Get the predicted x0:
            x0 = scheduler.step(noise_pred, t, x).pred_original_sample
            # Calculate loss
            loss = clip_loss(x0, text_features) * guidance_scale
            # Get gradient (scale by n_cuts since we want the average)
            cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts

        # Modify x based on this gradient
        alpha_bar = scheduler.alphas_cumprod[i]
        x = x.detach() + cond_grad * alpha_bar.sqrt()  # Note the additional scaling factor here!
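        # Note: alphas_cumprod is indexed by the loop counter i rather than the timestep t,
        # so with 40 DDIM steps alpha_bar stays close to 1 throughout sampling; indexing
        # by t instead is another scaling you could experiment with.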

        # Now step with scheduler
        x = scheduler.step(noise_pred, t, x).prev_sample
    
    grid = torchvision.utils.make_grid(x, nrow=1)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    im = Image.fromarray(np.array(im * 255).astype(np.uint8))

    return im


# See the gradio docs for the types of inputs and outputs available
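# The three inputs map positionally onto text2img's arguments: prompt, guidance_scale, n_cuts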
inputs = ["text", "number", "number"]
outputs = gr.Image(label="text-guided image")

# And the minimal interface
demo = gr.Interface(
    fn=text2img,
    inputs=inputs,
    outputs=outputs,
    examples=[
        ["Red Rose (still life), red flower painting", 10, 4],
        ["Blue sky, pure and bright, positive mood", 8, 5],  # You can provide some example inputs to get people started
    ],
)
demo.launch()
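# Tip: demo.launch(share=True) also serves the interface via a temporary public link.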