File size: 2,814 Bytes
3df4fd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b63cd34
3df4fd5
 
 
 
b63cd34
3df4fd5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
"""
from contextvars import ContextVar
from io import BytesIO
from typing import Any
from typing import cast

import torch
from torch._inductor.package.package import package_aoti
from torch.export.pt2_archive._package import AOTICompiledModel
from torch.export.pt2_archive._package_weights import TensorProperties
from torch.export.pt2_archive._package_weights import Weights


# Inductor options forced onto every aoti_compile() call. aoti_compile merges
# them on the right-hand side of ``|``, so they override any caller-supplied
# values for the same keys.
# NOTE(review): judging by the key names, these keep the model constants out of
# the compiled shared library and emit them separately. aoti_compile relies on
# getting exactly one Weights artifact back from aot_compile; confirm against
# the AOTInductor config documentation.
INDUCTOR_CONFIGS_OVERRIDES: dict[str, bool] = {
    'aot_inductor.package_constants_in_so': False,
    'aot_inductor.package_constants_on_disk': True,
    'aot_inductor.package': True,
}


class ZeroGPUCompiledModel:
    def __init__(self, archive_file: torch.types.FileLike, weights: Weights, cuda: bool = False):
        self.archive_file = archive_file
        self.weights = weights
        if cuda:
            self.weights_to_cuda_()
        self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
    def weights_to_cuda_(self):
        for name in self.weights:
            tensor, properties = self.weights.get_weight(name)
            self.weights[name] = (tensor.to('cuda'), properties)
    def __call__(self, *args, **kwargs):
        if (compiled_model := self.compiled_model.get()) is None:
            constants_map = {name: value[0] for name, value in self.weights.items()}
            compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
            compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
            self.compiled_model.set(compiled_model)
        return compiled_model(*args, **kwargs)
    def __reduce__(self):
        weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]] = {}
        for name in self.weights:
            tensor, properties = self.weights.get_weight(name)
            tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
            weight_dict[name] = (tensor_.copy_(tensor).detach().share_memory_(), properties)
        return ZeroGPUCompiledModel, (self.archive_file, Weights(weight_dict), True)


def aoti_compile(
    exported_program: torch.export.ExportedProgram,
    inductor_configs: dict[str, Any] | None = None,
):
    inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
    gm = cast(torch.fx.GraphModule, exported_program.module())
    assert exported_program.example_inputs is not None
    args, kwargs = exported_program.example_inputs
    artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
    archive_file = BytesIO()
    files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
    package_aoti(archive_file, files)
    weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
    return ZeroGPUCompiledModel(archive_file, weights)