Spaces:
Running
on
Zero
Running
on
Zero
Dynamic shapes
Browse files- app.py +1 -6
- optimization.py +16 -1
app.py
CHANGED
@@ -18,12 +18,7 @@ from optimization import optimize_pipeline_
|
|
18 |
MAX_SEED = np.iinfo(np.int32).max
|
19 |
|
20 |
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
|
21 |
-
|
22 |
-
optimize_pipeline_(pipe,
|
23 |
-
image=Image.new('RGB', (512, 512)),
|
24 |
-
prompt='prompt',
|
25 |
-
guidance_scale=2.5,
|
26 |
-
)
|
27 |
|
28 |
@spaces.GPU
|
29 |
def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
|
|
|
18 |
MAX_SEED = np.iinfo(np.int32).max
|
19 |
|
20 |
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
|
21 |
+
optimize_pipeline_(pipe, image=None, prompt='prompt')
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
@spaces.GPU
|
24 |
def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
|
optimization.py
CHANGED
@@ -29,10 +29,25 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
|
|
29 |
|
30 |
@spaces.GPU(duration=1500)
|
31 |
def compile_transformer():
|
|
|
32 |
with capture_component_call(pipeline, 'transformer') as call:
|
33 |
pipeline(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
pipeline.transformer.fuse_qkv_projections()
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
return aoti_compile(exported, INDUCTOR_CONFIGS)
|
37 |
|
38 |
transformer_config = pipeline.transformer.config
|
|
|
29 |
|
30 |
@spaces.GPU(duration=1500)
|
31 |
def compile_transformer():
|
32 |
+
|
33 |
with capture_component_call(pipeline, 'transformer') as call:
|
34 |
pipeline(*args, **kwargs)
|
35 |
+
|
36 |
+
hidden_dim = torch.export.Dim('hidden', min=4096, max=8212)
|
37 |
+
dynamic_shapes = {
|
38 |
+
'hidden_states': {1: hidden_dim},
|
39 |
+
'img_ids': {0: hidden_dim},
|
40 |
+
}
|
41 |
+
|
42 |
pipeline.transformer.fuse_qkv_projections()
|
43 |
+
|
44 |
+
exported = torch.export.export(
|
45 |
+
mod=pipeline.transformer,
|
46 |
+
args=call.args,
|
47 |
+
kwargs=call.kwargs,
|
48 |
+
dynamic_shapes=dynamic_shapes,
|
49 |
+
)
|
50 |
+
|
51 |
return aoti_compile(exported, INDUCTOR_CONFIGS)
|
52 |
|
53 |
transformer_config = pipeline.transformer.config
|