cbensimon HF Staff committed on
Commit
05cb184
·
verified ·
1 Parent(s): e5e8c84

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +1 -4
optimization.py CHANGED
@@ -23,7 +23,6 @@ TRANSFORMER_HIDDEN_DIM = torch.export.Dim('hidden', min=3584, max=9727)
23
 
24
  TRANSFORMER_DYNAMIC_SHAPES = {
25
  'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
26
- # 'img_ids': {0: TRANSFORMER_HIDDEN_DIM},
27
  }
28
 
29
 
@@ -45,11 +44,9 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
45
  with capture_component_call(pipeline, 'transformer') as call:
46
  pipeline(*args, **kwargs)
47
 
48
- dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
49
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
50
 
51
- # pipeline.transformer.fuse_qkv_projections()
52
-
53
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
54
 
55
  exported = torch.export.export(
 
23
 
24
  TRANSFORMER_DYNAMIC_SHAPES = {
25
  'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
 
26
  }
27
 
28
 
 
44
  with capture_component_call(pipeline, 'transformer') as call:
45
  pipeline(*args, **kwargs)
46
 
47
+ dynamic_shapes = tree_map_only((torch.Tensor, bool, int), lambda t: None, call.kwargs)
48
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
49
 
 
 
50
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
51
 
52
  exported = torch.export.export(