linoyts (HF Staff) committed
Commit 96a0788 · verified · 1 parent: 22e57cc

Upload 2 files

Files changed (2)
  1. optimization.py +65 -0
  2. optimization_utils.py +150 -0
optimization.py ADDED
@@ -0,0 +1,65 @@
"""ZeroGPU pipeline optimization: fp8 quantization plus ahead-of-time (AOTInductor) compilation of the pipeline transformer."""

from typing import Any
from typing import Callable
from typing import ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import cudagraph


P = ParamSpec('P')


# Token-sequence dimension that is allowed to vary between calls
TRANSFORMER_HIDDEN_DIM = torch.export.Dim('hidden', min=4096, max=8212)

TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
    'img_ids': {0: TRANSFORMER_HIDDEN_DIM},
}

INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):

    @spaces.GPU(duration=1500)
    def compile_transformer():

        # Run the pipeline once to capture the exact args/kwargs the transformer is called with
        with capture_component_call(pipeline, 'transformer') as call:
            pipeline(*args, **kwargs)

        # Every captured input is static (None) except the dims declared above
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        pipeline.transformer.fuse_qkv_projections()

        # Dynamic float8 quantization of activations and weights
        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())

        exported = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )

        return aoti_compile(exported, INDUCTOR_CONFIGS)

    # Swap in the compiled transformer while preserving the original config attribute
    transformer_config = pipeline.transformer.config
    pipeline.transformer = compile_transformer()
    pipeline.transformer.config = transformer_config  # pyright: ignore[reportAttributeAccessIssue]
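
For context, a minimal usage sketch. Everything below is an assumption, not part of this commit: the pipeline class, checkpoint, and call arguments are illustrative (the dynamic-shape keys above match a Flux-style transformer), and spaces.GPU only takes effect inside a ZeroGPU Space.

# Hypothetical usage sketch -- pipeline class, checkpoint and prompt are assumptions.
import torch
from diffusers import FluxPipeline  # assumed Flux-style pipeline

from optimization import optimize_pipeline_

pipeline = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',  # assumed checkpoint
    torch_dtype=torch.bfloat16,
).to('cuda')

# One capture call, then pipeline.transformer is replaced in place by its compiled version
optimize_pipeline_(pipeline, prompt='a photo of a cat', num_inference_steps=1)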
optimization_utils.py ADDED
@@ -0,0 +1,150 @@
"""Utilities for AOTInductor compilation, CUDA graph capture, and pipeline component call capture on ZeroGPU."""

import contextlib
from contextvars import ContextVar
from io import BytesIO
from typing import Any
from typing import Callable
from typing import ParamSpec
from typing import TypeVar
from typing import cast
from unittest.mock import patch

import torch
from torch.utils._pytree import tree_map_only
from torch._inductor.package.package import package_aoti
from torch.export.pt2_archive._package import AOTICompiledModel
from torch.export.pt2_archive._package_weights import TensorProperties
from torch.export.pt2_archive._package_weights import Weights


P = ParamSpec('P')
T = TypeVar('T')


# Keep weights out of the shared object so they can be managed (moved, shared) separately
INDUCTOR_CONFIGS_OVERRIDES = {
    'aot_inductor.package_constants_in_so': False,
    'aot_inductor.package_constants_on_disk': True,
    'aot_inductor.package': True,
}


class ZeroGPUCompiledModel:
    def __init__(self, archive_file: torch.types.FileLike, weights: Weights, cuda: bool = False):
        self.archive_file = archive_file
        self.weights = weights
        if cuda:
            self.weights_to_cuda_()
        self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)

    def weights_to_cuda_(self):
        for name in self.weights:
            tensor, properties = self.weights.get_weight(name)
            self.weights[name] = (tensor.to('cuda'), properties)

    def __call__(self, *args, **kwargs):
        # Lazily load the AOTI archive and bind the externally managed weights on first call
        if (compiled_model := self.compiled_model.get()) is None:
            constants_map = {name: value[0] for name, value in self.weights.items()}
            compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
            compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
            self.compiled_model.set(compiled_model)
        return compiled_model(*args, **kwargs)

    def __reduce__(self):
        # Pickle with weights copied to pinned, shared CPU memory so another process can reuse them
        weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]] = {}
        for name in self.weights:
            tensor, properties = self.weights.get_weight(name)
            tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
            weight_dict[name] = (tensor_.copy_(tensor).detach().share_memory_(), properties)
        return ZeroGPUCompiledModel, (self.archive_file, Weights(weight_dict), True)


def aoti_compile(
    exported_program: torch.export.ExportedProgram,
    inductor_configs: dict[str, Any] | None = None,
):
    inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
    gm = cast(torch.fx.GraphModule, exported_program.module())
    assert exported_program.example_inputs is not None
    args, kwargs = exported_program.example_inputs
    artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
    # Package the generated files into an in-memory PT2 archive; weights stay separate
    archive_file = BytesIO()
    files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
    package_aoti(archive_file, files)
    weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
    return ZeroGPUCompiledModel(archive_file, weights)


def cudagraph(fn: Callable[P, list[torch.Tensor]]):

    graphs = {}

    def fn_(*args: P.args, **kwargs: P.kwargs):

        # One graph per combination of tensor keyword-argument shapes
        key = hash(tuple(
            tuple(kwarg.shape)
            for name in sorted(kwargs.keys())
            if isinstance((kwarg := kwargs[name]), torch.Tensor)
        ))

        if key in graphs:
            wrapped, *_ = graphs[key]
            return wrapped(*args, **kwargs)

        graph = torch.cuda.CUDAGraph()
        in_args, in_kwargs = tree_map_only(torch.Tensor, lambda t: t.clone(), (args, kwargs))
        in_args, in_kwargs = _cast_as((args, kwargs), (in_args, in_kwargs))

        # Warmup call outside the graph, then capture
        fn(*in_args, **in_kwargs)
        with torch.cuda.graph(graph):
            out_tensors = fn(*in_args, **in_kwargs)

        def wrapped(*args: P.args, **kwargs: P.kwargs):
            # Copy fresh inputs into the captured buffers, replay, and return copies of the outputs
            for a, b in zip(in_args, args):
                if isinstance(a, torch.Tensor):
                    assert isinstance(b, torch.Tensor)
                    a.copy_(b)
            for key in kwargs:
                if isinstance((kwarg := kwargs[key]), torch.Tensor):
                    assert isinstance((in_kwarg := in_kwargs[key]), torch.Tensor)
                    in_kwarg.copy_(kwarg)
            graph.replay()
            return [tensor.clone() for tensor in out_tensors]

        graphs[key] = (wrapped, graph, in_args, in_kwargs, out_tensors)
        return wrapped(*args, **kwargs)

    return fn_


@contextlib.contextmanager
def capture_component_call(
    pipeline: Any,
    component_name: str,
    component_method='forward',
):
    """Capture the args/kwargs of the first call to a pipeline component, then abort the pipeline run."""

    class CapturedCallException(Exception):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.args = args
            self.kwargs = kwargs

    class CapturedCall:
        def __init__(self):
            self.args: tuple[Any, ...] = ()
            self.kwargs: dict[str, Any] = {}

    component = getattr(pipeline, component_name)
    captured_call = CapturedCall()

    def capture_call(*args, **kwargs):
        raise CapturedCallException(*args, **kwargs)

    with patch.object(component, component_method, new=capture_call):
        try:
            yield captured_call
        except CapturedCallException as e:
            captured_call.args = e.args
            captured_call.kwargs = e.kwargs


def _cast_as(type_from: T, value: Any) -> T:
    # Identity function used only to convince the type checker
    return value
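
And a minimal sketch of the cudagraph helper's contract, as implemented above: one CUDA graph is recorded per distinct combination of tensor keyword-argument shapes, and later same-shape calls copy inputs into the captured buffers and replay. The decorated function below is hypothetical, not part of the commit.

# Hypothetical sketch -- `double_plus_one` is illustrative only.
import torch

from optimization_utils import cudagraph

@cudagraph
def double_plus_one(*, x: torch.Tensor) -> list[torch.Tensor]:
    return [x * 2 + 1]

x = torch.randn(4096, 64, device='cuda')
out, = double_plus_one(x=x)      # first call: warmup, then graph capture for this shape
out, = double_plus_one(x=x + 1)  # same shape: inputs are copied in and the graph replays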