sayakpaul HF Staff commited on
Commit
0a64cad
·
verified ·
1 Parent(s): d418d58

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +5 -5
optimization.py CHANGED
@@ -46,11 +46,11 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
46
  latent_height = height // VAE_SPATIAL_COMPRESSION_RATIO
47
  latent_width = width // VAE_SPATIAL_COMPRESSION_RATIO
48
 
49
- with capture_component_call(pipeline, "transformer") as call:
50
- pipeline(*args, **kwargs)
51
-
52
  @spaces.GPU(duration=1500)
53
  def compile_transformer():
 
 
 
54
  dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
55
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
56
 
@@ -129,8 +129,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
129
  pipeline.transformer = combined_transformer
130
  # pipeline.transformer = cudagraph(combined_transformer)
131
 
132
- with torch.no_grad():
133
- pipeline.transformer(**call.kwargs)
134
 
135
  pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
136
  pipeline.transformer.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
 
46
  latent_height = height // VAE_SPATIAL_COMPRESSION_RATIO
47
  latent_width = width // VAE_SPATIAL_COMPRESSION_RATIO
48
 
 
 
 
49
  @spaces.GPU(duration=1500)
50
  def compile_transformer():
51
+ with capture_component_call(pipeline, "transformer") as call:
52
+ pipeline(*args, **kwargs)
53
+
54
  dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
55
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
56
 
 
129
  pipeline.transformer = combined_transformer
130
  # pipeline.transformer = cudagraph(combined_transformer)
131
 
132
+ # with torch.no_grad():
133
+ # pipeline.transformer(**call.kwargs)
134
 
135
  pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
136
  pipeline.transformer.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]