3v324v23 committed
Commit a240a7e · 1 Parent(s): 8c29878

confidence threshold

Files changed (1):
  1. app.py +14 -8
app.py CHANGED
@@ -2,21 +2,21 @@ import numpy as np
 import torch
 import gradio as gr
 from infer import detections
+'''
 import os
 os.system("mkdir data")
 os.system("mkdir data/models")
 os.system("wget https://www.cs.cmu.edu/~walt/models/walt_people.pth -O data/models/walt_people.pth")
 os.system("wget https://www.cs.cmu.edu/~walt/models/walt_vehicle.pth -O data/models/walt_vehicle.pth")
 '''
-'''
-def walt_demo(input_img):
+def walt_demo(input_img, confidence_threshold):
     #detect_people = detections('configs/walt/walt_people.py', 'cuda:0', model_path='data/models/walt_people.pth')
     if torch.cuda.is_available() == False:
         device='cpu'
     else:
         device='cuda:0'
     #detect_people = detections('configs/walt/walt_people.py', device, model_path='data/models/walt_people.pth')
-    detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth', threshold=0.75)
+    detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth', threshold=confidence_threshold)
 
     count = 0
     #img = detect_people.run_on_image(input_img)
@@ -45,9 +45,9 @@ article="""
 """
 
 examples = [
-    'demo/images/img_1.jpg',
-    'demo/images/img_2.jpg',
-    'demo/images/img_4.png',
+    ['demo/images/img_1.jpg',0.8],
+    ['demo/images/img_2.jpg',0.8],
+    ['demo/images/img_4.png',0.85],
 ]
 
 '''
@@ -58,9 +58,15 @@ img=walt_demo(img)
 cv2.imwrite(filename.replace('/images/','/results/'),img)
 cv2.imwrite('check.png',img)
 '''
+confidence_threshold = gr.Slider(minimum=0.3,
+                                 maximum=1.0,
+                                 step=0.01,
+                                 value=1.0,
+                                 label="Amodal Detection Confidence Threshold")
+inputs = [gr.Image(), confidence_threshold]
 demo = gr.Interface(walt_demo,
-                    gr.Image(),
-                    "image",
+                    outputs="image",
+                    inputs=inputs,
                     article=article,
                     title=title,
                     enable_queue=True,
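
For reference, below is a minimal, self-contained sketch of the interface wiring this commit produces. It assumes only that gradio is installed; the repo-specific detections class from infer.py is replaced by a pass-through stub so the snippet runs on its own, and it only illustrates how the slider value reaches the detector's threshold argument, not the WALT model itself.

import gradio as gr

def walt_demo(input_img, confidence_threshold):
    # In app.py the slider value is forwarded as
    #   detections('configs/walt/walt_vehicle.py', device,
    #              model_path='data/models/walt_vehicle.pth',
    #              threshold=confidence_threshold)
    # Stub: return the input unchanged to show only the data flow.
    return input_img

# Slider defined as in the diff: range 0.3-1.0, step 0.01, default 1.0.
confidence_threshold = gr.Slider(minimum=0.3,
                                 maximum=1.0,
                                 step=0.01,
                                 value=1.0,
                                 label="Amodal Detection Confidence Threshold")

demo = gr.Interface(walt_demo,
                    inputs=[gr.Image(), confidence_threshold],
                    outputs="image")

if __name__ == "__main__":
    demo.launch()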