from io import BytesIO
import contextlib
from typing import Any, cast
from unittest.mock import patch

import torch
from torch._inductor.package import load_package
from torch._inductor.package.package import package_aoti
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from PIL import Image

from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import load_video, export_to_video

MODEL_ID = "Lightricks/LTX-Video-0.9.8-13B-distilled"
LANDSCAPE_WIDTH = 480
LANDSCAPE_HEIGHT = 832
MAX_FRAMES_MODEL = 96

# Inductor compilation options for the transformer.
INDUCTOR_CONFIGS = {
    "conv_1x1_as_mm": True,
    "epilogue_fusion": False,
    "coordinate_descent_tuning": True,
    "coordinate_descent_check_all_directions": True,
    "max_autotune": False,
    "triton.cudagraphs": True,
}

# Keep model constants out of the compiled shared object and load them
# from disk instead, and emit a packaged (.pt2-style) artifact.
INDUCTOR_CONFIGS_OVERRIDES = {
    "aot_inductor.package_constants_in_so": False,
    "aot_inductor.package_constants_on_disk": True,
    "aot_inductor.package": True,
}


@contextlib.contextmanager
def capture_component_call(
    pipeline: LTXConditionPipeline,
    component_name: str,
    component_method: str = "forward",
):
    """Capture the args/kwargs of the first call to a pipeline component.

    The patched method raises immediately, so the pipeline run aborts as
    soon as the call is recorded; the captured inputs are exposed on the
    yielded object.
    """

    class CapturedCallException(Exception):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.args = args
            self.kwargs = kwargs

    class CapturedCall:
        def __init__(self):
            self.args: tuple[Any, ...] = ()
            self.kwargs: dict[str, Any] = {}

    component = getattr(pipeline, component_name)
    captured_call = CapturedCall()

    def capture_call(*args, **kwargs):
        raise CapturedCallException(*args, **kwargs)

    with patch.object(component, component_method, new=capture_call):
        try:
            yield captured_call
        except CapturedCallException as e:
            captured_call.args = e.args
            captured_call.kwargs = e.kwargs


pipe = LTXConditionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to("cuda")

# Quantize the transformer weights to float8 with dynamic float8 activations.
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())

# Build a dummy single-frame video condition at the target resolution.
resized_image = Image.new("RGB", (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT))
video = load_video(export_to_video([resized_image]))
condition1 = LTXVideoCondition(video=video, frame_index=0)

# Run the pipeline once just to capture the exact inputs the transformer
# receives; the run aborts at the first transformer call.
with capture_component_call(pipe, "transformer") as call:
    pipe(
        conditions=[condition1],
        prompt="prompt",
        height=LANDSCAPE_HEIGHT,
        width=LANDSCAPE_WIDTH,
        num_frames=MAX_FRAMES_MODEL,
        num_inference_steps=2,
    )

hidden_states: torch.Tensor = call.kwargs["hidden_states"]

# Export the transformer with the captured example inputs.
exported = torch.export.export(
    mod=pipe.transformer,
    args=call.args,
    kwargs=call.kwargs | {"hidden_states": hidden_states},
)

assert exported.example_inputs is not None
args, kwargs = exported.example_inputs
gm = cast(torch.fx.GraphModule, exported.module())

# AOT-compile the exported graph with Inductor.
artifacts = torch._inductor.aot_compile(
    gm,
    args,
    kwargs,
    options=INDUCTOR_CONFIGS | INDUCTOR_CONFIGS_OVERRIDES,
)

# Package the compiled artifacts into an in-memory archive and load it back.
archive_file = BytesIO()
files = [file for file in artifacts if isinstance(file, str)]
package_aoti(archive_file, files)
compiled_model = load_package(archive_file, run_single_threaded=True)
print("Package loaded.")

# Swap the compiled model in, preserving attributes the pipeline still reads
# from the transformer.
transformer_config = pipe.transformer.config
transformer_dtype = pipe.transformer.dtype
cache_context = pipe.transformer.cache_context
pipe.transformer = compiled_model
pipe.transformer.config = transformer_config
pipe.transformer.dtype = transformer_dtype
pipe.transformer.cache_context = cache_context
print("Configs done.")

# Run the pipeline again, now through the AOT-compiled transformer.
pipe(
    conditions=[condition1],
    prompt="prompt",
    height=LANDSCAPE_HEIGHT,
    width=LANDSCAPE_WIDTH,
    num_frames=MAX_FRAMES_MODEL,
    num_inference_steps=2,
)
print("Okay")
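
# Note (illustrative, not part of the original script): the final pipeline
# call above discards its output. To actually save the generated video, one
# could capture the frames and export them; "output.mp4" and fps=24 are
# assumed values here, not taken from the source.
#
#     frames = pipe(
#         conditions=[condition1],
#         prompt="prompt",
#         height=LANDSCAPE_HEIGHT,
#         width=LANDSCAPE_WIDTH,
#         num_frames=MAX_FRAMES_MODEL,
#         num_inference_steps=2,
#     ).frames[0]
#     export_to_video(frames, "output.mp4", fps=24)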