WensongSong commited on
Commit
3606138
·
verified ·
1 Parent(s): 72509e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -5
app.py CHANGED
@@ -143,6 +143,25 @@ def get_mask(image, label):
143
 
144
  return result_mask
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  hf_token = os.getenv("HF_TOKEN")
148
 
@@ -278,6 +297,8 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
278
  tar_image = cv2.resize(tar_image, size)
279
  diptych_ref_tar = np.concatenate([masked_ref_image, tar_image], axis=1)
280
 
 
 
281
 
282
  tar_mask = np.stack([tar_mask,tar_mask,tar_mask],-1)
283
  mask_black = np.ones_like(tar_image) * 0
@@ -316,9 +337,9 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
316
  edited_image = Image.fromarray(edited_image)
317
 
318
  if ref_mask_option != "Label to Mask":
319
- return [edited_image]
320
  else:
321
- return [return_ref_mask, edited_image]
322
 
323
  def update_ui(option):
324
  if option == "Draw Mask":
@@ -357,15 +378,18 @@ with gr.Blocks() as demo:
357
  ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"], label="Reference Mask Input Option", value="Upload with Mask")
358
 
359
  with gr.Row():
360
- text_prompt = gr.Textbox(label="Label")
361
 
362
  with gr.Column(scale=1):
363
- baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=765, columns=1)
364
  with gr.Accordion("Advanced Option", open=True):
365
  seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
366
  gr.Markdown("### Guidelines")
367
  gr.Markdown(" Users can try using different seeds. For example, seeds like 42 and 123456 may produce different effects.")
368
- gr.Markdown(" Label to Mask means generating a mask by simply inputting a label.")
 
 
 
369
 
370
  run_local_button = gr.Button(value="Run")
371
 
 
143
 
144
  return result_mask
145
 
146
+ def create_highlighted_mask(image_np, mask_np, alpha=0.5, gray_value=128):
147
+
148
+
149
+ if mask_np.max() <= 1.0:
150
+ mask_np = (mask_np * 255).astype(np.uint8)
151
+ mask_bool = mask_np > 128
152
+
153
+ image_float = image_np.astype(np.float32)
154
+
155
+ # 灰色图层
156
+ gray_overlay = np.full_like(image_float, gray_value, dtype=np.float32)
157
+
158
+ # 混合
159
+ result = image_float.copy()
160
+ result[mask_bool] = (
161
+ (1 - alpha) * image_float[mask_bool] + alpha * gray_overlay[mask_bool]
162
+ )
163
+
164
+ return result.astype(np.uint8)
165
 
166
  hf_token = os.getenv("HF_TOKEN")
167
 
 
297
  tar_image = cv2.resize(tar_image, size)
298
  diptych_ref_tar = np.concatenate([masked_ref_image, tar_image], axis=1)
299
 
300
+ show_diptych_ref_tar = create_highlighted_mask(diptych_ref_tar, mask_diptych)
301
+ show_diptych_ref_tar = Image.fromarray(show_diptych_ref_tar)
302
 
303
  tar_mask = np.stack([tar_mask,tar_mask,tar_mask],-1)
304
  mask_black = np.ones_like(tar_image) * 0
 
337
  edited_image = Image.fromarray(edited_image)
338
 
339
  if ref_mask_option != "Label to Mask":
340
+ return [edited_image, show_diptych_ref_tar]
341
  else:
342
+ return [return_ref_mask, show_diptych_ref_tar, edited_image]
343
 
344
  def update_ui(option):
345
  if option == "Draw Mask":
 
378
  ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"], label="Reference Mask Input Option", value="Upload with Mask")
379
 
380
  with gr.Row():
381
+ text_prompt = gr.Textbox(label="Label", placeholder="Enter the category of the reference object, e.g., car, dress, toy, etc.")
382
 
383
  with gr.Column(scale=1):
384
+ baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=705, columns=1)
385
  with gr.Accordion("Advanced Option", open=True):
386
  seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
387
  gr.Markdown("### Guidelines")
388
  gr.Markdown(" Users can try using different seeds. For example, seeds like 42 and 123456 may produce different effects.")
389
+ gr.Markdown(" Draw Mask means manually drawing a mask on the original image.")
390
+ gr.Markdown(" Upload with Mask means uploading a mask file.")
391
+ gr.Markdown(" Label to Mask means simply inputting a label to automatically extract the mask and obtain the result.")
392
+
393
 
394
  run_local_button = gr.Button(value="Run")
395