linoyts HF Staff commited on
Commit
10a1bb0
·
verified ·
1 Parent(s): 1866240

Create optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +77 -0
optimization.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ from typing import Any
5
+ from typing import Callable
6
+ from typing import ParamSpec
7
+ from torchao.quantization import quantize_
8
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
9
+ import spaces
10
+ import torch
11
+ from torch.utils._pytree import tree_map
12
+
13
+
14
+ P = ParamSpec('P')
15
+
16
+
17
+ TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
18
+ TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')
19
+
20
+ TRANSFORMER_DYNAMIC_SHAPES = {
21
+ 'hidden_states': {
22
+ 1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
23
+ },
24
+ 'encoder_hidden_states': {
25
+ 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
26
+ },
27
+ 'encoder_hidden_states_mask': {
28
+ 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
29
+ },
30
+ 'image_rotary_emb': ({
31
+ 0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
32
+ }, {
33
+ 0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
34
+ }),
35
+ }
36
+
37
+
38
# torch._inductor options forwarded to ``spaces.aoti_compile``: aggressive
# autotuning plus CUDA graphs, trading a long one-time compile for lower
# steady-state inference latency.
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}
46
+
47
+
48
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Fuse the Lightning LoRA into *pipeline* and AoT-compile its transformer in place.

    Performs one real pipeline call on a GPU worker to capture the exact
    inputs the transformer receives, exports the transformer with dynamic
    sequence-length dimensions, compiles it with ``INDUCTOR_CONFIGS``, and
    applies the compiled artifact back onto ``pipeline.transformer``.

    Args:
        pipeline: A diffusers-style pipeline exposing ``load_lora_weights``,
            ``fuse_lora``, ``unload_lora_weights`` and a ``transformer``
            submodule.
        *args: Positional arguments for one representative pipeline call
            (used only to trace the transformer).
        **kwargs: Keyword arguments for that same representative call.
    """

    # Long duration: export + max-autotune compilation is slow (up to ~25 min).
    @spaces.GPU(duration=1500)
    def compile_transformer():

        # Bake the 8-step Lightning LoRA into the base weights so the
        # exported graph carries no LoRA indirection.
        pipeline.load_lora_weights(
            "lightx2v/Qwen-Image-Lightning",
            weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
        )
        pipeline.fuse_lora()
        pipeline.unload_lora_weights()

        # Run the pipeline once, recording the args/kwargs passed to the
        # transformer's forward call.
        with spaces.aoti_capture(pipeline.transformer) as call:
            pipeline(*args, **kwargs)

        # Mark every captured kwarg static (None) by default, then override
        # the sequence-length axes with the symbolic dims.
        dynamic_shapes = tree_map(lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        # FP8 dynamic quantization deliberately left disabled; re-enable to
        # trade output quality for speed/memory — verify quality first.
        # quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())

        exported = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )

        return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)

    # Compile on the GPU worker, then graft the compiled transformer onto
    # the caller's pipeline.
    spaces.aoti_apply(compile_transformer(), pipeline.transformer)