Spaces:
Build error
Build error
Update skyreelsinfer/skyreels_video_infer.py
Browse files
skyreelsinfer/skyreels_video_infer.py
CHANGED
@@ -9,9 +9,6 @@ import torch
|
|
9 |
from diffusers import HunyuanVideoTransformer3DModel
|
10 |
from diffusers import DiffusionPipeline
|
11 |
from PIL import Image
|
12 |
-
# DELAY torchao imports:
|
13 |
-
# from torchao.quantization import float8_weight_only
|
14 |
-
# from torchao.quantization import quantize_
|
15 |
from transformers import LlamaModel
|
16 |
|
17 |
from . import TaskType
|
@@ -19,7 +16,6 @@ from .offload import Offload
|
|
19 |
from .offload import OffloadConfig
|
20 |
from .pipelines import SkyreelsVideoPipeline
|
21 |
|
22 |
-
|
23 |
logger = logging.getLogger("SkyReelsVideoInfer")
|
24 |
logger.setLevel(logging.DEBUG)
|
25 |
console_handler = logging.StreamHandler()
|
@@ -30,8 +26,6 @@ formatter = logging.Formatter(
|
|
30 |
console_handler.setFormatter(formatter)
|
31 |
logger.addHandler(console_handler)
|
32 |
|
33 |
-
|
34 |
-
|
35 |
class SkyReelsVideoInfer:
|
36 |
def __init__(
|
37 |
self,
|
@@ -55,11 +49,10 @@ class SkyReelsVideoInfer:
|
|
55 |
model_id: str,
|
56 |
base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
|
57 |
quant_model: bool = True,
|
58 |
-
device: str = "cpu",
|
59 |
) -> SkyreelsVideoPipeline:
|
60 |
logger.info(f"load model model_id:{model_id} quan_model:{quant_model} device:{device}")
|
61 |
|
62 |
-
# DELAYED IMPORTS:
|
63 |
from torchao.quantization import float8_weight_only
|
64 |
from torchao.quantization import quantize_
|
65 |
|
@@ -87,7 +80,6 @@ class SkyReelsVideoInfer:
|
|
87 |
pipe.vae.enable_tiling()
|
88 |
return pipe
|
89 |
|
90 |
-
|
91 |
def _initialize_pipeline(self):
|
92 |
self.pipe: SkyreelsVideoPipeline = self._load_model(
|
93 |
model_id=self.model_id, quant_model=self.quant_model, device="cpu"
|
@@ -99,11 +91,10 @@ class SkyReelsVideoInfer:
|
|
99 |
config=self.offload_config,
|
100 |
)
|
101 |
|
102 |
-
|
103 |
def inference(self, kwargs):
|
104 |
if self.task_type == TaskType.I2V:
|
105 |
image = kwargs.pop("image")
|
106 |
-
output = self.pipe(image=image, **kwargs)
|
107 |
else:
|
108 |
-
output = self.pipe(**kwargs)
|
109 |
-
return output
|
|
|
9 |
from diffusers import HunyuanVideoTransformer3DModel
|
10 |
from diffusers import DiffusionPipeline
|
11 |
from PIL import Image
|
|
|
|
|
|
|
12 |
from transformers import LlamaModel
|
13 |
|
14 |
from . import TaskType
|
|
|
16 |
from .offload import OffloadConfig
|
17 |
from .pipelines import SkyreelsVideoPipeline
|
18 |
|
|
|
19 |
logger = logging.getLogger("SkyReelsVideoInfer")
|
20 |
logger.setLevel(logging.DEBUG)
|
21 |
console_handler = logging.StreamHandler()
|
|
|
26 |
console_handler.setFormatter(formatter)
|
27 |
logger.addHandler(console_handler)
|
28 |
|
|
|
|
|
29 |
class SkyReelsVideoInfer:
|
30 |
def __init__(
|
31 |
self,
|
|
|
49 |
model_id: str,
|
50 |
base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
|
51 |
quant_model: bool = True,
|
52 |
+
device: str = "cpu",
|
53 |
) -> SkyreelsVideoPipeline:
|
54 |
logger.info(f"load model model_id:{model_id} quan_model:{quant_model} device:{device}")
|
55 |
|
|
|
56 |
from torchao.quantization import float8_weight_only
|
57 |
from torchao.quantization import quantize_
|
58 |
|
|
|
80 |
pipe.vae.enable_tiling()
|
81 |
return pipe
|
82 |
|
|
|
83 |
def _initialize_pipeline(self):
|
84 |
self.pipe: SkyreelsVideoPipeline = self._load_model(
|
85 |
model_id=self.model_id, quant_model=self.quant_model, device="cpu"
|
|
|
91 |
config=self.offload_config,
|
92 |
)
|
93 |
|
|
|
94 |
def inference(self, kwargs):
|
95 |
if self.task_type == TaskType.I2V:
|
96 |
image = kwargs.pop("image")
|
97 |
+
output = self.pipe(image=image, **kwargs) # Get full output
|
98 |
else:
|
99 |
+
output = self.pipe(**kwargs) # Get full output
|
100 |
+
return output.frames # Return frames directly
|