Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
@@ -202,7 +202,6 @@ def blue_loss(images):
|
|
202 |
|
203 |
return -variance
|
204 |
|
205 |
-
import torch
|
206 |
|
207 |
def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
|
208 |
"""
|
@@ -263,6 +262,82 @@ def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
|
|
263 |
return loss
|
264 |
|
265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
def blue_loss_variant(images, use_mean=False, alpha=1.0):
|
267 |
"""
|
268 |
Computes the blue loss for a batch of images with an optional mean component.
|
@@ -301,7 +376,7 @@ def blue_loss_variant(images, use_mean=False, alpha=1.0):
|
|
301 |
|
302 |
return loss
|
303 |
|
304 |
-
def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,guidance_scale):
|
305 |
|
306 |
prompt = prompt + ' in style of s'
|
307 |
|
@@ -386,7 +461,19 @@ def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,
|
|
386 |
denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
|
387 |
|
388 |
# Calculate loss
|
389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
390 |
|
391 |
# # Occasionally print it out
|
392 |
# if i%10==0:
|
@@ -423,7 +510,7 @@ def inference(prompt, seed, style,num_inference_steps,guidance_scale,loss_functi
|
|
423 |
print(loss_function)
|
424 |
style = dict_styles[style]
|
425 |
torch.manual_seed(seed)
|
426 |
-
result = generate_with_prompt_style_guidance(prompt, style,seed,num_inference_steps,guidance_scale)
|
427 |
return np.array(result)
|
428 |
else:
|
429 |
return None
|
@@ -450,7 +537,7 @@ demo = gr.Interface(inference,
|
|
450 |
step=8,
|
451 |
label="Select Guidance Scale",
|
452 |
interactive=True,
|
453 |
-
),gr.Radio(["contrast", "
|
454 |
],
|
455 |
outputs = [
|
456 |
gr.Image(label="Stable Diffusion Output"),
|
|
|
202 |
|
203 |
return -variance
|
204 |
|
|
|
205 |
|
206 |
def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
|
207 |
"""
|
|
|
262 |
return loss
|
263 |
|
264 |
|
265 |
+
|
266 |
+
def rgb_to_cmyk(images):
|
267 |
+
"""
|
268 |
+
Converts an RGB image tensor to CMYK.
|
269 |
+
|
270 |
+
Parameters:
|
271 |
+
images (torch.Tensor): A batch of images in RGB format. Expected shape is (N, 3, H, W).
|
272 |
+
|
273 |
+
Returns:
|
274 |
+
torch.Tensor: A tensor containing the CMYK channels.
|
275 |
+
"""
|
276 |
+
R = images[:, 0, :, :]
|
277 |
+
G = images[:, 1, :, :]
|
278 |
+
B = images[:, 2, :, :]
|
279 |
+
|
280 |
+
# Convert RGB to CMY
|
281 |
+
C = 1 - R
|
282 |
+
M = 1 - G
|
283 |
+
Y = 1 - B
|
284 |
+
|
285 |
+
# Convert CMY to CMYK
|
286 |
+
K = torch.min(torch.min(C, M), Y)
|
287 |
+
C = (C - K) / (1 - K + 1e-8)
|
288 |
+
M = (M - K) / (1 - K + 1e-8)
|
289 |
+
Y = (Y - K) / (1 - K + 1e-8)
|
290 |
+
|
291 |
+
CMYK = torch.stack([C, M, Y, K], dim=1)
|
292 |
+
return CMYK
|
293 |
+
|
294 |
+
def cymk_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
|
295 |
+
"""
|
296 |
+
Computes the CYMK loss for a batch of images.
|
297 |
+
|
298 |
+
The CYMK loss is a custom loss function combining the variance of the Cyan channel,
|
299 |
+
the mean value of the Yellow channel, the variance of the Magenta channel, and the
|
300 |
+
absolute sum of the Black channel.
|
301 |
+
|
302 |
+
Parameters:
|
303 |
+
images (torch.Tensor): A batch of images. Expected shape is (N, 3, H, W) for RGB input.
|
304 |
+
weights (tuple): A tuple of four floats representing the weights for each component of the loss
|
305 |
+
(default is (1.0, 1.0, 1.0, 1.0)).
|
306 |
+
|
307 |
+
Returns:
|
308 |
+
torch.Tensor: The CYMK loss, combining the specified components.
|
309 |
+
"""
|
310 |
+
# Ensure the input tensor has the correct shape
|
311 |
+
if images.shape[1] != 3:
|
312 |
+
raise ValueError("Expected images with 3 channels (RGB), but got shape {}".format(images.shape))
|
313 |
+
|
314 |
+
# Convert RGB to CMYK
|
315 |
+
cmyk_images = rgb_to_cmyk(images)
|
316 |
+
|
317 |
+
# Extract CMYK channels
|
318 |
+
C = cmyk_images[:, 0, :, :]
|
319 |
+
M = cmyk_images[:, 1, :, :]
|
320 |
+
Y = cmyk_images[:, 2, :, :]
|
321 |
+
K = cmyk_images[:, 3, :, :]
|
322 |
+
|
323 |
+
# Compute the variance of the C channel
|
324 |
+
variance_C = torch.var(C)
|
325 |
+
|
326 |
+
# Compute the mean of the Y channel
|
327 |
+
mean_Y = torch.mean(Y)
|
328 |
+
|
329 |
+
# Compute the variance of the M channel
|
330 |
+
variance_M = torch.var(M)
|
331 |
+
|
332 |
+
# Compute the absolute sum of the K channel
|
333 |
+
abs_sum_K = torch.sum(torch.abs(K))
|
334 |
+
|
335 |
+
# Combine the components with the given weights
|
336 |
+
loss = (weights[0] * variance_C) + (weights[1] * mean_Y) + (weights[2] * variance_M) + (weights[3] * abs_sum_K)
|
337 |
+
|
338 |
+
return loss
|
339 |
+
|
340 |
+
|
341 |
def blue_loss_variant(images, use_mean=False, alpha=1.0):
|
342 |
"""
|
343 |
Computes the blue loss for a batch of images with an optional mean component.
|
|
|
376 |
|
377 |
return loss
|
378 |
|
379 |
+
def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,guidance_scale,loss_function):
|
380 |
|
381 |
prompt = prompt + ' in style of s'
|
382 |
|
|
|
461 |
denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
|
462 |
|
463 |
# Calculate loss
|
464 |
+
# "contrast", "blue_original", "blue_modified","ymca_loss","cymk_loss"
|
465 |
+
if loss_function == "contrast":
|
466 |
+
loss = contrast_loss(denoised_images) * contrast_loss_scale
|
467 |
+
elif loss_function == "blue_original":
|
468 |
+
loss = blue_loss(denoised_images) * contrast_loss_scale
|
469 |
+
elif loss_function == "blue_modified":
|
470 |
+
loss = blue_loss_variant(denoised_images) * contrast_loss_scale
|
471 |
+
elif loss_function == "ymca_loss":
|
472 |
+
loss = ymca_loss(denoised_images) * contrast_loss_scale
|
473 |
+
elif loss_function == "cymk_loss":
|
474 |
+
loss = cymk_loss(denoised_images) * contrast_loss_scale
|
475 |
+
else :
|
476 |
+
loss = ymca_loss(denoised_images) * contrast_loss_scale
|
477 |
|
478 |
# # Occasionally print it out
|
479 |
# if i%10==0:
|
|
|
510 |
print(loss_function)
|
511 |
style = dict_styles[style]
|
512 |
torch.manual_seed(seed)
|
513 |
+
result = generate_with_prompt_style_guidance(prompt, style,seed,num_inference_steps,guidance_scale,loss_function)
|
514 |
return np.array(result)
|
515 |
else:
|
516 |
return None
|
|
|
537 |
step=8,
|
538 |
label="Select Guidance Scale",
|
539 |
interactive=True,
|
540 |
+
),gr.Radio(["contrast", "blue_original", "blue_modified","ymca_loss","cymk_loss"], label="loss-function", info="loss-function" , value="ymca_loss"),
|
541 |
],
|
542 |
outputs = [
|
543 |
gr.Image(label="Stable Diffusion Output"),
|