YOLOv8Detection / scripts /test_yolob8_tflite.py
Jaime García Villena
move scripts and create README
82604cd
raw
history blame
6.7 kB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Oct 4 16:44:12 2023
@author: lin
"""
import glob
import sys
sys.path.append('../../..')
import os
import cv2
import json
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# from utils.bbox_op import non_max_supression
def one_multiple_iou(box, boxes, box_area, boxes_area):
    """
    Compute the intersection over union (IoU) of one box against N boxes.

    Inputs:
        box: numpy array with 1 box, [ymin, xmin, ymax, xmax]
        boxes: numpy array with shape [N, 4] holding N boxes in the same layout
        box_area: area of ``box`` (scalar)
        boxes_area: numpy array with shape [N] holding the area of each row of ``boxes``
    Outputs:
        float numpy array with shape [N] holding the IoU of ``box`` with each
        row of ``boxes``. Pairs whose union is 0 (degenerate zero-area boxes)
        get an IoU of 0 instead of NaN.
    Raises:
        ValueError: if ``boxes`` and ``boxes_area`` have different lengths.
    """
    if boxes.shape[0] != boxes_area.shape[0]:
        raise ValueError(
            f"boxes ({boxes.shape[0]}) and boxes_area ({boxes_area.shape[0]}) "
            "must have the same length"
        )
    # Corners of the intersection rectangle of `box` with every row of `boxes`.
    ymin = np.maximum(box[0], boxes[:, 0])
    xmin = np.maximum(box[1], boxes[:, 1])
    ymax = np.minimum(box[2], boxes[:, 2])
    xmax = np.minimum(box[3], boxes[:, 3])
    # Non-overlapping pairs give a negative side length; clamp it to 0 so
    # their intersection area is 0.
    intersections = np.maximum(ymax - ymin, 0) * np.maximum(xmax - xmin, 0)
    # Inclusion-exclusion: |A u B| = |A| + |B| - |A n B|.
    unions = box_area + boxes_area - intersections
    # Only divide where the union is positive; 0/0 would otherwise be NaN.
    ious = np.zeros(intersections.shape, dtype=float)
    np.divide(intersections, unions, out=ious, where=unions > 0)
    return ious
def select_non_overlapping_bboxes(boxes, scores, iou_th):
    """
    Greedy NMS core: return the indices of the boxes to keep.

    Boxes are visited from highest to lowest score. Each visited box is kept,
    and every not-yet-visited box whose IoU with it is greater than ``iou_th``
    is discarded.

    Inputs:
        boxes: numpy array with shape [N, 4], [ymin, xmin, ymax, xmax]
        scores: N prediction scores, one per box
        iou_th: IoU above which a lower-scored box is suppressed
    Output:
        list of kept indices into ``boxes``, in descending score order
    """
    # Side lengths are (max - min), i.e. coordinates are treated as half-open
    # ranges, the same convention one_multiple_iou uses for intersections.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    # Candidates ordered best-first; ties resolve in reversed argsort order.
    remaining = np.argsort(scores)[::-1]
    kept = []
    while remaining.size:
        best, rest = remaining[0], remaining[1:]
        kept.append(best)
        overlap = one_multiple_iou(boxes[best], boxes[rest], areas[best], areas[rest])
        # Written as "not greater than" (not "<=") so a NaN IoU does not
        # suppress a box.
        remaining = rest[~(overlap > iou_th)]
    return kept
def non_max_supression(boxes, scores, classes, iou_th):
    """
    Remove overlapping bounding boxes with greedy non-maximum suppression.

    Starting from the highest-scoring box, every lower-scoring box whose IoU
    with an already-kept box is greater than ``iou_th`` is dropped. The
    suppression is class-agnostic: ``classes`` is filtered alongside the boxes
    but never consulted.

    Inputs:
        boxes: numpy array with shape [N, 4] holding N boxes, [ymin, xmin, ymax, xmax]
        scores: numpy array with N prediction scores, one per box
        classes: numpy array with N class ids, one per box
        iou_th: IoU threshold above which two boxes are treated as the same object
    Output:
        (boxes, scores, classes) restricted to the kept detections, ordered by
        descending score. Empty input is returned unchanged.
    """
    if not len(scores):
        return boxes, scores, classes
    keep = select_non_overlapping_bboxes(boxes, scores, iou_th)
    return boxes[keep], scores[keep], classes[keep]
def preprocess(img_path, input_size=640):
    """
    Load an image from disk and prepare it as model input.

    The image is center-cropped to a square, resized to
    ``input_size`` x ``input_size`` and scaled to [0, 1] as float64.

    Inputs:
        img_path: path of the image file to read
        input_size: side length of the square output (default 640)
    Output:
        numpy array with shape [input_size, input_size, 3], float64 in [0, 1],
        channels in the BGR order returned by cv2.imread
    Raises:
        ValueError: if OpenCV cannot read or decode the file
    """
    image_np = cv2.imread(img_path)
    if image_np is None:
        # cv2.imread signals failure by returning None instead of raising;
        # report the path here rather than failing obscurely in center_crop.
        raise ValueError(f"could not read image: {img_path}")
    image_np = center_crop(image_np)
    image_np = cv2.resize(image_np, (input_size, input_size))
    # NOTE(review): channels are left in BGR order (an RGB conversion was
    # disabled here). Confirm which channel order the exported model expects.
    image_np = image_np.astype(float)
    image_np /= 255.0
    return image_np
def center_crop(img):
    """
    Crop the largest centered square out of an image.

    Inputs:
        img: numpy array with shape [H, W] or [H, W, C]
    Output:
        view of ``img`` with shape [S, S, ...] where S = min(H, W)
    """
    height, width = img.shape[:2]
    crop_size = min(width, height)
    # Offset of the square's top-left corner. Deriving it from the size
    # difference keeps exactly crop_size pixels even when crop_size is odd;
    # "center +/- crop_size // 2" would drop a row/column for odd sizes and
    # return an empty crop for 1-pixel-wide images.
    top = (height - crop_size) // 2
    left = (width - crop_size) // 2
    return img[top:top + crop_size, left:left + crop_size]
def postprocess_prediction(preds, min_th=0.1, iou_th=0.5, input_size=640):
    """
    Decode raw YOLOv8 output into filtered, NMS-suppressed detections.

    Inputs:
        preds: model output indexed as preds[0][row, anchor]. Rows 0-3 are
            used as (cx, cy, w, h) and rows 4+ as per-class scores.
            NOTE(review): box values are assumed normalized to [0, 1]
            (they are scaled by ``input_size``); confirm against the export.
        min_th: minimum best-class score for an anchor to be kept
            (0.1 matches the value the script's __main__ uses)
        iou_th: IoU threshold passed to non_max_supression
        input_size: side length in pixels of the square model input
    Output:
        (bboxes, classes, scores): bboxes is an int array [K, 4] of
        [ymin, xmin, ymax, xmax] clipped to [0, input_size]; all three are
        sorted by descending score.
    """
    pred = preds[0]
    class_prob = pred[4:]
    # Best class per anchor and its score.
    classes = np.argmax(class_prob, axis=0)
    scores = np.max(class_prob, axis=0)

    # Drop low-confidence anchors before any geometry work.
    valid_idx = np.where(scores >= min_th)[0]
    boxes_cxcywh = pred[:4][:, valid_idx].transpose() * input_size
    classes = classes[valid_idx]
    scores = scores[valid_idx]

    # Center/size -> corners. Use true division so odd widths are not floored,
    # and clip every edge (not only the min edges) to the image.
    cx, cy = boxes_cxcywh[:, 0], boxes_cxcywh[:, 1]
    half_w, half_h = boxes_cxcywh[:, 2] / 2.0, boxes_cxcywh[:, 3] / 2.0
    xmin = np.clip(cx - half_w, 0, input_size)
    xmax = np.clip(cx + half_w, 0, input_size)
    ymin = np.clip(cy - half_h, 0, input_size)
    ymax = np.clip(cy + half_h, 0, input_size)
    bboxes = np.stack([ymin, xmin, ymax, xmax], axis=1).astype(int)

    bboxes, scores, classes = non_max_supression(bboxes, scores, classes, iou_th=iou_th)
    order = np.argsort(scores)[::-1]
    return bboxes[order], classes[order], scores[order]
def plot_prediction(image_np, bboxes, classes, scores, label_map,
                    *, color=(255, 0, 0), thickness=5, font_scale=3):
    """
    Draw detection boxes and class labels on an image and display it.

    Inputs:
        image_np: image array to draw on with cv2
        bboxes: integer array [N, 4] of boxes, [ymin, xmin, ymax, xmax]
        classes: N class ids; looked up in ``label_map`` by ``str(id)``
        scores: N scores (accepted for API symmetry; not drawn)
        label_map: dict mapping str(class id) -> display name; ids missing
            from it are drawn as the raw id
        color, thickness, font_scale: cv2 drawing parameters
    """
    for box, class_id in zip(bboxes, classes):
        # Some OpenCV versions reject numpy integer scalars as point
        # coordinates, so convert to plain ints.
        ymin, xmin, ymax, xmax = (int(v) for v in box)
        image_np = cv2.rectangle(image_np, (xmin, ymin), (xmax, ymax),
                                 color=color, thickness=thickness)
        # Nudge the label inward when the box touches the top/left edge.
        text_x = xmin - 10 if xmin > 20 else xmin + 10
        text_y = ymin - 10 if ymin > 20 else ymin + 10
        display_str = label_map.get(str(class_id), str(class_id))
        image_np = cv2.putText(
            image_np,
            display_str,
            (text_x, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            color,
            thickness,
        )
    # NOTE(review): matplotlib interprets channels as RGB; if image_np is the
    # BGR array produced by cv2.imread, colors will appear swapped.
    plt.imshow(image_np)
    plt.show()
def predict_yolo_tflite(intenpreter, image_np):
    """
    Run one image through a TFLite interpreter and return its first output.

    Inputs:
        intenpreter: an allocated ``tf.lite.Interpreter`` (the misspelled
            parameter name is kept for compatibility with keyword callers)
        image_np: [H, W, C] image array; a batch axis is added here
    Output:
        numpy array holding the interpreter's first output tensor
    """
    interp = intenpreter
    # Look up tensor indices on the interpreter that was passed in, rather
    # than on module-level globals that only exist when run as a script.
    input_index = interp.get_input_details()[0]['index']
    output_index = interp.get_output_details()[0]['index']
    # Cast to float32 and add the batch axis; no tf tensor round-trip needed.
    input_tensor = np.expand_dims(np.asarray(image_np, dtype=np.float32), axis=0)
    interp.set_tensor(input_index, input_tensor)
    interp.invoke()
    return interp.get_tensor(output_index)
if __name__ == "__main__":
    # These names stay at module level: other functions in this script can
    # read them as globals (e.g. min_th, interpreter, input_details).
    min_th = 0.1  # minimum class score for a detection to be kept
    labels_json = "coco_labels.json"
    # JSON is UTF-8 by spec; don't depend on the platform default encoding.
    with open(labels_json, encoding="utf-8") as f:
        label_map = json.load(f)
    img_path = "test_images"
    saved_tflite = "tflite_model.tflite"
    # load model
    interpreter = tf.lite.Interpreter(model_path=saved_tflite)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print(input_details)
    print(output_details)
    # Sorted for a reproducible processing order; sub-directories are skipped.
    images = sorted(
        p for p in glob.glob(os.path.join(img_path, "*")) if os.path.isfile(p)
    )
    for img in images:
        image_np = preprocess(img)
        print(image_np.shape)
        preds = predict_yolo_tflite(interpreter, image_np)
        bboxes, classes, scores = postprocess_prediction(preds)
        plot_prediction(image_np, bboxes, classes, scores, label_map)