Shivdutta committed
Commit 4e5be15 · verified · parent: 3527874

Update app.py

Files changed (1): app.py (+41 -10)
app.py CHANGED
@@ -1,5 +1,5 @@
 from base64 import b64encode
-
+import torch
 import numpy
 import torch
 from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
@@ -16,8 +16,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, logging
 import os
 import numpy as np
 
-torch.manual_seed(1)
-# if not (Path.home()/'.cache/huggingface'/'token').exists(): notebook_login()
+torch.manual_seed(24041975)
 
 # Supress some unnecessary warnings when loading the CLIPTextModel
 logging.set_verbosity_error()
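
Note: `torch.manual_seed` pins PyTorch's global RNG, so anything random that runs at import time is now reproducible. A tiny illustration (not part of app.py):

    import torch

    torch.manual_seed(24041975)
    a = torch.randn(2, 2)
    torch.manual_seed(24041975)
    b = torch.randn(2, 2)
    assert torch.equal(a, b)  # re-seeding replays the exact same draws

The generation entry points still take their own per-call seed (e.g. `generate_with_prompt_style(prompt, style, seed=42)` below), which governs the per-image noise.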
@@ -172,15 +171,47 @@ def generate_with_prompt_style(prompt, style, seed = 42):
     # And generate an image with this:
     return generate_with_embs(modified_output_embeddings, text_input, seed)
 
-
-import torch
-
-# def contrast_loss(images):
-#     variance = torch.var(images)
-#     return -variance
-
+def contrast_loss(images):
+    variance = torch.var(images)
+    return -variance
 
-import torch
+def blue_loss_variant(images, use_mean=False, alpha=1.0):
+    """
+    Computes the blue loss for a batch of images with an optional mean component.
+
+    The blue loss is defined as the negative variance of the blue channel's pixel values.
+    Optionally, it can also include the mean value of the blue channel.
+
+    Parameters:
+        images (torch.Tensor): A batch of images. Expected shape is (N, C, H, W) where
+            N is the batch size, C is the number of channels (3 for RGB),
+            H is the height, and W is the width.
+        use_mean (bool): If True, includes the mean of the blue channel in the loss calculation.
+        alpha (float): Weighting factor for the mean component when use_mean is True.
+
+    Returns:
+        torch.Tensor: The blue loss, which is the negative variance of the blue channel's pixel values,
+            optionally combined with the mean value of the blue channel.
+    """
+    # Ensure the input tensor has the correct shape
+    if images.shape[1] != 3:
+        raise ValueError("Expected images with 3 channels (RGB), but got shape {}".format(images.shape))
+
+    # Extract the blue channel (assuming the channels are in RGB order)
+    blue_channel = images[:, 2, :, :]
+
+    # Calculate the variance of the blue channel
+    variance = torch.var(blue_channel)
+
+    if use_mean:
+        # Calculate the mean of the blue channel
+        mean = torch.mean(blue_channel)
+        # Combine variance and mean into the loss
+        loss = -variance + alpha * mean
+    else:
+        loss = -variance
+
+    return loss
 
 def blue_loss(images):
     """
@@ -294,7 +325,7 @@ def generate_with_prompt_style_guidance(prompt, style, seed=42):
             denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)
 
             # Calculate loss
-            loss = blue_loss(denoised_images) * contrast_loss_scale
+            loss = blue_loss_variant(denoised_images) * contrast_loss_scale
 
             # # Occasionally print it out
             # if i%10==0:
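
Only the loss call changes in this hunk; the rest of the guidance step stays as it was. For context, the usual pattern differentiates the loss with respect to the latents and nudges them before the scheduler update. A sketch of that loop body, where `latents`, `noise_pred` and `sigma` are assumed names from the standard loss-guided sampling recipe rather than quotes from app.py:

    # Enable gradients on the current noisy latents
    latents = latents.detach().requires_grad_()

    # Predict the denoised x0, decode it to image space, and score it
    latents_x0 = latents - sigma * noise_pred
    denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)

    loss = blue_loss_variant(denoised_images) * contrast_loss_scale

    # Step the latents against the loss gradient before the scheduler update
    grad = torch.autograd.grad(loss, latents)[0]
    latents = latents.detach() - grad * sigma ** 2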
 