# Wanli
# Add script to evaluate face detection by WIDERFace (#70)
# 60ba673
import os
import tqdm
import pickle
import numpy as np
from scipy.io import loadmat
import cv2 as cv
def get_gt_boxes(gt_dir):
    """Load the WIDER FACE validation ground truth stored in ``gt_dir``.

    ``gt_dir`` must contain ``wider_face_val.mat``, ``wider_hard_val.mat``,
    ``wider_medium_val.mat`` and ``wider_easy_val.mat``; they are loaded in
    that order with ``scipy.io.loadmat``.

    Returns:
        Tuple ``(facebox_list, event_list, file_list, hard_gt_list,
        medium_gt_list, easy_gt_list)`` holding the raw arrays read from the
        ``.mat`` files.
    """
    face_mat, hard_mat, medium_mat, easy_mat = (
        loadmat(os.path.join(gt_dir, mat_name))
        for mat_name in ('wider_face_val.mat', 'wider_hard_val.mat',
                         'wider_medium_val.mat', 'wider_easy_val.mat')
    )
    return (
        face_mat['face_bbx_list'],
        face_mat['event_list'],
        face_mat['file_list'],
        hard_mat['gt_list'],
        medium_mat['gt_list'],
        easy_mat['gt_list'],
    )
def get_gt_boxes_from_txt(gt_path, cache_dir):
    """Parse a WIDER FACE ground-truth text file into ``{image_name: boxes}``.

    The file is read as repeated records:

    - an image-name line containing ``'--'``;
    - one line that is skipped (presumably the face count);
    - one line per face, whose first four space-separated fields form a box.

    The parsed dict is pickled to ``<cache_dir>/gt_cache.pkl``. If that file
    already exists, it is loaded and returned without reading ``gt_path``.

    Args:
        gt_path: path to the ground-truth text file.
        cache_dir: directory holding the pickle cache.

    Returns:
        dict mapping each image-name line to a float32 array of boxes
        (shape (K, 4) when the record has K box lines).
    """
    cache_file = os.path.join(cache_dir, 'gt_cache.pkl')
    if os.path.exists(cache_file):
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # cache_dir at a directory whose contents you trust.
        with open(cache_file, 'rb') as f:
            return pickle.load(f)

    with open(gt_path, 'r') as f:
        lines = [line.rstrip('\r\n') for line in f.readlines()]
    print(len(lines))

    boxes = {}
    state = 0
    current_boxes = []
    current_name = None
    for line in lines:
        if state == 0 and '--' in line:
            state = 1
            current_name = line
            continue
        if state == 1:
            # Line right after the image name is skipped.
            state = 2
            continue
        if state == 2 and '--' in line:
            # A new image name closes the previous record.
            state = 1
            boxes[current_name] = np.array(current_boxes).astype('float32')
            current_name = line
            current_boxes = []
            continue
        if state == 2:
            box = [float(x) for x in line.split(' ')[:4]]
            current_boxes.append(box)
            continue

    # Records are only stored when the next name line appears, so the final
    # record has to be flushed explicitly or the last image would be lost.
    if current_name is not None:
        boxes[current_name] = np.array(current_boxes).astype('float32')

    with open(cache_file, 'wb') as f:
        pickle.dump(boxes, f)
    return boxes
def norm_score(pred):
    """Min-max normalize every prediction score in place.

    pred: nested dict ``{event_name: {image_name: array}}``. The last column
    of each (N, C) array is the score; empty arrays are skipped.

    The range used is ``[min(lowest_score, 1), max(highest_score, 0)]``: the
    running min/max start at 1 and 0. Scores are rescaled as
    ``(s - min) / (max - min)``. When that range is empty (all scores
    identical, or no predictions at all), the scores are left unchanged
    instead of being overwritten with NaN from a 0/0 division.

    Returns None; the caller's arrays are modified.
    """
    max_score = 0
    min_score = 1
    for per_image in pred.values():
        for dets in per_image.values():
            if len(dets) == 0:
                continue
            max_score = max(np.max(dets[:, -1]), max_score)
            min_score = min(np.min(dets[:, -1]), min_score)

    diff = max_score - min_score
    if diff <= 0:
        # Degenerate range: normalizing would divide by zero.
        return
    for per_image in pred.values():
        for dets in per_image.values():
            if len(dets) == 0:
                continue
            dets[:, -1] = (dets[:, -1] - min_score) / diff
def bbox_overlaps(a, b):
    """Pairwise IoU of two box sets (numpy version).

    Boxes are rows whose first four columns are ``[x1, y1, x2, y2]``. Widths
    and heights are measured inclusively (``x2 - x1 + 1``). A pair whose
    corners do not strictly overlap on both axes gets zero intersection.

    Args:
        a: (N, >=4) array.
        b: (M, >=4) array.

    Returns:
        (N, M) array whose entry [i, j] is IoU(a[i], b[j]).
    """
    top_left = np.maximum(a[:, None, :2], b[:, :2])
    bottom_right = np.minimum(a[:, None, 2:4], b[:, 2:4])
    strictly_overlapping = np.all(top_left < bottom_right, axis=2)
    inter = np.prod(bottom_right - top_left + 1, axis=2) * strictly_overlapping

    area_a = np.prod(a[:, 2:4] - a[:, :2] + 1, axis=1)
    area_b = np.prod(b[:, 2:4] - b[:, :2] + 1, axis=1)
    union = area_a[:, None] + area_b - inter
    return inter / union
def image_eval(pred, gt, ignore, iou_thresh):
    """Match one image's predictions, in row order, to its ground-truth faces.

    The first four columns of ``pred`` and ``gt`` are ``[x, y, w, h]``. Both
    arrays are copied and converted to ``[x1, y1, x2, y2]`` (x + w, y + h)
    before IoU is computed, so the inputs are not modified. Each prediction
    is compared only with its single best-IoU ground-truth face.

    Args:
        pred: (N, >=4) prediction array; only the first four columns are used.
        gt: (M, 4) ground-truth array.
        ignore: length-M array. A hit on a face with ``ignore == 0`` marks that
            prediction as not counted; other faces can be recalled.
        iou_thresh: minimum IoU for a prediction to hit a face.

    Returns:
        ``(pred_recall, proposal_list)``:

        - ``pred_recall[h]``: number of recalled faces after processing
          predictions 0..h.
        - ``proposal_list[h]``: -1 if prediction h hit a face with
          ``ignore == 0``, else 1.
    """
    boxes = pred.copy()
    faces = gt.copy()
    # [x, y, w, h] -> [x1, y1, x2, y2]
    boxes[:, 2:4] += boxes[:, 0:2]
    faces[:, 2:4] += faces[:, 0:2]

    n_pred = boxes.shape[0]
    proposal_list = np.ones(n_pred)
    pred_recall = np.zeros(n_pred)
    recall_list = np.zeros(faces.shape[0])

    overlaps = bbox_overlaps(boxes[:, :4], faces)
    for h, ious in enumerate(overlaps):
        best_iou = ious.max()
        best = ious.argmax()
        if best_iou >= iou_thresh:
            if ignore[best] == 0:
                recall_list[best] = -1
                proposal_list[h] = -1
            elif recall_list[best] == 0:
                recall_list[best] = 1
        pred_recall[h] = np.count_nonzero(recall_list == 1)
    return pred_recall, proposal_list
def img_pr_info(thresh_num, pred_info, proposal_list, pred_recall):
    """Sample one image's proposal/recall counts at ``thresh_num`` thresholds.

    Row t uses threshold ``1 - (t + 1) / thresh_num``. It picks the LAST
    prediction (in row order) whose score (column 4) reaches that threshold
    and records:

    - column 0: how many of ``proposal_list[:last + 1]`` equal 1;
    - column 1: ``pred_recall[last]``.

    Rows where no prediction reaches the threshold stay ``[0, 0]``.

    Returns:
        (thresh_num, 2) float array.
    """
    scores = pred_info[:, 4]
    pr_info = np.zeros((thresh_num, 2), dtype=float)
    for t in range(thresh_num):
        thresh = 1 - (t + 1) / thresh_num
        passing = np.flatnonzero(scores >= thresh)
        if passing.size == 0:
            continue
        last = passing[-1]
        pr_info[t, 0] = np.count_nonzero(proposal_list[:last + 1] == 1)
        pr_info[t, 1] = pred_recall[last]
    return pr_info
def dataset_pr_info(thresh_num, pr_curve, count_face):
    """Convert accumulated per-threshold counts into precision and recall.

    Args:
        thresh_num: number of threshold rows to convert.
        pr_curve: array whose rows are ``[num_proposals, num_recalled]``,
            summed over all images.
        count_face: total number of faces that count for the setting.

    Returns:
        (thresh_num, 2) array of ``[precision, recall]``. Precision is
        ``num_recalled / num_proposals``, but is 0 for rows with no proposals
        instead of NaN from 0/0. Those rows also have zero recalled faces, so
        the AP computed downstream is unaffected. Recall is
        ``num_recalled / count_face``, and is all zeros when ``count_face``
        is 0.
    """
    counts = np.asarray(pr_curve, dtype=float)[:thresh_num]
    proposals = counts[:, 0]
    recalled = counts[:, 1]
    precision = np.divide(recalled, proposals,
                          out=np.zeros(len(proposals)),
                          where=proposals != 0)
    if count_face:
        recall = recalled / count_face
    else:
        recall = np.zeros(len(recalled))
    return np.stack([precision, recall], axis=1)
def voc_ap(rec, prec):
    """Area under a precision/recall curve, computed VOC-style over all points.

    Steps:

    1. Add sentinels at both ends: recall 0 and 1, precision 0.
    2. Replace each precision by the maximum precision at or to its right
       (the precision envelope).
    3. Sum ``delta_recall * precision`` over the points where recall changes.

    Args:
        rec: 1-D recall values (expected in non-decreasing order).
        prec: 1-D precision values aligned with ``rec``.

    Returns:
        The average precision as a numpy scalar.
    """
    recall = np.concatenate(([0.], rec, [1.]))
    precision = np.concatenate(([0.], prec, [0.]))
    # Running max taken from the right end.
    precision = np.maximum.accumulate(precision[::-1])[::-1]
    steps = np.flatnonzero(recall[1:] != recall[:-1])
    return np.sum((recall[steps + 1] - recall[steps]) * precision[steps + 1])
def evaluation(pred, gt_path, iou_thresh=0.5, thresh_num=1000):
    """Compute WIDER FACE average precision for the easy, medium and hard sets.

    Args:
        pred: nested dict ``{event_name: {image_name: (N, 5) array}}`` whose
            rows are ``[x, y, w, h, score]``. Scores are normalized in place
            by ``norm_score``, so the caller's arrays are modified.
        gt_path: directory holding ``wider_face_val.mat`` and the
            ``wider_{easy,medium,hard}_val.mat`` ground-truth files.
        iou_thresh: minimum IoU for a prediction to hit a face.
        thresh_num: number of score thresholds used to sample the PR curve.

    Returns:
        List ``[easy_ap, medium_ap, hard_ap]``.
    """
    norm_score(pred)

    (facebox_list, event_list, file_list,
     hard_gt_list, medium_gt_list, easy_gt_list) = get_gt_boxes(gt_path)
    event_num = len(event_list)
    settings = ['easy', 'medium', 'hard']
    setting_gts = [easy_gt_list, medium_gt_list, hard_gt_list]
    aps = []
    for setting_name, gt_list in zip(settings, setting_gts):
        count_face = 0
        pr_curve = np.zeros((thresh_num, 2)).astype('float')
        pbar = tqdm.tqdm(range(event_num))
        for i in pbar:
            pbar.set_description('Processing {}'.format(setting_name))
            event_name = str(event_list[i][0][0])
            img_list = file_list[i][0]
            pred_list = pred[event_name]
            sub_gt_list = gt_list[i][0]
            gt_bbx_list = facebox_list[i][0]

            for j in range(len(img_list)):
                pred_info = pred_list[str(img_list[j][0][0])]
                gt_boxes = gt_bbx_list[j][0].astype('float')
                # 1-based (MATLAB) indices of the faces that count for this setting.
                keep_index = sub_gt_list[j][0]
                count_face += len(keep_index)

                if len(gt_boxes) == 0 or len(pred_info) == 0:
                    continue

                # image_eval/img_pr_info treat row order as rank order (they take
                # the last row above each threshold), so rank by descending
                # score. A stable sort leaves already-sorted input unchanged.
                order = np.argsort(-pred_info[:, 4], kind='stable')
                pred_info = pred_info[order]

                ignore = np.zeros(gt_boxes.shape[0])
                if len(keep_index) != 0:
                    ignore[keep_index - 1] = 1
                pred_recall, proposal_list = image_eval(pred_info, gt_boxes, ignore, iou_thresh)
                pr_curve += img_pr_info(thresh_num, pred_info, proposal_list, pred_recall)

        pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
        propose = pr_curve[:, 0]
        recall = pr_curve[:, 1]
        aps.append(voc_ap(recall, propose))
    return aps
class WIDERFace:
    """WIDER FACE image list plus an evaluation driver.

    Image paths are read from the split's ``.mat`` file list under ``root``.
    Indexing yields ``(event, file_name, image)`` tuples. ``eval`` runs a
    model over every image and stores ``[easy, medium, hard]`` AP in
    ``self.aps``.
    """

    def __init__(self, root, split='val'):
        self.aps = []
        self.widerface_root = root
        self._split = split
        self.widerface_img_paths = {
            'val': os.path.join(self.widerface_root, 'WIDER_val', 'images'),
            'test': os.path.join(self.widerface_root, 'WIDER_test', 'images')
        }
        self.widerface_split_fpaths = {
            'val': os.path.join(self.widerface_root, 'wider_face_split', 'wider_face_val.mat'),
            'test': os.path.join(self.widerface_root, 'wider_face_split', 'wider_face_test.mat')
        }
        self.img_list, self.num_img = self.load_list()

    @property
    def name(self):
        return self.__class__.__name__

    def __len__(self):
        # Lets tqdm (and len()) know the dataset size. Iteration still goes
        # through __getitem__ until it raises IndexError.
        return self.num_img

    def load_list(self):
        """Return ``(image_paths, count)`` for the configured split.

        Paths are ``<images dir>/<event name>/<file name>.jpg``, built from
        the ``event_list`` and ``file_list`` entries of the split's ``.mat``
        file.
        """
        split_fpath = self.widerface_split_fpaths[self._split]
        img_path = self.widerface_img_paths[self._split]
        anno_data = loadmat(split_fpath)
        event_list = anno_data.get('event_list')
        file_list = anno_data.get('file_list')
        flist = []
        for event_idx, event in enumerate(event_list):
            event_name = event[0][0]
            for f in file_list[event_idx][0]:
                f_name = f[0][0]
                flist.append(os.path.join(img_path, event_name, f_name + '.jpg'))
        return flist, len(flist)

    def __getitem__(self, index):
        img = cv.imread(self.img_list[index])
        event, name = self.img_list[index].split(os.sep)[-2:]
        return event, name, img

    @staticmethod
    def _image_key(img_name):
        """Strip a trailing '.jpg' suffix (and only that exact suffix)."""
        # str.rstrip('.jpg') would strip any trailing '.', 'j', 'p' or 'g'
        # characters, mangling names such as 'x_jpg.jpg'.
        suffix = '.jpg'
        return img_name[:-len(suffix)] if img_name.endswith(suffix) else img_name

    def eval(self, model):
        """Run ``model`` on every image, then store AP results in ``self.aps``.

        ``model`` must provide ``name``, ``setInputSize([w, h])`` and
        ``infer(img)``. ``infer`` returns None or an array whose first four
        columns are the box and whose last column is the score. Images with
        no detection get a single dummy low-score box.
        """
        results_list = dict()
        pbar = tqdm.tqdm(self)
        pbar.set_description_str("Evaluating {} with {} val set".format(model.name, self.name))
        # forward
        for event_name, img_name, img in pbar:
            img_shape = [img.shape[1], img.shape[0]]
            model.setInputSize(img_shape)
            det = model.infer(img)
            if det is None:
                det = np.array([[10, 10, 20, 20, 0.002]])
            else:
                det = np.append(np.around(det[:, :4], 1), np.around(det[:, -1], 3).reshape(-1, 1), axis=1)
            results_list.setdefault(event_name, dict())[self._image_key(img_name)] = det
        self.aps = evaluation(results_list, os.path.join(self.widerface_root, 'eval_tools', 'ground_truth'))

    def print_result(self):
        print("==================== Results ====================")
        print("Easy Val AP: {}".format(self.aps[0]))
        print("Medium Val AP: {}".format(self.aps[1]))
        print("Hard Val AP: {}".format(self.aps[2]))
        print("=================================================")