Spaces:
Running
on
Zero
Running
on
Zero
fix pt2
Browse files
app.py
CHANGED
@@ -236,13 +236,29 @@ def create_tags(threshold, sorted_tag_score: dict):
|
|
236 |
return text_no_impl, filtered_tag_score
|
237 |
|
238 |
def clear_image():
|
239 |
-
return "", {}, None, {}, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
|
241 |
@spaces.GPU(duration=5)
|
242 |
def cam_inference(img, threshold, alpha, selected_tag: str):
|
243 |
-
|
244 |
-
|
245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
|
247 |
if torch.cuda.is_available():
|
248 |
tensor = tensor.to(device, dtype=torch.float16)
|
@@ -438,6 +454,7 @@ with gr.Blocks(css=custom_css) as demo:
|
|
438 |
original_image_state = gr.State() # stash a copy of the input image
|
439 |
sorted_tag_score_state = gr.State(value={}) # stash the sorted tag→score dict
|
440 |
cam_state = gr.State()
|
|
|
441 |
with gr.Row():
|
442 |
with gr.Column():
|
443 |
image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
|
@@ -458,7 +475,7 @@ with gr.Blocks(css=custom_css) as demo:
|
|
458 |
image.clear(
|
459 |
fn=clear_image,
|
460 |
inputs=[],
|
461 |
-
outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state]
|
462 |
)
|
463 |
|
464 |
threshold_slider.input(
|
@@ -469,8 +486,14 @@ with gr.Blocks(css=custom_css) as demo:
|
|
469 |
)
|
470 |
|
471 |
label_box.select(
|
|
|
|
|
|
|
|
|
|
|
|
|
472 |
fn=cam_inference,
|
473 |
-
inputs=[original_image_state, cam_slider, alpha_slider,
|
474 |
outputs=[image, cam_state],
|
475 |
show_progress='minimal'
|
476 |
)
|
|
|
236 |
return text_no_impl, filtered_tag_score
|
237 |
|
238 |
def clear_image():
|
239 |
+
return "", {}, None, {}, None, None
|
240 |
+
|
241 |
+
def extract_selected_tag(evt: gr.SelectData):
|
242 |
+
# evt is a gr.SelectData; keep it out of GPU calls
|
243 |
+
try:
|
244 |
+
return evt.value
|
245 |
+
except Exception:
|
246 |
+
return None
|
247 |
|
248 |
@spaces.GPU(duration=5)
|
249 |
def cam_inference(img, threshold, alpha, selected_tag: str):
|
250 |
+
if img is None or not selected_tag:
|
251 |
+
return img, None
|
252 |
|
253 |
+
# Map to index
|
254 |
+
if selected_tag not in tags:
|
255 |
+
key = selected_tag.replace("_", " ")
|
256 |
+
if key not in tags:
|
257 |
+
return img, None
|
258 |
+
selected_tag = key
|
259 |
+
|
260 |
+
target_tag_index = tags[selected_tag]
|
261 |
+
tensor = transform(img).unsqueeze(0)
|
262 |
|
263 |
if torch.cuda.is_available():
|
264 |
tensor = tensor.to(device, dtype=torch.float16)
|
|
|
454 |
original_image_state = gr.State() # stash a copy of the input image
|
455 |
sorted_tag_score_state = gr.State(value={}) # stash the sorted tag→score dict
|
456 |
cam_state = gr.State()
|
457 |
+
selected_tag_state = gr.State(value=None)
|
458 |
with gr.Row():
|
459 |
with gr.Column():
|
460 |
image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
|
|
|
475 |
image.clear(
|
476 |
fn=clear_image,
|
477 |
inputs=[],
|
478 |
+
outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state, selected_tag_state]
|
479 |
)
|
480 |
|
481 |
threshold_slider.input(
|
|
|
486 |
)
|
487 |
|
488 |
label_box.select(
|
489 |
+
fn=extract_selected_tag,
|
490 |
+
inputs=None,
|
491 |
+
outputs=selected_tag_state,
|
492 |
+
show_progress='hidden',
|
493 |
+
queue=False # This should be a very fast operation
|
494 |
+
).then(
|
495 |
fn=cam_inference,
|
496 |
+
inputs=[original_image_state, cam_slider, alpha_slider, selected_tag_state],
|
497 |
outputs=[image, cam_state],
|
498 |
show_progress='minimal'
|
499 |
)
|