1inkusFace committed
Commit c93b626 · verified · 1 Parent(s): a636200

Update skyreelsinfer/skyreels_video_infer.py

skyreelsinfer/skyreels_video_infer.py CHANGED
@@ -9,9 +9,6 @@ import torch
 from diffusers import HunyuanVideoTransformer3DModel
 from diffusers import DiffusionPipeline
 from PIL import Image
-# DELAY torchao imports:
-# from torchao.quantization import float8_weight_only
-# from torchao.quantization import quantize_
 from transformers import LlamaModel
 
 from . import TaskType
@@ -19,7 +16,6 @@ from .offload import Offload
 from .offload import OffloadConfig
 from .pipelines import SkyreelsVideoPipeline
 
-
 logger = logging.getLogger("SkyReelsVideoInfer")
 logger.setLevel(logging.DEBUG)
 console_handler = logging.StreamHandler()
@@ -30,8 +26,6 @@ formatter = logging.Formatter(
 console_handler.setFormatter(formatter)
 logger.addHandler(console_handler)
 
-
-
 class SkyReelsVideoInfer:
     def __init__(
         self,
@@ -55,11 +49,10 @@ class SkyReelsVideoInfer:
         model_id: str,
         base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
         quant_model: bool = True,
-        device: str = "cpu", # Use string "cpu"
+        device: str = "cpu",
     ) -> SkyreelsVideoPipeline:
         logger.info(f"load model model_id:{model_id} quan_model:{quant_model} device:{device}")
 
-        # DELAYED IMPORTS:
         from torchao.quantization import float8_weight_only
         from torchao.quantization import quantize_
 
@@ -87,7 +80,6 @@ class SkyReelsVideoInfer:
         pipe.vae.enable_tiling()
         return pipe
 
-
     def _initialize_pipeline(self):
         self.pipe: SkyreelsVideoPipeline = self._load_model(
             model_id=self.model_id, quant_model=self.quant_model, device="cpu"
@@ -99,11 +91,10 @@ class SkyReelsVideoInfer:
             config=self.offload_config,
         )
 
-
     def inference(self, kwargs):
         if self.task_type == TaskType.I2V:
             image = kwargs.pop("image")
-            output = self.pipe(image=image, **kwargs).frames
+            output = self.pipe(image=image, **kwargs) # Get full output
         else:
-            output = self.pipe(**kwargs).frames
-        return output
+            output = self.pipe(**kwargs) # Get full output
+        return output.frames # Return frames directly
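The removed `# DELAYED IMPORTS:` marker referred to the torchao imports that live inside `_load_model` rather than at module scope, so torchao is only pulled in when a model is actually loaded. The body of `_load_model` between those imports and `pipe.vae.enable_tiling()` is not shown in this diff, so the following is only a minimal sketch of how `quantize_` and `float8_weight_only` are commonly applied when `quant_model` is true; the function name and variable names here are assumptions, not the file's actual implementation.

```python
# Sketch only: illustrates the in-function (delayed) torchao import pattern.
# The real _load_model body is outside this diff; the transformer handling
# below is an assumption based on the HunyuanVideo components it imports.
import torch
from diffusers import HunyuanVideoTransformer3DModel


def _load_model_sketch(model_id: str, quant_model: bool = True, device: str = "cpu"):
    # Importing torchao here (not at module top) keeps module import cheap
    # when quantization is never used.
    from torchao.quantization import float8_weight_only
    from torchao.quantization import quantize_

    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    ).to(device)

    if quant_model:
        # Quantize the transformer weights to float8 in place to cut memory use.
        quantize_(transformer, float8_weight_only())

    return transformer
```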
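For callers, the restructured `inference` keeps the same contract: it still returns the pipeline's decoded frames, but now calls the pipeline once per branch and takes `.frames` in a single return statement. A minimal, hypothetical usage sketch follows; only `task_type`, `model_id`, `quant_model`, `offload_config`, and the dict-style `inference(kwargs)` call are visible in the diff, so the constructor arguments and pipeline kwargs below are assumptions.

```python
# Hypothetical driver code, not part of the commit. Constructor arguments and
# pipeline kwargs are assumptions; inference() takes a plain dict positionally.
from PIL import Image

from skyreelsinfer import TaskType
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer

infer = SkyReelsVideoInfer(
    task_type=TaskType.I2V,                      # assumed constructor parameter
    model_id="Skywork/SkyReels-V1-Hunyuan-I2V",  # assumed model id
    quant_model=True,                            # assumed constructor parameter
)

frames = infer.inference(
    {
        "image": Image.open("input.png"),  # popped by the I2V branch
        "prompt": "a sailboat gliding across a calm bay",
        "num_inference_steps": 30,         # assumed pipeline kwarg
    }
)
# Both before and after this commit the caller receives decoded frames;
# the change only moves the .frames access to a single return statement.
```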