# License-plate detection demo (Gradio app: image / video / webcam tabs).
# Source: Hugging Face Space by xangcastle, commit ab952d9
# ("add sorter and improve detections"), file size ~4.97 kB.
import gradio as gr
import numpy as np
import cv2
from norfair import Detection, Tracker, Video
from detector.utils import detect_plates, detect_chars, imcrop, send_request, draw_text
from threading import Thread
# Thresholds for norfair's tracker distance functions.
DISTANCE_THRESHOLD_BBOX: float = 0.7  # IoU threshold used by the "iou_opt" tracker in fn_video
DISTANCE_THRESHOLD_CENTROID: int = 30  # centroid-distance threshold — NOTE(review): unused in this file; confirm before removing
MAX_DISTANCE: int = 10000  # NOTE(review): unused in this file; confirm before removing
def yolo_to_norfair(yolo_detections):
    """Convert YOLO detections into a list of norfair ``Detection`` objects.

    Each row of ``yolo_detections.xyxy[0]`` is read as
    ``[x1, y1, x2, y2, confidence, ..., class_id]``; the bbox becomes a
    two-point array (top-left, bottom-right), the confidence is duplicated
    once per point, and the last column supplies the integer label.
    """
    converted = []
    for row in yolo_detections.xyxy[0]:
        top_left = [row[0].item(), row[1].item()]
        bottom_right = [row[2].item(), row[3].item()]
        confidence = row[4].item()
        converted.append(
            Detection(
                points=np.array([top_left, bottom_right]),
                scores=np.array([confidence, confidence]),
                label=int(row[-1].item()),
            )
        )
    return converted
def fn_image(foto):
    """Find license plates in one image.

    Every detected plate is outlined with a green rectangle and annotated
    with its recognized text. Returns the annotated image and the list of
    recognized plate strings.
    """
    recognized = []
    detections = detect_plates(foto).pandas().xyxy[0].to_dict(orient='records')
    for det in detections:
        x1, y1 = int(det['xmin']), int(det['ymin'])
        x2, y2 = int(det['xmax']), int(det['ymax'])
        crop = imcrop(foto, (x1, y1, x2, y2))
        if len(crop) == 0:
            # imcrop returned an empty region — nothing to OCR.
            continue
        cv2.rectangle(foto, (x1, y1), (x2, y2), (0, 255, 0), 2)
        text = detect_chars(crop)
        draw_text(foto, text, (x1, y1))
        recognized.append(text)
    return foto, recognized
def fn_video(video, initial_time, duration):
    """Detect and track license plates in a segment of *video*.

    Frames from ``initial_time`` (seconds) for ``duration`` seconds are run
    through the plate detector and a norfair IoU tracker, so each physical
    plate is OCR'd only once per tracked id. For every newly tracked plate
    the original frame is also sent via ``send_request`` in a background
    thread so the processing loop is not blocked.

    Returns the path of the annotated output video ('output.mp4') and the
    list of recognized plate texts.
    """
    tracker = Tracker(
        distance_function="iou_opt",
        distance_threshold=DISTANCE_THRESHOLD_BBOX,
    )
    cap = cv2.VideoCapture(video)
    fps = cap.get(cv2.CAP_PROP_FPS)
    image_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    # NOTE(review): VP9 ('VP90') inside an .mp4 container is unusual — VP9 is
    # normally muxed into .webm; confirm this combination plays back correctly.
    final_video = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'VP90'), fps, image_size)
    num_frames = 0
    min_frame = int(initial_time * fps)
    max_frame = int((initial_time + duration) * fps)
    plates = {}  # tracked object id -> recognized plate text
    while cap.isOpened():
        try:
            ret, frame = cap.read()
            if not ret:
                break
            frame_copy = frame.copy()
        except Exception as e:
            # FIX: the original `continue` here could spin forever on a
            # persistently failing capture (num_frames never advances);
            # bail out of the loop instead.
            print(e)
            break
        if num_frames < min_frame:
            # Skip frames before the requested start time.
            num_frames += 1
            continue
        yolo_detections = detect_plates(frame)
        detections = yolo_to_norfair(yolo_detections)
        tracked_objects = tracker.update(detections=detections)
        for obj in tracked_objects:
            if obj.last_detection is None:
                continue
            points = obj.last_detection.points
            bbox = int(points[0][0]), int(points[0][1]), int(points[1][0]), int(points[1][1])
            if obj.id not in plates:
                # First sighting of this tracked plate: OCR it once and
                # report it asynchronously.
                crop = imcrop(frame, bbox)
                plates[obj.id] = detect_chars(crop)
                thread = Thread(target=send_request, args=(frame_copy, plates[obj.id], bbox))
                thread.start()
            cv2.rectangle(
                frame,
                (bbox[0], bbox[1]),
                (bbox[2], bbox[3]),
                (0, 255, 0),
                2,
            )
            draw_text(frame, plates[obj.id], (bbox[0], bbox[1]))
            # NOTE(review): this renders the plate text a second time at the
            # same corner as draw_text above — looks redundant; confirm intent
            # before removing.
            cv2.putText(
                frame,
                plates[obj.id],
                (bbox[0], bbox[1]),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (0, 255, 0),
                2,
            )
        final_video.write(frame)
        num_frames += 1
        # FIX: `>=` instead of `==` so the bound still triggers even if the
        # counter ever steps past max_frame.
        if num_frames >= max_frame:
            break
    cap.release()
    final_video.release()
    return 'output.mp4', list(plates.values())
# Tab: plate detection on a single uploaded image.
image_interface = gr.Interface(
    fn=fn_image,
    title="Buscar números de placa en una imagen",
    inputs="image",
    outputs=["image", "text"],
    allow_screenshot=False,
    allow_flagging=False,
)
# Tab: plate detection on an uploaded video clip (start time + duration).
video_interface = gr.Interface(
    fn=fn_video,
    title="Buscar números de placa en un video",
    inputs=[
        gr.Video(type="file", label="Video"),
        gr.Slider(0, 600, value=0, label="Tiempo inicial en segundos", step=1),
        gr.Slider(0, 10, value=4, label="Duración en segundos", step=1),
    ],
    outputs=["video", "text"],
    allow_screenshot=False,
    allow_flagging=False,
)
# Tab: live plate detection from the webcam.
# FIX: fn_image returns a 2-tuple (annotated image, plate texts), but the
# original declared a single gr.Image output — the shapes did not match.
# Declare two outputs, mirroring image_interface, and pass fn as a keyword
# for consistency with the sibling interfaces.
webcam_interface = gr.Interface(
    fn=fn_image,
    inputs=[
        gr.Image(source='webcam', streaming=True),
    ],
    outputs=["image", "text"],
    live=True,
    title="Buscar placa con la cámara",
    allow_flagging=False,
    allow_screenshot=False,
)
if __name__ == "__main__":
    # Only the image and video tabs are mounted; webcam_interface is defined
    # above but not included here.
    demo = gr.TabbedInterface(
        [image_interface, video_interface],
        ["Fotos", "Videos"],
    )
    demo.launch()