sayakpaul HF Staff commited on
Commit
832becd
·
1 Parent(s): b2f5bcf
Files changed (1) hide show
  1. optimization.py +3 -2
optimization.py CHANGED
@@ -23,6 +23,7 @@ P = ParamSpec("P")
23
  # See: https://github.com/huggingface/diffusers/blob/c052791b5fe29ce8a308bf63dda97aa205b729be/src/diffusers/pipelines/ltx/pipeline_ltx.py#L420
24
  TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim("seq_len", min=4680, max=4680)
25
 
 
26
  TRANSFORMER_DYNAMIC_SHAPES = {
27
  "hidden_states": {1: TRANSFORMER_NUM_FRAMES_DIM},
28
  }
@@ -55,7 +56,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
55
  pipeline(*args, **kwargs)
56
 
57
  dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
58
- dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
59
 
60
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
61
 
@@ -94,7 +95,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
94
  mod=pipeline.transformer,
95
  args=call.args,
96
  kwargs=call.kwargs | {"hidden_states": hidden_states_portrait},
97
- dynamic_shapes=dynamic_shapes,
98
  )
99
 
100
  compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
 
23
  # See: https://github.com/huggingface/diffusers/blob/c052791b5fe29ce8a308bf63dda97aa205b729be/src/diffusers/pipelines/ltx/pipeline_ltx.py#L420
24
  TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim("seq_len", min=4680, max=4680)
25
 
26
+ # Unused currently as I don't know how to make the best use of it for LTX.
27
  TRANSFORMER_DYNAMIC_SHAPES = {
28
  "hidden_states": {1: TRANSFORMER_NUM_FRAMES_DIM},
29
  }
 
56
  pipeline(*args, **kwargs)
57
 
58
  dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
59
+ # dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
60
 
61
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
62
 
 
95
  mod=pipeline.transformer,
96
  args=call.args,
97
  kwargs=call.kwargs | {"hidden_states": hidden_states_portrait},
98
+ ddynamic_shapes=dynamic_shapes,
99
  )
100
 
101
  compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)