cbensimon HF Staff commited on
Commit
7601fad
·
verified ·
1 Parent(s): 0384ccd

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +7 -5
optimization.py CHANGED
@@ -8,16 +8,19 @@ from torchao.quantization import quantize_
8
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
9
  import spaces
10
  import torch
11
- from torch.utils._pytree import tree_map, tree_map_only
12
 
13
 
14
  P = ParamSpec('P')
15
 
16
 
17
- TRANSFORMER_HIDDEN_DIM = torch.export.Dim.AUTO(min=3584, max=9727)
 
 
18
 
19
  TRANSFORMER_DYNAMIC_SHAPES = {
20
- 'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
 
21
  }
22
 
23
 
@@ -40,7 +43,6 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
40
  pipeline(*args, **kwargs)
41
 
42
  dynamic_shapes = tree_map(lambda t: None, call.kwargs)
43
- # dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda x: None, call.kwargs)
44
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
45
 
46
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
@@ -49,7 +51,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
49
  mod=pipeline.transformer,
50
  args=call.args,
51
  kwargs=call.kwargs,
52
- # dynamic_shapes=dynamic_shapes,
53
  )
54
 
55
  return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
 
8
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
9
  import spaces
10
  import torch
11
+ from torch.utils._pytree import tree_map
12
 
13
 
14
  P = ParamSpec('P')
15
 
16
 
17
+ TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim.AUTO(min=3584, max=9727)
18
+ TRANSFORMER_IMAGE_HEIGHT_DIM = torch.export.Dim.DYNAMIC
19
+ TRANSFORMER_IMAGE_WIDTH_DIM = torch.export.Dim.DYNAMIC
20
 
21
  TRANSFORMER_DYNAMIC_SHAPES = {
22
+ 'hidden_states': {1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM},
23
+ 'img_shapes': [(None, TRANSFORMER_IMAGE_HEIGHT_DIM, TRANSFORMER_IMAGE_WIDTH_DIM)]
24
  }
25
 
26
 
 
43
  pipeline(*args, **kwargs)
44
 
45
  dynamic_shapes = tree_map(lambda t: None, call.kwargs)
 
46
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
47
 
48
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
 
51
  mod=pipeline.transformer,
52
  args=call.args,
53
  kwargs=call.kwargs,
54
+ dynamic_shapes=dynamic_shapes,
55
  )
56
 
57
  return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)