darshanjani committed on
Commit bbb78cc · 1 Parent(s): 9cd8c78

gradio app

Browse files
Files changed (1)
  1. app.py +322 -0
app.py ADDED
@@ -0,0 +1,322 @@
+ import gradio as gr
+ import os
+ import sys
+ from base64 import b64encode
+
+ import numpy as np
+ import torch
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
+
+ from matplotlib import pyplot as plt
+ from pathlib import Path
+ from PIL import Image
+ from torch import autocast
+ from torchvision import transforms as tfms
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
+ import cv2
+ import torchvision.transforms as T
+
+ torch.manual_seed(1)
+ logging.set_verbosity_error()
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ # Load the autoencoder
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder='vae')
+
+ # Load tokenizer and text encoder to tokenize and encode the text
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+ # UNet model for generating latents
+ unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder='unet')
+
+ # Noise scheduler
+ scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+
+ # Move everything to the target device (GPU if available)
+ vae = vae.to(torch_device)
+ text_encoder = text_encoder.to(torch_device)
+ unet = unet.to(torch_device)
+
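+ # The pieces above form the standard Stable Diffusion v1.4 setup: the VAE maps 512x512 RGB
+ # images to and from 4x64x64 latents, the CLIP ViT-L/14 text encoder turns a tokenized prompt
+ # into a 77x768 embedding sequence, the UNet predicts the noise in a latent at a given
+ # timestep, and the LMS scheduler decides how that prediction updates the latent at each step.
+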
+ def get_output_embeds(input_embeddings):
+     # CLIP's text model uses a causal mask, so we prepare it here:
+     bsz, seq_len = input_embeddings.shape[:2]
+     causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(bsz, seq_len, dtype=input_embeddings.dtype)
+
+     # Getting the output embeddings involves calling the model with output_hidden_states=True
+     # so that it doesn't just return the pooled final predictions:
+     encoder_outputs = text_encoder.text_model.encoder(
+         inputs_embeds=input_embeddings,
+         attention_mask=None,  # We aren't using an attention mask, so that can be None
+         causal_attention_mask=causal_attention_mask.to(torch_device),
+         output_attentions=None,
+         output_hidden_states=True,  # We want the output embeddings, not the final output
+         return_dict=None,
+     )
+
+     # We're interested in the output hidden state only
+     output = encoder_outputs[0]
+
+     # There is a final layer norm we need to pass these through
+     output = text_encoder.text_model.final_layer_norm(output)
+
+     # And now they're ready!
+     return output
+
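+ # Illustrative usage (the names below are only for the example): feeding CLIP-sized input
+ # embeddings of shape (1, 77, 768) through get_output_embeds returns final hidden states of
+ # the same shape, e.g.:
+ #   ids = tokenizer("a photograph", padding="max_length", max_length=77, return_tensors="pt").input_ids
+ #   embs = text_encoder.text_model.embeddings(ids.to(torch_device))   # token + position embeddings
+ #   out = get_output_embeds(embs)                                     # -> torch.Size([1, 77, 768])
+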
+ # Prep Scheduler
+ def set_timesteps(scheduler, num_inference_steps):
+     scheduler.set_timesteps(num_inference_steps)
+     scheduler.timesteps = scheduler.timesteps.to(torch.float32)  # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925
+
+
+ style_files = ['learned_embeds_birb_style.bin', 'learned_embeds_cute_game_style.bin',
+                'learned_embeds_manga_style.bin', 'learned_embeds_midjourney_style.bin', 'learned_embeds_space_style.bin']
+
+ seed_values = [8, 16, 50, 80, 128]
+ height = 512  # default height of Stable Diffusion
+ width = 512  # default width of Stable Diffusion
+ num_inference_steps = 5  # Number of denoising steps
+ guidance_scale = 7.5  # Scale for classifier-free guidance
+ num_styles = len(style_files)
+
+ def get_style_embeddings(style_file):
+     style_embed = torch.load(style_file)
+     style_name = list(style_embed.keys())[0]
+     return style_embed[style_name]
+
+ def get_EOS_pos_in_prompt(prompt):
+     # Position 0 holds the BOS token, so with one token per word the EOS token
+     # lands at index (number of words) + 1.
+     return len(prompt.split()) + 1
+
+
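+ # For example (illustrative): get_EOS_pos_in_prompt("a cute dog") returns 4, matching the
+ # token sequence [BOS, a, cute, dog, EOS]; the style embedding is later written into that slot.
+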
+ import torch.nn.functional as F
+ """
+ def gradient_loss(images):
+     # Compute gradient magnitude using Sobel filters.
+     gradient_x = F.conv2d(images, torch.Tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).view(1, 1, 3, 3).to(images.device))
+     gradient_y = F.conv2d(images, torch.Tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).view(1, 1, 3, 3).to(images.device))
+     gradient_magnitude = torch.sqrt(gradient_x**2 + gradient_y**2)
+     return gradient_magnitude.mean()
+ """
+
+ from torchvision.transforms import ToTensor
+ def pil_to_latent(input_im):
+     # Single image -> single latent in a batch (so size 1, 4, 64, 64)
+     with torch.no_grad():
+         latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1)  # Note scaling to (-1, 1)
+     return 0.18215 * latent.latent_dist.sample()
+
+ def latents_to_pil(latents):
+     # Batch of latents -> list of PIL images
+     latents = (1 / 0.18215) * latents
+     with torch.no_grad():
+         image = vae.decode(latents).sample
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
+
+
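+ # Illustrative round trip (the file name is just an example): a 512x512 RGB image encodes to a
+ # (1, 4, 64, 64) latent and decodes back to a single PIL image:
+ #   lat = pil_to_latent(Image.open("example.png").convert("RGB").resize((512, 512)))
+ #   img = latents_to_pil(lat)[0]
+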
+ def additional_guidance(latents, scheduler, noise_pred, t, sigma, custom_loss_fn, custom_loss_scale):
+     #### ADDITIONAL GUIDANCE ###
+     # Requires grad on the latents
+     latents = latents.detach().requires_grad_()
+
+     # Get the predicted x0:
+     latents_x0 = latents - sigma * noise_pred
+
+     # Decode to image space
+     denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)
+
+     # Calculate loss
+     loss = custom_loss_fn(denoised_images) * custom_loss_scale
+
+     # Get gradient
+     cond_grad = torch.autograd.grad(loss, latents, allow_unused=False)[0]
+
+     # Modify the latents based on this gradient
+     latents = latents.detach() - cond_grad * sigma**2
+     return latents, loss
+
+
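+ # Why this works (sketch): with the LMS/k-diffusion parameterisation x_t = x_0 + sigma * eps,
+ # the UNet's noise prediction gives an estimate x0_hat = x_t - sigma * eps_hat. The custom loss
+ # is evaluated on the decoded x0_hat, and the latents take one gradient-descent step,
+ # latents <- latents - sigma^2 * grad(loss), before normal scheduler stepping resumes.
+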
+ def generate_with_embs(text_embeddings, max_length, random_seed, loss_fn=None, custom_loss_scale=1.0):
+
+     height = 512  # default height of Stable Diffusion
+     width = 512  # default width of Stable Diffusion
+     num_inference_steps = 5  # Number of denoising steps
+     guidance_scale = 7.5  # Scale for classifier-free guidance
+
+     generator = torch.manual_seed(random_seed)  # Seed generator to create the initial latent noise
+     batch_size = 1
+
+     uncond_input = tokenizer(
+         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+     )
+     with torch.no_grad():
+         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+     # Prep scheduler
+     set_timesteps(scheduler, num_inference_steps)
+
+     # Prep latents
+     latents = torch.randn(
+         (batch_size, unet.in_channels, height // 8, width // 8),
+         generator=generator,
+     )
+     latents = latents.to(torch_device)
+     latents = latents * scheduler.init_noise_sigma
+
+     # Denoising loop
+     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
+         # Expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+         latent_model_input = torch.cat([latents] * 2)
+         sigma = scheduler.sigmas[i]
+         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+         # Predict the noise residual
+         with torch.no_grad():
+             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+         # Perform classifier-free guidance
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+         # Apply the optional loss-based guidance on every other step
+         if loss_fn is not None:
+             if i % 2 == 0:
+                 latents, custom_loss = additional_guidance(latents, scheduler, noise_pred, t, sigma, loss_fn, custom_loss_scale)
+                 print(i, 'loss:', custom_loss.item())
+
+         # Compute the previous noisy sample x_t -> x_t-1
+         latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+     return latents_to_pil(latents)[0]
+
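+ # Each step above runs the UNet once on a batch of two latents (unconditional and text-conditioned),
+ # combines the two noise predictions with classifier-free guidance,
+ # noise = noise_uncond + guidance_scale * (noise_text - noise_uncond), optionally nudges the latents
+ # with the custom loss on even steps, and then lets the scheduler take the x_t -> x_t-1 step.
+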
+ def generate_image_custom_style(prompt, style_num=None, random_seed=41, custom_loss_fn=None, custom_loss_scale=1.0):
+     eos_pos = get_EOS_pos_in_prompt(prompt)
+
+     # style_num can legitimately be 0 (the first style), so compare against None explicitly
+     style_token_embedding = None
+     if style_num is not None:
+         style_token_embedding = get_style_embeddings(style_files[style_num])
+
+     # Tokenize
+     text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+     max_length = text_input.input_ids.shape[-1]
+     input_ids = text_input.input_ids.to(torch_device)
+
+     # Get token embeddings
+     token_emb_layer = text_encoder.text_model.embeddings.token_embedding
+     token_embeddings = token_emb_layer(input_ids)
+
+     # Overwrite the embedding at the EOS position with the learned style embedding
+     if style_token_embedding is not None:
+         token_embeddings[-1, eos_pos, :] = style_token_embedding
+
+     # Combine with position embeddings
+     pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
+     position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
+     position_embeddings = pos_emb_layer(position_ids)
+     input_embeddings = token_embeddings + position_embeddings
+
+     # Feed through to get the final output embeddings
+     modified_output_embeddings = get_output_embeds(input_embeddings)
+
+     # And generate an image with this:
+     generated_image = generate_with_embs(modified_output_embeddings, max_length, random_seed, custom_loss_fn, custom_loss_scale)
+     return generated_image
+
+
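+ # Illustrative call (argument values are just examples): render a prompt in the manga style
+ # (index 2) while pushing the blue channel via blue_loss, defined further below:
+ #   img = generate_image_custom_style("a cat sitting on a chair", style_num=2, random_seed=16,
+ #                                     custom_loss_fn=blue_loss, custom_loss_scale=100)
+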
+ def show_images(images_list):
+     # Show a list of images side by side:
+     fig, axs = plt.subplots(1, len(images_list), figsize=(16, 4))
+     for c in range(len(images_list)):
+         axs[c].imshow(images_list[c])
+     plt.show()
+
+
+ def invert_loss(gen_image):
+     inverter = T.RandomInvert(p=1.0)
+     inverted_img = inverter(gen_image)
+     #loss = torch.abs(gen_image - inverted_img).sum()
+     # The active loss below ignores the inverted image and instead penalizes differences
+     # between the R, G and B channels:
+     loss = torch.nn.functional.mse_loss(gen_image[:, 0], gen_image[:, 2]) + torch.nn.functional.mse_loss(gen_image[:, 2], gen_image[:, 1]) + torch.nn.functional.mse_loss(gen_image[:, 0], gen_image[:, 1])
+     return loss
+
+ def contrast_loss(images):
+     # Calculate the variance of pixel values as a measure of contrast.
+     variance = torch.var(images)
+     return -variance
+
+ def blue_loss(images):
+     # How far the blue channel values are from 0.9:
+     error = torch.abs(images[:, 2] - 0.9).mean()  # [:,2] -> all images in batch, only the blue channel
+     return error
+
+
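+ # Other simple pixel-space losses can be dropped in the same way; for instance (illustrative
+ # sketch, not wired into the app below):
+ #   def brightness_loss(images):
+ #       # Penalize the distance of the mean pixel value from a bright target of 0.8.
+ #       return torch.abs(images.mean() - 0.8)
+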
+ def display_images_in_rows(images_with_titles, titles):
+     num_images = len(images_with_titles)
+     rows = 5  # Display 5 rows always
+     columns = 1 if num_images == 5 else 2  # Use 1 column if there are 5 images, otherwise 2 columns
+     fig, axes = plt.subplots(rows, columns + 1, figsize=(15, 5 * rows))  # Add an extra column for titles
+
+     for r in range(rows):
+         # Add the title on the extreme left, in the middle of each picture
+         axes[r, 0].text(0.5, 0.5, titles[r], ha='center', va='center')
+         axes[r, 0].axis('off')
+
+         # Add "Without Loss" label above the first column and "With Loss" label above the second column (if applicable)
+         if columns == 2:
+             axes[r, 1].set_title("Without Loss", pad=10)
+             axes[r, 2].set_title("With Loss", pad=10)
+
+         for c in range(1, columns + 1):
+             index = r * columns + c - 1
+             if index < num_images:
+                 image, _ = images_with_titles[index]
+                 axes[r, c].imshow(image)
+                 axes[r, c].axis('off')
+
+     return fig
+     # plt.show()
+
+
+ def image_generator(prompt="dog", loss_function=None):
+     images_without_loss = []
+     images_with_loss = []
+
+     for i in range(num_styles):
+         generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=None)
+         images_without_loss.append(generated_img)
+         if loss_function:
+             generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=loss_function)
+             images_with_loss.append(generated_img)
+
+     generated_sd_images = []
+     titles = ["Birb Style", "Cute Game Style", "Manga Style", "Mid Journey Style", "Space Style"]
+
+     for i in range(len(titles)):
+         generated_sd_images.append((images_without_loss[i], titles[i]))
+         if images_with_loss != []:
+             generated_sd_images.append((images_with_loss[i], titles[i]))
+
+     return display_images_in_rows(generated_sd_images, titles)
+
+ # Wrapper around image_generator() that maps the radio-button choice to a loss function
+ def image_generator_wrapper(prompt="dog", loss_function=None):
+     if loss_function == "Yes":
+         loss_function = contrast_loss
+     else:
+         loss_function = None
+
+     return image_generator(prompt, loss_function)
+
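+ # Illustrative call: image_generator_wrapper("astronaut riding a cycle", "Yes") returns the
+ # matplotlib figure with both the plain and contrast-loss-guided versions of each style.
+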
+ description = 'Stable Diffusion is a generative AI model that produces photorealistic images from text prompts. This demo renders the prompt in five learned styles, optionally steered by a contrast loss.'
+ title = 'Image Generation using Stable Diffusion'
+
+ demo = gr.Interface(image_generator_wrapper,
+                     inputs=[gr.Textbox(label="Enter prompt for generation", type="text", value="astronaut riding a cycle"),
+                             gr.Radio(["Yes", "No"], value="No", label="Apply Contrast Loss")],
+                     outputs=gr.Plot(label="Generated Images"), title=title, description=description)
+ demo.launch()
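+ # When running locally, demo.launch(share=True) can be used instead to expose a temporary public URL.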