cbensimon HF Staff commited on
Commit
10b04c2
·
verified ·
1 Parent(s): 7cc6b52

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +3 -3
optimization.py CHANGED
@@ -7,13 +7,13 @@ from typing import ParamSpec
7
 
8
  import spaces
9
  import torch
10
- from torch.utils._pytree import tree_map_only
11
 
12
 
13
  P = ParamSpec('P')
14
 
15
 
16
- TRANSFORMER_HIDDEN_DIM = torch.export.Dim('hidden', min=3584, max=9727)
17
 
18
  TRANSFORMER_DYNAMIC_SHAPES = {
19
  'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
@@ -38,7 +38,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
38
  with spaces.aoti_capture(pipeline.transformer) as call:
39
  pipeline(*args, **kwargs)
40
 
41
- dynamic_shapes = tree_map_only((torch.Tensor, bool, int), lambda t: None, call.kwargs)
42
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
43
 
44
  exported = torch.export.export(
 
7
 
8
  import spaces
9
  import torch
10
+ from torch.utils._pytree import tree_map
11
 
12
 
13
  P = ParamSpec('P')
14
 
15
 
16
+ TRANSFORMER_HIDDEN_DIM = torch.export.Dim.AUTO('hidden', min=3584, max=9727)
17
 
18
  TRANSFORMER_DYNAMIC_SHAPES = {
19
  'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
 
38
  with spaces.aoti_capture(pipeline.transformer) as call:
39
  pipeline(*args, **kwargs)
40
 
41
+ dynamic_shapes = tree_map(lambda t: None, call.kwargs)
42
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
43
 
44
  exported = torch.export.export(