Shivdutta committed on
Commit
a00c054
·
verified ·
1 Parent(s): d19ba9e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -3
app.py CHANGED
@@ -175,6 +175,126 @@ def contrast_loss(images):
175
  variance = torch.var(images)
176
  return -variance
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,guidance_scale):
179
 
180
  prompt = prompt + ' in style of s'
@@ -260,7 +380,7 @@ def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,
260
  denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
261
 
262
  # Calculate loss
263
- loss = contrast_loss(denoised_images) * contrast_loss_scale
264
 
265
  # # Occasionally print it out
266
  # if i%10==0:
@@ -291,9 +411,10 @@ dict_styles = {
291
  'Oil Painting':'styles/learned_embeds_oil.bin',
292
  }
293
 
294
- def inference(prompt, seed, style,num_inference_steps,guidance_scale):
295
 
296
  if prompt is not None and style is not None and seed is not None:
 
297
  style = dict_styles[style]
298
  torch.manual_seed(seed)
299
  result = generate_with_prompt_style_guidance(prompt, style,seed,num_inference_steps,guidance_scale)
@@ -323,7 +444,7 @@ demo = gr.Interface(inference,
323
  step=8,
324
  label="Select Guidance Scale",
325
  interactive=True,
326
- )
327
  ],
328
  outputs = [
329
  gr.Image(label="Stable Diffusion Output"),
 
175
  variance = torch.var(images)
176
  return -variance
177
 
178
+
179
def blue_loss(images):
    """
    Computes the blue loss for a batch of images.

    The blue loss is defined as the negative variance of the blue channel's pixel values,
    so minimising it pushes the blue channel towards a larger spread of values.

    Parameters:
    images (torch.Tensor): A batch of images. Expected shape is (N, C, H, W) where
                           N is the batch size, C is the number of channels (3 for RGB),
                           H is the height, and W is the width.

    Returns:
    torch.Tensor: A scalar tensor holding the negative variance of the blue channel
                  (index 2 along the channel axis), computed over the whole batch.

    Raises:
    ValueError: If `images` is not 4-dimensional or does not have exactly 3 channels.
    """
    # Check the rank first: reading shape[1] on a 0-D/1-D tensor would raise a bare
    # IndexError, and a 2-D/3-D tensor would only fail later on the 4-index slice.
    if images.ndim != 4:
        raise ValueError(
            f"Expected a 4-D batch of images (N, C, H, W), but got shape {images.shape}"
        )
    if images.shape[1] != 3:
        raise ValueError(
            f"Expected images with 3 channels (RGB), but got shape {images.shape}"
        )

    # Extract the blue channel (assuming the channels are in RGB order)
    blue_channel = images[:, 2, :, :]

    # torch.var with no dim reduces over every element of the slice
    return -torch.var(blue_channel)
204
+
205
def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
    """
    Computes the YMCA loss for a batch of images.

    The YMCA loss is a custom loss function combining the mean value of the Y (luminance) channel,
    the mean value of the M (magenta) channel, the variance of the C (cyan) channel, and the
    absolute sum of the A (alpha) channel:

        loss = w0 * mean(Y) + w1 * mean(M) - w2 * var(C) + w3 * sum(|A|)

    with Y = 0.299 * R + 0.587 * G + 0.114 * B, M = 1 - G and C = 1 - R.

    Parameters:
    images (torch.Tensor): A batch of images. Expected shape is (N, C, H, W) where
                           N is the batch size, C is the number of channels (4 for RGBA),
                           H is the height, and W is the width. Plain 3-channel RGB
                           batches are rejected.
    weights (tuple): A sequence of four floats representing the weights for each component
                     of the loss, in the order (Y, M, C, A). Only the first four entries
                     are used (default is (1.0, 1.0, 1.0, 1.0)).

    Returns:
    torch.Tensor: A scalar tensor holding the YMCA loss, combining the specified components.

    Raises:
    ValueError: If `images` is not 4-dimensional, does not have exactly 4 channels,
                or if `weights` has fewer than four entries.
    """
    # Check the rank first: reading shape[1] on a 0-D/1-D tensor would raise a bare
    # IndexError, and a 2-D/3-D tensor would only fail later on the 4-index slices.
    if images.ndim != 4:
        raise ValueError(
            f"Expected a 4-D batch of images (N, C, H, W), but got shape {images.shape}"
        )
    if images.shape[1] != 4:
        raise ValueError(
            f"Expected images with 4 channels (RGBA), but got shape {images.shape}"
        )
    # Fail early with a clear message instead of an IndexError after the tensor work.
    if len(weights) < 4:
        raise ValueError(
            f"Expected four weights (Y, M, C, A), but got {len(weights)}"
        )

    # Extract the RGBA channels
    R = images[:, 0, :, :]
    G = images[:, 1, :, :]
    B = images[:, 2, :, :]
    A = images[:, 3, :, :]

    # Luminance as a weighted sum of R, G and B
    Y = 0.299 * R + 0.587 * G + 0.114 * B
    # Magenta and cyan taken as the complements of green and red
    M = 1 - G
    C = 1 - R

    mean_Y = torch.mean(Y)
    mean_M = torch.mean(M)
    variance_C = torch.var(C)
    abs_sum_A = torch.sum(torch.abs(A))

    # The variance term is subtracted; the other three components are added.
    return (
        (weights[0] * mean_Y)
        + (weights[1] * mean_M)
        - (weights[2] * variance_C)
        + (weights[3] * abs_sum_A)
    )
258
+
259
+
260
def blue_loss_variant(images, use_mean=False, alpha=1.0):
    """
    Computes the blue loss for a batch of images with an optional mean component.

    The blue loss is defined as the negative variance of the blue channel's pixel values.
    Optionally, it can also include the mean value of the blue channel:

        loss = -var(blue)                        (use_mean is False)
        loss = -var(blue) + alpha * mean(blue)   (use_mean is True)

    Parameters:
    images (torch.Tensor): A batch of images. Expected shape is (N, C, H, W) where
                           N is the batch size, C is the number of channels (3 for RGB),
                           H is the height, and W is the width.
    use_mean (bool): If True, includes the mean of the blue channel in the loss calculation.
    alpha (float): Weighting factor for the mean component when use_mean is True.

    Returns:
    torch.Tensor: A scalar tensor holding the negative variance of the blue channel's
                  pixel values, optionally combined with the weighted mean of the blue channel.

    Raises:
    ValueError: If `images` is not 4-dimensional or does not have exactly 3 channels.
    """
    # Check the rank first: reading shape[1] on a 0-D/1-D tensor would raise a bare
    # IndexError, and a 2-D/3-D tensor would only fail later on the 4-index slice.
    if images.ndim != 4:
        raise ValueError(
            f"Expected a 4-D batch of images (N, C, H, W), but got shape {images.shape}"
        )
    if images.shape[1] != 3:
        raise ValueError(
            f"Expected images with 3 channels (RGB), but got shape {images.shape}"
        )

    # Extract the blue channel (assuming the channels are in RGB order)
    blue_channel = images[:, 2, :, :]

    loss = -torch.var(blue_channel)
    if use_mean:
        # Add the weighted mean of the blue channel to the variance term
        loss = loss + alpha * torch.mean(blue_channel)

    return loss
297
+
298
  def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,guidance_scale):
299
 
300
  prompt = prompt + ' in style of s'
 
380
  denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
381
 
382
  # Calculate loss
383
+ loss = ymca_loss(denoised_images) * contrast_loss_scale
384
 
385
  # # Occasionally print it out
386
  # if i%10==0:
 
411
  'Oil Painting':'styles/learned_embeds_oil.bin',
412
  }
413
 
414
+ def inference(prompt, seed, style,num_inference_steps,guidance_scale,loss_function):
415
 
416
  if prompt is not None and style is not None and seed is not None:
417
+ print(loss_function)
418
  style = dict_styles[style]
419
  torch.manual_seed(seed)
420
  result = generate_with_prompt_style_guidance(prompt, style,seed,num_inference_steps,guidance_scale)
 
444
  step=8,
445
  label="Select Guidance Scale",
446
  interactive=True,
447
+ ),gr.Radio(["contrast", "blue-original", "blue-modified","ymca_loss"], label="loss-function", info="loss-function"),
448
  ],
449
  outputs = [
450
  gr.Image(label="Stable Diffusion Output"),