Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,837 Bytes
3fc0a22 7f65363 3fc0a22 10bbb52 3fc0a22 10bbb52 3fc0a22 10bbb52 3fc0a22 7f65363 3fc0a22 7f65363 3fc0a22 7f65363 3fc0a22 7f65363 3fc0a22 7f65363 3fc0a22 10bbb52 3fc0a22 7f65363 3fc0a22 7f65363 10bbb52 3fc0a22 7f65363 3fc0a22 7f65363 10bbb52 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
"""
Taken from https://huggingface.co/spaces/cbensimon/wan2-1-fast/
"""
import contextlib
import functools
from contextvars import ContextVar
from io import BytesIO
from typing import Any, cast
from unittest.mock import patch

import torch
from torch._inductor.package import load_package
from torch._inductor.package.package import package_aoti
from torch.export.pt2_archive._package import AOTICompiledModel
from torch.export.pt2_archive._package_weights import Weights
from torch.utils._pytree import tree_map_only
# Inductor options that aoti_compile merges *on top of* any caller-supplied
# options, so these values always take precedence. Judging by their names they
# keep the model constants out of the compiled shared object, store them on
# disk, and request a packaged artifact (NOTE(review): confirm against the
# torch._inductor config reference for the installed torch version).
INDUCTOR_CONFIGS_OVERRIDES: dict[str, bool] = {
    "aot_inductor.package_constants_in_so": False,
    "aot_inductor.package_constants_on_disk": True,
    "aot_inductor.package": True,
}
class ZeroGPUWeights:
    """Name -> tensor mapping holding the constants of an AOT-compiled model.

    When pickled (see ``__reduce__``), every tensor is copied into a pinned,
    shared-memory CPU tensor, and the object is rebuilt with ``to_cuda=True``.
    The receiving side therefore ends up with the tensors on CUDA again.
    """

    def __init__(self, constants_map: dict[str, torch.Tensor], to_cuda: bool = False):
        """Store ``constants_map``.

        Args:
            constants_map: mapping from constant name to tensor. Kept by
                reference when ``to_cuda`` is False.
            to_cuda: if True, build a new mapping whose tensors are moved
                with ``.to("cuda")``.
        """
        self.constants_map = (
            {key: value.to("cuda") for key, value in constants_map.items()}
            if to_cuda
            else constants_map
        )

    def __reduce__(self):
        """Pickle as pinned, shared CPU copies that are moved to CUDA on unpickle."""
        cpu_snapshot: dict[str, torch.Tensor] = {}
        for key, value in self.constants_map.items():
            pinned = torch.empty_like(value, device="cpu").pin_memory()
            cpu_snapshot[key] = pinned.copy_(value).detach().share_memory_()
        # Unpickling calls ZeroGPUWeights(cpu_snapshot, True), i.e. to_cuda=True.
        return ZeroGPUWeights, (cpu_snapshot, True)
class ZeroGPUCompiledModel:
    """Picklable, lazily-loaded wrapper around an AOTInductor package and its weights.

    The package is only loaded (via ``load_package``) on the first call made
    in a ``contextvars`` context where it has not been loaded yet. The loaded
    model is cached in a ``ContextVar``. Pickling keeps just the archive and
    the weights, so an unpickled copy reloads the package on its first call.
    """

    def __init__(self, archive_file: torch.types.FileLike, weights: ZeroGPUWeights):
        """
        Args:
            archive_file: the packaged AOTI archive (file path or file-like object).
            weights: constants installed into the model after it is loaded.
        """
        self.archive_file = archive_file
        self.weights = weights
        # Per-context cache of the loaded model; holds None until the first call.
        self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar("compiled_model", default=None)

    def __call__(self, *args, **kwargs):
        """Run the compiled model, loading it first if this context has none yet."""
        model = self.compiled_model.get()
        if model is None:
            model = load_package(self.archive_file, run_single_threaded=True)
            # Install the externally held weights. user_managed=True presumably
            # means the package references rather than copies them
            # (NOTE(review): confirm against the AOTI load_constants docs).
            model.load_constants(self.weights.constants_map, check_full_update=True, user_managed=True)
            self.compiled_model.set(model)
        return model(*args, **kwargs)

    def __reduce__(self):
        # Only the archive and the weights are pickled; the loaded model is not.
        return ZeroGPUCompiledModel, (self.archive_file, self.weights)
def aoti_compile(
    exported_program: torch.export.ExportedProgram,
    inductor_configs: dict[str, Any] | None = None,
):
    """AOT-compile ``exported_program`` into a ``ZeroGPUCompiledModel``.

    Args:
        exported_program: the program to compile. Its ``example_inputs``
            (an ``(args, kwargs)`` pair) are used as the compile-time inputs.
        inductor_configs: optional extra Inductor options.
            ``INDUCTOR_CONFIGS_OVERRIDES`` is merged on top, so its keys win.

    Returns:
        A ``ZeroGPUCompiledModel`` with the packaged (string) artifacts in an
        in-memory ``BytesIO`` archive and the compiled constants held
        separately in a ``ZeroGPUWeights`` built with ``to_cuda=True``.

    Raises:
        AssertionError: if ``exported_program.example_inputs`` is None.
        ValueError: if the compile artifacts do not contain exactly one
            ``Weights`` object.
    """
    merged_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
    graph_module = cast(torch.fx.GraphModule, exported_program.module())
    assert exported_program.example_inputs is not None
    example_args, example_kwargs = exported_program.example_inputs
    artifacts = torch._inductor.aot_compile(graph_module, example_args, example_kwargs, options=merged_configs)

    # Only the generated files go into the archive; the weights are handed
    # over separately below.
    archive_file = BytesIO()
    package_files: list[str | Weights] = [artifact for artifact in artifacts if isinstance(artifact, str)]
    package_aoti(archive_file, package_files)

    # Unpacking enforces exactly one Weights artifact (ValueError otherwise).
    (weights,) = (artifact for artifact in artifacts if isinstance(artifact, Weights))
    constants = {name: weights.get_weight(name)[0] for name in weights}
    return ZeroGPUCompiledModel(archive_file, ZeroGPUWeights(constants, to_cuda=True))
@contextlib.contextmanager
def capture_component_call(
    pipeline: Any,
    component_name: str,
    component_method="forward",
):
    """Intercept the next call to ``getattr(pipeline, component_name).<component_method>``.

    While the context is active, the method is replaced (``unittest.mock.patch.object``)
    by a stub that raises instead of running. That aborts the rest of the
    ``with`` body, and the exception is swallowed when it reaches this context
    manager. The intercepted positional and keyword arguments are then
    available as ``.args`` / ``.kwargs`` on the yielded object. If the method
    is never called, they stay ``()`` / ``{}``. The original method is restored
    on exit.

    NOTE(review): if code between the ``with`` body and the component catches
    broad ``Exception``s, the interception will not reach this context manager.
    """

    class CapturedCallException(Exception):
        # Carries the intercepted arguments out of the patched method.
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.args = args
            self.kwargs = kwargs

    class CapturedCall:
        # Filled in once an intercepted call reaches this context manager.
        def __init__(self):
            self.args: tuple[Any, ...] = ()
            self.kwargs: dict[str, Any] = {}

    target = getattr(pipeline, component_name)
    record = CapturedCall()

    def intercept(*call_args, **call_kwargs):
        raise CapturedCallException(*call_args, **call_kwargs)

    with patch.object(target, component_method, new=intercept):
        try:
            yield record
        except CapturedCallException as stop:
            record.args, record.kwargs = stop.args, stop.kwargs
# Taken from
# https://github.com/huggingface/flux-fast/blob/5027798d7f69a8e0e478df92f48663c40727f8ea/utils/pipeline_utils.py#L198C1-L231C14
# Modified: the cache key also covers positional tensors, keyword names and dtypes.
def cudagraph(f):
    """Wrap ``f`` so each distinct tensor-input signature is captured once as a CUDA graph, then replayed.

    On the first call with a given signature, the tensor inputs are cloned
    into static buffers. ``f`` is run once eagerly as a warm-up, then run
    again under ``torch.cuda.graph`` capture. Later calls with the same
    signature copy their tensor inputs into those static buffers and replay
    the graph.

    The signature is the position (or keyword name), shape and dtype of every
    tensor argument. Non-tensor arguments are *not* part of it: whatever
    values they had at capture time are baked into the graph and ignored on
    replay.

    ``f`` is expected to return an iterable of tensors. The wrapper returns a
    list with a clone of each, so results are not overwritten by later replays.
    """
    graphs: dict[tuple, tuple] = {}

    def _signature(args, kwargs):
        # Structural key: a tuple (not its hash), so distinct signatures can never collide.
        positional = tuple(
            (index, tuple(value.shape), value.dtype)
            for index, value in enumerate(args)
            if isinstance(value, torch.Tensor)
        )
        keyword = tuple(
            (name, tuple(kwargs[name].shape), kwargs[name].dtype)
            for name in sorted(kwargs)
            if isinstance(kwargs[name], torch.Tensor)
        )
        return positional, keyword

    @functools.wraps(f)
    def f_(*args, **kwargs):
        key = _signature(args, kwargs)
        cached = graphs.get(key)
        if cached is not None:
            # Reuse the cached replay wrapper for this signature.
            return cached[0](*args, **kwargs)

        # Record a new CUDA graph and cache it for future use.
        g = torch.cuda.CUDAGraph()
        in_args, in_kwargs = tree_map_only(torch.Tensor, lambda t: t.clone(), (args, kwargs))
        f(*in_args, **in_kwargs)  # stream warmup
        with torch.cuda.graph(g):
            out_tensors = f(*in_args, **in_kwargs)

        def wrapped(*args, **kwargs):
            # CUDA graphs read inputs from fixed memory locations, so new inputs
            # must be copied into the static buffers captured above.
            for static, new in zip(in_args, args):
                if isinstance(static, torch.Tensor):
                    static.copy_(new)
            for name, value in kwargs.items():
                if isinstance(value, torch.Tensor):
                    in_kwargs[name].copy_(value)
            g.replay()
            # Clone outputs to detach them from the fixed output memory, which
            # the next replay overwrites.
            return [o.clone() for o in out_tensors]

        # Keep the graph and its static buffers alive alongside the wrapper.
        graphs[key] = (wrapped, g, in_args, in_kwargs, out_tensors)
        return wrapped(*args, **kwargs)

    return f_