Spaces:
Running
on
Zero
Running
on
Zero
Fix dynamic shapes
Browse files- optimization.py +3 -1
optimization.py
CHANGED
@@ -7,6 +7,7 @@ from typing import ParamSpec
|
|
7 |
|
8 |
import spaces
|
9 |
import torch
|
|
|
10 |
|
11 |
from pipeline_utils import capture_component_call
|
12 |
from zerogpu import aoti_compile
|
@@ -34,7 +35,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
|
|
34 |
pipeline(*args, **kwargs)
|
35 |
|
36 |
hidden_dim = torch.export.Dim('hidden', min=4096, max=8212)
|
37 |
-
dynamic_shapes =
|
|
|
38 |
'hidden_states': {1: hidden_dim},
|
39 |
'img_ids': {0: hidden_dim},
|
40 |
}
|
|
|
7 |
|
8 |
import spaces
|
9 |
import torch
|
10 |
+
from torch.utils._pytree import tree_map_only
|
11 |
|
12 |
from pipeline_utils import capture_component_call
|
13 |
from zerogpu import aoti_compile
|
|
|
35 |
pipeline(*args, **kwargs)
|
36 |
|
37 |
hidden_dim = torch.export.Dim('hidden', min=4096, max=8212)
|
38 |
+
dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
|
39 |
+
dynamic_shapes |= {
|
40 |
'hidden_states': {1: hidden_dim},
|
41 |
'img_ids': {0: hidden_dim},
|
42 |
}
|