cbensimon HF Staff committed on
Commit
b63cd34
·
1 Parent(s): 2e80751

capture_component_call

Browse files
Files changed (4) hide show
  1. app.py +6 -1
  2. optimization.py +20 -29
  3. pipeline_utils.py +40 -0
  4. zerogpu.py +2 -2
app.py CHANGED
@@ -18,7 +18,12 @@ from optimization import optimize_pipeline_
18
  MAX_SEED = np.iinfo(np.int32).max
19
 
20
  pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
21
- optimize_pipeline_(pipe)
 
 
 
 
 
22
 
23
  @spaces.GPU
24
  def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
 
20
  pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
21
+
22
# Optimize the pipeline at start-up. Everything after `pipe` is an example
# input that optimize_pipeline_ forwards to a call of the pipeline itself.
optimize_pipeline_(
    pipe,
    image=Image.new('RGB', (512, 512)),
    prompt='prompt',
    guidance_scale=2.5,
)
27
 
28
  @spaces.GPU
29
  def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
optimization.py CHANGED
@@ -1,49 +1,40 @@
1
  """
2
  """
3
 
 
 
 
 
4
  import spaces
5
  import torch
6
- from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
7
 
 
8
  from zerogpu import aoti_compile
9
 
10
 
11
- def _example_tensor(*shape):
12
- return torch.randn(*shape, device='cuda', dtype=torch.bfloat16)
13
-
14
 
15
- def optimize_pipeline_(pipeline: FluxPipeline):
16
 
17
- is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
18
- seq_length = 256 if is_timestep_distilled else 512
 
 
 
 
 
 
19
 
20
- transformer_kwargs = {
21
- 'hidden_states': _example_tensor(1, 4096, 64),
22
- 'timestep': torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
23
- 'guidance': None if is_timestep_distilled else torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
24
- 'pooled_projections': _example_tensor(1, 768),
25
- 'encoder_hidden_states': _example_tensor(1, seq_length, 4096),
26
- 'txt_ids': _example_tensor(seq_length, 3),
27
- 'img_ids': _example_tensor(4096, 3),
28
- 'joint_attention_kwargs': {},
29
- 'return_dict': False,
30
- }
31
 
32
- inductor_configs = {
33
- 'conv_1x1_as_mm': True,
34
- 'epilogue_fusion': False,
35
- 'coordinate_descent_tuning': True,
36
- 'coordinate_descent_check_all_directions': True,
37
- 'max_autotune': True,
38
- 'triton.cudagraphs': True,
39
- }
40
 
41
  @spaces.GPU(duration=1500)
42
  def compile_transformer():
 
 
43
  pipeline.transformer.fuse_qkv_projections()
44
- exported = torch.export.export(pipeline.transformer, args=(), kwargs=transformer_kwargs)
45
- return aoti_compile(exported, inductor_configs)
46
 
47
  transformer_config = pipeline.transformer.config
48
  pipeline.transformer = compile_transformer()
49
- pipeline.transformer.config = transformer_config
 
1
  """
2
  """
3
 
4
+ from typing import Any
5
+ from typing import Callable
6
+ from typing import ParamSpec
7
+
8
  import spaces
9
  import torch
 
10
 
11
+ from pipeline_utils import capture_component_call
12
  from zerogpu import aoti_compile
13
 
14
 
15
+ P = ParamSpec('P')
 
 
16
 
 
17
 
18
+ INDUCTOR_CONFIGS = {
19
+ 'conv_1x1_as_mm': True,
20
+ 'epilogue_fusion': False,
21
+ 'coordinate_descent_tuning': True,
22
+ 'coordinate_descent_check_all_directions': True,
23
+ 'max_autotune': True,
24
+ 'triton.cudagraphs': True,
25
+ }
26
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Compile ``pipeline.transformer`` ahead of time and swap it in place.

    The pipeline is called once with ``*args`` / ``**kwargs`` while the
    transformer call is intercepted, so the arguments the transformer receives
    can be reused as example inputs for ``torch.export.export``. The exported
    program is compiled with ``aoti_compile`` using ``INDUCTOR_CONFIGS``, the
    result replaces ``pipeline.transformer``, and the original transformer
    config is re-attached to the replacement.

    Args:
        pipeline: Callable pipeline object exposing a ``transformer`` attribute.
        *args: Positional arguments forwarded to the pipeline call.
        **kwargs: Keyword arguments forwarded to the pipeline call.
    """

    @spaces.GPU(duration=1500)
    def compile_transformer():
        # The pipeline is invoked only to record how the transformer is called.
        with capture_component_call(pipeline, 'transformer') as captured:
            pipeline(*args, **kwargs)
        pipeline.transformer.fuse_qkv_projections()
        exported_program = torch.export.export(
            pipeline.transformer,
            args=captured.args,
            kwargs=captured.kwargs,
        )
        return aoti_compile(exported_program, INDUCTOR_CONFIGS)

    # Read the config before the transformer attribute is overwritten.
    original_config = pipeline.transformer.config
    pipeline.transformer = compile_transformer()
    pipeline.transformer.config = original_config  # pyright: ignore[reportAttributeAccessIssue]
pipeline_utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ import contextlib
5
+ from unittest.mock import patch
6
+
7
+ from typing import Any
8
+
9
+
10
class CapturedCallException(Exception):
    """Exception that carries the arguments of an intercepted call.

    Positional arguments end up in ``args`` and keyword arguments in
    ``kwargs``.
    """

    def __init__(self, *args, **kwargs):
        # Exception.__init__ stores the positional arguments as `self.args`.
        super().__init__(*args)
        self.kwargs = kwargs
15
+
16
+
17
class CapturedCall:
    """Mutable record of the arguments of a single call.

    A new instance starts with no positional and no keyword arguments.
    """

    def __init__(self):
        # Each instance gets its own containers, nothing is shared.
        self.kwargs: dict[str, Any] = dict()
        self.args: tuple[Any, ...] = tuple()
21
+
22
+
23
@contextlib.contextmanager
def capture_component_call(
    pipeline: Any,
    component_name: str,
    component_method='forward',
):
    """Record the arguments of the next call to a pipeline component method.

    While the context is active, ``component_method`` on the component found at
    ``getattr(pipeline, component_name)`` is replaced by a stub that raises
    ``CapturedCallException`` with the arguments it was given. That exception
    is caught here and its ``args`` / ``kwargs`` are copied onto the yielded
    ``CapturedCall``. Any other exception propagates to the caller. If the
    method is never invoked, the yielded object keeps its empty defaults.

    Args:
        pipeline: Object holding the component as an attribute.
        component_name: Attribute name of the component on ``pipeline``.
        component_method: Name of the component method to intercept.

    Yields:
        The ``CapturedCall`` that receives the intercepted arguments.
    """
    target = getattr(pipeline, component_name)
    recorded = CapturedCall()

    def _intercept(*call_args, **call_kwargs):
        # Abort the surrounding call and carry its arguments out with the exception.
        raise CapturedCallException(*call_args, **call_kwargs)

    patcher = patch.object(target, component_method, new=_intercept)
    with patcher:
        try:
            yield recorded
        except CapturedCallException as interception:
            recorded.args, recorded.kwargs = interception.args, interception.kwargs
zerogpu.py CHANGED
@@ -51,12 +51,12 @@ def aoti_compile(
51
  inductor_configs: dict[str, Any] | None = None,
52
  ):
53
  inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
54
- gm = exported_program.module()
55
  assert exported_program.example_inputs is not None
56
  args, kwargs = exported_program.example_inputs
57
  artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
58
  archive_file = BytesIO()
59
- files = [file for file in artifacts if isinstance(file, str)]
60
  package_aoti(archive_file, files)
61
  weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
62
  return ZeroGPUCompiledModel(archive_file, weights)
 
51
  inductor_configs: dict[str, Any] | None = None,
52
  ):
53
  inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
54
+ gm = cast(torch.fx.GraphModule, exported_program.module())
55
  assert exported_program.example_inputs is not None
56
  args, kwargs = exported_program.example_inputs
57
  artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
58
  archive_file = BytesIO()
59
+ files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
60
  package_aoti(archive_file, files)
61
  weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
62
  return ZeroGPUCompiledModel(archive_file, weights)