"""
Taken from https://huggingface.co/spaces/cbensimon/wan2-1-fast/
"""

from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from diffusers import LTXConditionPipeline
from optimization_utils import capture_component_call, aoti_compile, cudagraph


P = ParamSpec("P")


# Sequence packing in LTX is a bit of a pain.
# See: https://github.com/huggingface/diffusers/blob/c052791b5fe29ce8a308bf63dda97aa205b729be/src/diffusers/pipelines/ltx/pipeline_ltx.py#L420
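# The transformer consumes packed latents of shape (batch, sequence, channels); dim 1
# (the packed sequence length) varies with num_frames and resolution, so it is the only
# dimension exported as dynamic.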
TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim.AUTO
TRANSFORMER_DYNAMIC_SHAPES = {
    "hidden_states": {1: TRANSFORMER_NUM_FRAMES_DIM},
}

# Torch Inductor options forwarded to the AOTI compilation step (optimization_utils.aoti_compile).
INDUCTOR_CONFIGS = {
    "conv_1x1_as_mm": True,
    "epilogue_fusion": False,
    "coordinate_descent_tuning": True,
    "coordinate_descent_check_all_directions": True,
    "max_autotune": False, # doesn't help much
    "triton.cudagraphs": True,
}
TRANSFORMER_SPATIAL_PATCH_SIZE = 1
TRANSFORMER_TEMPORAL_PATCH_SIZE = 1
VAE_SPATIAL_COMPRESSION_RATIO = 32
VAE_TEMPORAL_COMPRESSION_RATIO = 8


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
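    """Ahead-of-time compile the pipeline's transformer for the given call signature.

    Runs the pipeline once to capture the real transformer inputs, quantizes the
    transformer to float8, exports it with a dynamic packed-sequence dimension,
    AOTI-compiles a landscape and a portrait variant, and replaces
    pipeline.transformer with a wrapper that dispatches on the latent aspect ratio.
    num_frames, height and width must be passed as keyword arguments.
    """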
    num_frames = kwargs.get("num_frames")
    height = kwargs.get("height")
    width = kwargs.get("width")
    latent_num_frames = (num_frames - 1) // VAE_TEMPORAL_COMPRESSION_RATIO + 1
    latent_height = height // VAE_SPATIAL_COMPRESSION_RATIO
    latent_width = width // VAE_SPATIAL_COMPRESSION_RATIO
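    # e.g. (illustrative values) num_frames=97, height=512, width=704 ->
    # latent_num_frames = (97 - 1) // 8 + 1 = 13, latent_height = 512 // 32 = 16, latent_width = 704 // 32 = 22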

    @spaces.GPU(duration=1500)
    def compile_transformer():
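        # Run the full pipeline once under capture_component_call so the exact
        # args/kwargs of the transformer forward are recorded as example inputs for export.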
        with capture_component_call(pipeline, "transformer") as call:
            pipeline(*args, **kwargs)
        
        # Mark every tensor/bool leaf of the captured kwargs as static (None), then
        # override the hidden_states entry with the dynamic sequence dimension.
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        # Quantize the transformer to float8 (dynamic activations, float8 weights)
        # before export so the compiled graphs use the quantized weights.
        quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight())

        # Export two shape-specialized variants: one for landscape latents and one for
        # portrait latents, obtained by transposing the unpacked latents of the captured call.
        hidden_states: torch.Tensor = call.kwargs["hidden_states"]
        unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
            hidden_states,
            latent_num_frames,
            latent_height,
            latent_width,
            TRANSFORMER_SPATIAL_PATCH_SIZE,
            TRANSFORMER_TEMPORAL_PATCH_SIZE,
        )
        unpacked_hidden_states_transposed = unpacked_hidden_states.transpose(-1, -2).contiguous()
        if unpacked_hidden_states.shape[-1] > hidden_states.shape[-2]:
            hidden_states_landscape = unpacked_hidden_states
            hidden_states_portrait = unpacked_hidden_states_transposed
        else:
            hidden_states_landscape = unpacked_hidden_states_transposed
            hidden_states_portrait = unpacked_hidden_states

        hidden_states_landscape = LTXConditionPipeline._pack_latents(
            hidden_states_landscape, TRANSFORMER_SPATIAL_PATCH_SIZE, TRANSFORMER_TEMPORAL_PATCH_SIZE
        )
        hidden_states_portrait = LTXConditionPipeline._pack_latents(
            hidden_states_portrait, TRANSFORMER_SPATIAL_PATCH_SIZE, TRANSFORMER_TEMPORAL_PATCH_SIZE
        )
        
        exported_landscape = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {"hidden_states": hidden_states_landscape},
            dynamic_shapes=dynamic_shapes,
        )
        exported_portrait = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {"hidden_states": hidden_states_portrait},
            dynamic_shapes=dynamic_shapes,
        )

        compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
        compiled_portrait = aoti_compile(exported_portrait, INDUCTOR_CONFIGS)
        compiled_portrait.weights = (
            compiled_landscape.weights
        )  # Avoid weights duplication when serializing back to main process

        return compiled_landscape, compiled_portrait

    # compile_transformer runs in a GPU-attached worker (spaces.GPU); the compiled
    # artifacts are returned (serialized) to the main process.
    compiled_landscape, compiled_portrait = compile_transformer()

    @torch.no_grad()
    def combined_transformer(*args, **kwargs):
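        # Unpack the latents only to inspect the aspect ratio, then dispatch to the
        # matching compiled transformer.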
        hidden_states: torch.Tensor = kwargs["hidden_states"]
        unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
            hidden_states,
            latent_num_frames,
            latent_height,
            latent_width,
            TRANSFORMER_SPATIAL_PATCH_SIZE,
            TRANSFORMER_TEMPORAL_PATCH_SIZE,
        )
        if unpacked_hidden_states.shape[-1] > unpacked_hidden_states.shape[-2]:
            return compiled_landscape(*args, **kwargs)
        else:
            return compiled_portrait(*args, **kwargs)

    # Preserve attributes that downstream pipeline code reads from the transformer,
    # since the compiled wrapper is a plain function rather than an nn.Module.
    transformer_config = pipeline.transformer.config
    transformer_dtype = pipeline.transformer.dtype
    cache_context = pipeline.transformer.cache_context

    # with torch.no_grad():
    #     combined_transformer(**call.kwargs)

    pipeline.transformer = combined_transformer
    # pipeline.transformer = cudagraph(combined_transformer)

    # with torch.no_grad():
    #     pipeline.transformer(**call.kwargs)

    pipeline.transformer.config = transformer_config  # pyright: ignore[reportAttributeAccessIssue]
    pipeline.transformer.dtype = transformer_dtype  # pyright: ignore[reportAttributeAccessIssue]
    pipeline.transformer.cache_context = cache_context
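

# Illustrative usage (a sketch, not part of the upstream Space; the checkpoint id and
# call arguments below are assumptions): load the pipeline once, then compile it with
# the same keyword arguments that will later be used for inference.
#
#     pipeline = LTXConditionPipeline.from_pretrained(
#         "Lightricks/LTX-Video", torch_dtype=torch.bfloat16
#     ).to("cuda")
#     optimize_pipeline_(pipeline, prompt="warmup", num_frames=97, height=512, width=704)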