Update skyreelsinfer/skyreels_video_infer.py
skyreelsinfer/skyreels_video_infer.py
CHANGED
@@ -1,34 +1,22 @@
 import logging
-import os
+import os
 import time
 from datetime import timedelta
 from typing import Any
 from typing import Dict
+
 import torch
 from diffusers import HunyuanVideoTransformer3DModel
-from diffusers import DiffusionPipeline
 from PIL import Image
-from transformers import LlamaModel
 from torchao.quantization import float8_weight_only
 from torchao.quantization import quantize_
-from
-from .offload import Offload
-from .offload import OffloadConfig
-from . import TaskType
-
-# DELAY ALL THESE IMPORTS:
-# import torch
-# from diffusers import HunyuanVideoTransformer3DModel
-# from diffusers import DiffusionPipeline
-# from PIL import Image
-# from transformers import LlamaModel
+from transformers import LlamaModel
 
-
-
-
-# from .pipelines import SkyreelsVideoPipeline
+from . import TaskType  # Assuming these are still needed
+from .offload import Offload, OffloadConfig
+from .pipelines import SkyreelsVideoPipeline
 
-logger = logging.getLogger("
+logger = logging.getLogger("SkyreelsVideoInfer")
 logger.setLevel(logging.DEBUG)
 console_handler = logging.StreamHandler()
 console_handler.setLevel(logging.DEBUG)
@@ -38,66 +26,103 @@ formatter = logging.Formatter(
 console_handler.setFormatter(formatter)
 logger.addHandler(console_handler)
 
-class
-    def __init__(
-        self,
-        task_type,  # No TaskType.
-        model_id: str,
-        quant_model: bool = True,
-        is_offload: bool = True,
-        offload_config: OffloadConfig = OffloadConfig(),
-        use_multiprocessing: bool = False,
-    ):
-        self.task_type = task_type
-        self.model_id = model_id
-        self.quant_model = quant_model
-        self.is_offload = is_offload
-        self.offload_config = offload_config
-        self._initialize_pipeline()
-
+class SkyReelsVideoSingleGpuInfer:
     def _load_model(
         self,
         model_id: str,
         base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
         quant_model: bool = True,
-
-    ):
-        logger.info(f"load model model_id:{model_id} quan_model:{quant_model}
+        gpu_device: str = "cuda:0",
+    ) -> SkyreelsVideoPipeline:
+        logger.info(f"load model model_id:{model_id} quan_model:{quant_model} gpu_device:{gpu_device}")
         text_encoder = LlamaModel.from_pretrained(
             base_model_id,
             subfolder="text_encoder",
             torch_dtype=torch.float16,
-        ).to(
+        ).to("cpu")
         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
             model_id,
+            # subfolder="transformer",
            torch_dtype=torch.float16,
-
+            device="cpu",
+        ).to("cpu")
         if quant_model:
-            quantize_(text_encoder, float8_weight_only(), device=
-
+            quantize_(text_encoder, float8_weight_only(), device="cpu")
+            text_encoder.to("cpu")
+            torch.cuda.empty_cache()
+            quantize_(transformer, float8_weight_only(), device="cpu")
+            transformer.to("cpu")
+            torch.cuda.empty_cache()
         pipe = SkyreelsVideoPipeline.from_pretrained(
             base_model_id,
             transformer=transformer,
             text_encoder=text_encoder,
             torch_dtype=torch.float16,
-        ).to(
+        ).to("cpu")
         pipe.vae.enable_tiling()
+        torch.cuda.empty_cache()
         return pipe
 
-    def
-        self
-
+    def __init__(
+        self,
+        task_type: TaskType,
+        model_id: str,
+        quant_model: bool = True,
+        is_offload: bool = True,
+        offload_config: OffloadConfig = OffloadConfig(),
+    ):
+        self.task_type = task_type
+        # os.environ["LOCAL_RANK"] = "0"  # No longer needed in single-GPU
+        # torch.cuda.set_device(0)  # Still a good idea to be explicit.
+        torch.backends.cuda.enable_cudnn_sdp(False)  # Still a good idea to keep it.
+        gpu_device = "cuda:0"
+
+        self.pipe: SkyreelsVideoPipeline = self._load_model(
+            model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
         )
-
+
+        if is_offload:
             Offload.offload(
                 pipeline=self.pipe,
-                config=
+                config=offload_config,
             )
+        else:
+            self.pipe.to(gpu_device)
 
-
+        if offload_config.compiler_transformer:
+            torch._dynamo.config.suppress_errors = True
+            os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_1"  # _1 represents 1 gpu.
+            self.pipe.transformer = torch.compile(
+                self.pipe.transformer,
+                mode="max-autotune-no-cudagraphs",
+                dynamic=True,
+            )
+            self.warm_up()
+
+    def warm_up(self):
+        init_kwargs = {
+            "prompt": "A woman is dancing in a room",
+            "height": 512,
+            "width": 512,
+            "guidance_scale": 6,
+            "num_inference_steps": 1,
+            "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
+            "num_frames": 97,
+            "generator": torch.Generator("cuda").manual_seed(42),
+            "embedded_guidance_scale": 1.0,
+        }
         if self.task_type == TaskType.I2V:
-            image =
-
-
-
-
+            init_kwargs["image"] = Image.new("RGB", (512, 512), color="black")
+        self.pipe(**init_kwargs)
+
+    def inference(self, kwargs: Dict[str, Any]):
+        logger.info(f"kwargs: {kwargs}")
+        if "seed" in kwargs:
+            kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs["seed"])
+            del kwargs["seed"]
+        start_time = time.time()
+        assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
+        out = self.pipe(**kwargs).frames[0]
+        logger.info(f"inference time: {time.time() - start_time}")
+        return out
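
The core memory trick in this revision is quantizing both large modules to float8 weight-only form while they still sit on CPU, so the full-precision weights never have to fit on the GPU at once. In isolation, the torchao pattern looks roughly like this (a minimal sketch on a toy module, not the commit's code; whether float8 kernels execute on CPU depends on the torchao build, which is why the pipeline later offloads or moves blocks to CUDA):

import torch
from torchao.quantization import float8_weight_only, quantize_

# Toy stand-in for the text encoder / transformer; quantize_ mutates in
# place, swapping each Linear's fp16 weight for a float8 tensor subclass.
module = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 4096),
).half()

quantize_(module, float8_weight_only(), device="cpu")  # same call shape as in the diff
print(type(module[0].weight))  # tensor subclass holding float8 data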
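And a rough driver for the new class (a sketch, not from the commit: the model_id and every generation kwarg below are illustrative placeholders, and the import paths assume the skyreelsinfer package layout implied by the relative imports above):

from skyreelsinfer import TaskType
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer

infer = SkyReelsVideoSingleGpuInfer(
    task_type=TaskType.T2V,
    model_id="path-or-repo-id-of-the-transformer",  # placeholder
    quant_model=True,   # float8 weight-only quantization on CPU
    is_offload=True,    # Offload manages GPU residency per block
    offload_config=OffloadConfig(),
)
frames = infer.inference({
    "prompt": "A cat runs on the grass",
    "height": 512,
    "width": 512,
    "num_frames": 97,
    "num_inference_steps": 30,
    "seed": 42,  # inference() swaps this for a CUDA torch.Generator
})

Note that warm_up() only runs when offload_config.compiler_transformer is set, so an uncompiled pipeline pays no warm-up cost at construction; for I2V, inference() additionally expects an "image" entry, which the assert enforces.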