Shivdutta committed on
Commit
9fe3ab5
·
verified ·
1 Parent(s): 9337350

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -5
app.py CHANGED
@@ -202,7 +202,6 @@ def blue_loss(images):
202
 
203
  return -variance
204
 
205
- import torch
206
 
207
  def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
208
  """
@@ -263,6 +262,82 @@ def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
263
  return loss
264
 
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  def blue_loss_variant(images, use_mean=False, alpha=1.0):
267
  """
268
  Computes the blue loss for a batch of images with an optional mean component.
@@ -301,7 +376,7 @@ def blue_loss_variant(images, use_mean=False, alpha=1.0):
301
 
302
  return loss
303
 
304
- def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,guidance_scale):
305
 
306
  prompt = prompt + ' in style of s'
307
 
@@ -386,7 +461,19 @@ def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,
386
  denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
387
 
388
  # Calculate loss
389
- loss = ymca_loss(denoised_images) * contrast_loss_scale
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
  # # Occasionally print it out
392
  # if i%10==0:
@@ -423,7 +510,7 @@ def inference(prompt, seed, style,num_inference_steps,guidance_scale,loss_functi
423
  print(loss_function)
424
  style = dict_styles[style]
425
  torch.manual_seed(seed)
426
- result = generate_with_prompt_style_guidance(prompt, style,seed,num_inference_steps,guidance_scale)
427
  return np.array(result)
428
  else:
429
  return None
@@ -450,7 +537,7 @@ demo = gr.Interface(inference,
450
  step=8,
451
  label="Select Guidance Scale",
452
  interactive=True,
453
- ),gr.Radio(["contrast", "blue-original", "blue-modified","ymca_loss"], label="loss-function", info="loss-function"),
454
  ],
455
  outputs = [
456
  gr.Image(label="Stable Diffusion Output"),
 
202
 
203
  return -variance
204
 
 
205
 
206
  def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
207
  """
 
262
  return loss
263
 
264
 
265
+
266
def rgb_to_cmyk(images):
    """
    Convert a batch of RGB images into CMYK channels.

    The input is indexed as a 4-D tensor whose dimension 1 holds the colour
    channels, with R, G and B at positions 0, 1 and 2. Any further channels
    are ignored.

    Parameters:
    images (torch.Tensor): Batch of RGB images, expected shape (N, 3, H, W).
        Presumably values lie in [0, 1] - confirm against the caller.

    Returns:
    torch.Tensor: The C, M, Y and K planes stacked along dimension 1,
        i.e. shape (N, 4, H, W) for an (N, 3, H, W) input.
    """
    red = images[:, 0, :, :]
    green = images[:, 1, :, :]
    blue = images[:, 2, :, :]

    # Plain CMY is the complement of RGB.
    cyan = 1 - red
    magenta = 1 - green
    yellow = 1 - blue

    # Black is the component shared by all three inks.
    black = torch.min(torch.min(cyan, magenta), yellow)

    # The small epsilon keeps the division finite where black == 1.
    remaining = 1 - black + 1e-8
    cyan = (cyan - black) / remaining
    magenta = (magenta - black) / remaining
    yellow = (yellow - black) / remaining

    return torch.stack([cyan, magenta, yellow, black], dim=1)
293
+
294
def cymk_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
    """
    Compute the CYMK loss for a batch of RGB images.

    The images are converted to CMYK with ``rgb_to_cmyk`` and four scalar
    terms are combined as a weighted sum: the variance of the Cyan plane,
    the mean of the Yellow plane, the variance of the Magenta plane and the
    sum of absolute values of the Black plane.

    Parameters:
    images (torch.Tensor): Batch of RGB images, expected shape (N, 3, H, W).
    weights (tuple): Four floats applied, in order, to the Cyan variance,
        Yellow mean, Magenta variance and Black absolute sum
        (default is (1.0, 1.0, 1.0, 1.0)).

    Returns:
    torch.Tensor: The weighted sum of the four components.

    Raises:
    ValueError: If dimension 1 of ``images`` does not have size 3.
    """
    # Reject anything that is not a 3-channel batch before converting.
    if images.shape[1] != 3:
        raise ValueError("Expected images with 3 channels (RGB), but got shape {}".format(images.shape))

    cmyk = rgb_to_cmyk(images)

    # Channel order produced by rgb_to_cmyk is C, M, Y, K.
    cyan = cmyk[:, 0, :, :]
    magenta = cmyk[:, 1, :, :]
    yellow = cmyk[:, 2, :, :]
    black = cmyk[:, 3, :, :]

    cyan_term = weights[0] * torch.var(cyan)
    yellow_term = weights[1] * torch.mean(yellow)
    magenta_term = weights[2] * torch.var(magenta)
    black_term = weights[3] * torch.sum(torch.abs(black))

    # Summed left to right in the same order as the weights tuple.
    return cyan_term + yellow_term + magenta_term + black_term
339
+
340
+
341
  def blue_loss_variant(images, use_mean=False, alpha=1.0):
342
  """
343
  Computes the blue loss for a batch of images with an optional mean component.
 
376
 
377
  return loss
378
 
379
+ def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,guidance_scale,loss_function):
380
 
381
  prompt = prompt + ' in style of s'
382
 
 
461
  denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
462
 
463
  # Calculate loss
464
+ # "contrast", "blue_original", "blue_modified","ymca_loss","cymk_loss"
465
+ if loss_function == "contrast":
466
+ loss = contrast_loss(denoised_images) * contrast_loss_scale
467
+ elif loss_function == "blue_original":
468
+ loss = blue_loss(denoised_images) * contrast_loss_scale
469
+ elif loss_function == "blue_modified":
470
+ loss = blue_loss_variant(denoised_images) * contrast_loss_scale
471
+ elif loss_function == "ymca_loss":
472
+ loss = ymca_loss(denoised_images) * contrast_loss_scale
473
+ elif loss_function == "cymk_loss":
474
+ loss = cymk_loss(denoised_images) * contrast_loss_scale
475
+ else :
476
+ loss = ymca_loss(denoised_images) * contrast_loss_scale
477
 
478
  # # Occasionally print it out
479
  # if i%10==0:
 
510
  print(loss_function)
511
  style = dict_styles[style]
512
  torch.manual_seed(seed)
513
+ result = generate_with_prompt_style_guidance(prompt, style,seed,num_inference_steps,guidance_scale,loss_function)
514
  return np.array(result)
515
  else:
516
  return None
 
537
  step=8,
538
  label="Select Guidance Scale",
539
  interactive=True,
540
+ ),gr.Radio(["contrast", "blue_original", "blue_modified","ymca_loss","cymk_loss"], label="loss-function", info="loss-function" , value="ymca_loss"),
541
  ],
542
  outputs = [
543
  gr.Image(label="Stable Diffusion Output"),