multimodalart (HF Staff) committed
Commit 7a2b253 · verified · Parent: 4743e79

Update app.py

Rework the demo around a single editable "Full Prompt": add a Style-Driven generation mode next to the existing Subject-Driven one, make segmentation of the reference image optional, and expose the processed reference, the full diptych, and the final prompt as additional outputs.

Files changed (1)
  1. app.py +146 -64
app.py CHANGED
@@ -19,7 +19,7 @@ import random
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# --- Helper Dataclasses (Identical to diptych_prompting_inference.py) ---
+# --- Helper Dataclasses (Identical to previous version) ---
 @dataclass
 class BoundingBox:
     xmin: int
@@ -48,7 +48,7 @@ class DetectionResult:
                                    ymax=detection_dict['box']['ymax']))
 
 
-# --- Helper Functions (Identical to diptych_prompting_inference.py) ---
+# --- Helper Functions (Identical to previous version) ---
 def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
     contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
     if not contours:
@@ -127,7 +127,7 @@ def make_diptych(image):
     return Image.fromarray(diptych_np)
 
 
-# --- Custom Attention Processor (EXACTLY as in diptych_prompting_inference.py) ---
+# --- Custom Attention Processor (Identical to previous version) ---
 class CustomFluxAttnProcessor2_0:
     def __init__(self, height=44, width=88, attn_enforce=1.0):
         if not hasattr(F, "scaled_dot_product_attention"):
@@ -197,7 +197,6 @@ class CustomFluxAttnProcessor2_0:
 print("--- Loading Models: This may take a few minutes and requires >40GB VRAM ---")
 controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", torch_dtype=torch.bfloat16)
 pipe = FluxControlNetInpaintingPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16).to(device)
-# pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
 
 pipe.transformer.to(torch.bfloat16)
 pipe.controlnet.to(torch.bfloat16)
@@ -213,21 +212,21 @@ print("--- All models loaded successfully! ---")
 
 def get_duration(
     input_image: Image.Image,
-    subject_name: str,
-    target_prompt: str,
-    attn_enforce: float = 1.3,
-    ctrl_scale: float = 0.95,
-    width: int = 768,
-    height: int = 768,
-    pixel_offset: int = 8,
-    num_steps: int = 28,
-    guidance: float = 3.5,
-    real_guidance: float = 4.5,
-    seed: int = 42,
-    randomize_seed: bool = False,
+    do_segmentation: bool,
+    full_prompt: str,
+    attn_enforce: float,
+    ctrl_scale: float,
+    width: int,
+    height: int,
+    pixel_offset: int,
+    num_steps: int,
+    guidance: float,
+    real_guidance: float,
+    seed: int,
+    randomize_seed: bool,
     progress=gr.Progress(track_tqdm=True)
 ):
-    if width > 768 and height > 768:
+    if width > 768 or height > 768:
        return 210
     else:
        return 120
@@ -236,17 +235,18 @@ def get_duration(
 
 def run_diptych_prompting(
     input_image: Image.Image,
     subject_name: str,
-    target_prompt: str,
-    attn_enforce: float = 1.3,
-    ctrl_scale: float = 0.95,
-    width: int = 768,
-    height: int = 768,
-    pixel_offset: int = 8,
-    num_steps: int = 28,
-    guidance: float = 3.5,
-    real_guidance: float = 4.5,
-    seed: int = 42,
-    randomize_seed: bool = False,
+    do_segmentation: bool,
+    full_prompt: str,
+    attn_enforce: float,
+    ctrl_scale: float,
+    width: int,
+    height: int,
+    pixel_offset: int,
+    num_steps: int,
+    guidance: float,
+    real_guidance: float,
+    seed: int,
+    randomize_seed: bool,
     progress=gr.Progress(track_tqdm=True)
 ):
     if randomize_seed:
@@ -255,40 +255,42 @@ def run_diptych_prompting(
        actual_seed = seed
 
     if input_image is None: raise gr.Error("Please upload a reference image.")
-    if not subject_name: raise gr.Error("Please provide the subject's name (e.g., 'a red car').")
-    if not target_prompt: raise gr.Error("Please provide a target prompt.")
+    if not full_prompt: raise gr.Error("Full Prompt is empty. Please fill out the prompt fields.")
 
-    # 1. Prepare dimensions (logic from original script's main block)
+    # 1. Prepare dimensions and reference image
     padded_width = width + pixel_offset * 2
     padded_height = height + pixel_offset * 2
     diptych_size = (padded_width * 2, padded_height)
-
-    # 2. Prepare prompts and images
-    progress(0, desc="Resizing and segmenting reference image...")
-    base_prompt = f"a photo of {subject_name}"
-    diptych_text_prompt = f"A diptych with two side-by-side images of same {subject_name}. On the left, {base_prompt}. On the right, replicate this {subject_name} exactly but as {target_prompt}"
-
     reference_image = input_image.resize((padded_width, padded_height)).convert("RGB")
-    segmented_image = segment_image(reference_image, subject_name, object_detector, segmentator, segment_processor)
+
+    # 2. Process reference image based on segmentation flag
+    progress(0, desc="Preparing reference image...")
+    if do_segmentation:
+        if not subject_name:
+            raise gr.Error("Subject Name is required when 'Do Segmentation' is checked.")
+        progress(0.05, desc="Segmenting reference image...")
+        processed_image = segment_image(reference_image, subject_name, object_detector, segmentator, segment_processor)
+    else:
+        processed_image = reference_image
 
+    # 3. Create diptych and mask
     progress(0.2, desc="Creating diptych and mask...")
     mask_image = np.concatenate([np.zeros((padded_height, padded_width, 3)), np.ones((padded_height, padded_width, 3)) * 255], axis=1)
     mask_image = Image.fromarray(mask_image.astype(np.uint8))
-    diptych_image_prompt = make_diptych(segmented_image)
+    diptych_image_prompt = make_diptych(processed_image)
 
-    # 3. Setup Attention Processor (logic from original script's main block)
+    # 4. Setup Attention Processor
     progress(0.3, desc="Setting up attention processors...")
     new_attn_procs = base_attn_procs.copy()
     for k in new_attn_procs:
-        # Use full diptych dimensions for the attention processor
        new_attn_procs[k] = CustomFluxAttnProcessor2_0(height=padded_height // 16, width=padded_width * 2 // 16, attn_enforce=attn_enforce)
     pipe.transformer.set_attn_processor(new_attn_procs)
 
-    # 4. Run Inference (using parameters identical to the original script)
+    # 5. Run Inference
     progress(0.4, desc="Running diffusion process...")
     generator = torch.Generator(device="cuda").manual_seed(actual_seed)
-    result = pipe(
-        prompt=diptych_text_prompt,
+    full_diptych_result = pipe(
+        prompt=full_prompt,
        height=diptych_size[1],
        width=diptych_size[0],
        control_image=diptych_image_prompt,
@@ -301,14 +303,13 @@ def run_diptych_prompting(
        true_guidance_scale=real_guidance
     ).images[0]
 
-    # 5. Final cropping (logic from original script's main block)
+    # 6. Final cropping
     progress(0.95, desc="Finalizing image...")
-    # Crop the right panel
-    result = result.crop((padded_width, 0, padded_width * 2, padded_height))
-    # Crop the pixel offset padding
-    result = result.crop((pixel_offset, pixel_offset, padded_width - pixel_offset, padded_height - pixel_offset))
+    final_image = full_diptych_result.crop((padded_width, 0, padded_width * 2, padded_height))
+    final_image = final_image.crop((pixel_offset, pixel_offset, padded_width - pixel_offset, padded_height - pixel_offset))
 
-    return result
+    # 7. Return all outputs
+    return final_image, processed_image, full_diptych_result, full_prompt
 
 
 # --- Gradio UI Definition ---
@@ -318,18 +319,29 @@ css = '''
 with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     gr.Markdown(
         """
-        # Diptych Prompting: Zero-Shot Subject-Driven Image Generation
+        # Diptych Prompting: Zero-Shot Subject-Driven & Style-Driven Image Generation
         ### Gradio Demo for the paper "[Large-Scale Text-to-Image Model with Inpainting is a Zero-Shot Subject-Driven Image Generator](https://diptychprompting.github.io/)"
+        This demo implements both subject-driven generation and style transfer with advanced controls.
         """
     )
     with gr.Row():
         with gr.Column(scale=1):
-            input_image = gr.Image(type="pil", label="1. Reference Image")
-            subject_name = gr.Textbox(label="2. Subject Name", placeholder="e.g., a plush bear")
-            target_prompt = gr.Textbox(label="3. Target Prompt", placeholder="e.g., a plush bear riding a skate on the moon")
+            input_image = gr.Image(type="pil", label="Reference Image")
+
+            with gr.Group() as subject_driven_group:
+                subject_name = gr.Textbox(label="Subject Name", placeholder="e.g., a plush bear")
+
+            target_prompt = gr.Textbox(label="Target Prompt", placeholder="e.g., riding a skateboard on the moon")
+
             run_button = gr.Button("Generate Image", variant="primary")
+
             with gr.Accordion("Advanced Settings", open=False):
+                mode = gr.Radio(["Subject-Driven", "Style-Driven (unstable)"], label="Generation Mode", value="Subject-Driven")
+                with gr.Group(visible=False) as style_driven_group:
+                    original_style_description = gr.Textbox(label="Original Image Description", placeholder="e.g., in watercolor painting style")
+                do_segmentation = gr.Checkbox(label="Do Segmentation", value=True)
                 attn_enforce = gr.Slider(minimum=1.0, maximum=2.0, value=1.3, step=0.05, label="Attention Enforcement")
+                full_prompt = gr.Textbox(label="Full Prompt (Auto-generated, editable)", lines=3)
                 ctrl_scale = gr.Slider(minimum=0.5, maximum=1.0, value=0.95, step=0.01, label="ControlNet Scale")
                 num_steps = gr.Slider(minimum=20, maximum=50, value=28, step=1, label="Inference Steps")
                 guidance = gr.Slider(minimum=1.0, maximum=10.0, value=3.5, step=0.1, label="Distilled Guidance Scale")
@@ -339,9 +351,85 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
                 pixel_offset = gr.Slider(minimum=0, maximum=32, value=8, step=1, label="Padding (Pixel Offset)")
                 seed = gr.Slider(minimum=0, maximum=9223372036854775807, value=42, step=1, label="Seed")
                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+
         with gr.Column(scale=1):
             output_image = gr.Image(type="pil", label="Generated Image")
+            with gr.Accordion("Other Outputs", open=False) as other_outputs_accordion:
+                processed_ref_image = gr.Image(label="Processed Reference (Left Panel)")
+                full_diptych_image = gr.Image(label="Full Diptych Output")
+                final_prompt_used = gr.Textbox(label="Final Prompt Used")
 
+    # --- UI Event Handlers ---
+
+    def toggle_mode_visibility(mode_choice):
+        """Hides/shows the relevant input textboxes based on mode."""
+        if mode_choice == "Subject-Driven":
+            return gr.update(visible=True), gr.update(visible=False)
+        else:
+            return gr.update(visible=False), gr.update(visible=True)
+
+    def update_derived_fields(mode_choice, subject, style_desc, target):
+        """Updates the full prompt and segmentation checkbox based on other inputs."""
+        if mode_choice == "Subject-Driven":
+            prompt = f"A diptych with two side-by-side images of same {subject}. On the left, a photo of {subject}. On the right, replicate this {subject} exactly but as {target}"
+            return gr.update(value=prompt), gr.update(value=True)
+        else:  # Style-Driven
+            prompt = f"A diptych with two side-by-side images of same style. On the left, {style_desc}. On the right, replicate this style exactly but as {target}"
+            return gr.update(value=prompt), gr.update(value=False)
+
+    # --- UI Connections ---
+
+    # When mode changes, toggle visibility of the specific prompt fields
+    mode.change(
+        fn=toggle_mode_visibility,
+        inputs=mode,
+        outputs=[subject_driven_group, style_driven_group],
+        queue=False
+    )
+
+    # A list of all inputs that affect the full prompt or segmentation checkbox
+    prompt_component_inputs = [mode, subject_name, original_style_description, target_prompt]
+    # A list of the UI elements that are derived from the above inputs
+    derived_outputs = [full_prompt, do_segmentation]
+
+    # When any prompt component changes, update the derived fields
+    for component in prompt_component_inputs:
+        # Use .then() to chain the update after the visibility toggle for the mode radio
+        if component == mode:
+            component.change(update_derived_fields, inputs=prompt_component_inputs, outputs=derived_outputs, queue=False)
+        else:
+            component.input(update_derived_fields, inputs=prompt_component_inputs, outputs=derived_outputs, queue=False)
+
+    run_button.click(
+        fn=run_diptych_prompting,
+        inputs=[
+            input_image, subject_name, do_segmentation, full_prompt, attn_enforce,
+            ctrl_scale, width, height, pixel_offset, num_steps, guidance,
+            real_guidance, seed, randomize_seed
+        ],
+        outputs=[output_image, processed_ref_image, full_diptych_image, final_prompt_used]
+    )
+    def run_subject_driven_example(input_image, subject_name, target_prompt):
+        # Construct the full prompt for subject-driven mode
+        full_prompt = f"A diptych with two side-by-side images of same {subject_name}. On the left, a photo of {subject_name}. On the right, replicate this {subject_name} exactly but as {target_prompt}"
+
+        # Call the main function with all arguments, using defaults for subject-driven mode
+        return run_diptych_prompting(
+            input_image=input_image,
+            subject_name=subject_name,
+            do_segmentation=True,
+            full_prompt=full_prompt,
+            attn_enforce=1.3,
+            ctrl_scale=0.95,
+            width=768,
+            height=768,
+            pixel_offset=8,
+            num_steps=28,
+            guidance=3.5,
+            real_guidance=4.5,
+            seed=42,
+            randomize_seed=False,
+        )
     gr.Examples(
         examples=[
            ["./assets/cat_squished.png", "a cat toy", "a cat toy riding a skate"],
@@ -349,16 +437,10 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
            ["./assets/bear_plushie.jpg", "a bear plushie", "a bear plushie drinking bubble tea"]
        ],
        inputs=[input_image, subject_name, target_prompt],
-       outputs=output_image,
-       fn=run_diptych_prompting,
-       cache_examples="lazy",
-    )
-
-    run_button.click(
-        fn=run_diptych_prompting,
-        inputs=[input_image, subject_name, target_prompt, attn_enforce, ctrl_scale, width, height, pixel_offset, num_steps, guidance, real_guidance, seed, randomize_seed],
-        outputs=output_image
+       outputs=[output_image, processed_ref_image, full_diptych_image, final_prompt_used],
+       fn=run_subject_driven_example,
+       cache_examples="lazy"
     )
 
 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch(share=True, debug=True)
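
Note: the commit only wires gr.Examples through the Subject-Driven path (run_subject_driven_example). As a minimal sketch of how the reworked run_diptych_prompting signature can also be driven in the new Style-Driven mode, the snippet below builds the same prompt template that update_derived_fields() produces for that mode and calls the function directly with the UI defaults. The asset path, style description, and target text are illustrative placeholders, and the call assumes it runs inside app.py (or a module that imports it) after the models have loaded; it is not part of this commit.

    # Style-Driven usage sketch (hypothetical values; not part of the committed app).
    from PIL import Image

    reference = Image.open("./assets/cat_squished.png").convert("RGB")

    style_desc = "a watercolor painting of a cat toy"           # describes the reference image (placeholder)
    target = "a watercolor painting of a lighthouse at sunset"  # what the right panel should become (placeholder)

    # Same template that update_derived_fields() builds for Style-Driven mode
    style_prompt = (
        f"A diptych with two side-by-side images of same style. "
        f"On the left, {style_desc}. "
        f"On the right, replicate this style exactly but as {target}"
    )

    final_image, processed_ref, full_diptych, prompt_used = run_diptych_prompting(
        input_image=reference,
        subject_name="",        # unused because segmentation is disabled
        do_segmentation=False,  # Style-Driven mode keeps the whole reference image
        full_prompt=style_prompt,
        attn_enforce=1.3,
        ctrl_scale=0.95,
        width=768,
        height=768,
        pixel_offset=8,
        num_steps=28,
        guidance=3.5,
        real_guidance=4.5,
        seed=42,
        randomize_seed=False,
    )
    final_image.save("style_driven_result.png")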