# ltx-dev-fast / reproduce.py
import contextlib
from io import BytesIO
from typing import Any, cast
from unittest.mock import patch

import torch
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_video
from PIL import Image
from torch._inductor.package import load_package
from torch._inductor.package.package import package_aoti
from torchao.quantization import float8_dynamic_activation_float8_weight, quantize_
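
# Reproduction flow: run the pipeline once to capture the transformer's real
# call inputs, torch.export the transformer with those inputs, AOT-compile the
# exported graph with Inductor, package and reload the compiled artifact, then
# swap it into the pipeline in place of the transformer and run again.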
MODEL_ID = "Lightricks/LTX-Video-0.9.8-13B-distilled"
# Note: despite the names, width (480) is smaller than height (832), so the
# generated frames are portrait-oriented.
LANDSCAPE_WIDTH = 480
LANDSCAPE_HEIGHT = 832
MAX_FRAMES_MODEL = 96
INDUCTOR_CONFIGS = {
    "conv_1x1_as_mm": True,
    "epilogue_fusion": False,
    "coordinate_descent_tuning": True,
    "coordinate_descent_check_all_directions": True,
    "max_autotune": False,
    "triton.cudagraphs": True,
}
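# Packaging overrides (my reading of these torch._inductor flags): keep model
# constants out of the compiled .so, store them on disk instead, and emit a
# loadable package.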
INDUCTOR_CONFIGS_OVERRIDES = {
    "aot_inductor.package_constants_in_so": False,
    "aot_inductor.package_constants_on_disk": True,
    "aot_inductor.package": True,
}
@contextlib.contextmanager
def capture_component_call(
    pipeline: LTXConditionPipeline,
    component_name: str,
    component_method: str = "forward",
):
    """Capture the args/kwargs of a pipeline component's first call.

    The patched method raises immediately, so the surrounding pipeline call
    aborts as soon as the component is invoked; the captured inputs are then
    recovered from the exception.
    """

    class CapturedCallException(Exception):
        def __init__(self, *args, **kwargs):
            super().__init__()
            # Deliberately overwrite BaseException.args with the captured ones.
            self.args = args
            self.kwargs = kwargs

    class CapturedCall:
        def __init__(self):
            self.args: tuple[Any, ...] = ()
            self.kwargs: dict[str, Any] = {}

    component = getattr(pipeline, component_name)
    captured_call = CapturedCall()

    def capture_call(*args, **kwargs):
        raise CapturedCallException(*args, **kwargs)

    with patch.object(component, component_method, new=capture_call):
        try:
            yield captured_call
        except CapturedCallException as e:
            captured_call.args = e.args
            captured_call.kwargs = e.kwargs
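
# Load the distilled pipeline in bfloat16 and quantize the transformer with
# torchao's dynamic-FP8-activation / FP8-weight scheme.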
pipe = LTXConditionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to("cuda")
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
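# Placeholder conditioning: a single blank frame round-tripped through
# export_to_video/load_video, attached at frame 0.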
resized_image = Image.new("RGB", (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT))
video = load_video(export_to_video([resized_image]))
condition1 = LTXVideoCondition(video=video, frame_index=0)
with capture_component_call(pipe, "transformer") as call:
    pipe(
        conditions=[condition1],
        prompt="prompt",
        height=LANDSCAPE_HEIGHT,
        width=LANDSCAPE_WIDTH,
        num_frames=MAX_FRAMES_MODEL,
        num_inference_steps=2,
    )
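# The patched forward aborted the pipeline call, and `call` now holds the
# exact args/kwargs the transformer received; they serve as the example
# inputs for torch.export below.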
exported = torch.export.export(
    mod=pipe.transformer,
    args=call.args,
    kwargs=call.kwargs,
)
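# exported.example_inputs carries the captured (args, kwargs) back out of the
# ExportedProgram; exported.module() yields the GraphModule for AOT Inductor.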
assert exported.example_inputs is not None
args, kwargs = exported.example_inputs
gm = cast(torch.fx.GraphModule, exported.module())
artifacts = torch._inductor.aot_compile(
    gm, args, kwargs, options=INDUCTOR_CONFIGS | INDUCTOR_CONFIGS_OVERRIDES
)
# With "aot_inductor.package" enabled, aot_compile returns a list of
# artifacts; keep only the on-disk file paths and pack them into an
# in-memory PT2 archive.
archive_file = BytesIO()
files = [file for file in artifacts if isinstance(file, str)]
package_aoti(archive_file, files)
archive_file.seek(0)  # rewind the buffer before handing it to load_package
compiled_model = load_package(archive_file, run_single_threaded=True)
print("Package loaded.")
transformer_config = pipe.transformer.config
transformer_dtype = pipe.transformer.dtype
cache_context = pipe.transformer.cache_context
pipe.transformer = compiled_model
pipe.transformer.config = transformer_config
pipe.transformer.dtype = transformer_dtype
pipe.transformer.cache_context = cache_context
print("Configs done.")
pipe(
    conditions=[condition1],
    prompt="prompt",
    height=LANDSCAPE_HEIGHT,
    width=LANDSCAPE_WIDTH,
    num_frames=MAX_FRAMES_MODEL,
    num_inference_steps=2,
)
print("Okay")