Spaces:
Running
on
Zero
Running
on
Zero
Update optimization.py
Browse files- optimization.py +5 -5
optimization.py
CHANGED
@@ -46,11 +46,11 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
|
|
46 |
latent_height = height // VAE_SPATIAL_COMPRESSION_RATIO
|
47 |
latent_width = width // VAE_SPATIAL_COMPRESSION_RATIO
|
48 |
|
49 |
-
with capture_component_call(pipeline, "transformer") as call:
|
50 |
-
pipeline(*args, **kwargs)
|
51 |
-
|
52 |
@spaces.GPU(duration=1500)
|
53 |
def compile_transformer():
|
|
|
|
|
|
|
54 |
dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
|
55 |
dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
|
56 |
|
@@ -129,8 +129,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
|
|
129 |
pipeline.transformer = combined_transformer
|
130 |
# pipeline.transformer = cudagraph(combined_transformer)
|
131 |
|
132 |
-
with torch.no_grad():
|
133 |
-
|
134 |
|
135 |
pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
|
136 |
pipeline.transformer.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
|
|
|
46 |
latent_height = height // VAE_SPATIAL_COMPRESSION_RATIO
|
47 |
latent_width = width // VAE_SPATIAL_COMPRESSION_RATIO
|
48 |
|
|
|
|
|
|
|
49 |
@spaces.GPU(duration=1500)
|
50 |
def compile_transformer():
|
51 |
+
with capture_component_call(pipeline, "transformer") as call:
|
52 |
+
pipeline(*args, **kwargs)
|
53 |
+
|
54 |
dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
|
55 |
dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
|
56 |
|
|
|
129 |
pipeline.transformer = combined_transformer
|
130 |
# pipeline.transformer = cudagraph(combined_transformer)
|
131 |
|
132 |
+
# with torch.no_grad():
|
133 |
+
# pipeline.transformer(**call.kwargs)
|
134 |
|
135 |
pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
|
136 |
pipeline.transformer.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
|