Shivdutta commited on
Commit
3527874
·
verified ·
1 Parent(s): 723c4a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -3
app.py CHANGED
@@ -175,10 +175,40 @@ def generate_with_prompt_style(prompt, style, seed = 42):
175
 
176
  import torch
177
 
178
- def contrast_loss(images):
179
- variance = torch.var(images)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  return -variance
181
 
 
182
  def generate_with_prompt_style_guidance(prompt, style, seed=42):
183
 
184
  prompt = prompt + ' in style of s'
@@ -264,7 +294,7 @@ def generate_with_prompt_style_guidance(prompt, style, seed=42):
264
  denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
265
 
266
  # Calculate loss
267
- loss = contrast_loss(denoised_images) * contrast_loss_scale
268
 
269
  # # Occasionally print it out
270
  # if i%10==0:
 
175
 
176
  import torch
177
 
178
+ # def contrast_loss(images):
179
+ # variance = torch.var(images)
180
+ # return -variance
181
+
182
+
183
+ import torch
184
+
185
+ def blue_loss(images):
186
+ """
187
+ Computes the blue loss for a batch of images.
188
+
189
+ The blue loss is defined as the negative variance of the blue channel's pixel values.
190
+
191
+ Parameters:
192
+ images (torch.Tensor): A batch of images. Expected shape is (N, C, H, W) where
193
+ N is the batch size, C is the number of channels (3 for RGB),
194
+ H is the height, and W is the width.
195
+
196
+ Returns:
197
+ torch.Tensor: The blue loss, which is the negative variance of the blue channel's pixel values.
198
+ """
199
+ # Ensure the input tensor has the correct shape
200
+ if images.shape[1] != 3:
201
+ raise ValueError("Expected images with 3 channels (RGB), but got shape {}".format(images.shape))
202
+
203
+ # Extract the blue channel (assuming the channels are in RGB order)
204
+ blue_channel = images[:, 2, :, :]
205
+
206
+ # Calculate the variance of the blue channel
207
+ variance = torch.var(blue_channel)
208
+
209
  return -variance
210
 
211
+
212
  def generate_with_prompt_style_guidance(prompt, style, seed=42):
213
 
214
  prompt = prompt + ' in style of s'
 
294
  denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
295
 
296
  # Calculate loss
297
+ loss = blue_loss(denoised_images) * contrast_loss_scale
298
 
299
  # # Occasionally print it out
300
  # if i%10==0: