cbensimon HF Staff commited on
Commit
fc3f0ed
·
1 Parent(s): 318b03c

Fix dynamic shapes

Browse files
Files changed (1) hide show
  1. optimization.py +3 -1
optimization.py CHANGED
@@ -7,6 +7,7 @@ from typing import ParamSpec
7
 
8
  import spaces
9
  import torch
 
10
 
11
  from pipeline_utils import capture_component_call
12
  from zerogpu import aoti_compile
@@ -34,7 +35,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
34
  pipeline(*args, **kwargs)
35
 
36
  hidden_dim = torch.export.Dim('hidden', min=4096, max=8212)
37
- dynamic_shapes = {
 
38
  'hidden_states': {1: hidden_dim},
39
  'img_ids': {0: hidden_dim},
40
  }
 
7
 
8
  import spaces
9
  import torch
10
+ from torch.utils._pytree import tree_map_only
11
 
12
  from pipeline_utils import capture_component_call
13
  from zerogpu import aoti_compile
 
35
  pipeline(*args, **kwargs)
36
 
37
  hidden_dim = torch.export.Dim('hidden', min=4096, max=8212)
38
+ dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
39
+ dynamic_shapes |= {
40
  'hidden_states': {1: hidden_dim},
41
  'img_ids': {0: hidden_dim},
42
  }