Shivdutta commited on
Commit
6798689
·
verified ·
1 Parent(s): 4e5be15

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -75
app.py CHANGED
@@ -1,9 +1,9 @@
1
  from base64 import b64encode
2
- import torch
3
  import numpy
4
  import torch
5
  from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
6
  from huggingface_hub import notebook_login
 
7
 
8
  # For video display:
9
  from matplotlib import pyplot as plt
@@ -16,7 +16,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, logging
16
  import os
17
  import numpy as np
18
 
19
- torch.manual_seed(24041975)
20
 
21
  # Supress some unnecessary warnings when loading the CLIPTextModel
22
  logging.set_verbosity_error()
@@ -145,7 +145,7 @@ def generate_with_embs(text_embeddings, text_input, seed):
145
  return latents_to_pil(latents)[0]
146
 
147
 
148
- def generate_with_prompt_style(prompt, style, seed = 42):
149
 
150
  prompt = prompt + ' in style of s'
151
  embed = torch.load(style)
@@ -175,72 +175,7 @@ def contrast_loss(images):
175
  variance = torch.var(images)
176
  return -variance
177
 
178
- def blue_loss_variant(images, use_mean=False, alpha=1.0):
179
- """
180
- Computes the blue loss for a batch of images with an optional mean component.
181
-
182
- The blue loss is defined as the negative variance of the blue channel's pixel values.
183
- Optionally, it can also include the mean value of the blue channel.
184
-
185
- Parameters:
186
- images (torch.Tensor): A batch of images. Expected shape is (N, C, H, W) where
187
- N is the batch size, C is the number of channels (3 for RGB),
188
- H is the height, and W is the width.
189
- use_mean (bool): If True, includes the mean of the blue channel in the loss calculation.
190
- alpha (float): Weighting factor for the mean component when use_mean is True.
191
-
192
- Returns:
193
- torch.Tensor: The blue loss, which is the negative variance of the blue channel's pixel values,
194
- optionally combined with the mean value of the blue channel.
195
- """
196
- # Ensure the input tensor has the correct shape
197
- if images.shape[1] != 3:
198
- raise ValueError("Expected images with 3 channels (RGB), but got shape {}".format(images.shape))
199
-
200
- # Extract the blue channel (assuming the channels are in RGB order)
201
- blue_channel = images[:, 2, :, :]
202
-
203
- # Calculate the variance of the blue channel
204
- variance = torch.var(blue_channel)
205
-
206
- if use_mean:
207
- # Calculate the mean of the blue channel
208
- mean = torch.mean(blue_channel)
209
- # Combine variance and mean into the loss
210
- loss = -variance + alpha * mean
211
- else:
212
- loss = -variance
213
-
214
- return loss
215
-
216
- def blue_loss(images):
217
- """
218
- Computes the blue loss for a batch of images.
219
-
220
- The blue loss is defined as the negative variance of the blue channel's pixel values.
221
-
222
- Parameters:
223
- images (torch.Tensor): A batch of images. Expected shape is (N, C, H, W) where
224
- N is the batch size, C is the number of channels (3 for RGB),
225
- H is the height, and W is the width.
226
-
227
- Returns:
228
- torch.Tensor: The blue loss, which is the negative variance of the blue channel's pixel values.
229
- """
230
- # Ensure the input tensor has the correct shape
231
- if images.shape[1] != 3:
232
- raise ValueError("Expected images with 3 channels (RGB), but got shape {}".format(images.shape))
233
-
234
- # Extract the blue channel (assuming the channels are in RGB order)
235
- blue_channel = images[:, 2, :, :]
236
-
237
- # Calculate the variance of the blue channel
238
- variance = torch.var(blue_channel)
239
-
240
- return -variance
241
-
242
-
243
- def generate_with_prompt_style_guidance(prompt, style, seed=42):
244
 
245
  prompt = prompt + ' in style of s'
246
 
@@ -325,7 +260,7 @@ def generate_with_prompt_style_guidance(prompt, style, seed=42):
325
  denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
326
 
327
  # Calculate loss
328
- loss = blue_loss_variant(denoised_images) * contrast_loss_scale
329
 
330
  # # Occasionally print it out
331
  # if i%10==0:
@@ -344,7 +279,7 @@ def generate_with_prompt_style_guidance(prompt, style, seed=42):
344
  return latents_to_pil(latents)[0]
345
 
346
 
347
- import gradio as gr
348
 
349
  dict_styles = {
350
  'Dr Strange': 'styles/learned_embeds_dr_strange.bin',
@@ -354,11 +289,12 @@ dict_styles = {
354
  }
355
  # dict_styles.keys()
356
 
357
- def inference(prompt, style):
358
 
359
- if prompt is not None and style is not None:
360
  style = dict_styles[style]
361
- result = generate_with_prompt_style_guidance(prompt, style)
 
362
  return np.array(result)
363
  else:
364
  return None
@@ -369,6 +305,7 @@ examples = [['A man sipping wine wearing a spacesuit on the moon', 'Stripes']]
369
 
370
  demo = gr.Interface(inference,
371
  inputs = [gr.Textbox(label='Prompt'),
 
372
  gr.Dropdown(['Dr Strange', 'GTA-5',
373
  'Manga', 'Pokemon'], label='Style')
374
  ],
@@ -377,7 +314,8 @@ demo = gr.Interface(inference,
377
  ],
378
  title = title,
379
  description = description,
380
- # examples = examples,
381
  # cache_examples=True
382
  )
383
  demo.launch()
 
 
1
  from base64 import b64encode
 
2
  import numpy
3
  import torch
4
  from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
5
  from huggingface_hub import notebook_login
6
+ import gradio as gr
7
 
8
  # For video display:
9
  from matplotlib import pyplot as plt
 
16
  import os
17
  import numpy as np
18
 
19
+
20
 
21
  # Supress some unnecessary warnings when loading the CLIPTextModel
22
  logging.set_verbosity_error()
 
145
  return latents_to_pil(latents)[0]
146
 
147
 
148
+ def generate_with_prompt_style(prompt, style, seed):
149
 
150
  prompt = prompt + ' in style of s'
151
  embed = torch.load(style)
 
175
  variance = torch.var(images)
176
  return -variance
177
 
178
+ def generate_with_prompt_style_guidance(prompt, style, seed):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  prompt = prompt + ' in style of s'
181
 
 
260
  denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
261
 
262
  # Calculate loss
263
+ loss = contrast_loss(denoised_images) * contrast_loss_scale
264
 
265
  # # Occasionally print it out
266
  # if i%10==0:
 
279
  return latents_to_pil(latents)[0]
280
 
281
 
282
+
283
 
284
  dict_styles = {
285
  'Dr Strange': 'styles/learned_embeds_dr_strange.bin',
 
289
  }
290
  # dict_styles.keys()
291
 
292
+ def inference(prompt, seed,style):
293
 
294
+ if prompt is not None and style is not None and seed is not None:
295
  style = dict_styles[style]
296
+ torch.manual_seed(seed)
297
+ result = generate_with_prompt_style_guidance(prompt, style,seed)
298
  return np.array(result)
299
  else:
300
  return None
 
305
 
306
  demo = gr.Interface(inference,
307
  inputs = [gr.Textbox(label='Prompt'),
308
+ gr.Textbox(label='Seed', value='24041975'),
309
  gr.Dropdown(['Dr Strange', 'GTA-5',
310
  'Manga', 'Pokemon'], label='Style')
311
  ],
 
314
  ],
315
  title = title,
316
  description = description,
317
+ examples = examples,
318
  # cache_examples=True
319
  )
320
  demo.launch()
321
+