Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,407 Bytes
3fc0a22 10bbb52 e2da1c9 10bbb52 3fc0a22 7f65363 3fc0a22 e2da1c9 c11b1fb 3fc0a22 832becd c11b1fb 3fc0a22 7f65363 10bbb52 7f65363 3fc0a22 e2da1c9 3fc0a22 7f65363 3fc0a22 e2da1c9 7f65363 e2da1c9 7f65363 3fc0a22 10bbb52 3fc0a22 10bbb52 832becd 3fc0a22 10bbb52 7f65363 e2da1c9 7f65363 e2da1c9 3fc0a22 e2da1c9 b2f5bcf 3fc0a22 7f65363 10bbb52 3fc0a22 7f65363 10bbb52 3fc0a22 7f65363 3fc0a22 10bbb52 3fc0a22 7f65363 e2da1c9 3fc0a22 fa8999e 10bbb52 7f65363 fa8999e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
"""
Taken from https://huggingface.co/spaces/cbensimon/wan2-1-fast/
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from diffusers import LTXConditionPipeline
from optimization_utils import capture_component_call, aoti_compile, cudagraph
P = ParamSpec("P")
# Sequence packing in LTX is a bit of a pain.
# See: https://github.com/huggingface/diffusers/blob/c052791b5fe29ce8a308bf63dda97aa205b729be/src/diffusers/pipelines/ltx/pipeline_ltx.py#L420
# TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim("seq_len", min=4680, max=4680)
# Unused currently as I don't know how to make the best use of it for LTX.
# TRANSFORMER_DYNAMIC_SHAPES = {
# "hidden_states": {1: TRANSFORMER_NUM_FRAMES_DIM},
# }
INDUCTOR_CONFIGS = {
"conv_1x1_as_mm": True,
"epilogue_fusion": False,
"coordinate_descent_tuning": True,
"coordinate_descent_check_all_directions": True,
# "max_autotune": True,
"max_autotune": False,
"triton.cudagraphs": True,
}
TRANSFORMER_SPATIAL_PATCH_SIZE = 1
TRANSFORMER_TEMPORAL_PATCH_SIZE = 1
VAE_SPATIAL_COMPRESSION_RATIO = 32
VAE_TEMPORAL_COMPRESSION_RATIO = 8
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Compile ``pipeline.transformer`` ahead of time and install it in place.

    Steps:
    1. Run ``pipeline(*args, **kwargs)`` once under ``capture_component_call``
       to obtain the arguments its transformer is called with.
    2. In a ``spaces.GPU`` job, quantize the transformer with torchao
       (``float8_dynamic_activation_float8_weight``).
    3. Export it twice with ``torch.export``: once with landscape-oriented
       latents and once with the spatial axes transposed (portrait).
    4. Compile both exports with ``aoti_compile``.
    5. Install a dispatcher over the two compiled programs, wrapped with
       ``cudagraph``, as ``pipeline.transformer``.

    The original transformer's ``config``, ``dtype`` and ``cache_context``
    attributes are copied onto the replacement.

    Args:
        pipeline: Pipeline exposing a ``transformer`` component. The latent
            pack/unpack helpers used here come from ``LTXConditionPipeline``,
            so presumably an instance of that class.
        *args: Positional arguments forwarded to ``pipeline`` for the
            capture run.
        **kwargs: Keyword arguments forwarded to ``pipeline``. Must include
            ``num_frames``, ``height`` and ``width``; they size the latent
            grid used to unpack the captured hidden states.

    Raises:
        ValueError: If ``num_frames``, ``height`` or ``width`` is missing
            (or ``None``) in ``kwargs``.
    """
    missing = [name for name in ("num_frames", "height", "width") if kwargs.get(name) is None]
    if missing:
        raise ValueError(
            f"optimize_pipeline_ requires keyword argument(s): {', '.join(missing)}"
        )
    num_frames = kwargs.get("num_frames")
    height = kwargs.get("height")
    width = kwargs.get("width")

    # Latent grid size derived from the requested output size via the VAE ratios.
    latent_num_frames = (num_frames - 1) // VAE_TEMPORAL_COMPRESSION_RATIO + 1
    latent_height = height // VAE_SPATIAL_COMPRESSION_RATIO
    latent_width = width // VAE_SPATIAL_COMPRESSION_RATIO

    # One pipeline run; `call` exposes the .args/.kwargs the transformer was
    # invoked with (presumably the first transformer call - see optimization_utils).
    with capture_component_call(pipeline, "transformer") as call:
        pipeline(*args, **kwargs)

    def _unpack(packed: torch.Tensor) -> torch.Tensor:
        """Undo LTX sequence packing using the fixed latent grid computed above."""
        return LTXConditionPipeline._unpack_latents(
            packed,
            latent_num_frames,
            latent_height,
            latent_width,
            TRANSFORMER_SPATIAL_PATCH_SIZE,
            TRANSFORMER_TEMPORAL_PATCH_SIZE,
        )

    def _is_landscape(unpacked: torch.Tensor) -> bool:
        """Return True when the unpacked grid's last axis is longer than the one before it."""
        return unpacked.shape[-1] > unpacked.shape[-2]

    @spaces.GPU(duration=1500)
    def compile_transformer():
        """Quantize, export and AOT-compile the transformer for both orientations."""
        # Mutates pipeline.transformer in place before it is exported.
        quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight())

        hidden_states: torch.Tensor = call.kwargs["hidden_states"]
        unpacked_hidden_states = _unpack(hidden_states)
        unpacked_hidden_states_transposed = unpacked_hidden_states.transpose(-1, -2).contiguous()

        # Use the same orientation test as the runtime dispatcher (on the
        # unpacked grid, never on the packed tensor) so each compiled program
        # is labelled consistently with how it is later selected.
        if _is_landscape(unpacked_hidden_states):
            hidden_states_landscape = unpacked_hidden_states
            hidden_states_portrait = unpacked_hidden_states_transposed
        else:
            hidden_states_landscape = unpacked_hidden_states_transposed
            hidden_states_portrait = unpacked_hidden_states

        hidden_states_landscape = LTXConditionPipeline._pack_latents(
            hidden_states_landscape, TRANSFORMER_SPATIAL_PATCH_SIZE, TRANSFORMER_TEMPORAL_PATCH_SIZE
        )
        hidden_states_portrait = LTXConditionPipeline._pack_latents(
            hidden_states_portrait, TRANSFORMER_SPATIAL_PATCH_SIZE, TRANSFORMER_TEMPORAL_PATCH_SIZE
        )

        # Exports use static shapes: no dynamic_shapes are passed.
        exported_landscape = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {"hidden_states": hidden_states_landscape},
        )
        exported_portrait = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {"hidden_states": hidden_states_portrait},
        )

        compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
        compiled_portrait = aoti_compile(exported_portrait, INDUCTOR_CONFIGS)
        # Share one set of weights to avoid duplicating them when the results
        # are serialized back to the main process.
        compiled_portrait.weights = compiled_landscape.weights

        return compiled_landscape, compiled_portrait

    compiled_landscape, compiled_portrait = compile_transformer()

    @torch.no_grad()
    def combined_transformer(*args, **kwargs):
        """Dispatch to the compiled program matching the latents' orientation."""
        # NOTE(review): _unpack reshapes with the fixed latent grid captured
        # above, so this choice presumably resolves the same way on every call.
        hidden_states: torch.Tensor = kwargs["hidden_states"]
        if _is_landscape(_unpack(hidden_states)):
            return compiled_landscape(*args, **kwargs)
        return compiled_portrait(*args, **kwargs)

    # Save attributes of the original module; they are re-attached to the
    # replacement callable below (presumably because the pipeline reads them).
    transformer_config = pipeline.transformer.config
    transformer_dtype = pipeline.transformer.dtype
    cache_context = pipeline.transformer.cache_context

    # Warm up the dispatcher with the captured call, then wrap it with
    # cudagraph and run it once more before real inference.
    with torch.no_grad():
        combined_transformer(*call.args, **call.kwargs)

    pipeline.transformer = cudagraph(combined_transformer)

    with torch.no_grad():
        pipeline.transformer(*call.args, **call.kwargs)

    pipeline.transformer.config = transformer_config  # pyright: ignore[reportAttributeAccessIssue]
    pipeline.transformer.dtype = transformer_dtype  # pyright: ignore[reportAttributeAccessIssue]
    pipeline.transformer.cache_context = cache_context  # pyright: ignore[reportAttributeAccessIssue]