cdnuts commited on
Commit
8d39524
·
verified ·
1 Parent(s): 6fb195f

gr.selectdata can't be pickled fix

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -239,15 +239,15 @@ def clear_image():
239
  return "", {}, None, {}, None
240
 
241
  @spaces.GPU(duration=5)
242
- def cam_inference(img, threshold, alpha, evt: gr.SelectData):
243
- target_tag_index = tags[evt.value]
244
  tensor = transform(img).unsqueeze(0)
245
 
 
246
  if torch.cuda.is_available():
247
  tensor = tensor.to(device, dtype=torch.float16)
248
  else:
249
  tensor = tensor.to(device)
250
-
251
  tensor.requires_grad_()
252
 
253
  gradients = {}
@@ -263,7 +263,7 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
263
  handle_backward = model.norm.register_full_backward_hook(hook_backward)
264
 
265
  probits = model(tensor)[0]
266
-
267
  model.zero_grad()
268
  probits[target_tag_index].backward(retain_graph=True)
269
 
@@ -470,7 +470,7 @@ with gr.Blocks(css=custom_css) as demo:
470
 
471
  label_box.select(
472
  fn=cam_inference,
473
- inputs=[original_image_state, cam_slider, alpha_slider],
474
  outputs=[image, cam_state],
475
  show_progress='minimal'
476
  )
 
239
  return "", {}, None, {}, None
240
 
241
  @spaces.GPU(duration=5)
242
+ def cam_inference(img, threshold, alpha, selected_tag: str):
243
+ target_tag_index = tags[selected_tag]
244
  tensor = transform(img).unsqueeze(0)
245
 
246
+
247
  if torch.cuda.is_available():
248
  tensor = tensor.to(device, dtype=torch.float16)
249
  else:
250
  tensor = tensor.to(device)
 
251
  tensor.requires_grad_()
252
 
253
  gradients = {}
 
263
  handle_backward = model.norm.register_full_backward_hook(hook_backward)
264
 
265
  probits = model(tensor)[0]
266
+
267
  model.zero_grad()
268
  probits[target_tag_index].backward(retain_graph=True)
269
 
 
470
 
471
  label_box.select(
472
  fn=cam_inference,
473
+ inputs=[original_image_state, cam_slider, alpha_slider, label_box],
474
  outputs=[image, cam_state],
475
  show_progress='minimal'
476
  )