fffiloni commited on
Commit
e681a74
Β·
verified Β·
1 Parent(s): b4d03cb

Update wan/modules/vae.py

Browse files
Files changed (1) hide show
  1. wan/modules/vae.py +3 -3
wan/modules/vae.py CHANGED
@@ -2,7 +2,7 @@
2
  import logging
3
 
4
  import torch
5
- import torch.cuda.amp as amp
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
  from einops import rearrange
@@ -648,14 +648,14 @@ class WanVAE:
648
  """
649
  videos: A list of videos each with shape [C, T, H, W].
650
  """
651
- with amp.autocast(dtype=self.dtype):
652
  return [
653
  self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
654
  for u in videos
655
  ]
656
 
657
  def decode(self, zs):
658
- with amp.autocast(dtype=self.dtype):
659
  return [
660
  self.model.decode(u.unsqueeze(0),
661
  self.scale).float().clamp_(-1, 1).squeeze(0)
 
2
  import logging
3
 
4
  import torch
5
+ import torch.amp as amp
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
  from einops import rearrange
 
648
  """
649
  videos: A list of videos each with shape [C, T, H, W].
650
  """
651
+ with amp.autocast("cuda", dtype=self.dtype):
652
  return [
653
  self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
654
  for u in videos
655
  ]
656
 
657
  def decode(self, zs):
658
+ with amp.autocast("cuda", dtype=self.dtype):
659
  return [
660
  self.model.decode(u.unsqueeze(0),
661
  self.scale).float().clamp_(-1, 1).squeeze(0)