Spaces:
Running
on
Zero
Running
on
Zero
gr.selectdata can't be pickled fix
Browse files
app.py
CHANGED
@@ -239,15 +239,15 @@ def clear_image():
|
|
239 |
return "", {}, None, {}, None
|
240 |
|
241 |
@spaces.GPU(duration=5)
|
242 |
-
def cam_inference(img, threshold, alpha,
|
243 |
-
target_tag_index = tags[
|
244 |
tensor = transform(img).unsqueeze(0)
|
245 |
|
|
|
246 |
if torch.cuda.is_available():
|
247 |
tensor = tensor.to(device, dtype=torch.float16)
|
248 |
else:
|
249 |
tensor = tensor.to(device)
|
250 |
-
|
251 |
tensor.requires_grad_()
|
252 |
|
253 |
gradients = {}
|
@@ -263,7 +263,7 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
|
|
263 |
handle_backward = model.norm.register_full_backward_hook(hook_backward)
|
264 |
|
265 |
probits = model(tensor)[0]
|
266 |
-
|
267 |
model.zero_grad()
|
268 |
probits[target_tag_index].backward(retain_graph=True)
|
269 |
|
@@ -470,7 +470,7 @@ with gr.Blocks(css=custom_css) as demo:
|
|
470 |
|
471 |
label_box.select(
|
472 |
fn=cam_inference,
|
473 |
-
inputs=[original_image_state, cam_slider, alpha_slider],
|
474 |
outputs=[image, cam_state],
|
475 |
show_progress='minimal'
|
476 |
)
|
|
|
239 |
return "", {}, None, {}, None
|
240 |
|
241 |
@spaces.GPU(duration=5)
|
242 |
+
def cam_inference(img, threshold, alpha, selected_tag: str):
|
243 |
+
target_tag_index = tags[selected_tag]
|
244 |
tensor = transform(img).unsqueeze(0)
|
245 |
|
246 |
+
|
247 |
if torch.cuda.is_available():
|
248 |
tensor = tensor.to(device, dtype=torch.float16)
|
249 |
else:
|
250 |
tensor = tensor.to(device)
|
|
|
251 |
tensor.requires_grad_()
|
252 |
|
253 |
gradients = {}
|
|
|
263 |
handle_backward = model.norm.register_full_backward_hook(hook_backward)
|
264 |
|
265 |
probits = model(tensor)[0]
|
266 |
+
|
267 |
model.zero_grad()
|
268 |
probits[target_tag_index].backward(retain_graph=True)
|
269 |
|
|
|
470 |
|
471 |
label_box.select(
|
472 |
fn=cam_inference,
|
473 |
+
inputs=[original_image_state, cam_slider, alpha_slider, label_box],
|
474 |
outputs=[image, cam_state],
|
475 |
show_progress='minimal'
|
476 |
)
|