Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
from base64 import b64encode
|
2 |
-
import torch
|
3 |
import numpy
|
4 |
import torch
|
5 |
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
|
6 |
from huggingface_hub import notebook_login
|
|
|
7 |
|
8 |
# For video display:
|
9 |
from matplotlib import pyplot as plt
|
@@ -16,7 +16,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, logging
|
|
16 |
import os
|
17 |
import numpy as np
|
18 |
|
19 |
-
|
20 |
|
21 |
# Suppress some unnecessary warnings when loading the CLIPTextModel
|
22 |
logging.set_verbosity_error()
|
@@ -145,7 +145,7 @@ def generate_with_embs(text_embeddings, text_input, seed):
|
|
145 |
return latents_to_pil(latents)[0]
|
146 |
|
147 |
|
148 |
-
def generate_with_prompt_style(prompt, style, seed
|
149 |
|
150 |
prompt = prompt + ' in style of s'
|
151 |
embed = torch.load(style)
|
@@ -175,72 +175,7 @@ def contrast_loss(images):
|
|
175 |
variance = torch.var(images)
|
176 |
return -variance
|
177 |
|
178 |
-
def
|
179 |
-
"""
|
180 |
-
Computes the blue loss for a batch of images with an optional mean component.
|
181 |
-
|
182 |
-
The blue loss is defined as the negative variance of the blue channel's pixel values.
|
183 |
-
Optionally, it can also include the mean value of the blue channel.
|
184 |
-
|
185 |
-
Parameters:
|
186 |
-
images (torch.Tensor): A batch of images. Expected shape is (N, C, H, W) where
|
187 |
-
N is the batch size, C is the number of channels (3 for RGB),
|
188 |
-
H is the height, and W is the width.
|
189 |
-
use_mean (bool): If True, includes the mean of the blue channel in the loss calculation.
|
190 |
-
alpha (float): Weighting factor for the mean component when use_mean is True.
|
191 |
-
|
192 |
-
Returns:
|
193 |
-
torch.Tensor: The blue loss, which is the negative variance of the blue channel's pixel values,
|
194 |
-
optionally combined with the mean value of the blue channel.
|
195 |
-
"""
|
196 |
-
# Ensure the input tensor has the correct shape
|
197 |
-
if images.shape[1] != 3:
|
198 |
-
raise ValueError("Expected images with 3 channels (RGB), but got shape {}".format(images.shape))
|
199 |
-
|
200 |
-
# Extract the blue channel (assuming the channels are in RGB order)
|
201 |
-
blue_channel = images[:, 2, :, :]
|
202 |
-
|
203 |
-
# Calculate the variance of the blue channel
|
204 |
-
variance = torch.var(blue_channel)
|
205 |
-
|
206 |
-
if use_mean:
|
207 |
-
# Calculate the mean of the blue channel
|
208 |
-
mean = torch.mean(blue_channel)
|
209 |
-
# Combine variance and mean into the loss
|
210 |
-
loss = -variance + alpha * mean
|
211 |
-
else:
|
212 |
-
loss = -variance
|
213 |
-
|
214 |
-
return loss
|
215 |
-
|
216 |
-
def blue_loss(images):
|
217 |
-
"""
|
218 |
-
Computes the blue loss for a batch of images.
|
219 |
-
|
220 |
-
The blue loss is defined as the negative variance of the blue channel's pixel values.
|
221 |
-
|
222 |
-
Parameters:
|
223 |
-
images (torch.Tensor): A batch of images. Expected shape is (N, C, H, W) where
|
224 |
-
N is the batch size, C is the number of channels (3 for RGB),
|
225 |
-
H is the height, and W is the width.
|
226 |
-
|
227 |
-
Returns:
|
228 |
-
torch.Tensor: The blue loss, which is the negative variance of the blue channel's pixel values.
|
229 |
-
"""
|
230 |
-
# Ensure the input tensor has the correct shape
|
231 |
-
if images.shape[1] != 3:
|
232 |
-
raise ValueError("Expected images with 3 channels (RGB), but got shape {}".format(images.shape))
|
233 |
-
|
234 |
-
# Extract the blue channel (assuming the channels are in RGB order)
|
235 |
-
blue_channel = images[:, 2, :, :]
|
236 |
-
|
237 |
-
# Calculate the variance of the blue channel
|
238 |
-
variance = torch.var(blue_channel)
|
239 |
-
|
240 |
-
return -variance
|
241 |
-
|
242 |
-
|
243 |
-
def generate_with_prompt_style_guidance(prompt, style, seed=42):
|
244 |
|
245 |
prompt = prompt + ' in style of s'
|
246 |
|
@@ -325,7 +260,7 @@ def generate_with_prompt_style_guidance(prompt, style, seed=42):
|
|
325 |
denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
|
326 |
|
327 |
# Calculate loss
|
328 |
-
loss =
|
329 |
|
330 |
# # Occasionally print it out
|
331 |
# if i%10==0:
|
@@ -344,7 +279,7 @@ def generate_with_prompt_style_guidance(prompt, style, seed=42):
|
|
344 |
return latents_to_pil(latents)[0]
|
345 |
|
346 |
|
347 |
-
|
348 |
|
349 |
dict_styles = {
|
350 |
'Dr Strange': 'styles/learned_embeds_dr_strange.bin',
|
@@ -354,11 +289,12 @@ dict_styles = {
|
|
354 |
}
|
355 |
# dict_styles.keys()
|
356 |
|
357 |
-
def inference(prompt, style):
|
358 |
|
359 |
-
if prompt is not None and style is not None:
|
360 |
style = dict_styles[style]
|
361 |
-
|
|
|
362 |
return np.array(result)
|
363 |
else:
|
364 |
return None
|
@@ -369,6 +305,7 @@ examples = [['A man sipping wine wearing a spacesuit on the moon', 'Stripes']]
|
|
369 |
|
370 |
demo = gr.Interface(inference,
|
371 |
inputs = [gr.Textbox(label='Prompt'),
|
|
|
372 |
gr.Dropdown(['Dr Strange', 'GTA-5',
|
373 |
'Manga', 'Pokemon'], label='Style')
|
374 |
],
|
@@ -377,7 +314,8 @@ demo = gr.Interface(inference,
|
|
377 |
],
|
378 |
title = title,
|
379 |
description = description,
|
380 |
-
|
381 |
# cache_examples=True
|
382 |
)
|
383 |
demo.launch()
|
|
|
|
1 |
from base64 import b64encode
|
|
|
2 |
import numpy
|
3 |
import torch
|
4 |
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
|
5 |
from huggingface_hub import notebook_login
|
6 |
+
import gradio as gr
|
7 |
|
8 |
# For video display:
|
9 |
from matplotlib import pyplot as plt
|
|
|
16 |
import os
|
17 |
import numpy as np
|
18 |
|
19 |
+
|
20 |
|
21 |
# Suppress some unnecessary warnings when loading the CLIPTextModel
|
22 |
logging.set_verbosity_error()
|
|
|
145 |
return latents_to_pil(latents)[0]
|
146 |
|
147 |
|
148 |
+
def generate_with_prompt_style(prompt, style, seed):
|
149 |
|
150 |
prompt = prompt + ' in style of s'
|
151 |
embed = torch.load(style)
|
|
|
175 |
variance = torch.var(images)
|
176 |
return -variance
|
177 |
|
178 |
+
def generate_with_prompt_style_guidance(prompt, style, seed):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
prompt = prompt + ' in style of s'
|
181 |
|
|
|
260 |
denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
|
261 |
|
262 |
# Calculate loss
|
263 |
+
loss = contrast_loss(denoised_images) * contrast_loss_scale
|
264 |
|
265 |
# # Occasionally print it out
|
266 |
# if i%10==0:
|
|
|
279 |
return latents_to_pil(latents)[0]
|
280 |
|
281 |
|
282 |
+
|
283 |
|
284 |
dict_styles = {
|
285 |
'Dr Strange': 'styles/learned_embeds_dr_strange.bin',
|
|
|
289 |
}
|
290 |
# dict_styles.keys()
|
291 |
|
292 |
+
def inference(prompt, seed, style):
    """Generate one styled image for the Gradio interface.

    Parameters:
        prompt: Text prompt from the UI.
        seed: Random seed. The UI supplies it from a text box, so it usually
            arrives as a string; any value accepted by ``int()`` is allowed.
        style: Key into ``dict_styles`` selecting the style embedding file.

    Returns:
        The generated image converted with ``np.array``, or ``None`` when the
        prompt, seed or style is missing (a blank seed counts as missing).

    Raises:
        ValueError: If the seed is present but cannot be parsed as an integer.
        KeyError: If ``style`` is not a key of ``dict_styles``.
    """
    # Guard clause: nothing to generate unless every input was supplied.
    if prompt is None or style is None or seed is None:
        return None

    # A cleared text box yields an empty string rather than None; treat a
    # blank seed like any other missing input instead of crashing below.
    if isinstance(seed, str):
        seed = seed.strip()
        if not seed:
            return None

    # Parse once so the seeding call and the generator both receive an int.
    try:
        seed = int(seed)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Seed must be an integer, got {seed!r}") from err

    style = dict_styles[style]
    torch.manual_seed(seed)
    result = generate_with_prompt_style_guidance(prompt, style, seed)
    return np.array(result)
|
|
|
305 |
|
306 |
demo = gr.Interface(inference,
|
307 |
inputs = [gr.Textbox(label='Prompt'),
|
308 |
+
gr.Textbox(label='Seed', value='24041975'),
|
309 |
gr.Dropdown(['Dr Strange', 'GTA-5',
|
310 |
'Manga', 'Pokemon'], label='Style')
|
311 |
],
|
|
|
314 |
],
|
315 |
title = title,
|
316 |
description = description,
|
317 |
+
examples = examples,
|
318 |
# cache_examples=True
|
319 |
)
|
320 |
demo.launch()
|
321 |
+
|