Shivdutta commited on
Commit
3d736df
·
verified ·
1 Parent(s): 9fe3ab5

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -10
app.py CHANGED
@@ -388,7 +388,7 @@ def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,
388
  guidance_scale = guidance_scale # # Scale for classifier-free guidance
389
  generator = torch.manual_seed(seed) # Seed generator to create the inital latent noise
390
  batch_size = 1
391
- contrast_loss_scale = 200 #
392
 
393
  # Prep text
394
  text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
@@ -463,17 +463,22 @@ def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,
463
  # Calculate loss
464
  # "contrast", "blue_original", "blue_modified","ymca_loss","cymk_loss"
465
  if loss_function == "contrast":
466
- loss = contrast_loss(denoised_images) * contrast_loss_scale
 
467
  elif loss_function == "blue_original":
468
- loss = blue_loss(denoised_images) * contrast_loss_scale
 
469
  elif loss_function == "blue_modified":
470
- loss = blue_loss_variant(denoised_images) * contrast_loss_scale
471
- elif loss_function == "ymca_loss":
472
- loss = ymca_loss(denoised_images) * contrast_loss_scale
473
- elif loss_function == "cymk_loss":
474
- loss = cymk_loss(denoised_images) * contrast_loss_scale
 
 
 
475
  else :
476
- loss = ymca_loss(denoised_images) * contrast_loss_scale
477
 
478
  # # Occasionally print it out
479
  # if i%10==0:
@@ -537,7 +542,7 @@ demo = gr.Interface(inference,
537
  step=8,
538
  label="Select Guidance Scale",
539
  interactive=True,
540
- ),gr.Radio(["contrast", "blue_original", "blue_modified","ymca_loss","cymk_loss"], label="loss-function", info="loss-function" , value="ymca_loss"),
541
  ],
542
  outputs = [
543
  gr.Image(label="Stable Diffusion Output"),
 
388
  guidance_scale = guidance_scale # # Scale for classifier-free guidance
389
  generator = torch.manual_seed(seed) # Seed generator to create the inital latent noise
390
  batch_size = 1
391
+
392
 
393
  # Prep text
394
  text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
 
463
  # Calculate loss
464
  # "contrast", "blue_original", "blue_modified","ymca_loss","cymk_loss"
465
  if loss_function == "contrast":
466
+ loss_scale = 200 #
467
+ loss = contrast_loss(denoised_images) * loss_scale
468
  elif loss_function == "blue_original":
469
+ loss_scale = 200 #
470
+ loss = blue_loss(denoised_images) * loss_scale
471
  elif loss_function == "blue_modified":
472
+ loss_scale = 200 #
473
+ loss = blue_loss_variant(denoised_images) * loss_scale
474
+ elif loss_function == "ymca":
475
+ loss_scale = 200 #
476
+ loss = ymca_loss(denoised_images) * loss_scale
477
+ elif loss_function == "cmyk":
478
+ loss_scale = 10 #
479
+ loss = cymk_loss(denoised_images) * loss_scale
480
  else :
481
+ loss = ymca_loss(denoised_images) * loss_scale
482
 
483
  # # Occasionally print it out
484
  # if i%10==0:
 
542
  step=8,
543
  label="Select Guidance Scale",
544
  interactive=True,
545
+ ),gr.Radio(["contrast", "blue_original", "blue_modified","ymca","cmyk"], label="loss-function", info="loss-function" , value="ymca_loss"),
546
  ],
547
  outputs = [
548
  gr.Image(label="Stable Diffusion Output"),