Qwen-Image-Edit-Outpaint / optimization.py
cbensimon's picture
cbensimon HF Staff
Update optimization.py
59021ac verified
raw
history blame
1.43 kB
"""
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
P = ParamSpec('P')
TRANSFORMER_HIDDEN_DIM = torch.export.Dim('hidden', min=3584, max=9727)
TRANSFORMER_DYNAMIC_SHAPES = {
'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
}
INDUCTOR_CONFIGS = {
'conv_1x1_as_mm': True,
'epilogue_fusion': False,
'coordinate_descent_tuning': True,
'coordinate_descent_check_all_directions': True,
'max_autotune': True,
'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Ahead-of-time compile ``pipeline.transformer`` and hand the result to ``spaces.aoti_apply``.

    Steps, all run inside a ``spaces.GPU(duration=1500)``-decorated inner function
    except the final apply:

    1. Call ``pipeline(*args, **kwargs)`` once under ``spaces.aoti_capture`` to
       record the positional and keyword inputs the transformer receives.
    2. Build a ``dynamic_shapes`` spec in which every captured tensor / bool /
       int leaf is ``None`` (no dynamic dimensions), then override it with
       ``TRANSFORMER_DYNAMIC_SHAPES``.
    3. ``torch.export.export`` the transformer with those inputs and that spec.
    4. Compile the exported program via ``spaces.aoti_compile`` with
       ``INDUCTOR_CONFIGS``.

    Args:
        pipeline: Callable pipeline exposing a ``transformer`` attribute.
        *args: Example positional arguments, forwarded verbatim to ``pipeline``
            for the capture run.
        **kwargs: Example keyword arguments, forwarded verbatim to ``pipeline``
            for the capture run.

    Returns:
        None. The compiled artifact is passed to ``spaces.aoti_apply`` together
        with ``pipeline.transformer``. That call presumably patches the module in
        place (TODO confirm against the ``spaces`` docs).
    """

    @spaces.GPU(duration=1500)
    def compile_transformer():
        # Record the transformer's real forward inputs from one end-to-end run.
        with spaces.aoti_capture(pipeline.transformer) as captured:
            pipeline(*args, **kwargs)

        # Start from an all-static spec: each tensor/bool/int leaf maps to None.
        # Then let the module-level overrides mark the dynamic axes.
        static_spec = tree_map_only(
            (torch.Tensor, bool, int), lambda _leaf: None, captured.kwargs
        )
        shape_spec = static_spec | TRANSFORMER_DYNAMIC_SHAPES

        exported_program = torch.export.export(
            mod=pipeline.transformer,
            args=captured.args,
            kwargs=captured.kwargs,
            dynamic_shapes=shape_spec,
        )
        return spaces.aoti_compile(exported_program, INDUCTOR_CONFIGS)

    # Compile first, then read pipeline.transformer, matching the original
    # argument-evaluation order.
    compiled = compile_transformer()
    spaces.aoti_apply(compiled, pipeline.transformer)