1inkusFace commited on
Commit
11907e7
·
verified ·
1 Parent(s): df19679

Update skyreelsinfer/skyreels_video_infer.py

Browse files
Files changed (1) hide show
  1. skyreelsinfer/skyreels_video_infer.py +12 -23
skyreelsinfer/skyreels_video_infer.py CHANGED
@@ -4,6 +4,17 @@ import time
4
  from datetime import timedelta
5
  from typing import Any
6
  from typing import Dict
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # DELAY ALL THESE IMPORTS:
9
  # import torch
@@ -34,7 +45,7 @@ class SkyReelsVideoInfer:
34
  model_id: str,
35
  quant_model: bool = True,
36
  is_offload: bool = True,
37
- offload_config = None, # No OffloadConfig
38
  use_multiprocessing: bool = False,
39
  ):
40
  self.task_type = task_type
@@ -42,7 +53,6 @@ class SkyReelsVideoInfer:
42
  self.quant_model = quant_model
43
  self.is_offload = is_offload
44
  self.offload_config = offload_config
45
-
46
  self._initialize_pipeline()
47
 
48
  def _load_model(
@@ -52,31 +62,16 @@ class SkyReelsVideoInfer:
52
  quant_model: bool = True,
53
  device: str = "cuda",
54
  ):
55
- # DELAYED IMPORTS:
56
- import torch
57
- from diffusers import HunyuanVideoTransformer3DModel
58
- from diffusers import DiffusionPipeline
59
- from PIL import Image
60
- from transformers import LlamaModel
61
- from torchao.quantization import float8_weight_only
62
- from torchao.quantization import quantize_
63
- from .pipelines import SkyreelsVideoPipeline # Local import
64
-
65
-
66
  logger.info(f"load model model_id:{model_id} quan_model:{quant_model} device:{device}")
67
-
68
  text_encoder = LlamaModel.from_pretrained(
69
  base_model_id,
70
  subfolder="text_encoder",
71
  torch_dtype=torch.bfloat16,
72
  ).to(device)
73
-
74
  transformer = HunyuanVideoTransformer3DModel.from_pretrained(
75
  model_id,
76
  torch_dtype=torch.bfloat16,
77
  ).to(device)
78
-
79
-
80
  if quant_model:
81
  quantize_(text_encoder, float8_weight_only(), device=device)
82
  quantize_(transformer, float8_weight_only(), device=device)
@@ -90,13 +85,9 @@ class SkyReelsVideoInfer:
90
  return pipe
91
 
92
  def _initialize_pipeline(self):
93
- #More Delayed Imports
94
- from .offload import Offload
95
-
96
  self.pipe = self._load_model( #No : SkyreelsVideoPipeline
97
  model_id=self.model_id, quant_model=self.quant_model, device="cuda"
98
  )
99
-
100
  if self.is_offload and self.offload_config:
101
  Offload.offload(
102
  pipeline=self.pipe,
@@ -104,8 +95,6 @@ class SkyReelsVideoInfer:
104
  )
105
 
106
  def inference(self, kwargs):
107
- #DELAYED IMPORTS
108
- from . import TaskType
109
  if self.task_type == TaskType.I2V:
110
  image = kwargs.pop("image")
111
  output = self.pipe(image=image, **kwargs)
 
4
  from datetime import timedelta
5
  from typing import Any
6
  from typing import Dict
7
+ import torch
8
+ from diffusers import HunyuanVideoTransformer3DModel
9
+ from diffusers import DiffusionPipeline
10
+ from PIL import Image
11
+ from transformers import LlamaModel
12
+ from torchao.quantization import float8_weight_only
13
+ from torchao.quantization import quantize_
14
+ from .pipelines import SkyreelsVideoPipeline # Local import
15
+ from .offload import Offload
16
+ from .offload import OffloadConfig
17
+ from . import TaskType
18
 
19
  # DELAY ALL THESE IMPORTS:
20
  # import torch
 
45
  model_id: str,
46
  quant_model: bool = True,
47
  is_offload: bool = True,
48
+ offload_config: OffloadConfig = OffloadConfig(),
49
  use_multiprocessing: bool = False,
50
  ):
51
  self.task_type = task_type
 
53
  self.quant_model = quant_model
54
  self.is_offload = is_offload
55
  self.offload_config = offload_config
 
56
  self._initialize_pipeline()
57
 
58
  def _load_model(
 
62
  quant_model: bool = True,
63
  device: str = "cuda",
64
  ):
 
 
 
 
 
 
 
 
 
 
 
65
  logger.info(f"load model model_id:{model_id} quan_model:{quant_model} device:{device}")
 
66
  text_encoder = LlamaModel.from_pretrained(
67
  base_model_id,
68
  subfolder="text_encoder",
69
  torch_dtype=torch.bfloat16,
70
  ).to(device)
 
71
  transformer = HunyuanVideoTransformer3DModel.from_pretrained(
72
  model_id,
73
  torch_dtype=torch.bfloat16,
74
  ).to(device)
 
 
75
  if quant_model:
76
  quantize_(text_encoder, float8_weight_only(), device=device)
77
  quantize_(transformer, float8_weight_only(), device=device)
 
85
  return pipe
86
 
87
  def _initialize_pipeline(self):
 
 
 
88
  self.pipe = self._load_model( #No : SkyreelsVideoPipeline
89
  model_id=self.model_id, quant_model=self.quant_model, device="cuda"
90
  )
 
91
  if self.is_offload and self.offload_config:
92
  Offload.offload(
93
  pipeline=self.pipe,
 
95
  )
96
 
97
  def inference(self, kwargs):
 
 
98
  if self.task_type == TaskType.I2V:
99
  image = kwargs.pop("image")
100
  output = self.pipe(image=image, **kwargs)