ZeroGPU
Browse files
demo.py
CHANGED
@@ -243,6 +243,7 @@ def preprocess_mask(mask):
|
|
243 |
|
244 |
|
245 |
@spaces.GPU
|
|
|
246 |
def generate_latent_image(mask, class_selection, sampling_steps=50):
|
247 |
"""Generate a latent image based on mask, class selection, and sampling steps"""
|
248 |
|
@@ -306,6 +307,7 @@ def generate_latent_image(mask, class_selection, sampling_steps=50):
|
|
306 |
|
307 |
|
308 |
@spaces.GPU
|
|
|
309 |
def decode_images(latents):
|
310 |
"""Decode latent representations to pixel space using a VAE.
|
311 |
|
@@ -384,6 +386,7 @@ def decode_latent_to_pixel(latent_image):
|
|
384 |
|
385 |
|
386 |
@spaces.GPU
|
|
|
387 |
def check_privacy(latent_image_numpy, class_selection):
|
388 |
"""Check if the latent image is too similar to database images"""
|
389 |
latent_image = torch.from_numpy(latent_image_numpy).to(device, dtype=dtype)
|
@@ -410,6 +413,7 @@ def check_privacy(latent_image_numpy, class_selection):
|
|
410 |
|
411 |
|
412 |
@spaces.GPU
|
|
|
413 |
def generate_animation(
|
414 |
latent_image, ejection_fraction, sampling_steps=50, cfg_scale=1.0
|
415 |
):
|
@@ -491,7 +495,7 @@ def generate_animation(
|
|
491 |
|
492 |
# print("Synthetic video:", synthetic_video.shape)
|
493 |
|
494 |
-
return synthetic_video # B x C x T x H x W
|
495 |
|
496 |
|
497 |
def decode_animation(latent_animation):
|
|
|
243 |
|
244 |
|
245 |
@spaces.GPU
|
246 |
+
@torch.no_grad()
|
247 |
def generate_latent_image(mask, class_selection, sampling_steps=50):
|
248 |
"""Generate a latent image based on mask, class selection, and sampling steps"""
|
249 |
|
|
|
307 |
|
308 |
|
309 |
@spaces.GPU
|
310 |
+
@torch.no_grad()
|
311 |
def decode_images(latents):
|
312 |
"""Decode latent representations to pixel space using a VAE.
|
313 |
|
|
|
386 |
|
387 |
|
388 |
@spaces.GPU
|
389 |
+
@torch.no_grad()
|
390 |
def check_privacy(latent_image_numpy, class_selection):
|
391 |
"""Check if the latent image is too similar to database images"""
|
392 |
latent_image = torch.from_numpy(latent_image_numpy).to(device, dtype=dtype)
|
|
|
413 |
|
414 |
|
415 |
@spaces.GPU
|
416 |
+
@torch.no_grad()
|
417 |
def generate_animation(
|
418 |
latent_image, ejection_fraction, sampling_steps=50, cfg_scale=1.0
|
419 |
):
|
|
|
495 |
|
496 |
# print("Synthetic video:", synthetic_video.shape)
|
497 |
|
498 |
+
return synthetic_video.detach() # B x C x T x H x W
|
499 |
|
500 |
|
501 |
def decode_animation(latent_animation):
|