File size: 3,725 Bytes
10bbb52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import load_video, export_to_video
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from io import BytesIO
import contextlib
from typing import Any, cast
from unittest.mock import patch
import torch 
from torch._inductor.package.package import package_aoti
from torch._inductor.package import load_package
from PIL import Image

# Hugging Face Hub checkpoint that is loaded, quantized and AOT-compiled below.
MODEL_ID = "Lightricks/LTX-Video-0.9.8-13B-distilled"

# Frame size (width, height) used both for tracing and for the final run.
# NOTE(review): despite the "LANDSCAPE" prefix, width < height here, i.e. a
# portrait frame of 480 x 832. The values are used consistently as
# width/height further down, so only the names are misleading.
LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT = 480, 832

# Value passed as ``num_frames`` when tracing and when running the pipeline.
# NOTE(review): LTX-Video is usually driven with 8*k + 1 frames (e.g. 97).
# Confirm that 96 is intended, since the exported graph is traced at this size.
MAX_FRAMES_MODEL = 96

# Inductor options for compiling the transformer. The script merges them with
# INDUCTOR_CONFIGS_OVERRIDES before handing them to ``aot_compile``.
INDUCTOR_CONFIGS = {
    "conv_1x1_as_mm": True,
    "epilogue_fusion": False,
    "coordinate_descent_tuning": True,
    "coordinate_descent_check_all_directions": True,
    "max_autotune": False,
    "triton.cudagraphs": True,
}

# AOTInductor packaging switches, all under the "aot_inductor." namespace.
# Judging by the option names, constants (weights) are kept out of the
# compiled shared library and emitted separately; verify against the torch
# version in use.
INDUCTOR_CONFIGS_OVERRIDES = {
    f"aot_inductor.{option}": value
    for option, value in (
        ("package_constants_in_so", False),
        ("package_constants_on_disk", True),
        ("package", True),
    )
}


@contextlib.contextmanager
def capture_component_call(
    pipeline: LTXConditionPipeline,
    component_name: str,
    component_method="forward",
):
    """Record the arguments of the first call to a pipeline component's method.

    While the ``with`` block runs, ``getattr(pipeline, component_name)`` has its
    ``component_method`` attribute (``forward`` by default) replaced with a stub
    that raises a private exception carrying the call's positional and keyword
    arguments. That exception aborts the code inside the ``with`` block. It is
    caught here, its arguments are copied onto the yielded object, and it is
    then suppressed, so execution continues after the ``with`` statement. The
    original method is restored on exit. Any other exception propagates.

    Yields:
        An object with ``args`` (tuple) and ``kwargs`` (dict) attributes. Both
        remain empty if the patched method is never called inside the block.
    """

    class _CallIntercepted(Exception):
        # Carries the intercepted call's arguments out of the patched method.
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.args = args
            self.kwargs = kwargs

    class _Capture:
        # Mutable holder handed to the caller; filled once the call is seen.
        def __init__(self):
            self.args: tuple[Any, ...] = ()
            self.kwargs: dict[str, Any] = {}

    target = getattr(pipeline, component_name)
    capture = _Capture()

    def _intercept(*args, **kwargs):
        raise _CallIntercepted(*args, **kwargs)

    with patch.object(target, component_method, new=_intercept):
        try:
            yield capture
        except _CallIntercepted as intercepted:
            capture.args, capture.kwargs = intercepted.args, intercepted.kwargs


pipe = LTXConditionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to("cuda")
# Quantize the transformer in place (float8 weights, dynamically quantized
# float8 activations) so that export and compilation see the quantized module.
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())

# A blank single-frame conditioning video at the compile resolution.
# Presumably only its shape matters for tracing, not its pixel content.
resized_image = Image.new("RGB", (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT))
video = load_video(export_to_video([resized_image]))
condition1 = LTXVideoCondition(video=video, frame_index=0)

# Run the pipeline only up to its first transformer call. That call is
# intercepted and its inputs are kept as example inputs for torch.export.
with capture_component_call(pipe, "transformer") as call:
    pipe(
        conditions=[condition1],
        prompt="prompt",
        height=LANDSCAPE_HEIGHT,
        width=LANDSCAPE_WIDTH,
        num_frames=MAX_FRAMES_MODEL,
        num_inference_steps=2,
    )

# If the pipeline finished without ever calling the transformer, nothing was
# captured. Stop here instead of exporting with empty example inputs.
if not call.args and not call.kwargs:
    raise RuntimeError(
        "pipe.transformer was never called; no example inputs were captured"
    )

# Export with the captured arguments as-is (a shallow copy of the kwargs).
# No dynamic_shapes are given, so the graph is traced at these input shapes.
exported = torch.export.export(
    mod=pipe.transformer,
    args=call.args,
    kwargs=dict(call.kwargs),
)

if exported.example_inputs is None:
    raise RuntimeError("torch.export produced no example inputs to compile with")
args, kwargs = exported.example_inputs
gm = cast(torch.fx.GraphModule, exported.module())
artifacts = torch._inductor.aot_compile(
    gm, args, kwargs, options=INDUCTOR_CONFIGS | INDUCTOR_CONFIGS_OVERRIDES
)

# With "aot_inductor.package_constants_in_so" disabled, the weights are not
# embedded in the generated library. aot_compile returns them as separate,
# non-path artifact(s) next to the generated file paths. Only the paths are
# packaged, so the weights must be handed to the loaded model explicitly;
# otherwise it would have no constants to run with.
files = [artifact for artifact in artifacts if isinstance(artifact, str)]
weights_artifacts = [artifact for artifact in artifacts if not isinstance(artifact, str)]

archive_file = BytesIO()
package_aoti(archive_file, files)
compiled_model = load_package(archive_file, run_single_threaded=True)

# Map every constant name to its tensor and load the constants into the
# compiled model.
# NOTE(review): this assumes Weights iterates over constant names and that
# get_weight(name) returns (tensor, properties). Verify against the installed
# torch version.
# With user_managed=True the tensors are expected not to be copied, so
# constants_map is deliberately kept as a module-level name to keep them alive.
constants_map = {
    name: weights.get_weight(name)[0]
    for weights in weights_artifacts
    for name in weights
}
if constants_map:
    compiled_model.load_constants(
        constants_map, check_full_update=True, user_managed=True
    )
print("Package loaded.")

# Carry over the attributes the pipeline presumably reads from its
# transformer, so the compiled model can stand in for the eager module.
transformer_config = pipe.transformer.config
transformer_dtype = pipe.transformer.dtype
cache_context = pipe.transformer.cache_context
pipe.transformer = compiled_model
pipe.transformer.config = transformer_config
pipe.transformer.dtype = transformer_dtype
pipe.transformer.cache_context = cache_context
print("Configs done.")

# Same call as during capture, so the compiled graph sees inputs of the
# shapes it was traced with.
pipe(
    conditions=[condition1],
    prompt="prompt",
    height=LANDSCAPE_HEIGHT,
    width=LANDSCAPE_WIDTH,
    num_frames=MAX_FRAMES_MODEL,
    num_inference_steps=2,
)
print("Okay")