confidence threshold
Browse files
app.py
CHANGED
@@ -2,21 +2,21 @@ import numpy as np
|
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
from infer import detections
|
|
|
5 |
import os
|
6 |
os.system("mkdir data")
|
7 |
os.system("mkdir data/models")
|
8 |
os.system("wget https://www.cs.cmu.edu/~walt/models/walt_people.pth -O data/models/walt_people.pth")
|
9 |
os.system("wget https://www.cs.cmu.edu/~walt/models/walt_vehicle.pth -O data/models/walt_vehicle.pth")
|
10 |
'''
|
11 |
-
|
12 |
-
def walt_demo(input_img):
|
13 |
#detect_people = detections('configs/walt/walt_people.py', 'cuda:0', model_path='data/models/walt_people.pth')
|
14 |
if torch.cuda.is_available() == False:
|
15 |
device='cpu'
|
16 |
else:
|
17 |
device='cuda:0'
|
18 |
#detect_people = detections('configs/walt/walt_people.py', device, model_path='data/models/walt_people.pth')
|
19 |
-
detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth', threshold=
|
20 |
|
21 |
count = 0
|
22 |
#img = detect_people.run_on_image(input_img)
|
@@ -45,9 +45,9 @@ article="""
|
|
45 |
"""
|
46 |
|
47 |
examples = [
|
48 |
-
'demo/images/img_1.jpg',
|
49 |
-
'demo/images/img_2.jpg',
|
50 |
-
'demo/images/img_4.png',
|
51 |
]
|
52 |
|
53 |
'''
|
@@ -58,9 +58,15 @@ img=walt_demo(img)
|
|
58 |
cv2.imwrite(filename.replace('/images/','/results/'),img)
|
59 |
cv2.imwrite('check.png',img)
|
60 |
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
demo = gr.Interface(walt_demo,
|
62 |
-
|
63 |
-
|
64 |
article=article,
|
65 |
title=title,
|
66 |
enable_queue=True,
|
|
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
from infer import detections
|
5 |
+
'''
|
6 |
import os
|
7 |
os.system("mkdir data")
|
8 |
os.system("mkdir data/models")
|
9 |
os.system("wget https://www.cs.cmu.edu/~walt/models/walt_people.pth -O data/models/walt_people.pth")
|
10 |
os.system("wget https://www.cs.cmu.edu/~walt/models/walt_vehicle.pth -O data/models/walt_vehicle.pth")
|
11 |
'''
|
12 |
+
def walt_demo(input_img, confidence_threshold):
|
|
|
13 |
#detect_people = detections('configs/walt/walt_people.py', 'cuda:0', model_path='data/models/walt_people.pth')
|
14 |
if torch.cuda.is_available() == False:
|
15 |
device='cpu'
|
16 |
else:
|
17 |
device='cuda:0'
|
18 |
#detect_people = detections('configs/walt/walt_people.py', device, model_path='data/models/walt_people.pth')
|
19 |
+
detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth', threshold=confidence_threshold)
|
20 |
|
21 |
count = 0
|
22 |
#img = detect_people.run_on_image(input_img)
|
|
|
45 |
"""
|
46 |
|
47 |
examples = [
|
48 |
+
['demo/images/img_1.jpg',0.8],
|
49 |
+
['demo/images/img_2.jpg',0.8],
|
50 |
+
['demo/images/img_4.png',0.85],
|
51 |
]
|
52 |
|
53 |
'''
|
|
|
58 |
cv2.imwrite(filename.replace('/images/','/results/'),img)
|
59 |
cv2.imwrite('check.png',img)
|
60 |
'''
|
61 |
+
confidence_threshold = gr.Slider(minimum=0.3,
|
62 |
+
maximum=1.0,
|
63 |
+
step=0.01,
|
64 |
+
value=1.0,
|
65 |
+
label="Amodal Detection Confidence Threshold")
|
66 |
+
inputs = [gr.Image(), confidence_threshold]
|
67 |
demo = gr.Interface(walt_demo,
|
68 |
+
outputs="image",
|
69 |
+
inputs=inputs,
|
70 |
article=article,
|
71 |
title=title,
|
72 |
enable_queue=True,
|