cbensimon HF Staff committed on
Commit
59021ac
·
verified ·
1 Parent(s): 66f5ac6

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +3 -11
optimization.py CHANGED
@@ -11,10 +11,6 @@ from torch.utils._pytree import tree_map_only
11
  from torchao.quantization import quantize_
12
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
13
 
14
- from optimization_utils import capture_component_call
15
- from optimization_utils import aoti_compile
16
- from optimization_utils import cudagraph
17
-
18
 
19
  P = ParamSpec('P')
20
 
@@ -41,13 +37,11 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
41
  @spaces.GPU(duration=1500)
42
  def compile_transformer():
43
 
44
- with capture_component_call(pipeline, 'transformer') as call:
45
  pipeline(*args, **kwargs)
46
 
47
  dynamic_shapes = tree_map_only((torch.Tensor, bool, int), lambda t: None, call.kwargs)
48
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
49
-
50
- quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
51
 
52
  exported = torch.export.export(
53
  mod=pipeline.transformer,
@@ -56,8 +50,6 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
56
  dynamic_shapes=dynamic_shapes,
57
  )
58
 
59
- return aoti_compile(exported, INDUCTOR_CONFIGS)
60
 
61
- transformer_config = pipeline.transformer.config
62
- pipeline.transformer = compile_transformer()
63
- pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
 
11
  from torchao.quantization import quantize_
12
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
13
 
 
 
 
 
14
 
15
  P = ParamSpec('P')
16
 
 
37
  @spaces.GPU(duration=1500)
38
  def compile_transformer():
39
 
40
+ with spaces.aoti_capture(pipeline.transformer) as call:
41
  pipeline(*args, **kwargs)
42
 
43
  dynamic_shapes = tree_map_only((torch.Tensor, bool, int), lambda t: None, call.kwargs)
44
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
 
 
45
 
46
  exported = torch.export.export(
47
  mod=pipeline.transformer,
 
50
  dynamic_shapes=dynamic_shapes,
51
  )
52
 
53
+ return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
54
 
55
+ spaces.aoti_apply(compile_transformer(), pipeline.transformer)