multimodalart HF Staff commited on
Commit
37c471e
·
verified ·
1 Parent(s): 69c43ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -5
app.py CHANGED
@@ -197,7 +197,7 @@ class CustomFluxAttnProcessor2_0:
197
  print("--- Loading Models: This may take a few minutes and requires >40GB VRAM ---")
198
  controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", torch_dtype=torch.bfloat16)
199
  pipe = FluxControlNetInpaintingPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16).to(device)
200
- pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
201
 
202
  pipe.transformer.to(torch.bfloat16)
203
  pipe.controlnet.to(torch.bfloat16)
@@ -211,9 +211,27 @@ segment_processor = AutoProcessor.from_pretrained(segmenter_id)
211
  object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
212
  print("--- All models loaded successfully! ---")
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
- # --- Main Inference Function for Gradio ---
216
- @spaces.GPU(duration=180)
217
  def run_diptych_prompting(
218
  input_image: Image.Image,
219
  subject_name: str,
@@ -223,7 +241,7 @@ def run_diptych_prompting(
223
  width: int = 1024,
224
  height: int = 1024,
225
  pixel_offset: int = 8,
226
- num_steps: int = 8,
227
  guidance: float = 3.5,
228
  seed: int = 42,
229
  randomize_seed: bool = False,
@@ -308,7 +326,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
308
  with gr.Accordion("Advanced Settings", open=False):
309
  attn_enforce = gr.Slider(minimum=1.0, maximum=2.0, value=1.3, step=0.05, label="Attention Enforcement")
310
  ctrl_scale = gr.Slider(minimum=0.5, maximum=1.0, value=0.95, step=0.01, label="ControlNet Scale")
311
- num_steps = gr.Slider(minimum=20, maximum=50, value=8, step=1, label="Inference Steps")
312
  guidance = gr.Slider(minimum=1.0, maximum=10.0, value=3.5, step=0.1, label="Guidance Scale")
313
  width = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Width")
314
  height = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Height")
 
197
  print("--- Loading Models: This may take a few minutes and requires >40GB VRAM ---")
198
  controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", torch_dtype=torch.bfloat16)
199
  pipe = FluxControlNetInpaintingPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16).to(device)
200
+ # pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
201
 
202
  pipe.transformer.to(torch.bfloat16)
203
  pipe.controlnet.to(torch.bfloat16)
 
211
  object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
212
  print("--- All models loaded successfully! ---")
213
 
214
+ def get_duration(
215
+ input_image: Image.Image,
216
+ subject_name: str,
217
+ target_prompt: str,
218
+ attn_enforce: float = 1.3,
219
+ ctrl_scale: float = 0.95,
220
+ width: int = 1024,
221
+ height: int = 1024,
222
+ pixel_offset: int = 8,
223
+ num_steps: int = 28,
224
+ guidance: float = 3.5,
225
+ seed: int = 42,
226
+ randomize_seed: bool = False,
227
+ progress=gr.Progress(track_tqdm=True)
228
+ ):
229
+ if width > 768 and height > 768:
230
+ return 210
231
+ else:
232
+ return 120
233
 
234
+ @spaces.GPU(duration=get_duration)
 
235
  def run_diptych_prompting(
236
  input_image: Image.Image,
237
  subject_name: str,
 
241
  width: int = 1024,
242
  height: int = 1024,
243
  pixel_offset: int = 8,
244
+ num_steps: int = 28,
245
  guidance: float = 3.5,
246
  seed: int = 42,
247
  randomize_seed: bool = False,
 
326
  with gr.Accordion("Advanced Settings", open=False):
327
  attn_enforce = gr.Slider(minimum=1.0, maximum=2.0, value=1.3, step=0.05, label="Attention Enforcement")
328
  ctrl_scale = gr.Slider(minimum=0.5, maximum=1.0, value=0.95, step=0.01, label="ControlNet Scale")
329
+ num_steps = gr.Slider(minimum=20, maximum=50, value=28, step=1, label="Inference Steps")
330
  guidance = gr.Slider(minimum=1.0, maximum=10.0, value=3.5, step=0.1, label="Guidance Scale")
331
  width = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Width")
332
  height = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Height")