cdnuts commited on
Commit
9977d4a
·
verified ·
1 Parent(s): 8d39524
Files changed (1) hide show
  1. app.py +28 -5
app.py CHANGED
@@ -236,13 +236,29 @@ def create_tags(threshold, sorted_tag_score: dict):
236
  return text_no_impl, filtered_tag_score
237
 
238
  def clear_image():
239
- return "", {}, None, {}, None
 
 
 
 
 
 
 
240
 
241
  @spaces.GPU(duration=5)
242
  def cam_inference(img, threshold, alpha, selected_tag: str):
243
- target_tag_index = tags[selected_tag]
244
- tensor = transform(img).unsqueeze(0)
245
 
 
 
 
 
 
 
 
 
 
246
 
247
  if torch.cuda.is_available():
248
  tensor = tensor.to(device, dtype=torch.float16)
@@ -438,6 +454,7 @@ with gr.Blocks(css=custom_css) as demo:
438
  original_image_state = gr.State() # stash a copy of the input image
439
  sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
440
  cam_state = gr.State()
 
441
  with gr.Row():
442
  with gr.Column():
443
  image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
@@ -458,7 +475,7 @@ with gr.Blocks(css=custom_css) as demo:
458
  image.clear(
459
  fn=clear_image,
460
  inputs=[],
461
- outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state]
462
  )
463
 
464
  threshold_slider.input(
@@ -469,8 +486,14 @@ with gr.Blocks(css=custom_css) as demo:
469
  )
470
 
471
  label_box.select(
 
 
 
 
 
 
472
  fn=cam_inference,
473
- inputs=[original_image_state, cam_slider, alpha_slider, label_box],
474
  outputs=[image, cam_state],
475
  show_progress='minimal'
476
  )
 
236
  return text_no_impl, filtered_tag_score
237
 
238
  def clear_image():
239
+ return "", {}, None, {}, None, None
240
+
241
+ def extract_selected_tag(evt: gr.SelectData):
242
+ # evt is a gr.SelectData; keep it out of GPU calls
243
+ try:
244
+ return evt.value
245
+ except Exception:
246
+ return None
247
 
248
  @spaces.GPU(duration=5)
249
  def cam_inference(img, threshold, alpha, selected_tag: str):
250
+ if img is None or not selected_tag:
251
+ return img, None
252
 
253
+ # Map to index
254
+ if selected_tag not in tags:
255
+ key = selected_tag.replace("_", " ")
256
+ if key not in tags:
257
+ return img, None
258
+ selected_tag = key
259
+
260
+ target_tag_index = tags[selected_tag]
261
+ tensor = transform(img).unsqueeze(0)
262
 
263
  if torch.cuda.is_available():
264
  tensor = tensor.to(device, dtype=torch.float16)
 
454
  original_image_state = gr.State() # stash a copy of the input image
455
  sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
456
  cam_state = gr.State()
457
+ selected_tag_state = gr.State(value=None)
458
  with gr.Row():
459
  with gr.Column():
460
  image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
 
475
  image.clear(
476
  fn=clear_image,
477
  inputs=[],
478
+ outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state, selected_tag_state]
479
  )
480
 
481
  threshold_slider.input(
 
486
  )
487
 
488
  label_box.select(
489
+ fn=extract_selected_tag,
490
+ inputs=None,
491
+ outputs=selected_tag_state,
492
+ show_progress='hidden',
493
+ queue=False # This should be a very fast operation
494
+ ).then(
495
  fn=cam_inference,
496
+ inputs=[original_image_state, cam_slider, alpha_slider, selected_tag_state],
497
  outputs=[image, cam_state],
498
  show_progress='minimal'
499
  )