testing git upload
- LICENSE +20 -0
- app.py +49 -9
- cwalt/CWALT.py +161 -0
- cwalt/Clip_WALT_Generate.py +284 -0
- cwalt/Download_Detections.py +28 -0
- cwalt/clustering_utils.py +132 -0
- cwalt/kmedoid.py +55 -0
- cwalt/utils.py +168 -0
- cwalt_generate.py +14 -0
- infer.py +114 -0
- test.py +226 -0
- train.py +191 -0
LICENSE
ADDED
@@ -0,0 +1,20 @@
Copyright (c) 2022-2022 dinesh reddy and others

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
app.py
CHANGED
@@ -1,17 +1,57 @@
import numpy as np
import gradio as gr
import cv2
from infer import detections

def walt_demo(input_img):
    #detect_people = detections('configs/walt/walt_people.py', 'cuda:0', model_path='data/models/walt_people.pth')
    detect = detections('configs/walt/walt_vehicle.py', 'cuda:0', model_path='data/models/walt_vehicle.pth')
    #img = detect_people.run_on_image(input_img)
    output_img = detect.run_on_image(input_img)
    return output_img

description = """
WALT demo on the WALT dataset. After watching and automatically learning for several days, this approach shows significant performance improvement in detecting and segmenting occluded people and vehicles over human-supervised amodal approaches.
<center>
<a href="https://www.cs.cmu.edu/~walt/">
<img style="display:inline" alt="Project page" src="https://img.shields.io/badge/Project%20Page-WALT-green">
</a>
<a href="https://www.cs.cmu.edu/~walt/pdf/walt.pdf"><img style="display:inline" src="https://img.shields.io/badge/Paper-Pdf-red"></a>
<a href="https://github.com/dineshreddy91/WALT"><img style="display:inline" src="https://img.shields.io/github/stars/dineshreddy91/WALT?style=social"></a>
</center>
"""
title = "WALT: Watch And Learn 2D Amodal Representation using Time-lapse Imagery"
article = """
<center>
<img src='https://visitor-badge.glitch.me/badge?page_id=anhquancao.MonoScene&left_color=darkmagenta&right_color=purple' alt='visitor badge'>
</center>
"""

examples = [
    'demo/images/img_1.jpg',
]

# quick smoke test on the bundled example image before launching the demo
filename = 'demo/images/img_1.jpg'
img = cv2.imread(filename)
img = walt_demo(img)
cv2.imwrite(filename.replace('demo', 'demo/results/'), img)

demo = gr.Interface(walt_demo,
                    gr.Image(),
                    "image",
                    article=article,
                    title=title,
                    enable_queue=True,
                    examples=examples,
                    description=description)

demo.launch(server_name="0.0.0.0", server_port=7000)
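Note: enable_queue is a Gradio 3.x Interface argument; on Gradio 4.x it no longer exists. A minimal compatibility sketch, assuming a newer Gradio, is to call queue() on the interface instead:

demo = gr.Interface(walt_demo, gr.Image(), "image",
                    article=article, title=title,
                    examples=examples, description=description)
demo.queue()  # replaces enable_queue=True on Gradio 4.x
demo.launch(server_name="0.0.0.0", server_port=7000)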
cwalt/CWALT.py
ADDED
@@ -0,0 +1,161 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 19 19:14:47 2021

@author: dinesh
"""
import glob
from .utils import bb_intersection_over_union_unoccluded
import numpy as np
from PIL import Image
import datetime
import cv2
import os
from tqdm import tqdm


def get_image(time, folder):
    # the frame for a timestamp may live in any of the week folders
    image = None
    for week_loop in range(5):
        try:
            image = np.array(Image.open(folder + '/week' + str(week_loop) + '/' + str(time).replace(' ', 'T').replace(':', '-').split('+')[0] + '.jpg'))
            break
        except:
            continue
    if image is None:
        print('file not found')
    return image


def get_mask(segm, image):
    poly = np.array(segm).reshape((int(len(segm)/2), 2))
    mask = image.copy()*0
    cv2.fillConvexPoly(mask, poly, (255, 255, 255))
    return mask


def get_unoccluded(indices, tracks_all):
    # greedily peel off layers of mutually non-overlapping (unoccluded) boxes
    unoccluded_indexes = []
    unoccluded_index_all = []
    while 1:
        unoccluded_clusters = []
        len_unocc = len(unoccluded_indexes)
        for ind in indices:
            if ind in unoccluded_indexes:
                continue
            occ = False
            for ind_compare in indices:
                if ind_compare in unoccluded_indexes:
                    continue
                if bb_intersection_over_union_unoccluded(tracks_all[ind], tracks_all[ind_compare]) > 0.01 and ind_compare != ind:
                    occ = True
            if occ == False:
                unoccluded_indexes.extend([ind])
                unoccluded_clusters.extend([ind])
        if len(unoccluded_indexes) == len_unocc and len_unocc != 0:
            # no progress this pass: sweep up whatever is left
            for ind in indices:
                if ind not in unoccluded_indexes:
                    unoccluded_indexes.extend([ind])
                    unoccluded_clusters.extend([ind])
        unoccluded_index_all.append(unoccluded_clusters)
        if len(unoccluded_indexes) > len(indices) - 5:
            break
    return unoccluded_index_all


def primes(n):  # simple sieve of multiples
    odds = range(3, n + 1, 2)
    sieve = set(sum([list(range(q*q, n + 1, q + q)) for q in odds], []))
    return [2] + [p for p in odds if p not in sieve]


def save_image(image_read, save_path, data, path):
    # composite 30 randomly chosen unoccluded instances onto the median image
    tracks = data['tracks_all_unoccluded']
    segmentations = data['segmentation_all_unoccluded']
    timestamps = data['timestamps_final_unoccluded']

    image = image_read.copy()
    indices = np.random.randint(len(tracks), size=30)
    prime_numbers = primes(1000)
    unoccluded_index_all = get_unoccluded(indices, tracks)

    mask_stacked = image*0
    mask_stacked_all = []
    count = 0
    time = datetime.datetime.now()

    for l in indices:
        try:
            image_crop = get_image(timestamps[l], path)
        except:
            continue
        try:
            bb_left, bb_top, bb_width, bb_height, confidence = tracks[l]
        except:
            bb_left, bb_top, bb_width, bb_height, confidence, track_id = tracks[l]
        mask = get_mask(segmentations[l], image)

        image[mask > 0] = image_crop[mask > 0]
        mask[mask > 0] = 1
        # pixels shared with earlier instances are marked as occluded (2)
        for count, mask_inc in enumerate(mask_stacked_all):
            mask_stacked_all[count][cv2.bitwise_and(mask, mask_inc) > 0] = 2
        mask_stacked_all.append(mask)
        mask_stacked += mask
        count = count + 1

    timestamp_name = str(time).replace(' ', 'T').replace(':', '-').split('+')[0]
    cv2.imwrite(save_path + '/images/' + timestamp_name + '.jpg', image[:, :, ::-1])
    cv2.imwrite(save_path + '/Segmentation/' + timestamp_name + '.jpg', mask_stacked[:, :, ::-1]*30)
    np.savez_compressed(save_path + '/Segmentation/' + timestamp_name, mask=mask_stacked_all)


def CWALT_Generation(camera_name):
    save_path_train = 'data/cwalt_train'
    save_path_test = 'data/cwalt_test'

    json_file_path = 'data/{}/{}.json'.format(camera_name, camera_name)
    path = 'data/' + camera_name

    data = np.load(json_file_path + '.npz', allow_pickle=True)

    # split the unoccluded instances 80/20 into train and test
    data_train = dict()
    data_test = dict()

    split_index = int(len(data['timestamps_final_unoccluded'])*0.8)

    data_train['tracks_all_unoccluded'] = data['tracks_all_unoccluded'][0:split_index]
    data_train['segmentation_all_unoccluded'] = data['segmentation_all_unoccluded'][0:split_index]
    data_train['timestamps_final_unoccluded'] = data['timestamps_final_unoccluded'][0:split_index]

    data_test['tracks_all_unoccluded'] = data['tracks_all_unoccluded'][split_index:]
    data_test['segmentation_all_unoccluded'] = data['segmentation_all_unoccluded'][split_index:]
    data_test['timestamps_final_unoccluded'] = data['timestamps_final_unoccluded'][split_index:]

    image_read = np.array(Image.open(path + '/T18-median_image.jpg'))
    image_read = cv2.resize(image_read, (int(image_read.shape[1]/2), int(image_read.shape[0]/2)))

    os.makedirs(save_path_train + '/images', exist_ok=True)
    os.makedirs(save_path_train + '/Segmentation', exist_ok=True)
    os.makedirs(save_path_test + '/images', exist_ok=True)
    os.makedirs(save_path_test + '/Segmentation', exist_ok=True)

    for loop in tqdm(range(3000), desc="Generating training CWALT Images "):
        save_image(image_read, save_path_train, data_train, path)

    for loop in tqdm(range(300), desc="Generating testing CWALT Images "):
        save_image(image_read, save_path_test, data_test, path)
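CWALT_Generation assumes the per-camera inputs already exist on disk. A minimal invocation sketch (the camera name matches the one used elsewhere in this commit; the paths are the ones read by the code above):

# Hedged usage sketch: assumes data/cam2/cam2.json.npz (produced by
# Get_unoccluded_objects) and data/cam2/T18-median_image.jpg exist.
from cwalt.CWALT import CWALT_Generation

CWALT_Generation('cam2')
# writes composited images and masks under data/cwalt_train/ and data/cwalt_test/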
cwalt/Clip_WALT_Generate.py
ADDED
@@ -0,0 +1,284 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri May 20 15:15:11 2022

@author: dinesh
"""

from collections import OrderedDict
from matplotlib import pyplot as plt
from .utils import *
import scipy.interpolate

from scipy import interpolate
from .clustering_utils import *
import glob
import os
import cv2
from PIL import Image

import json
import numpy as np
from tqdm import tqdm


def ignore_indexes(tracks_all, labels_all):
    # get repeating bounding boxes
    get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if x == y]
    ignore_ind = []
    for index, track in enumerate(tracks_all):
        print('in ignore', index, len(tracks_all))
        if index in ignore_ind:
            continue

        if labels_all[index] < 1 or labels_all[index] > 3:
            ignore_ind.extend([index])

        ind = get_indexes(track, tracks_all)
        if len(ind) > 30:
            ignore_ind.extend(ind)

    return ignore_ind


def repeated_indexes_old(tracks_all, ignore_ind, unoccluded_indexes=None):
    # get repeating bounding boxes
    get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if bb_intersection_over_union(x, y) > 0.8 and i not in ignore_ind]
    repeat_ind = []
    repeat_inds = []
    if unoccluded_indexes == None:
        for index, track in enumerate(tracks_all):
            if index in repeat_ind or index in ignore_ind:
                continue
            ind = get_indexes(track, tracks_all)
            if len(ind) > 20:
                repeat_ind.extend(ind)
                repeat_inds.append([ind, track])
    else:
        for index in unoccluded_indexes:
            if index in repeat_ind or index in ignore_ind:
                continue
            ind = get_indexes(tracks_all[index], tracks_all)
            if len(ind) > 3:
                repeat_ind.extend(ind)
                repeat_inds.append([ind, tracks_all[index]])
    return repeat_inds


def get_unoccluded_instances(timestamps_final, tracks_all, ignore_ind=[], threshold=0.01):
    get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if x == y]
    unoccluded_indexes = []
    time_checked = []
    stationary_obj = []
    count = 0

    for time in tqdm(np.unique(timestamps_final), desc="Detecting Unoccluded objects in Image "):
        count += 1
        if [time.year, time.month, time.day, time.hour, time.minute, time.second, time.microsecond] in time_checked:
            # timestamp already processed: compare against the instances seen then
            analyze_bb = []
            for ind in unoccluded_indexes_time:
                for ind_compare in same_time_instances:
                    iou = bb_intersection_over_union(tracks_all[ind], tracks_all[ind_compare])
                    if iou < 0.5 and iou > 0:
                        analyze_bb.extend([ind_compare])
                    if iou > 0.99:
                        stationary_obj.extend([str(ind_compare) + '+' + str(ind)])

            for ind in analyze_bb:
                occ = False
                for ind_compare in same_time_instances:
                    if bb_intersection_over_union_unoccluded(tracks_all[ind], tracks_all[ind_compare], threshold=threshold) > threshold and ind_compare != ind:
                        occ = True
                        break
                if occ == False:
                    unoccluded_indexes.extend([ind])
            continue

        same_time_instances = get_indexes(time, timestamps_final)
        unoccluded_indexes_time = []

        for ind in same_time_instances:
            if tracks_all[ind][4] < 0.9 or ind in ignore_ind:
                continue
            occ = False
            for ind_compare in same_time_instances:
                if bb_intersection_over_union_unoccluded(tracks_all[ind], tracks_all[ind_compare], threshold=threshold) > threshold and ind_compare != ind and tracks_all[ind_compare][4] < 0.5:
                    occ = True
                    break
            if occ == False:
                unoccluded_indexes.extend([ind])
                unoccluded_indexes_time.extend([ind])
        time_checked.append([time.year, time.month, time.day, time.hour, time.minute, time.second, time.microsecond])
    return unoccluded_indexes, stationary_obj


def visualize_unoccluded_detection(timestamps_final, tracks_all, segmentation_all, unoccluded_indexes, cwalt_data_path, camera_name, ignore_ind=[]):
    try:
        os.mkdir(cwalt_data_path + '/' + camera_name + '_unoccluded_car_detection/')
    except:
        print('Unoccluded debugging folder exists')

    for time in tqdm(np.unique(timestamps_final), desc="Visualizing Unoccluded objects in Image "):
        get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if x == y]
        ind = get_indexes(time, timestamps_final)
        image_unocc = False
        for index in ind:
            if index not in unoccluded_indexes:
                continue
            else:
                image_unocc = True
                break
        if image_unocc == False:
            continue

        image = None
        for week_loop in range(5):
            try:
                image = np.array(Image.open(cwalt_data_path + '/week' + str(week_loop) + '/' + str(time).replace(' ', 'T').replace(':', '-').split('+')[0] + '.jpg'))
                break
            except:
                continue

        try:
            mask = image*0
        except:
            print('image not found for ' + str(time).replace(' ', 'T').replace(':', '-').split('+')[0] + '.jpg')
            continue
        image_original = image.copy()

        for index in ind:
            track = tracks_all[index]

            if index in ignore_ind:
                continue
            if index not in unoccluded_indexes:
                continue
            try:
                bb_left, bb_top, bb_width, bb_height, confidence, id = track
            except:
                bb_left, bb_top, bb_width, bb_height, confidence = track

            if confidence > 0.6:
                mask = poly_seg(image, segmentation_all[index])
                cv2.imwrite(cwalt_data_path + '/' + camera_name + '_unoccluded_car_detection/' + str(index) + '.png', mask[:, :, ::-1])


def repeated_indexes(tracks_all, ignore_ind, repeat_count=10, unoccluded_indexes=None):
    get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if bb_intersection_over_union(x, y) > 0.8 and i not in ignore_ind]
    repeat_ind = []
    repeat_inds = []
    if unoccluded_indexes == None:
        for index, track in enumerate(tracks_all):
            if index in repeat_ind or index in ignore_ind:
                continue

            ind = get_indexes(track, tracks_all)
            if len(ind) > repeat_count:
                repeat_ind.extend(ind)
                repeat_inds.append([ind, track])
    else:
        for index in unoccluded_indexes:
            if index in repeat_ind or index in ignore_ind:
                continue
            ind = get_indexes(tracks_all[index], tracks_all)
            if len(ind) > repeat_count:
                repeat_ind.extend(ind)
                repeat_inds.append([ind, tracks_all[index]])

    return repeat_inds


def poly_seg(image, segm):
    # overlay a filled segmentation polygon with 50% opacity
    poly = np.array(segm).reshape((int(len(segm)/2), 2))
    overlay = image.copy()
    alpha = 0.5
    cv2.fillPoly(overlay, [poly], color=(255, 255, 0))
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
    return image


def visualize_unoccluded_clusters(repeat_inds, tracks, segmentation_all, timestamps_final, cwalt_data_path):
    for index_, repeat_ind in enumerate(repeat_inds):
        image = np.array(Image.open(cwalt_data_path + '/' + 'T18-median_image.jpg'))
        try:
            os.mkdir(cwalt_data_path + '/Cwalt_database/')
        except:
            print('folder exists')
        try:
            os.mkdir(cwalt_data_path + '/Cwalt_database/' + str(index_) + '/')
        except:
            print(cwalt_data_path + '/Cwalt_database/' + str(index_) + '/')

        for i in repeat_ind[0]:
            try:
                bb_left, bb_top, bb_width, bb_height, confidence = tracks[i]  # bbox
            except:
                bb_left, bb_top, bb_width, bb_height, confidence, track_id = tracks[i]  # bbox

            cv2.rectangle(image, (int(bb_left), int(bb_top)), (int(bb_left + bb_width), int(bb_top + bb_height)), (0, 0, 255), 2)
            time = timestamps_final[i]
            for week_loop in range(5):
                try:
                    image1 = np.array(Image.open(cwalt_data_path + '/week' + str(week_loop) + '/' + str(time).replace(' ', 'T').replace(':', '-').split('+')[0] + '.jpg'))
                    break
                except:
                    continue

            crop = image1[int(bb_top): int(bb_top + bb_height), int(bb_left):int(bb_left + bb_width)]
            cv2.imwrite(cwalt_data_path + '/Cwalt_database/' + str(index_) + '/o_' + str(i) + '.jpg', crop[:, :, ::-1])
            image1 = poly_seg(image1, segmentation_all[i])
            crop = image1[int(bb_top): int(bb_top + bb_height), int(bb_left):int(bb_left + bb_width)]
            cv2.imwrite(cwalt_data_path + '/Cwalt_database/' + str(index_) + '/' + str(i) + '.jpg', crop[:, :, ::-1])
        if index_ > 100:
            break

        cv2.imwrite(cwalt_data_path + '/Cwalt_database/' + str(index_) + '.jpg', image[:, :, ::-1])


def Get_unoccluded_objects(camera_name, debug=False, scale=True):
    cwalt_data_path = 'data/' + camera_name
    json_file_path = cwalt_data_path + '/' + camera_name + '.json'

    with open(json_file_path, 'r') as j:
        annotations = json.loads(j.read())

    tracks_all = [parse_bbox(anno['bbox']) for anno in annotations]
    segmentation_all = [parse_bbox(anno['segmentation']) for anno in annotations]
    labels_all = [anno['label_id'] for anno in annotations]
    timestamps_final = [parse(anno['time']) for anno in annotations]

    if scale == True:
        # annotations were produced at full resolution; halve them to match the resized frames
        scale_factor = 2
        tracks_all_numpy = np.array(tracks_all)
        tracks_all_numpy[:, :4] = np.array(tracks_all)[:, :4]/scale_factor
        tracks_all = tracks_all_numpy.tolist()

        segmentation_all_scaled = []
        for list_loop in segmentation_all:
            segmentation_all_scaled.append((np.floor_divide(np.array(list_loop), scale_factor)).tolist())
        segmentation_all = segmentation_all_scaled

    if debug == True:
        timestamps_final = timestamps_final[:1000]
        labels_all = labels_all[:1000]
        segmentation_all = segmentation_all[:1000]
        tracks_all = tracks_all[:1000]

    unoccluded_indexes, stationary = get_unoccluded_instances(timestamps_final, tracks_all, threshold=0.05)
    if debug == True:
        visualize_unoccluded_detection(timestamps_final, tracks_all, segmentation_all, unoccluded_indexes, cwalt_data_path, camera_name)

    tracks_all_unoccluded = [tracks_all[i] for i in unoccluded_indexes]
    segmentation_all_unoccluded = [segmentation_all[i] for i in unoccluded_indexes]
    labels_all_unoccluded = [labels_all[i] for i in unoccluded_indexes]
    timestamps_final_unoccluded = [timestamps_final[i] for i in unoccluded_indexes]
    np.savez(json_file_path, tracks_all_unoccluded=tracks_all_unoccluded, segmentation_all_unoccluded=segmentation_all_unoccluded, labels_all_unoccluded=labels_all_unoccluded, timestamps_final_unoccluded=timestamps_final_unoccluded)

    if debug == True:
        repeat_inds_clusters = repeated_indexes(tracks_all_unoccluded, [], repeat_count=1)
        visualize_unoccluded_clusters(repeat_inds_clusters, tracks_all_unoccluded, segmentation_all_unoccluded, timestamps_final_unoccluded, cwalt_data_path)
    else:
        repeat_inds_clusters = repeated_indexes(tracks_all_unoccluded, [], repeat_count=10)

    np.savez(json_file_path + '_clubbed', repeat_inds=repeat_inds_clusters)
    np.savez(json_file_path + '_stationary', stationary=stationary)
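Get_unoccluded_objects is the first stage of the pipeline: it filters the raw per-camera detections down to unoccluded instances and saves them as .npz files next to the source json. A minimal sketch of a debug run, assuming data/cam2/cam2.json exists:

from cwalt.Clip_WALT_Generate import Get_unoccluded_objects

# debug=True truncates to the first 1000 annotations and writes
# visualization crops under data/cam2/Cwalt_database/
Get_unoccluded_objects('cam2', debug=True, scale=True)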
cwalt/Download_Detections.py
ADDED
@@ -0,0 +1,28 @@
import json
from psycopg2.extras import RealDictCursor
import psycopg2
import cv2


CONNECTION = "postgres://postgres:"

conn = psycopg2.connect(CONNECTION)
cursor = conn.cursor(cursor_factory=RealDictCursor)


def get_sample():
    camera_name, camera_id = 'cam2', 4

    print('Executing SQL command')

    cursor.execute("SELECT * FROM annotations WHERE camera_id = {} and time >='2021-05-01 00:00:00' and time <='2021-05-07 23:59:50' and label_id in (1,2)".format(camera_id))

    print('Dumping to json')
    annotations = json.dumps(cursor.fetchall(), indent=2, default=str)
    wjdata = json.loads(annotations)
    with open('{}_{}_test.json'.format(camera_name, camera_id), 'w') as f:
        json.dump(wjdata, f)
    print('Done dumping to json')

get_sample()
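The CONNECTION string above is committed with its credentials stripped. For reference, a full libpq connection URI has the form below; host, port, password, and database name here are placeholders, not values from this repository:

CONNECTION = "postgres://postgres:PASSWORD@HOST:5432/DBNAME"  # illustrative only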
cwalt/clustering_utils.py
ADDED
@@ -0,0 +1,132 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri May 20 15:18:20 2022

@author: dinesh
"""

# 0 - Import related libraries

import urllib
import zipfile
import os
import scipy.io
import math
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.spatial.distance import directed_hausdorff
from sklearn.cluster import DBSCAN
from sklearn.metrics.pairwise import pairwise_distances
import scipy.spatial.distance

from .kmedoid import kMedoids  # kMedoids code is adapted from https://github.com/letiantian/kmedoids

# Some visualization stuff, not so important
# sns.set()
plt.rcParams['figure.figsize'] = (12, 12)

# Utility Functions

color_lst = plt.rcParams['axes.prop_cycle'].by_key()['color']
color_lst.extend(['firebrick', 'olive', 'indigo', 'khaki', 'teal', 'saddlebrown',
                  'skyblue', 'coral', 'darkorange', 'lime', 'darkorchid', 'dimgray'])


def plot_cluster(image, traj_lst, cluster_lst):
    '''
    Plots given trajectories with a color that is specific for every trajectory's own cluster index.
    Outlier trajectories which are specified with -1 in `cluster_lst` are plotted dashed with black color
    '''
    cluster_count = np.max(cluster_lst) + 1

    for traj, cluster in zip(traj_lst, cluster_lst):
        # if cluster == -1:
        #     # Means it is a noisy trajectory, paint it black
        #     plt.plot(traj[:, 0], traj[:, 1], c='k', linestyle='dashed')
        # else:
        plt.plot(traj[:, 0], traj[:, 1], c=color_lst[cluster % len(color_lst)])

    plt.imshow(image)
    plt.axis('off')
    plt.savefig('trajectory.png', bbox_inches='tight')
    plt.show()


# 3 - Distance matrix

def hausdorff(u, v):
    # symmetric Hausdorff distance between two point sets
    d = max(directed_hausdorff(u, v)[0], directed_hausdorff(v, u)[0])
    return d


def build_distance_matrix(traj_lst):
    # 2 - Trajectory segmentation
    print('Running trajectory segmentation...')
    degree_threshold = 5

    for traj_index, traj in enumerate(traj_lst):

        hold_index_lst = []
        previous_azimuth = 1000

        for point_index, point in enumerate(traj[:-1]):
            next_point = traj[point_index + 1]
            diff_vector = next_point - point
            azimuth = (math.degrees(math.atan2(*diff_vector)) + 360) % 360

            if abs(azimuth - previous_azimuth) > degree_threshold:
                hold_index_lst.append(point_index)
                previous_azimuth = azimuth
        hold_index_lst.append(traj.shape[0] - 1)  # Last point of trajectory is always added

        traj_lst[traj_index] = traj[hold_index_lst, :]

    print('Building distance matrix...')
    traj_count = len(traj_lst)
    D = np.zeros((traj_count, traj_count))

    # This may take a while
    for i in range(traj_count):
        if i % 20 == 0:
            print(i)
        for j in range(i + 1, traj_count):
            distance = hausdorff(traj_lst[i], traj_lst[j])
            D[i, j] = distance
            D[j, i] = distance

    return D


def run_kmedoids(image, traj_lst, D):
    # 4 - Different clustering methods

    # 4.1 - kmedoids
    traj_count = len(traj_lst)

    k = 3  # The number of clusters
    medoid_center_lst, cluster2index_lst = kMedoids(D, k)

    cluster_lst = np.empty((traj_count,), dtype=int)

    for cluster in cluster2index_lst:
        cluster_lst[cluster2index_lst[cluster]] = cluster

    plot_cluster(image, traj_lst, cluster_lst)


def run_dbscan(image, traj_lst, D):
    # D is the precomputed distance matrix, so tell DBSCAN to use it directly
    mdl = DBSCAN(eps=400, min_samples=10, metric='precomputed')
    cluster_lst = mdl.fit_predict(D)

    plot_cluster(image, traj_lst, cluster_lst)
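A minimal sketch of how these helpers chain together, using synthetic random-walk trajectories in place of real tracked paths (everything here besides the three functions above is illustrative):

import numpy as np

# 40 synthetic trajectories of 50 (x, y) points each
traj_lst = [np.cumsum(np.random.randn(50, 2), axis=0) for _ in range(40)]
image = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in background image

D = build_distance_matrix(traj_lst)  # pairwise Hausdorff distances; also simplifies traj_lst in place
run_kmedoids(image, traj_lst, D)     # k = 3 medoid-based clusters
run_dbscan(image, traj_lst, D)       # density-based alternative on the same matrix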
cwalt/kmedoid.py
ADDED
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri May 20 15:18:56 2022

@author: dinesh
"""

import numpy as np
import math

def kMedoids(D, k, tmax=100):
    # determine dimensions of distance matrix D
    m, n = D.shape

    np.fill_diagonal(D, math.inf)

    if k > n:
        raise Exception('too many medoids')
    # randomly initialize an array of k medoid indices
    M = np.arange(n)
    np.random.shuffle(M)
    M = np.sort(M[:k])

    # create a copy of the array of medoid indices
    Mnew = np.copy(M)

    # initialize a dictionary to represent clusters
    C = {}
    for t in range(tmax):
        # determine clusters, i.e. arrays of data indices
        J = np.argmin(D[:, M], axis=1)
        for kappa in range(k):
            C[kappa] = np.where(J == kappa)[0]
        # update cluster medoids
        for kappa in range(k):
            J = np.mean(D[np.ix_(C[kappa], C[kappa])], axis=1)
            j = np.argmin(J)
            Mnew[kappa] = C[kappa][j]
        Mnew.sort()  # sort in place (np.sort alone would discard the result)
        # check for convergence
        if np.array_equal(M, Mnew):
            break
        M = np.copy(Mnew)
    else:
        # final update of cluster memberships
        J = np.argmin(D[:, M], axis=1)
        for kappa in range(k):
            C[kappa] = np.where(J == kappa)[0]

    np.fill_diagonal(D, 0)

    # return results
    return M, C
cwalt/utils.py
ADDED
@@ -0,0 +1,168 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri May 20 15:16:56 2022

@author: dinesh
"""

import json
import cv2
from PIL import Image
import numpy as np
from dateutil.parser import parse

def bb_intersection_over_union(box1, box2):
    boxA = box1.copy()
    boxB = box2.copy()
    # convert [x, y, w, h] to [x1, y1, x2, y2]
    boxA[2] = boxA[0] + boxA[2]
    boxA[3] = boxA[1] + boxA[3]
    boxB[2] = boxB[0] + boxB[2]
    boxB[3] = boxB[1] + boxB[3]
    # determine the (x, y)-coordinates of the intersection rectangle
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    # compute the area of the intersection rectangle
    interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0))

    if interArea == 0:
        return 0
    # compute the area of both the prediction and ground-truth rectangles
    boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
    boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))

    # compute the intersection over union by taking the intersection
    # area and dividing it by the sum of prediction + ground-truth
    # areas - the intersection area
    iou = interArea / float(boxAArea + boxBArea - interArea)
    return iou

def bb_intersection_over_union_unoccluded(box1, box2, threshold=0.01):
    boxA = box1.copy()
    boxB = box2.copy()
    boxA[2] = boxA[0] + boxA[2]
    boxA[3] = boxA[1] + boxA[3]
    boxB[2] = boxB[0] + boxB[2]
    boxB[3] = boxB[1] + boxB[3]
    # determine the (x, y)-coordinates of the intersection rectangle
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    # compute the area of the intersection rectangle
    interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0))

    if interArea == 0:
        return 0
    boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
    boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))

    iou = interArea / float(boxAArea + boxBArea - interArea)

    # keep the IoU only when the overlap actually occludes boxA: boxA must sit
    # below boxB and the horizontal overlap must cover a significant fraction
    # of boxA's width
    occlusion = False
    if iou > threshold and iou < 1:
        if boxA[3] < boxB[3]:
            if boxB[2] > boxA[0]:
                if (min(boxB[2], boxA[2]) - boxA[0])/(boxA[2] - boxA[0]) > threshold:
                    occlusion = True
            if boxB[0] < boxA[2]:
                if (boxA[2] - max(boxB[0], boxA[0]))/(boxA[2] - boxA[0]) > threshold:
                    occlusion = True
    if occlusion == False:
        iou = iou*0
    return iou

def draw_tracks(image, tracks):
    """
    Draw track ids on the input image.

    Args:
        image (numpy.ndarray): image
        tracks (list): list of tracks to be drawn on the image.

    Returns:
        numpy.ndarray: image with the track ids drawn on it.
    """

    for trk in tracks:

        trk_id = trk[1]
        xmin = trk[2]
        ymin = trk[3]
        width = trk[4]
        height = trk[5]

        xcentroid, ycentroid = int(xmin + 0.5*width), int(ymin + 0.5*height)

        text = "ID {}".format(trk_id)

        cv2.putText(image, text, (xcentroid - 10, ycentroid - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        cv2.circle(image, (xcentroid, ycentroid), 4, (0, 255, 0), -1)

    return image


def draw_bboxes(image, tracks):
    """
    Draw the bounding boxes of detected objects on the image.

    Args:
        image (numpy.ndarray): Image or video frame.
        tracks (list): tracks whose pixel coordinates are stored as (xmin, ymin, width, height) at indices 2-5.

    Returns:
        numpy.ndarray: image with the bounding boxes drawn on it.
    """

    for trk in tracks:
        xmin = int(trk[2])
        ymin = int(trk[3])
        width = int(trk[4])
        height = int(trk[5])
        clr = (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255))
        cv2.rectangle(image, (xmin, ymin), (xmin + width, ymin + height), clr, 2)

    return image


def num(v):
    # parse a numeric string as int when possible, float otherwise
    number_as_float = float(v)
    number_as_int = int(number_as_float)
    return number_as_int if number_as_float == number_as_int else number_as_float


def parse_bbox(bbox_str):
    bbox_list = bbox_str.strip('{').strip('}').split(',')
    bbox_list = [num(elem) for elem in bbox_list]
    return bbox_list

def parse_seg(bbox_str):
    # same flat [x1, y1, x2, y2, ...] format as parse_bbox
    bbox_list = bbox_str.strip('{').strip('}').split(',')
    bbox_list = [num(elem) for elem in bbox_list]
    return bbox_list
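A worked check of bb_intersection_over_union in the [x, y, width, height] form that parse_bbox produces: boxes [0, 0, 10, 10] and [5, 0, 10, 10] overlap in a 5x10 strip, so IoU = 50 / (100 + 100 - 50) = 1/3:

from cwalt.utils import bb_intersection_over_union

boxA = [0, 0, 10, 10]  # x, y, width, height
boxB = [5, 0, 10, 10]
# intersection = 5 * 10 = 50, union = 100 + 100 - 50 = 150
print(bb_intersection_over_union(boxA, boxB))  # 0.3333...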
cwalt_generate.py
ADDED
@@ -0,0 +1,14 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jun  4 16:55:58 2022

@author: dinesh
"""
from cwalt.CWALT import CWALT_Generation
from cwalt.Clip_WALT_Generate import Get_unoccluded_objects

if __name__ == '__main__':
    camera_name = 'cam2'
    Get_unoccluded_objects(camera_name)
    CWALT_Generation(camera_name)
infer.py
ADDED
@@ -0,0 +1,114 @@
from argparse import ArgumentParser

from mmdet.apis import inference_detector, init_detector, show_result_pyplot
from mmdet.core.mask.utils import encode_mask_results
import numpy as np
import mmcv
import torch
from imantics import Polygons, Mask
import json
import os
import cv2, glob


class detections():
    def __init__(self, cfg_path, device, model_path='data/models/walt_vehicle.pth'):
        self.model = init_detector(cfg_path, model_path, device=device)
        self.all_preds = []
        self.all_scores = []
        self.index = []
        self.score_thr = 0.6
        self.result = []
        self.record_dict = {'model': cfg_path, 'results': []}
        self.detect_count = []

    def run_on_image(self, image):
        self.result = inference_detector(self.model, image)
        image_labelled = self.model.show_result(image, self.result, score_thr=self.score_thr)
        return image_labelled

    def process_output(self, count):
        # convert the raw mmdet result into a json-serializable record
        result = self.result
        infer_result = {'url': count,
                        'boxes': [],
                        'scores': [],
                        'keypoints': [],
                        'segmentation': [],
                        'label_ids': [],
                        'track': [],
                        'labels': []}

        if isinstance(result, tuple):
            bbox_result, segm_result = result
            #segm_result = encode_mask_results(segm_result)
            if isinstance(segm_result, tuple):
                segm_result = segm_result[0]  # ms rcnn
        bboxes = np.vstack(bbox_result)
        labels = [np.full(bbox.shape[0], i, dtype=np.int32) for i, bbox in enumerate(bbox_result)]

        labels = np.concatenate(labels)
        segms = None
        if segm_result is not None and len(labels) > 0:  # non empty
            segms = mmcv.concat_list(segm_result)
            if isinstance(segms[0], torch.Tensor):
                segms = torch.stack(segms, dim=0).detach().cpu().numpy()
            else:
                segms = np.stack(segms, axis=0)

        for i, (bbox, label, segm) in enumerate(zip(bboxes, labels, segms)):
            if bbox[-1].item() < 0.3:
                continue
            box = [bbox[0].item(), bbox[1].item(), bbox[2].item(), bbox[3].item()]
            polygons = Mask(segm).polygons()

            infer_result['boxes'].append(box)
            infer_result['segmentation'].append(polygons.segmentation)
            infer_result['scores'].append(bbox[-1].item())
            infer_result['labels'].append(self.model.CLASSES[label])
            infer_result['label_ids'].append(label)
        self.record_dict['results'].append(infer_result)
        self.detect_count = labels

    def write_json(self, filename):
        with open(filename + '.json', 'w') as f:
            json.dump(self.record_dict, f)


def main():
    detect_people = detections('configs/walt/walt_people.py', 'cuda:0', model_path='data/models/walt_people.pth')
    detect = detections('configs/walt/walt_vehicle.py', 'cuda:0', model_path='data/models/walt_vehicle.pth')
    filenames = sorted(glob.glob('demo/images/*'))
    count = 0
    for filename in filenames:
        img = cv2.imread(filename)
        try:
            img = detect_people.run_on_image(img)
            img = detect.run_on_image(img)
        except:
            continue
        count = count + 1

        os.makedirs(os.path.dirname(filename.replace('demo', 'demo/results/')), exist_ok=True)
        cv2.imwrite(filename.replace('demo', 'demo/results/'), img)
        if count == 30000:
            break
        try:
            detect.process_output(count)
        except:
            continue


if __name__ == "__main__":
    main()
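A minimal sketch of driving the detections wrapper on a single frame. The config and checkpoint paths are the ones referenced in this commit but are not bundled with it, so they must be present locally:

import cv2
from infer import detections

detector = detections('configs/walt/walt_vehicle.py', 'cuda:0',
                      model_path='data/models/walt_vehicle.pth')
frame = cv2.imread('demo/images/img_1.jpg')
overlay = detector.run_on_image(frame)  # frame with boxes/masks above score_thr drawn
detector.process_output(0)              # accumulate structured detections
detector.write_json('detections')       # writes detections.json
cv2.imwrite('overlay.jpg', overlay)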
test.py
ADDED
@@ -0,0 +1,226 @@
import argparse
import os
import warnings

import mmcv
import numpy as np
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

from mmdet.apis import multi_gpu_test, single_gpu_test
from walt.datasets import (build_dataloader, build_dataset,
                           replace_ImageToTensor)
from mmdet.models import build_detector


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--out', help='output result file in pickle format')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--show-score-thr',
        type=float,
        default=0.3,
        help='score threshold (default: 0.3)')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecated), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.eval_options:
        raise ValueError(
            '--options and --eval-options cannot be both '
            'specified, --options is deprecated in favor of --eval-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --eval-options')
        args.eval_options = args.options
    return args


def main():
    args = parse_args()

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format_only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    if cfg.model.get('neck'):
        if isinstance(cfg.model.neck, list):
            for neck_cfg in cfg.model.neck:
                if neck_cfg.get('rfp_backbone'):
                    if neck_cfg.rfp_backbone.get('pretrained'):
                        neck_cfg.rfp_backbone.pretrained = None
        elif cfg.model.neck.get('rfp_backbone'):
            if cfg.model.neck.rfp_backbone.get('pretrained'):
                cfg.model.neck.rfp_backbone.pretrained = None

    # in case the test dataset is concatenated
    samples_per_gpu = 7
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        samples_per_gpu = max(
            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
        if samples_per_gpu > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # build the dataloader
    print(samples_per_gpu, cfg.data.workers_per_gpu)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)
    # old versions did not save class info in checkpoints, this workaround is
    # for backward compatibility
    if 'CLASSES' in checkpoint.get('meta', {}):
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
                                  args.show_score_thr)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

    rank, _ = get_dist_info()
    if rank == 0:
        if args.out:
            print(f'\nwriting results to {args.out}')
            mmcv.dump(outputs, args.out)
        kwargs = {} if args.eval_options is None else args.eval_options
        if args.format_only:
            dataset.format_results(outputs, **kwargs)
        if args.eval:
            eval_kwargs = cfg.get('evaluation', {}).copy()
            # hard-code way to remove EvalHook args
            for key in [
                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                    'rule'
            ]:
                eval_kwargs.pop(key, None)
            eval_kwargs.update(dict(metric=args.eval, **kwargs))
            data_evaluated = dataset.evaluate(outputs, **eval_kwargs)
            np.save(args.checkpoint + '_new1', data_evaluated)
            print(data_evaluated)


if __name__ == '__main__':
    main()
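A typical single-GPU evaluation run mirrors upstream MMDetection usage; with illustrative config and checkpoint names: python test.py configs/walt/walt_vehicle.py data/models/walt_vehicle.pth --eval bbox segm --out results.pkl. For multi-GPU evaluation, pass --launcher pytorch and launch one process per GPU, as in upstream MMDetection.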
train.py
ADDED
@@ -0,0 +1,191 @@
import argparse
import copy
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash

from mmdet import __version__
from mmdet.apis import set_random_seed
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger
from walt.apis import train_detector
from walt.datasets import build_dataset


def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecated), '
        'change to --cfg-options instead.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both '
            'specified, --options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options

    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text
    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__ + get_git_hash()[:7],
            CLASSES=datasets[0].CLASSES)

    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()
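A typical launch, with an illustrative config name: python train.py configs/walt/walt_vehicle.py --work-dir work_dirs/walt_vehicle. Distributed training follows the upstream MMDetection pattern (--launcher pytorch with one process per GPU).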