Spaces: Running on Zero
Browse files
- check.py +82 -0
- optimization.py +21 -14
- optimization_utils.py +43 -4
- reproduce.py +117 -0
- requirements.txt +0 -11
check.py
ADDED
@@ -0,0 +1,82 @@
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
import numpy as np
from PIL import Image
import torch
from diffusers.utils import load_image, load_video, export_to_video
from optimization import optimize_pipeline_

MODEL_ID = "Lightricks/LTX-Video-0.9.8-13B-distilled"

LANDSCAPE_WIDTH = 480
LANDSCAPE_HEIGHT = 832
MAX_SEED = np.iinfo(np.int32).max

FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 96

MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)

def resize_image(image: Image.Image) -> Image.Image:
    if image.height > image.width:
        transposed = image.transpose(Image.Transpose.ROTATE_90)
        resized = resize_image_landscape(transposed)
        return resized.transpose(Image.Transpose.ROTATE_270)
    return resize_image_landscape(image)


def resize_image_landscape(image: Image.Image) -> Image.Image:
    target_aspect = LANDSCAPE_WIDTH / LANDSCAPE_HEIGHT
    width, height = image.size
    in_aspect = width / height
    if in_aspect > target_aspect:
        new_width = round(height * target_aspect)
        left = (width - new_width) // 2
        image = image.crop((left, 0, left + new_width, height))
    else:
        new_height = round(width / target_aspect)
        top = (height - new_height) // 2
        image = image.crop((0, top, width, top + new_height))
    return image.resize((LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT), Image.LANCZOS)


pipe = LTXConditionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to("cuda")
dummy_image = Image.new("RGB", (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT))
video = load_video(export_to_video([dummy_image]))
condition1 = LTXVideoCondition(video=video, frame_index=0)
optimize_pipeline_(
    pipe,
    conditions=[condition1],
    prompt="prompt",
    height=LANDSCAPE_HEIGHT,
    width=LANDSCAPE_WIDTH,
    num_frames=MAX_FRAMES_MODEL,
    num_inference_steps=2
)

default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
input_image = load_image("peng.png")
duration_seconds = MAX_DURATION
guidance_scale = 1.0
num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
current_seed = 42
resized_image = resize_image(input_image)
steps = 8

video = load_video(export_to_video([resized_image]))
condition1 = LTXVideoCondition(video=video, frame_index=0)

output_frames_list = pipe(
    conditions=[condition1],
    prompt=default_prompt_i2v,
    negative_prompt=default_negative_prompt,
    height=resized_image.height,
    width=resized_image.width,
    num_frames=num_frames,
    guidance_scale=float(guidance_scale),
    num_inference_steps=int(steps),
    generator=torch.Generator(device="cuda").manual_seed(current_seed),
).frames[0]
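A minimal sketch (not part of the commit) for actually saving the generated frames, reusing the export_to_video helper and FIXED_FPS constant that check.py already defines; the output filename is arbitrary:

# Hypothetical follow-up to check.py: write the frames to disk for inspection.
output_path = export_to_video(output_frames_list, "check_output.mp4", fps=FIXED_FPS)
print(f"Saved {len(output_frames_list)} frames to {output_path}")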
optimization.py
CHANGED
@@ -9,11 +9,9 @@ from typing import ParamSpec
 import spaces
 import torch
 from torch.utils._pytree import tree_map_only
-from torchao.quantization import quantize_
-from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
+from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
 from diffusers import LTXConditionPipeline
-from optimization_utils import capture_component_call
-from optimization_utils import aoti_compile
+from optimization_utils import capture_component_call, aoti_compile, cudagraph


 P = ParamSpec("P")
@@ -33,7 +31,8 @@ INDUCTOR_CONFIGS = {
     "epilogue_fusion": False,
     "coordinate_descent_tuning": True,
     "coordinate_descent_check_all_directions": True,
-    "max_autotune": True,
+    # "max_autotune": True,
+    "max_autotune": False,
     "triton.cudagraphs": True,
 }
 TRANSFORMER_SPATIAL_PATCH_SIZE = 1
@@ -50,15 +49,15 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs)
     latent_height = height // VAE_SPATIAL_COMPRESSION_RATIO
     latent_width = width // VAE_SPATIAL_COMPRESSION_RATIO

+    with capture_component_call(pipeline, "transformer") as call:
+        pipeline(*args, **kwargs)
+
     @spaces.GPU(duration=1500)
     def compile_transformer():
-        with capture_component_call(pipeline, "transformer") as call:
-            pipeline(*args, **kwargs)
-
-        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
+        # dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
         # dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

-        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
+        quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight())

         hidden_states: torch.Tensor = call.kwargs["hidden_states"]
         unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
@@ -88,14 +87,13 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs)
             mod=pipeline.transformer,
             args=call.args,
             kwargs=call.kwargs | {"hidden_states": hidden_states_landscape},
-            dynamic_shapes=dynamic_shapes,
+            # dynamic_shapes=dynamic_shapes,
         )
-
         exported_portrait = torch.export.export(
             mod=pipeline.transformer,
             args=call.args,
             kwargs=call.kwargs | {"hidden_states": hidden_states_portrait},
-            dynamic_shapes=dynamic_shapes,
+            # dynamic_shapes=dynamic_shapes,
         )

         compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
@@ -108,6 +106,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs)

     compiled_landscape, compiled_portrait = compile_transformer()

+    @torch.no_grad()
     def combined_transformer(*args, **kwargs):
         hidden_states: torch.Tensor = kwargs["hidden_states"]
         unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
@@ -126,7 +125,15 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs)
     transformer_config = pipeline.transformer.config
     transformer_dtype = pipeline.transformer.dtype
     cache_context = pipeline.transformer.cache_context
-    pipeline.transformer = combined_transformer
+
+    with torch.no_grad():
+        combined_transformer(**call.kwargs)
+
+    pipeline.transformer = cudagraph(combined_transformer)
+
+    with torch.no_grad():
+        pipeline.transformer(**call.kwargs)
+
     pipeline.transformer.config = transformer_config  # pyright: ignore[reportAttributeAccessIssue]
     pipeline.transformer.dtype = transformer_dtype  # pyright: ignore[reportAttributeAccessIssue]
     pipeline.transformer.cache_context = cache_context
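The net effect of this change: the transformer call is captured once at the top of optimize_pipeline_, the transformer is quantized with torchao's float8 dynamic activation/weight scheme, exported at fixed shapes (the dynamic_shapes path is commented out), AOT-compiled for landscape and portrait layouts, and the shape-dispatching combined_transformer is finally wrapped in a manually recorded CUDA graph. A self-contained sketch of the quantize, export, and AOTInductor sequence on a toy module (illustrative names; it uses the one-call aoti_compile_and_package helper rather than the Space's lower-level aot_compile/package_aoti split, and float8 quantization assumes a GPU that supports it):

import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)

toy = ToyBlock().to("cuda", torch.bfloat16)
quantize_(toy, float8_dynamic_activation_float8_weight())  # same torchao config as the diff

example = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)
exported = torch.export.export(toy, args=(example,))  # fixed shapes, no dynamic_shapes
package_path = torch._inductor.aoti_compile_and_package(exported)  # compile and package in one call
compiled = torch._inductor.aoti_load_package(package_path)
out = compiled(example)  # runs the AOTI-compiled module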
optimization_utils.py
CHANGED
@@ -5,12 +5,13 @@ Taken from https://huggingface.co/spaces/cbensimon/wan2-1-fast/
 import contextlib
 from contextvars import ContextVar
 from io import BytesIO
-from typing import Any
-from typing import cast
+from typing import Any, cast
 from unittest.mock import patch

 import torch
+from torch.utils._pytree import tree_map_only
 from torch._inductor.package.package import package_aoti
+from torch._inductor.package import load_package
 from torch.export.pt2_archive._package import AOTICompiledModel
 from torch.export.pt2_archive._package_weights import Weights

@@ -45,7 +46,9 @@ class ZeroGPUCompiledModel:

     def __call__(self, *args, **kwargs):
         if (compiled_model := self.compiled_model.get()) is None:
-            compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
+            # compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
+            # compiled_model = torch._inductor.aoti_load_package(self.archive_file, run_single_threaded=True)
+            compiled_model = load_package(self.archive_file, run_single_threaded=True)
             compiled_model.load_constants(self.weights.constants_map, check_full_update=True, user_managed=True)
             self.compiled_model.set(compiled_model)
         return compiled_model(*args, **kwargs)
@@ -67,7 +70,7 @@ def aoti_compile(
     files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
     package_aoti(archive_file, files)
     (weights,) = (artifact for artifact in artifacts if isinstance(artifact, Weights))
-    zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights})
+    zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights}, to_cuda=True)
     return ZeroGPUCompiledModel(archive_file, zerogpu_weights)


@@ -100,3 +103,39 @@ def capture_component_call(
         except CapturedCallException as e:
             captured_call.args = e.args
             captured_call.kwargs = e.kwargs
+
+
+# Taken from
+# https://github.com/huggingface/flux-fast/blob/5027798d7f69a8e0e478df92f48663c40727f8ea/utils/pipeline_utils.py#L198C1-L231C14
+def cudagraph(f):
+    _graphs = {}
+    def f_(*args, **kwargs):
+        key = hash(tuple(tuple(kwargs[a].shape) for a in sorted(kwargs.keys())
+                         if isinstance(kwargs[a], torch.Tensor)))
+        if key in _graphs:
+            # use the cached wrapper if one exists. this will perform CUDAGraph replay
+            wrapped, *_ = _graphs[key]
+            return wrapped(*args, **kwargs)
+
+        # record a new CUDAGraph and cache it for future use
+        g = torch.cuda.CUDAGraph()
+        in_args, in_kwargs = tree_map_only(torch.Tensor, lambda t: t.clone(), (args, kwargs))
+        f(*in_args, **in_kwargs)  # stream warmup
+        with torch.cuda.graph(g):
+            out_tensors = f(*in_args, **in_kwargs)
+        def wrapped(*args, **kwargs):
+            # note that CUDAGraphs require inputs / outputs to be in fixed memory locations.
+            # inputs must be copied into the fixed input memory locations.
+            [a.copy_(b) for a, b in zip(in_args, args) if isinstance(a, torch.Tensor)]
+            for key in kwargs:
+                if isinstance(kwargs[key], torch.Tensor):
+                    in_kwargs[key].copy_(kwargs[key])
+            g.replay()
+            # clone() outputs on the way out to disconnect them from the fixed output memory
+            # locations. this allows for CUDAGraph reuse without accidentally overwriting memory
+            return [o.clone() for o in out_tensors]
+
+        # cache function that does CUDAGraph replay
+        _graphs[key] = (wrapped, g, in_args, in_kwargs, out_tensors)
+        return wrapped(*args, **kwargs)
+    return f_
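The new cudagraph helper keys its cache on the shapes of the keyword tensor arguments, records one CUDA graph per shape set, and on replay copies fresh inputs into the captured buffers and clones the outputs back out. A minimal usage sketch (keyword tensors only, since the cache key is built from kwargs, and the wrapped function must return a list of tensors, as combined_transformer does):

import torch

def scale(*, x):
    return [x * 2]  # list of tensors, like combined_transformer's return value

graphed = cudagraph(scale)
a = torch.ones(4, device="cuda")
print(graphed(x=a)[0])      # first call for this shape: warm-up, capture, then replay
print(graphed(x=a + 1)[0])  # same shape: straight replay with the new input copied in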
reproduce.py
ADDED
@@ -0,0 +1,117 @@
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import load_video, export_to_video
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from io import BytesIO
import contextlib
from typing import Any, cast
from unittest.mock import patch
import torch
from torch._inductor.package.package import package_aoti
from torch._inductor.package import load_package
from PIL import Image

MODEL_ID = "Lightricks/LTX-Video-0.9.8-13B-distilled"
LANDSCAPE_WIDTH = 480
LANDSCAPE_HEIGHT = 832
MAX_FRAMES_MODEL = 96
INDUCTOR_CONFIGS = {
    "conv_1x1_as_mm": True,
    "epilogue_fusion": False,
    "coordinate_descent_tuning": True,
    "coordinate_descent_check_all_directions": True,
    "max_autotune": False,
    "triton.cudagraphs": True,
}
INDUCTOR_CONFIGS_OVERRIDES = {
    "aot_inductor.package_constants_in_so": False,
    "aot_inductor.package_constants_on_disk": True,
    "aot_inductor.package": True,
}


@contextlib.contextmanager
def capture_component_call(
    pipeline: LTXConditionPipeline,
    component_name: str,
    component_method="forward",
):
    class CapturedCallException(Exception):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.args = args
            self.kwargs = kwargs

    class CapturedCall:
        def __init__(self):
            self.args: tuple[Any, ...] = ()
            self.kwargs: dict[str, Any] = {}

    component = getattr(pipeline, component_name)
    captured_call = CapturedCall()

    def capture_call(*args, **kwargs):
        raise CapturedCallException(*args, **kwargs)

    with patch.object(component, component_method, new=capture_call):
        try:
            yield captured_call
        except CapturedCallException as e:
            captured_call.args = e.args
            captured_call.kwargs = e.kwargs


pipe = LTXConditionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to("cuda")
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())

resized_image = Image.new("RGB", (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT))
video = load_video(export_to_video([resized_image]))
condition1 = LTXVideoCondition(video=video, frame_index=0)

with capture_component_call(pipe, "transformer") as call:
    pipe(
        conditions=[condition1],
        prompt="prompt",
        height=LANDSCAPE_HEIGHT,
        width=LANDSCAPE_WIDTH,
        num_frames=MAX_FRAMES_MODEL,
        num_inference_steps=2
    )

hidden_states: torch.Tensor = call.kwargs["hidden_states"]
exported = torch.export.export(
    mod=pipe.transformer,
    args=call.args,
    kwargs=call.kwargs | {"hidden_states": hidden_states},
)

assert exported.example_inputs is not None
args, kwargs = exported.example_inputs
gm = cast(torch.fx.GraphModule, exported.module())
artifacts = torch._inductor.aot_compile(
    gm, args, kwargs, options=INDUCTOR_CONFIGS | INDUCTOR_CONFIGS_OVERRIDES
)
archive_file = BytesIO()
files = [file for file in artifacts if isinstance(file, str)]
package_aoti(archive_file, files)
compiled_model = load_package(archive_file, run_single_threaded=True)
print("Package loaded.")

transformer_config = pipe.transformer.config
transformer_dtype = pipe.transformer.dtype
cache_context = pipe.transformer.cache_context
pipe.transformer = compiled_model
pipe.transformer.config = transformer_config
pipe.transformer.dtype = transformer_dtype
pipe.transformer.cache_context = cache_context
print("Configs done.")

pipe(
    conditions=[condition1],
    prompt="prompt",
    height=LANDSCAPE_HEIGHT,
    width=LANDSCAPE_WIDTH,
    num_frames=MAX_FRAMES_MODEL,
    num_inference_steps=2
)
print("Okay")
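Unlike aoti_compile in optimization_utils.py, reproduce.py packages only the string artifacts and discards the Weights object, so the loaded package is never handed its constants via load_constants. If parity with the Space is the goal, a sketch of the missing step (mirroring ZeroGPUCompiledModel.__call__ above) would be:

# Hypothetical addition after load_package(...), reusing names defined in reproduce.py.
from torch.export.pt2_archive._package_weights import Weights

(weights,) = (artifact for artifact in artifacts if isinstance(artifact, Weights))
constants_map = {name: weights.get_weight(name)[0] for name in weights}
compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)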
requirements.txt
DELETED
@@ -1,11 +0,0 @@
git+https://github.com/huggingface/diffusers.git
transformers
accelerate
safetensors
sentencepiece
peft
ftfy
imageio
imageio-ffmpeg
opencv-python
torchao==0.11.0