AkashDataScience committed
Commit 0c1e9f5 · 1 Parent(s): a67c790
Files changed (1)
  1. app.py +7 -7
app.py CHANGED
@@ -151,29 +151,29 @@ def generate_with_embs(num_inference_steps, guidance_scale, seed, text_input, te
 
     return latents_to_pil(latents)[0]
 
-def guide_loss(images, loss_type='grayscale'):
+def guide_loss(images, loss_type='Grayscale'):
     # grayscale loss
-    if loss_type == 'grayscale':
+    if loss_type == 'Grayscale':
         transformed_imgs = grayscale_transformer(images)
         error = torch.abs(transformed_imgs - images).mean()
 
     # brightness loss
-    elif loss_type == 'bright':
+    elif loss_type == 'Bright':
         transformed_imgs = tfms.functional.adjust_brightness(images, brightness_factor=3)
         error = torch.abs(transformed_imgs - images).mean()
 
     # contrast loss
-    elif loss_type == 'contrast':
+    elif loss_type == 'Contrast':
         transformed_imgs = tfms.functional.adjust_contrast(images, contrast_factor=10)
         error = torch.abs(transformed_imgs - images).mean()
 
     # symmetry loss - Flip the image along the width
-    elif loss_type == "symmetry":
+    elif loss_type == "Symmetry":
         flipped_image = torch.flip(images, [3])
         error = F.mse_loss(images, flipped_image)
 
     # saturation loss
-    elif loss_type == 'saturation':
+    elif loss_type == 'Saturation':
         transformed_imgs = tfms.functional.adjust_saturation(images, saturation_factor=10)
         error = torch.abs(transformed_imgs - images).mean()
 
@@ -291,7 +291,7 @@ demo = gr.Interface(inference,
                     gr.Slider(1, 10, 7.5, step = 0.1, label="Guidance scale"),
                     gr.Slider(0, 10000, 1, step = 1, label="Seed"),
                     gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast',
-                                 'Symmetry', 'Saturation'], value="Concept"),
+                                 'Symmetry', 'Saturation'], value="Grayscale"),
                     gr.Slider(100, 10000, 100, step = 1, label="Loss scale")],
           outputs= [gr.Image(width=320, height=320, label="Generated art"),
                     gr.Image(width=320, height=320, label="Generated art with guidance")],
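For context: the renamed strings matter because guide_loss only assigns error inside the branch whose loss_type matches, and with the previous dropdown default of "Concept" no branch matched at all. Aligning the branch names with the capitalized dropdown choices and defaulting to "Grayscale" keeps the UI and the loss in agreement. Below is a minimal sketch of how such a loss is typically used for gradient-based guidance during sampling; the names guidance_step, decode_latents, sigma, and loss_scale are assumptions for illustration and are not taken from app.py.

import torch

# Minimal sketch (assumed helper names, not from app.py) of gradient-based guidance:
# backpropagate guide_loss through the decoded image to the latents and nudge the
# latents against the gradient before the next denoising step.
def guidance_step(latents, sigma, loss_type, loss_scale, decode_latents):
    latents = latents.detach().requires_grad_()
    images = decode_latents(latents)           # differentiable latents -> image tensor
    error = guide_loss(images, loss_type)      # loss_type must match a branch, e.g. "Grayscale"
    grad = torch.autograd.grad(error * loss_scale, latents)[0]
    return latents.detach() - grad * sigma**2  # step size scaled by the current noise level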