Shivdutta commited on
Commit
9337350
·
verified ·
1 Parent(s): a00c054

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -11
app.py CHANGED
@@ -202,17 +202,19 @@ def blue_loss(images):
202
 
203
  return -variance
204
 
 
 
205
  def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
206
  """
207
  Computes the YMCA loss for a batch of images.
208
 
209
  The YMCA loss is a custom loss function combining the mean value of the Y (luminance) channel,
210
  the mean value of the M (magenta) channel, the variance of the C (cyan) channel, and the
211
- absolute sum of the A (alpha) channel.
212
 
213
  Parameters:
214
  images (torch.Tensor): A batch of images. Expected shape is (N, C, H, W) where
215
- N is the batch size, C is the number of channels (assumed 4 for RGBA),
216
  H is the height, and W is the width.
217
  weights (tuple): A tuple of four floats representing the weights for each component of the loss
218
  (default is (1.0, 1.0, 1.0, 1.0)).
@@ -220,15 +222,15 @@ def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
220
  Returns:
221
  torch.Tensor: The YMCA loss, combining the specified components.
222
  """
223
- # Ensure the input tensor has the correct shape
224
- if images.shape[1] != 4:
225
- raise ValueError("Expected images with 4 channels (RGBA), but got shape {}".format(images.shape))
 
226
 
227
- # Extract the RGBA channels
228
  R = images[:, 0, :, :]
229
  G = images[:, 1, :, :]
230
  B = images[:, 2, :, :]
231
- A = images[:, 3, :, :]
232
 
233
  # Convert RGB to Y (luminance) channel
234
  Y = 0.299 * R + 0.587 * G + 0.114 * B
@@ -248,11 +250,15 @@ def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
248
  # Compute the variance of the C channel
249
  variance_C = torch.var(C)
250
 
251
- # Compute the absolute sum of the A channel
252
- abs_sum_A = torch.sum(torch.abs(A))
253
 
254
- # Combine the components with the given weights
255
- loss = (weights[0] * mean_Y) + (weights[1] * mean_M) - (weights[2] * variance_C) + (weights[3] * abs_sum_A)
 
 
 
 
 
256
 
257
  return loss
258
 
 
202
 
203
  return -variance
204
 
205
+ import torch
206
+
207
  def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
208
  """
209
  Computes the YMCA loss for a batch of images.
210
 
211
  The YMCA loss is a custom loss function combining the mean value of the Y (luminance) channel,
212
  the mean value of the M (magenta) channel, the variance of the C (cyan) channel, and the
213
+ absolute sum of the A (alpha) channel if present.
214
 
215
  Parameters:
216
  images (torch.Tensor): A batch of images. Expected shape is (N, C, H, W) where
217
+ N is the batch size, C is the number of channels (3 for RGB or 4 for RGBA),
218
  H is the height, and W is the width.
219
  weights (tuple): A tuple of four floats representing the weights for each component of the loss
220
  (default is (1.0, 1.0, 1.0, 1.0)).
 
222
  Returns:
223
  torch.Tensor: The YMCA loss, combining the specified components.
224
  """
225
+ num_channels = images.shape[1]
226
+
227
+ if num_channels not in [3, 4]:
228
+ raise ValueError("Expected images with 3 (RGB) or 4 (RGBA) channels, but got shape {}".format(images.shape))
229
 
230
+ # Extract the RGB channels
231
  R = images[:, 0, :, :]
232
  G = images[:, 1, :, :]
233
  B = images[:, 2, :, :]
 
234
 
235
  # Convert RGB to Y (luminance) channel
236
  Y = 0.299 * R + 0.587 * G + 0.114 * B
 
250
  # Compute the variance of the C channel
251
  variance_C = torch.var(C)
252
 
253
+ loss = weights[0] * mean_Y + weights[1] * mean_M - weights[2] * variance_C
 
254
 
255
+ if num_channels == 4:
256
+ # Extract the alpha channel
257
+ A = images[:, 3, :, :]
258
+ # Compute the absolute sum of the A channel
259
+ abs_sum_A = torch.sum(torch.abs(A))
260
+ # Include the alpha component in the loss
261
+ loss += weights[3] * abs_sum_A
262
 
263
  return loss
264