HReynaud commited on
Commit
b004a19
·
1 Parent(s): ef26eff
Files changed (1) hide show
  1. demo.py +5 -1
demo.py CHANGED
@@ -243,6 +243,7 @@ def preprocess_mask(mask):
243
 
244
 
245
  @spaces.GPU
 
246
  def generate_latent_image(mask, class_selection, sampling_steps=50):
247
  """Generate a latent image based on mask, class selection, and sampling steps"""
248
 
@@ -306,6 +307,7 @@ def generate_latent_image(mask, class_selection, sampling_steps=50):
306
 
307
 
308
  @spaces.GPU
 
309
  def decode_images(latents):
310
  """Decode latent representations to pixel space using a VAE.
311
 
@@ -384,6 +386,7 @@ def decode_latent_to_pixel(latent_image):
384
 
385
 
386
  @spaces.GPU
 
387
  def check_privacy(latent_image_numpy, class_selection):
388
  """Check if the latent image is too similar to database images"""
389
  latent_image = torch.from_numpy(latent_image_numpy).to(device, dtype=dtype)
@@ -410,6 +413,7 @@ def check_privacy(latent_image_numpy, class_selection):
410
 
411
 
412
  @spaces.GPU
 
413
  def generate_animation(
414
  latent_image, ejection_fraction, sampling_steps=50, cfg_scale=1.0
415
  ):
@@ -491,7 +495,7 @@ def generate_animation(
491
 
492
  # print("Synthetic video:", synthetic_video.shape)
493
 
494
- return synthetic_video # B x C x T x H x W
495
 
496
 
497
  def decode_animation(latent_animation):
 
243
 
244
 
245
  @spaces.GPU
246
+ @torch.no_grad()
247
  def generate_latent_image(mask, class_selection, sampling_steps=50):
248
  """Generate a latent image based on mask, class selection, and sampling steps"""
249
 
 
307
 
308
 
309
  @spaces.GPU
310
+ @torch.no_grad()
311
  def decode_images(latents):
312
  """Decode latent representations to pixel space using a VAE.
313
 
 
386
 
387
 
388
  @spaces.GPU
389
+ @torch.no_grad()
390
  def check_privacy(latent_image_numpy, class_selection):
391
  """Check if the latent image is too similar to database images"""
392
  latent_image = torch.from_numpy(latent_image_numpy).to(device, dtype=dtype)
 
413
 
414
 
415
  @spaces.GPU
416
+ @torch.no_grad()
417
  def generate_animation(
418
  latent_image, ejection_fraction, sampling_steps=50, cfg_scale=1.0
419
  ):
 
495
 
496
  # print("Synthetic video:", synthetic_video.shape)
497
 
498
+ return synthetic_video.detach() # B x C x T x H x W
499
 
500
 
501
  def decode_animation(latent_animation):