sayakpaul HF Staff commited on
Commit
b2f5bcf
·
1 Parent(s): c910edc
Files changed (1) hide show
  1. optimization.py +2 -5
optimization.py CHANGED
@@ -21,7 +21,7 @@ P = ParamSpec("P")
21
 
22
  # Sequence packing in LTX is a bit of a pain.
23
  # See: https://github.com/huggingface/diffusers/blob/c052791b5fe29ce8a308bf63dda97aa205b729be/src/diffusers/pipelines/ltx/pipeline_ltx.py#L420
24
- TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim("seq_len", min=4680, max=6000)
25
 
26
  TRANSFORMER_DYNAMIC_SHAPES = {
27
  "hidden_states": {1: TRANSFORMER_NUM_FRAMES_DIM},
@@ -60,7 +60,6 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
60
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
61
 
62
  hidden_states: torch.Tensor = call.kwargs["hidden_states"]
63
- print(f"{hidden_states.shape=}")
64
  unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
65
  hidden_states,
66
  latent_num_frames,
@@ -69,7 +68,6 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
69
  TRANSFORMER_SPATIAL_PATCH_SIZE,
70
  TRANSFORMER_TEMPORAL_PATCH_SIZE,
71
  )
72
- print(f"{unpacked_hidden_states.shape=}")
73
  unpacked_hidden_states_transposed = unpacked_hidden_states.transpose(-1, -2).contiguous()
74
  if unpacked_hidden_states.shape[-1] > hidden_states.shape[-2]:
75
  hidden_states_landscape = unpacked_hidden_states
@@ -84,8 +82,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
84
  hidden_states_portrait = LTXConditionPipeline._pack_latents(
85
  hidden_states_portrait, TRANSFORMER_SPATIAL_PATCH_SIZE, TRANSFORMER_TEMPORAL_PATCH_SIZE
86
  )
87
- print(f"{hidden_states_landscape.shape=}, {hidden_states_portrait.shape=}")
88
-
89
  exported_landscape = torch.export.export(
90
  mod=pipeline.transformer,
91
  args=call.args,
 
21
 
22
  # Sequence packing in LTX is a bit of a pain.
23
  # See: https://github.com/huggingface/diffusers/blob/c052791b5fe29ce8a308bf63dda97aa205b729be/src/diffusers/pipelines/ltx/pipeline_ltx.py#L420
24
+ TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim("seq_len", min=4680, max=4680)
25
 
26
  TRANSFORMER_DYNAMIC_SHAPES = {
27
  "hidden_states": {1: TRANSFORMER_NUM_FRAMES_DIM},
 
60
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
61
 
62
  hidden_states: torch.Tensor = call.kwargs["hidden_states"]
 
63
  unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
64
  hidden_states,
65
  latent_num_frames,
 
68
  TRANSFORMER_SPATIAL_PATCH_SIZE,
69
  TRANSFORMER_TEMPORAL_PATCH_SIZE,
70
  )
 
71
  unpacked_hidden_states_transposed = unpacked_hidden_states.transpose(-1, -2).contiguous()
72
  if unpacked_hidden_states.shape[-1] > hidden_states.shape[-2]:
73
  hidden_states_landscape = unpacked_hidden_states
 
82
  hidden_states_portrait = LTXConditionPipeline._pack_latents(
83
  hidden_states_portrait, TRANSFORMER_SPATIAL_PATCH_SIZE, TRANSFORMER_TEMPORAL_PATCH_SIZE
84
  )
85
+
 
86
  exported_landscape = torch.export.export(
87
  mod=pipeline.transformer,
88
  args=call.args,