Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,670 Bytes
3df4fd5 288103a 3df4fd5 288103a 3df4fd5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
"""
"""
import spaces
import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from zerogpu import aoti_compile
def _example_tensor(*shape: int, device='cuda', dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Return a random tensor of the given shape for use as an example input.

    Args:
        *shape: Dimension sizes, forwarded positionally to ``torch.randn``.
        device: Device the tensor is created on. Defaults to ``'cuda'``.
        dtype: Element type of the tensor. Defaults to ``torch.bfloat16``.

    Returns:
        A tensor produced by ``torch.randn`` with the requested shape,
        device and dtype.
    """
    # device/dtype are keyword-only so every existing positional call
    # (``_example_tensor(1, 4096, 64)``) keeps meaning "shape only".
    return torch.randn(*shape, device=device, dtype=dtype)
def optimize_pipeline_(pipeline: FluxPipeline):
    """Replace ``pipeline.transformer`` with a compiled version, in place.

    Example keyword inputs are built for the transformer, the module is
    exported with ``torch.export.export`` inside a ``spaces.GPU`` task, and
    the exported program is handed to ``aoti_compile`` together with a set of
    inductor options. The ``config`` of the original transformer is
    re-attached to the object that replaces it.

    Args:
        pipeline: Pipeline whose ``transformer`` attribute is replaced.

    Returns:
        None. ``pipeline`` is mutated.
    """
    # A transformer without guidance embeddings is treated as
    # timestep-distilled: it gets a shorter text sequence and no guidance input.
    distilled = not pipeline.transformer.config.guidance_embeds
    text_len = 256 if distilled else 512

    # Example inputs, created in the same order as they appear in the kwargs
    # mapping below. The random ones come from ``_example_tensor``.
    latents = _example_tensor(1, 4096, 64)
    timestep = torch.tensor([1.], device='cuda', dtype=torch.bfloat16)
    if distilled:
        guidance = None
    else:
        guidance = torch.tensor([1.], device='cuda', dtype=torch.bfloat16)
    pooled = _example_tensor(1, 768)
    text_states = _example_tensor(1, text_len, 4096)
    text_ids = _example_tensor(text_len, 3)
    image_ids = _example_tensor(4096, 3)

    example_inputs = {
        'hidden_states': latents,
        'timestep': timestep,
        'guidance': guidance,
        'pooled_projections': pooled,
        'encoder_hidden_states': text_states,
        'txt_ids': text_ids,
        'img_ids': image_ids,
        'joint_attention_kwargs': {},
        'return_dict': False,
    }

    # Options forwarded unchanged to ``aoti_compile``.
    inductor_options = {
        'conv_1x1_as_mm': True,
        'epilogue_fusion': False,
        'coordinate_descent_tuning': True,
        'coordinate_descent_check_all_directions': True,
        'max_autotune': True,
        'triton.cudagraphs': True,
    }

    @spaces.GPU(duration=1500)
    def compile_transformer():
        # QKV projections are fused before export so the exported program
        # contains the fused form.
        transformer = pipeline.transformer
        transformer.fuse_qkv_projections()
        program = torch.export.export(transformer, args=(), kwargs=example_inputs)
        return aoti_compile(program, inductor_options)

    # Capture the config before the transformer is swapped out, then put it
    # back on the replacement.
    original_config = pipeline.transformer.config
    compiled = compile_transformer()
    pipeline.transformer = compiled
    pipeline.transformer.config = original_config
|