sayakpaul (HF Staff) committed
Commit f3f8d99
1 Parent(s): 7f65363
Files changed (1):
  optimization.py (+2 -0)
optimization.py CHANGED
@@ -60,6 +60,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
     quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
 
     hidden_states: torch.Tensor = call.kwargs["hidden_states"]
+    print(f"{hidden_states.shape=}")
     unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
         hidden_states,
         latent_num_frames,
@@ -68,6 +69,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
         TRANSFORMER_SPATIAL_PATCH_SIZE,
         TRANSFORMER_TEMPORAL_PATCH_SIZE,
     )
+    print(f"{unpacked_hidden_states.shape=}")
     unpacked_hidden_states_transposed = unpacked_hidden_states.transpose(-1, -2).contiguous()
     if unpacked_hidden_states.shape[-1] > hidden_states.shape[-2]:
         hidden_states_landscape = unpacked_hidden_states
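
For reference, the two added lines use Python's f-string "=" specifier to log tensor shapes before and after unpacking. A minimal standalone sketch of the same idiom is below; the tensor and its dimensions are made up for illustration and are not the pipeline's actual latents.

# Illustration of the shape logging added in this commit
# (dummy sizes; the real values come from call.kwargs["hidden_states"]).
import torch

hidden_states = torch.randn(1, 2048, 128)             # stand-in for packed latents
unpacked = hidden_states.reshape(1, 128, 8, 16, 16)   # stand-in for the _unpack_latents output

# The "=" specifier prints the expression together with its value,
# e.g. hidden_states.shape=torch.Size([1, 2048, 128]).
print(f"{hidden_states.shape=}")
print(f"{unpacked.shape=}")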