sayakpaul HF Staff commited on
Commit
e2da1c9
·
1 Parent(s): 4e531c0
Files changed (1) hide show
  1. optimization.py +45 -11
optimization.py CHANGED
@@ -11,7 +11,7 @@ import torch
11
  from torch.utils._pytree import tree_map_only
12
  from torchao.quantization import quantize_
13
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
14
-
15
  from optimization_utils import capture_component_call
16
  from optimization_utils import aoti_compile
17
 
@@ -19,10 +19,12 @@ from optimization_utils import aoti_compile
19
  P = ParamSpec('P')
20
 
21
 
22
- TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
 
 
23
 
24
  TRANSFORMER_DYNAMIC_SHAPES = {
25
- 'hidden_states': {2: TRANSFORMER_NUM_FRAMES_DIM},
26
  }
27
 
28
  INDUCTOR_CONFIGS = {
@@ -33,9 +35,18 @@ INDUCTOR_CONFIGS = {
33
  'max_autotune': True,
34
  'triton.cudagraphs': True,
35
  }
36
-
 
 
 
37
 
38
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
 
 
 
 
 
 
39
 
40
  @spaces.GPU(duration=1500)
41
  def compile_transformer():
@@ -49,13 +60,28 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
49
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
50
 
51
  hidden_states: torch.Tensor = call.kwargs['hidden_states']
52
- hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
53
- if hidden_states.shape[-1] > hidden_states.shape[-2]:
54
- hidden_states_landscape = hidden_states
55
- hidden_states_portrait = hidden_states_transposed
 
 
 
 
 
 
 
 
56
  else:
57
- hidden_states_landscape = hidden_states_transposed
58
- hidden_states_portrait = hidden_states
 
 
 
 
 
 
 
59
 
60
  exported_landscape = torch.export.export(
61
  mod=pipeline.transformer,
@@ -81,7 +107,15 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
81
 
82
  def combined_transformer(*args, **kwargs):
83
  hidden_states: torch.Tensor = kwargs['hidden_states']
84
- if hidden_states.shape[-1] > hidden_states.shape[-2]:
 
 
 
 
 
 
 
 
85
  return compiled_landscape(*args, **kwargs)
86
  else:
87
  return compiled_portrait(*args, **kwargs)
 
11
  from torch.utils._pytree import tree_map_only
12
  from torchao.quantization import quantize_
13
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
14
+ from diffusers import LTXConditionPipeline
15
  from optimization_utils import capture_component_call
16
  from optimization_utils import aoti_compile
17
 
 
19
  P = ParamSpec('P')
20
 
21
 
22
+ # Sequence packing in LTX is a bit of a pain.
23
+ # See: https://github.com/huggingface/diffusers/blob/c052791b5fe29ce8a308bf63dda97aa205b729be/src/diffusers/pipelines/ltx/pipeline_ltx.py#L420
24
+ TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('seq_len', min=4680, max=6000)
25
 
26
  TRANSFORMER_DYNAMIC_SHAPES = {
27
+ 'hidden_states': {1: TRANSFORMER_NUM_FRAMES_DIM},
28
  }
29
 
30
  INDUCTOR_CONFIGS = {
 
35
  'max_autotune': True,
36
  'triton.cudagraphs': True,
37
  }
38
+ TRANSFORMER_SPATIAL_PATCH_SIZE = 1
39
+ TRANSFORMER_TEMPORAL_PATCH_SIZE = 1
40
+ VAE_SPATIAL_COMPRESSION_RATIO = 32
41
+ VAE_TEMPORAL_COMPRESSION_RATIO = 8
42
 
43
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
44
+ num_frames = kwargs.get("num_frames")
45
+ height = kwargs.get("height")
46
+ width = kwargs.get("width")
47
+ latent_num_frames = (num_frames - 1) // VAE_TEMPORAL_COMPRESSION_RATIO + 1
48
+ latent_height = height // VAE_SPATIAL_COMPRESSION_RATIO
49
+ latent_width = width //VAE_SPATIAL_COMPRESSION_RATIO
50
 
51
  @spaces.GPU(duration=1500)
52
  def compile_transformer():
 
60
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
61
 
62
  hidden_states: torch.Tensor = call.kwargs['hidden_states']
63
+ unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
64
+ hidden_states,
65
+ latent_num_frames,
66
+ latent_height,
67
+ latent_width,
68
+ TRANSFORMER_SPATIAL_PATCH_SIZE,
69
+ TRANSFORMER_TEMPORAL_PATCH_SIZE,
70
+ )
71
+ unpacked_hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
72
+ if unpacked_hidden_states.shape[-1] > hidden_states.shape[-2]:
73
+ hidden_states_landscape = unpacked_hidden_states
74
+ hidden_states_portrait = unpacked_hidden_states_transposed
75
  else:
76
+ hidden_states_landscape = unpacked_hidden_states_transposed
77
+ hidden_states_portrait = unpacked_hidden_states
78
+
79
+ hidden_states_landscape = LTXConditionPipeline._pack_latents(
80
+ hidden_states_landscape, TRANSFORMER_SPATIAL_PATCH_SIZE, TRANSFORMER_TEMPORAL_PATCH_SIZE
81
+ )
82
+ hidden_states_portrait = LTXConditionPipeline._pack_latents(
83
+ hidden_states_portrait, TRANSFORMER_SPATIAL_PATCH_SIZE, TRANSFORMER_TEMPORAL_PATCH_SIZE
84
+ )
85
 
86
  exported_landscape = torch.export.export(
87
  mod=pipeline.transformer,
 
107
 
108
  def combined_transformer(*args, **kwargs):
109
  hidden_states: torch.Tensor = kwargs['hidden_states']
110
+ unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
111
+ hidden_states,
112
+ latent_num_frames,
113
+ latent_height,
114
+ latent_width,
115
+ TRANSFORMER_SPATIAL_PATCH_SIZE,
116
+ TRANSFORMER_TEMPORAL_PATCH_SIZE,
117
+ )
118
+ if unpacked_hidden_states.shape[-1] > unpacked_hidden_states.shape[-2]:
119
  return compiled_landscape(*args, **kwargs)
120
  else:
121
  return compiled_portrait(*args, **kwargs)