"""
Taken from https://huggingface.co/spaces/cbensimon/wan2-1-fast/
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from diffusers import LTXConditionPipeline
from optimization_utils import capture_component_call, aoti_compile, cudagraph
P = ParamSpec("P")
# Sequence packing in LTX is a bit of a pain.
# See: https://github.com/huggingface/diffusers/blob/c052791b5fe29ce8a308bf63dda97aa205b729be/src/diffusers/pipelines/ltx/pipeline_ltx.py#L420
TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim.AUTO
TRANSFORMER_DYNAMIC_SHAPES = {
"hidden_states": {1: TRANSFORMER_NUM_FRAMES_DIM},
}
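# With patch sizes of 1 (see below), dim 1 of the packed `hidden_states` is the
# flattened token sequence (latent frames x latent height x latent width), so
# marking it Dim.AUTO lets one exported graph serve different video lengths.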
INDUCTOR_CONFIGS = {
"conv_1x1_as_mm": True,
"epilogue_fusion": False,
"coordinate_descent_tuning": True,
"coordinate_descent_check_all_directions": True,
"max_autotune": False, # doesn't help much
"triton.cudagraphs": True,
}
TRANSFORMER_SPATIAL_PATCH_SIZE = 1
TRANSFORMER_TEMPORAL_PATCH_SIZE = 1
VAE_SPATIAL_COMPRESSION_RATIO = 32
VAE_TEMPORAL_COMPRESSION_RATIO = 8
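# Worked example of the ratios above (input values are illustrative): for
# num_frames=161, height=512, width=768 the latent grid is
# (161 - 1) // 8 + 1 = 21 frames by 512 // 32 = 16 by 768 // 32 = 24,
# i.e. a packed sequence of 21 * 16 * 24 = 8064 tokens.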
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
num_frames = kwargs.get("num_frames")
height = kwargs.get("height")
width = kwargs.get("width")
latent_num_frames = (num_frames - 1) // VAE_TEMPORAL_COMPRESSION_RATIO + 1
latent_height = height // VAE_SPATIAL_COMPRESSION_RATIO
latent_width = width // VAE_SPATIAL_COMPRESSION_RATIO
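    # Export and AOT-compile inside a ZeroGPU worker; the long duration budget
    # covers the warmup run, torch.export, and the two AOTInductor builds. The
    # compiled artifacts are serialized back to the main process on return.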
@spaces.GPU(duration=1500)
def compile_transformer():
with capture_component_call(pipeline, "transformer") as call:
pipeline(*args, **kwargs)
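        # `capture_component_call` (from optimization_utils) runs the pipeline
        # once and records the exact args/kwargs the transformer receives; these
        # become the example inputs for torch.export below.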
dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight())
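        # Quantizing before export means the float8 weight/activation kernels
        # are traced into the graph that AOTInductor then compiles.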
hidden_states: torch.Tensor = call.kwargs["hidden_states"]
unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
hidden_states,
latent_num_frames,
latent_height,
latent_width,
TRANSFORMER_SPATIAL_PATCH_SIZE,
TRANSFORMER_TEMPORAL_PATCH_SIZE,
)
unpacked_hidden_states_transposed = unpacked_hidden_states.transpose(-1, -2).contiguous()
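        # Keep a landscape and a portrait variant of the latents: one graph is
        # exported per orientation, and `combined_transformer` below picks the
        # matching compiled artifact from the unpacked aspect ratio.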
        if unpacked_hidden_states.shape[-1] > unpacked_hidden_states.shape[-2]:  # width > height
hidden_states_landscape = unpacked_hidden_states
hidden_states_portrait = unpacked_hidden_states_transposed
else:
hidden_states_landscape = unpacked_hidden_states_transposed
hidden_states_portrait = unpacked_hidden_states
hidden_states_landscape = LTXConditionPipeline._pack_latents(
hidden_states_landscape, TRANSFORMER_SPATIAL_PATCH_SIZE, TRANSFORMER_TEMPORAL_PATCH_SIZE
)
hidden_states_portrait = LTXConditionPipeline._pack_latents(
hidden_states_portrait, TRANSFORMER_SPATIAL_PATCH_SIZE, TRANSFORMER_TEMPORAL_PATCH_SIZE
)
exported_landscape = torch.export.export(
mod=pipeline.transformer,
args=call.args,
kwargs=call.kwargs | {"hidden_states": hidden_states_landscape},
dynamic_shapes=dynamic_shapes,
)
exported_portrait = torch.export.export(
mod=pipeline.transformer,
args=call.args,
kwargs=call.kwargs | {"hidden_states": hidden_states_portrait},
dynamic_shapes=dynamic_shapes,
)
compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
compiled_portrait = aoti_compile(exported_portrait, INDUCTOR_CONFIGS)
        # Share one weight blob to avoid duplication when serializing back to
        # the main process
        compiled_portrait.weights = compiled_landscape.weights
return compiled_landscape, compiled_portrait
compiled_landscape, compiled_portrait = compile_transformer()
@torch.no_grad()
def combined_transformer(*args, **kwargs):
hidden_states: torch.Tensor = kwargs["hidden_states"]
unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
hidden_states,
latent_num_frames,
latent_height,
latent_width,
TRANSFORMER_SPATIAL_PATCH_SIZE,
TRANSFORMER_TEMPORAL_PATCH_SIZE,
)
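        # Dispatch on the unpacked aspect ratio, mirroring the export-time split.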
if unpacked_hidden_states.shape[-1] > unpacked_hidden_states.shape[-2]:
return compiled_landscape(*args, **kwargs)
else:
return compiled_portrait(*args, **kwargs)
transformer_config = pipeline.transformer.config
transformer_dtype = pipeline.transformer.dtype
cache_context = pipeline.transformer.cache_context
    # Swap in the compiled dispatcher, then restore the attributes the pipeline
    # still reads from `pipeline.transformer`. A `cudagraph(combined_transformer)`
    # wrapper is available via optimization_utils but is left disabled here.
    pipeline.transformer = combined_transformer
    pipeline.transformer.config = transformer_config  # pyright: ignore[reportAttributeAccessIssue]
    pipeline.transformer.dtype = transformer_dtype  # pyright: ignore[reportAttributeAccessIssue]
    pipeline.transformer.cache_context = cache_context
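# Minimal usage sketch (hypothetical caller, not part of this file), assuming a
# CUDA-enabled ZeroGPU Space, this module saved as `optimization.py`, and the
# Lightricks/LTX-Video checkpoint:
#
#     import torch
#     from diffusers import LTXConditionPipeline
#     from optimization import optimize_pipeline_
#
#     pipeline = LTXConditionPipeline.from_pretrained(
#         "Lightricks/LTX-Video", torch_dtype=torch.bfloat16
#     ).to("cuda")
#     optimize_pipeline_(pipeline, prompt="warmup", num_frames=161, height=512, width=768)
#
# Later inference should reuse the warmup height and width: every input shape
# other than the packed sequence dimension is exported statically.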