cbensimon HF Staff commited on
Commit
318b03c
·
1 Parent(s): b63cd34

Dynamic shapes

Browse files
Files changed (2) hide show
  1. app.py +1 -6
  2. optimization.py +16 -1
app.py CHANGED
@@ -18,12 +18,7 @@ from optimization import optimize_pipeline_
18
  MAX_SEED = np.iinfo(np.int32).max
19
 
20
  pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
21
-
22
- optimize_pipeline_(pipe,
23
- image=Image.new('RGB', (512, 512)),
24
- prompt='prompt',
25
- guidance_scale=2.5,
26
- )
27
 
28
  @spaces.GPU
29
  def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
 
20
  pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
21
+ optimize_pipeline_(pipe, image=None, prompt='prompt')
 
 
 
 
 
22
 
23
  @spaces.GPU
24
  def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
optimization.py CHANGED
@@ -29,10 +29,25 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
29
 
30
  @spaces.GPU(duration=1500)
31
  def compile_transformer():
 
32
  with capture_component_call(pipeline, 'transformer') as call:
33
  pipeline(*args, **kwargs)
 
 
 
 
 
 
 
34
  pipeline.transformer.fuse_qkv_projections()
35
- exported = torch.export.export(pipeline.transformer, args=call.args, kwargs=call.kwargs)
 
 
 
 
 
 
 
36
  return aoti_compile(exported, INDUCTOR_CONFIGS)
37
 
38
  transformer_config = pipeline.transformer.config
 
29
 
30
  @spaces.GPU(duration=1500)
31
  def compile_transformer():
32
+
33
  with capture_component_call(pipeline, 'transformer') as call:
34
  pipeline(*args, **kwargs)
35
+
36
+ hidden_dim = torch.export.Dim('hidden', min=4096, max=8212)
37
+ dynamic_shapes = {
38
+ 'hidden_states': {1: hidden_dim},
39
+ 'img_ids': {0: hidden_dim},
40
+ }
41
+
42
  pipeline.transformer.fuse_qkv_projections()
43
+
44
+ exported = torch.export.export(
45
+ mod=pipeline.transformer,
46
+ args=call.args,
47
+ kwargs=call.kwargs,
48
+ dynamic_shapes=dynamic_shapes,
49
+ )
50
+
51
  return aoti_compile(exported, INDUCTOR_CONFIGS)
52
 
53
  transformer_config = pipeline.transformer.config