Spaces:
Runtime error
Runtime error
Update wan/modules/vae.py
Browse files- wan/modules/vae.py +3 -3
wan/modules/vae.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
import logging
|
3 |
|
4 |
import torch
|
5 |
-
import torch.
|
6 |
import torch.nn as nn
|
7 |
import torch.nn.functional as F
|
8 |
from einops import rearrange
|
@@ -648,14 +648,14 @@ class WanVAE:
|
|
648 |
"""
|
649 |
videos: A list of videos each with shape [C, T, H, W].
|
650 |
"""
|
651 |
-
with amp.autocast(dtype=self.dtype):
|
652 |
return [
|
653 |
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
|
654 |
for u in videos
|
655 |
]
|
656 |
|
657 |
def decode(self, zs):
|
658 |
-
with amp.autocast(dtype=self.dtype):
|
659 |
return [
|
660 |
self.model.decode(u.unsqueeze(0),
|
661 |
self.scale).float().clamp_(-1, 1).squeeze(0)
|
|
|
2 |
import logging
|
3 |
|
4 |
import torch
|
5 |
+
import torch.amp as amp
|
6 |
import torch.nn as nn
|
7 |
import torch.nn.functional as F
|
8 |
from einops import rearrange
|
|
|
648 |
"""
|
649 |
videos: A list of videos each with shape [C, T, H, W].
|
650 |
"""
|
651 |
+
with amp.autocast("cuda", dtype=self.dtype):
|
652 |
return [
|
653 |
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
|
654 |
for u in videos
|
655 |
]
|
656 |
|
657 |
def decode(self, zs):
|
658 |
+
with amp.autocast("cuda", dtype=self.dtype):
|
659 |
return [
|
660 |
self.model.decode(u.unsqueeze(0),
|
661 |
self.scale).float().clamp_(-1, 1).squeeze(0)
|