Commit
·
0c1e9f5
1
Parent(s):
a67c790
Minor fix
Browse files
app.py
CHANGED
@@ -151,29 +151,29 @@ def generate_with_embs(num_inference_steps, guidance_scale, seed, text_input, te
|
|
151 |
|
152 |
return latents_to_pil(latents)[0]
|
153 |
|
154 |
-
def guide_loss(images, loss_type='
|
155 |
# grayscale loss
|
156 |
-
if loss_type == '
|
157 |
transformed_imgs = grayscale_transformer(images)
|
158 |
error = torch.abs(transformed_imgs - images).mean()
|
159 |
|
160 |
# brightness loss
|
161 |
-
elif loss_type == '
|
162 |
transformed_imgs = tfms.functional.adjust_brightness(images, brightness_factor=3)
|
163 |
error = torch.abs(transformed_imgs - images).mean()
|
164 |
|
165 |
# contrast loss
|
166 |
-
elif loss_type == '
|
167 |
transformed_imgs = tfms.functional.adjust_contrast(images, contrast_factor=10)
|
168 |
error = torch.abs(transformed_imgs - images).mean()
|
169 |
|
170 |
# symmetry loss - Flip the image along the width
|
171 |
-
elif loss_type == "
|
172 |
flipped_image = torch.flip(images, [3])
|
173 |
error = F.mse_loss(images, flipped_image)
|
174 |
|
175 |
# saturation loss
|
176 |
-
elif loss_type == '
|
177 |
transformed_imgs = tfms.functional.adjust_saturation(images,saturation_factor = 10)
|
178 |
error = torch.abs(transformed_imgs - images).mean()
|
179 |
|
@@ -291,7 +291,7 @@ demo = gr.Interface(inference,
|
|
291 |
gr.Slider(1, 10, 7.5, step = 0.1, label="Guidance scale"),
|
292 |
gr.Slider(0, 10000, 1, step = 1, label="Seed"),
|
293 |
gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast',
|
294 |
-
'Symmetry', 'Saturation'], value="
|
295 |
gr.Slider(100, 10000, 100, step = 1, label="Loss scale")],
|
296 |
outputs= [gr.Image(width=320, height=320, label="Generated art"),
|
297 |
gr.Image(width=320, height=320, label="Generated art with guidance")],
|
|
|
151 |
|
152 |
return latents_to_pil(latents)[0]
|
153 |
|
154 |
+
def guide_loss(images, loss_type='Gayscale'):
|
155 |
# grayscale loss
|
156 |
+
if loss_type == 'Grayscale':
|
157 |
transformed_imgs = grayscale_transformer(images)
|
158 |
error = torch.abs(transformed_imgs - images).mean()
|
159 |
|
160 |
# brightness loss
|
161 |
+
elif loss_type == 'Bright':
|
162 |
transformed_imgs = tfms.functional.adjust_brightness(images, brightness_factor=3)
|
163 |
error = torch.abs(transformed_imgs - images).mean()
|
164 |
|
165 |
# contrast loss
|
166 |
+
elif loss_type == 'Contrast':
|
167 |
transformed_imgs = tfms.functional.adjust_contrast(images, contrast_factor=10)
|
168 |
error = torch.abs(transformed_imgs - images).mean()
|
169 |
|
170 |
# symmetry loss - Flip the image along the width
|
171 |
+
elif loss_type == "Symmetry":
|
172 |
flipped_image = torch.flip(images, [3])
|
173 |
error = F.mse_loss(images, flipped_image)
|
174 |
|
175 |
# saturation loss
|
176 |
+
elif loss_type == 'Saturation':
|
177 |
transformed_imgs = tfms.functional.adjust_saturation(images,saturation_factor = 10)
|
178 |
error = torch.abs(transformed_imgs - images).mean()
|
179 |
|
|
|
291 |
gr.Slider(1, 10, 7.5, step = 0.1, label="Guidance scale"),
|
292 |
gr.Slider(0, 10000, 1, step = 1, label="Seed"),
|
293 |
gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast',
|
294 |
+
'Symmetry', 'Saturation'], value="Grayscale"),
|
295 |
gr.Slider(100, 10000, 100, step = 1, label="Loss scale")],
|
296 |
outputs= [gr.Image(width=320, height=320, label="Generated art"),
|
297 |
gr.Image(width=320, height=320, label="Generated art with guidance")],
|