linoyts HF Staff commited on
Commit
dc28977
·
verified ·
1 Parent(s): 777d439

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +4 -3
optimization.py CHANGED
@@ -8,7 +8,7 @@ from torchao.quantization import quantize_
8
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
9
  import spaces
10
  import torch
11
- from torch.utils._pytree import tree_map
12
 
13
 
14
  P = ParamSpec('P')
@@ -39,10 +39,11 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
39
  with spaces.aoti_capture(pipeline.transformer) as call:
40
  pipeline(*args, **kwargs)
41
 
42
- dynamic_shapes = tree_map(lambda t: None, call.kwargs)
 
43
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
44
 
45
- quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
46
 
47
  exported = torch.export.export(
48
  mod=pipeline.transformer,
 
8
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
9
  import spaces
10
  import torch
11
+ from torch.utils._pytree import tree_map, tree_map_only
12
 
13
 
14
  P = ParamSpec('P')
 
39
  with spaces.aoti_capture(pipeline.transformer) as call:
40
  pipeline(*args, **kwargs)
41
 
42
+ #dynamic_shapes = tree_map(lambda t: None, call.kwargs)
43
+ dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda x: None, call.kwargs)
44
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
45
 
46
+ #quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
47
 
48
  exported = torch.export.export(
49
  mod=pipeline.transformer,