Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
@@ -202,17 +202,19 @@ def blue_loss(images):
|
|
202 |
|
203 |
return -variance
|
204 |
|
|
|
|
|
205 |
def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
|
206 |
"""
|
207 |
Computes the YMCA loss for a batch of images.
|
208 |
|
209 |
The YMCA loss is a custom loss function combining the mean value of the Y (luminance) channel,
|
210 |
the mean value of the M (magenta) channel, the variance of the C (cyan) channel, and the
|
211 |
-
absolute sum of the A (alpha) channel.
|
212 |
|
213 |
Parameters:
|
214 |
images (torch.Tensor): A batch of images. Expected shape is (N, C, H, W) where
|
215 |
-
N is the batch size, C is the number of channels (
|
216 |
H is the height, and W is the width.
|
217 |
weights (tuple): A tuple of four floats representing the weights for each component of the loss
|
218 |
(default is (1.0, 1.0, 1.0, 1.0)).
|
@@ -220,15 +222,15 @@ def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
|
|
220 |
Returns:
|
221 |
torch.Tensor: The YMCA loss, combining the specified components.
|
222 |
"""
|
223 |
-
|
224 |
-
|
225 |
-
|
|
|
226 |
|
227 |
-
# Extract the
|
228 |
R = images[:, 0, :, :]
|
229 |
G = images[:, 1, :, :]
|
230 |
B = images[:, 2, :, :]
|
231 |
-
A = images[:, 3, :, :]
|
232 |
|
233 |
# Convert RGB to Y (luminance) channel
|
234 |
Y = 0.299 * R + 0.587 * G + 0.114 * B
|
@@ -248,11 +250,15 @@ def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
|
|
248 |
# Compute the variance of the C channel
|
249 |
variance_C = torch.var(C)
|
250 |
|
251 |
-
|
252 |
-
abs_sum_A = torch.sum(torch.abs(A))
|
253 |
|
254 |
-
|
255 |
-
|
|
|
|
|
|
|
|
|
|
|
256 |
|
257 |
return loss
|
258 |
|
|
|
202 |
|
203 |
return -variance
|
204 |
|
205 |
+
import torch
|
206 |
+
|
207 |
def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
|
208 |
"""
|
209 |
Computes the YMCA loss for a batch of images.
|
210 |
|
211 |
The YMCA loss is a custom loss function combining the mean value of the Y (luminance) channel,
|
212 |
the mean value of the M (magenta) channel, the variance of the C (cyan) channel, and the
|
213 |
+
absolute sum of the A (alpha) channel if present.
|
214 |
|
215 |
Parameters:
|
216 |
images (torch.Tensor): A batch of images. Expected shape is (N, C, H, W) where
|
217 |
+
N is the batch size, C is the number of channels (3 for RGB or 4 for RGBA),
|
218 |
H is the height, and W is the width.
|
219 |
weights (tuple): A tuple of four floats representing the weights for each component of the loss
|
220 |
(default is (1.0, 1.0, 1.0, 1.0)).
|
|
|
222 |
Returns:
|
223 |
torch.Tensor: The YMCA loss, combining the specified components.
|
224 |
"""
|
225 |
+
num_channels = images.shape[1]
|
226 |
+
|
227 |
+
if num_channels not in [3, 4]:
|
228 |
+
raise ValueError("Expected images with 3 (RGB) or 4 (RGBA) channels, but got shape {}".format(images.shape))
|
229 |
|
230 |
+
# Extract the RGB channels
|
231 |
R = images[:, 0, :, :]
|
232 |
G = images[:, 1, :, :]
|
233 |
B = images[:, 2, :, :]
|
|
|
234 |
|
235 |
# Convert RGB to Y (luminance) channel
|
236 |
Y = 0.299 * R + 0.587 * G + 0.114 * B
|
|
|
250 |
# Compute the variance of the C channel
|
251 |
variance_C = torch.var(C)
|
252 |
|
253 |
+
loss = weights[0] * mean_Y + weights[1] * mean_M - weights[2] * variance_C
|
|
|
254 |
|
255 |
+
if num_channels == 4:
|
256 |
+
# Extract the alpha channel
|
257 |
+
A = images[:, 3, :, :]
|
258 |
+
# Compute the absolute sum of the A channel
|
259 |
+
abs_sum_A = torch.sum(torch.abs(A))
|
260 |
+
# Include the alpha component in the loss
|
261 |
+
loss += weights[3] * abs_sum_A
|
262 |
|
263 |
return loss
|
264 |
|