linoyts HF Staff commited on
Commit
f2e1d62
·
verified ·
1 Parent(s): 61f93ff

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +3 -1
optimization.py CHANGED
@@ -4,7 +4,7 @@
4
  from typing import Any
5
  from typing import Callable
6
  from typing import ParamSpec
7
-
8
  import spaces
9
  import torch
10
  from torch.utils._pytree import tree_map
@@ -40,6 +40,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
40
 
41
  dynamic_shapes = tree_map(lambda t: None, call.kwargs)
42
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
 
 
43
 
44
  exported = torch.export.export(
45
  mod=pipeline.transformer,
 
4
  from typing import Any
5
  from typing import Callable
6
  from typing import ParamSpec
7
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
8
  import spaces
9
  import torch
10
  from torch.utils._pytree import tree_map
 
40
 
41
  dynamic_shapes = tree_map(lambda t: None, call.kwargs)
42
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
43
+
44
+ quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
45
 
46
  exported = torch.export.export(
47
  mod=pipeline.transformer,