1inkusFace committed
Commit cce4ca7 · verified · 1 Parent(s): e7a7164

Update skyreelsinfer/skyreels_video_infer.py

Files changed (1)
  1. skyreelsinfer/skyreels_video_infer.py +80 -55
skyreelsinfer/skyreels_video_infer.py CHANGED
@@ -1,34 +1,22 @@
 import logging
-import os  # Keep os here
+import os
 import time
 from datetime import timedelta
 from typing import Any
 from typing import Dict
+
 import torch
 from diffusers import HunyuanVideoTransformer3DModel
-from diffusers import DiffusionPipeline
 from PIL import Image
-from transformers import LlamaModel
 from torchao.quantization import float8_weight_only
 from torchao.quantization import quantize_
-from .pipelines import SkyreelsVideoPipeline  # Local import
-from .offload import Offload
-from .offload import OffloadConfig
-from . import TaskType
-
-# DELAY ALL THESE IMPORTS:
-# import torch
-# from diffusers import HunyuanVideoTransformer3DModel
-# from diffusers import DiffusionPipeline
-# from PIL import Image
-# from transformers import LlamaModel
+from transformers import LlamaModel

-# from . import TaskType
-# from .offload import Offload
-# from .offload import OffloadConfig
-# from .pipelines import SkyreelsVideoPipeline
+from . import TaskType  # Assuming these are still needed
+from .offload import Offload, OffloadConfig
+from .pipelines import SkyreelsVideoPipeline

-logger = logging.getLogger("SkyReelsVideoInfer")
+logger = logging.getLogger("SkyreelsVideoInfer")
 logger.setLevel(logging.DEBUG)
 console_handler = logging.StreamHandler()
 console_handler.setLevel(logging.DEBUG)
@@ -38,66 +26,103 @@ formatter = logging.Formatter(
 console_handler.setFormatter(formatter)
 logger.addHandler(console_handler)

-class SkyReelsVideoInfer:
-    def __init__(
-        self,
-        task_type,  # No TaskType.
-        model_id: str,
-        quant_model: bool = True,
-        is_offload: bool = True,
-        offload_config: OffloadConfig = OffloadConfig(),
-        use_multiprocessing: bool = False,
-    ):
-        self.task_type = task_type
-        self.model_id = model_id
-        self.quant_model = quant_model
-        self.is_offload = is_offload
-        self.offload_config = offload_config
-        self._initialize_pipeline()
-
+class SkyReelsVideoSingleGpuInfer:
     def _load_model(
         self,
         model_id: str,
         base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
         quant_model: bool = True,
-        device: str = "cuda",
-    ):
-        logger.info(f"load model model_id:{model_id} quan_model:{quant_model} device:{device}")
+        gpu_device: str = "cuda:0",
+    ) -> SkyreelsVideoPipeline:
+        logger.info(f"load model model_id:{model_id} quan_model:{quant_model} gpu_device:{gpu_device}")
         text_encoder = LlamaModel.from_pretrained(
             base_model_id,
             subfolder="text_encoder",
             torch_dtype=torch.float16,
-        ).to(device)
+        ).to("cpu")
         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
             model_id,
+            # subfolder="transformer",
             torch_dtype=torch.float16,
-        ).to(device)
+            device="cpu",
+        ).to("cpu")
         if quant_model:
-            quantize_(text_encoder, float8_weight_only(), device=device)
-            quantize_(transformer, float8_weight_only(), device=device)
+            quantize_(text_encoder, float8_weight_only(), device="cpu")
+            text_encoder.to("cpu")
+            torch.cuda.empty_cache()
+            quantize_(transformer, float8_weight_only(), device="cpu")
+            transformer.to("cpu")
+            torch.cuda.empty_cache()
         pipe = SkyreelsVideoPipeline.from_pretrained(
             base_model_id,
             transformer=transformer,
             text_encoder=text_encoder,
             torch_dtype=torch.float16,
-        ).to(device)
+        ).to("cpu")
         pipe.vae.enable_tiling()
+        torch.cuda.empty_cache()
         return pipe

-    def _initialize_pipeline(self):
-        self.pipe = self._load_model(  # No : SkyreelsVideoPipeline
-            model_id=self.model_id, quant_model=self.quant_model, device="cuda"
+    def __init__(
+        self,
+        task_type: TaskType,
+        model_id: str,
+        quant_model: bool = True,
+        is_offload: bool = True,
+        offload_config: OffloadConfig = OffloadConfig(),
+    ):
+        self.task_type = task_type
+        # os.environ["LOCAL_RANK"] = "0"  # No longer needed in single-GPU
+        # torch.cuda.set_device(0)  # Still a good idea to be explicit.
+        torch.backends.cuda.enable_cudnn_sdp(False)  # Still a good idea to keep it.
+        gpu_device = "cuda:0"
+
+        self.pipe: SkyreelsVideoPipeline = self._load_model(
+            model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
         )
-        if self.is_offload and self.offload_config:
+
+        if is_offload:
             Offload.offload(
                 pipeline=self.pipe,
-                config=self.offload_config,
+                config=offload_config,
             )
+        else:
+            self.pipe.to(gpu_device)

-    def inference(self, kwargs):
+        if offload_config.compiler_transformer:
+            torch._dynamo.config.suppress_errors = True
+            os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_1"  # _1 represents 1 GPU.
+            self.pipe.transformer = torch.compile(
+                self.pipe.transformer,
+                mode="max-autotune-no-cudagraphs",
+                dynamic=True,
+            )
+            self.warm_up()
+
+    def warm_up(self):
+        init_kwargs = {
+            "prompt": "A woman is dancing in a room",
+            "height": 512,
+            "width": 512,
+            "guidance_scale": 6,
+            "num_inference_steps": 1,
+            "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
+            "num_frames": 97,
+            "generator": torch.Generator("cuda").manual_seed(42),
+            "embedded_guidance_scale": 1.0,
+        }
         if self.task_type == TaskType.I2V:
-            image = kwargs.pop("image")
-            output = self.pipe(image=image, **kwargs)
-        else:
-            output = self.pipe(**kwargs)
-        return output.frames
+            init_kwargs["image"] = Image.new("RGB", (512, 512), color="black")
+        self.pipe(**init_kwargs)
+
+    def inference(self, kwargs: Dict[str, Any]):
+        logger.info(f"kwargs: {kwargs}")
+        if "seed" in kwargs:
+            kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs["seed"])
+            del kwargs["seed"]
+        start_time = time.time()
+        assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
+        out = self.pipe(**kwargs).frames[0]
+        logger.info(f"inference time: {time.time() - start_time}")
+        return out
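
The pattern worth noting in `_load_model` is that both the text encoder and the transformer are quantized to float8 weight-only on the CPU before anything is moved to the GPU, which keeps peak GPU memory low until the offloader takes over. A minimal sketch of that pattern on a toy module, assuming a torchao build whose `quantize_` accepts the `device` argument used in the diff:

```python
# Minimal sketch of the CPU-side float8 weight-only quantization used in
# _load_model above; the toy Linear stands in for the text encoder or
# transformer, and the sizes are arbitrary.
import torch
from torchao.quantization import float8_weight_only, quantize_

module = torch.nn.Linear(4096, 4096, dtype=torch.float16)

# quantize_ mutates the module in place; device="cpu" keeps the quantized
# weights in host memory so an offloader can stream them to the GPU later.
quantize_(module, float8_weight_only(), device="cpu")

print(type(module.weight))  # a torchao quantized tensor subclass, not a plain fp16 Parameter
```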
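For context, a hypothetical end-to-end use of the revised class; the checkpoint id and every generation parameter below are illustrative assumptions, not part of this commit:

```python
# Hypothetical usage of SkyReelsVideoSingleGpuInfer as it stands after this
# commit; the model id and all call parameters are assumptions.
from skyreelsinfer import TaskType
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer

infer = SkyReelsVideoSingleGpuInfer(
    task_type=TaskType.T2V,
    model_id="Skywork/SkyReels-V1-Hunyuan-T2V",  # assumed checkpoint
    quant_model=True,
    is_offload=True,
    offload_config=OffloadConfig(),
)

# inference() takes a plain dict; a "seed" key is turned into a CUDA
# torch.Generator internally, and the first video's frames are returned.
frames = infer.inference({
    "prompt": "A cat walks on the grass, realistic style.",
    "height": 512,
    "width": 512,
    "num_frames": 97,
    "num_inference_steps": 30,
    "seed": 42,
    "embedded_guidance_scale": 1.0,
})
```

Note that with `is_offload=True` the pipeline is left on the CPU and `Offload.offload` manages device placement; the class only calls `self.pipe.to(gpu_device)` itself in the non-offload branch.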