Qwen-Image-Edit-Outpaint / optimization.py
cbensimon's picture
cbensimon HF Staff
Update optimization.py
10b04c2 verified
raw
history blame
1.28 kB
"""
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map
P = ParamSpec('P')
TRANSFORMER_HIDDEN_DIM = torch.export.Dim.AUTO('hidden', min=3584, max=9727)
TRANSFORMER_DYNAMIC_SHAPES = {
'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
}
INDUCTOR_CONFIGS = {
'conv_1x1_as_mm': True,
'epilogue_fusion': False,
'coordinate_descent_tuning': True,
'coordinate_descent_check_all_directions': True,
'max_autotune': True,
'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
    """Ahead-of-time compile ``pipeline.transformer`` and apply the result in place.

    The pipeline is invoked once with the given example arguments so that the
    inputs reaching ``pipeline.transformer`` can be captured. The transformer
    is then exported with ``torch.export``, compiled through
    ``spaces.aoti_compile`` using ``INDUCTOR_CONFIGS``, and the compiled
    artifact is applied to ``pipeline.transformer`` via ``spaces.aoti_apply``.

    Args:
        pipeline: Callable pipeline object exposing a ``transformer`` attribute.
        *args: Example positional arguments forwarded to ``pipeline``.
        **kwargs: Example keyword arguments forwarded to ``pipeline``.

    Returns:
        None. ``pipeline.transformer`` is modified as a side effect.
    """

    # NOTE(review): presumably reserves the GPU for up to 1500 seconds for
    # this one-off compilation -- confirm against the `spaces` documentation.
    @spaces.GPU(duration=1500)
    def compile_transformer():
        # Run the pipeline once; `call` records the args/kwargs the
        # transformer was invoked with during that run.
        with spaces.aoti_capture(pipeline.transformer) as call:
            pipeline(*args, **kwargs)

        # Start from "no dynamic dims" (None) for every captured keyword
        # input, then overlay the entries that must stay dynamic.
        dynamic_shapes = tree_map(lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        exported = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )

        return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)

    # Compile once, then apply the compiled artifact to the transformer module.
    spaces.aoti_apply(compile_transformer(), pipeline.transformer)