FLUX.1-Kontext-Dev / optimization.py
cbensimon's picture
cbensimon HF Staff
Fix dynamic shapes
fc3f0ed
raw
history blame
1.55 kB
"""
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map_only
from pipeline_utils import capture_component_call
from zerogpu import aoti_compile
P = ParamSpec('P')
INDUCTOR_CONFIGS = {
'conv_1x1_as_mm': True,
'epilogue_fusion': False,
'coordinate_descent_tuning': True,
'coordinate_descent_check_all_directions': True,
'max_autotune': True,
'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Replace ``pipeline.transformer`` in place with the output of ``aoti_compile``.

    The pipeline is invoked once with ``*args`` / ``**kwargs`` inside
    ``capture_component_call`` so that the call made to its ``transformer``
    component is recorded. The transformer is then exported through
    ``torch.export.export`` using the recorded inputs, and the exported
    program is handed to ``aoti_compile`` together with ``INDUCTOR_CONFIGS``.

    Args:
        pipeline: Callable pipeline exposing a ``transformer`` attribute.
        *args: Positional arguments forwarded to ``pipeline`` for the capture run.
        **kwargs: Keyword arguments forwarded to ``pipeline`` for the capture run.

    Returns:
        Nothing. ``pipeline.transformer`` is reassigned, and the ``config`` of
        the original transformer is re-attached to the replacement.
    """

    # The whole capture/export/compile sequence runs inside a function
    # decorated with ``spaces.GPU(duration=1500)``.
    @spaces.GPU(duration=1500)
    def compile_transformer():
        # Run the pipeline once purely to record how the transformer is called.
        with capture_component_call(pipeline, 'transformer') as captured:
            pipeline(*args, **kwargs)

        # Single symbolic dimension shared by both dynamic inputs below.
        dyn_dim = torch.export.Dim('hidden', min=4096, max=8212)

        # Start with every tensor/bool leaf of the recorded kwargs mapped to
        # None, then override the two entries that carry the dynamic dim.
        shape_spec = tree_map_only((torch.Tensor, bool), lambda _leaf: None, captured.kwargs)
        shape_spec.update({
            'hidden_states': {1: dyn_dim},
            'img_ids': {0: dyn_dim},
        })

        # Fuse after the capture run but before export, so the fused module
        # is the one that gets exported.
        transformer = pipeline.transformer
        transformer.fuse_qkv_projections()

        exported_program = torch.export.export(
            mod=transformer,
            args=captured.args,
            kwargs=captured.kwargs,
            dynamic_shapes=shape_spec,
        )
        return aoti_compile(exported_program, INDUCTOR_CONFIGS)

    # Keep a handle on the original config: it is re-attached to the
    # replacement transformer once that has been installed on the pipeline.
    saved_config = pipeline.transformer.config
    pipeline.transformer = compile_transformer()
    pipeline.transformer.config = saved_config  # pyright: ignore[reportAttributeAccessIssue]