Aumkeshchy2003's picture
Update app.py
a83113c verified
raw
history blame
2.98 kB
import cv2
import torch
import numpy as np
import gradio as gr
import time
import os
from pathlib import Path
import onnxruntime as ort
# Choose ONNX Runtime execution providers. CUDA is requested only when torch
# reports a GPU *and* the installed onnxruntime build actually provides the
# CUDA execution provider; the CPU provider is always the fallback.
_cuda_ok = torch.cuda.is_available() and 'CUDAExecutionProvider' in ort.get_available_providers()
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if _cuda_ok else ['CPUExecutionProvider']
# Resolve the model relative to this file rather than the current working
# directory, so the app finds it no matter where it is launched from.
MODEL_PATH = Path(__file__).resolve().parent / "models" / "yolov5n.onnx"
session = ort.InferenceSession(str(MODEL_PATH), providers=providers)

# Class names, indexed by the model's class id. These are the 80 COCO classes
# used by the stock YOLOv5 checkpoints (the original 10-entry list was their
# first ten, in the same order). Replace them if the ONNX model was trained on
# a different dataset.
class_names = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
    "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
    "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
    "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush",
]

# One pseudo-random color per class. The seed keeps colors stable across runs,
# and the first ten rows are identical to those of the original 10-class list.
np.random.seed(42)
colors = np.random.uniform(0, 255, size=(len(class_names), 3))
def preprocess(image, size=640):
    """Convert an H x W x C image into the NCHW float32 tensor fed to the detector.

    The image is stretched (not letterboxed) to ``size`` x ``size``. Channels are
    then moved first, values are scaled by 1/255 and a batch axis of length 1 is
    prepended.

    Args:
        image: H x W x C array; values are divided by 255, so uint8 pixels are
            expected.
        size: side length of the square model input. The default is 640, the
            resolution this app has always used; it must match the ONNX model's
            input size.

    Returns:
        C-contiguous ``float32`` array of shape ``(1, C, size, size)``.
    """
    resized = cv2.resize(image, (size, size))
    # Convert to float32 first so no float64 intermediate is allocated. For
    # uint8 input the result is bit-identical to dividing in float64 and casting.
    chw = resized.transpose((2, 0, 1)).astype(np.float32) / 255.0
    # transpose() produces a strided layout; hand ONNX Runtime a dense buffer.
    return np.ascontiguousarray(chw[np.newaxis])
def _nms(boxes, scores, iou_threshold):
    """Greedy non-maximum suppression; return indices of the boxes to keep."""
    x1, y1, x2, y2 = boxes.T
    areas = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    order = np.argsort(-scores, kind="stable")  # highest score first
    keep = []
    while order.size:
        best, rest = order[0], order[1:]
        keep.append(best)
        inter_w = np.clip(np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]), 0, None)
        inter_h = np.clip(np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]), 0, None)
        inter = inter_w * inter_h
        union = areas[best] + areas[rest] - inter
        iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)
        # Drop every remaining box that overlaps the kept one too much.
        order = rest[iou <= iou_threshold]
    return np.asarray(keep, dtype=np.int64)


def _decode_detections(raw, conf_threshold, iou_threshold):
    """Turn one image's model output into ``(boxes_xyxy, scores, class_ids)``.

    Two row layouts are recognised, distinguished by row width:

    * width 6: ``[x1, y1, x2, y2, conf, cls]``. These rows are treated as
      already decoded (the layout the original drawing loop assumed) and are
      only filtered by ``conf_threshold``.
    * width > 6: the raw YOLOv5 head ``[cx, cy, w, h, objectness, p_0..p_{C-1}]``.
      Boxes are converted to corners and scored as objectness times the best
      class probability. They are then filtered and pruned with per-class NMS.

    NOTE(review): a single-class raw export (width 5 + 1 = 6) would be read as
    the pre-decoded layout, and other exported layouts (e.g. with a batch-index
    column) are not recognised. Confirm against the actual ONNX model.

    Box coordinates are returned in model-input pixels.

    Raises:
        ValueError: if ``raw`` is not a 2-D array with at least 6 columns.
    """
    rows = np.asarray(raw, dtype=np.float32)
    if rows.ndim != 2 or rows.shape[1] < 6:
        raise ValueError(f"unexpected detector output shape {rows.shape}")

    if rows.shape[1] == 6:
        boxes, scores = rows[:, :4], rows[:, 4]
        class_ids = rows[:, 5].astype(np.int64)
        keep = scores > conf_threshold
        return boxes[keep], scores[keep], class_ids[keep]

    class_probs = rows[:, 5:]
    class_ids = class_probs.argmax(axis=1)
    scores = rows[:, 4] * class_probs[np.arange(len(rows)), class_ids]
    keep = scores > conf_threshold
    rows, scores, class_ids = rows[keep], scores[keep], class_ids[keep]

    half_wh = rows[:, 2:4] / 2
    boxes = np.concatenate([rows[:, :2] - half_wh, rows[:, :2] + half_wh], axis=1)

    # Suppress overlaps only within the same class.
    kept = []
    for cls in np.unique(class_ids):
        members = np.flatnonzero(class_ids == cls)
        kept.extend(members[_nms(boxes[members], scores[members], iou_threshold)])
    kept = np.asarray(kept, dtype=np.int64)
    return boxes[kept], scores[kept], class_ids[kept]


def detect_objects(image, conf_threshold=0.6, iou_threshold=0.45):
    """Run the ONNX detector on *image* and return an annotated copy.

    Args:
        image: H x W x 3 uint8 array. Gradio's numpy ``Image`` component
            supplies RGB. ``None`` (the button pressed with nothing uploaded)
            is passed straight through.
        conf_threshold: minimum score a detection needs to be drawn. The
            default, 0.6, is the threshold this app has always used.
        iou_threshold: IoU above which a lower-scoring box of the same class is
            suppressed (only applied to raw YOLOv5 output).

    Returns:
        A copy of *image* with boxes, labels and an FPS overlay, or ``None``
        when *image* is ``None``.
    """
    if image is None:
        return None
    start_time = time.perf_counter()

    image_input = preprocess(image)
    outputs = session.run(None, {session.get_inputs()[0].name: image_input})
    boxes, scores, class_ids = _decode_detections(outputs[0][0], conf_threshold, iou_threshold)

    # Boxes are in model-input pixels. preprocess() stretched the image, so
    # each axis is scaled back independently, then clipped to the image.
    img_h, img_w = image.shape[:2]
    in_h, in_w = image_input.shape[2:]
    boxes = boxes * np.array([img_w / in_w, img_h / in_h] * 2, dtype=np.float32)
    boxes[:, 0::2] = boxes[:, 0::2].clip(0, img_w - 1)
    boxes[:, 1::2] = boxes[:, 1::2].clip(0, img_h - 1)

    output_image = image.copy()
    for box, conf, cls in zip(boxes, scores, class_ids):
        # OpenCV drawing calls want plain Python ints for points.
        x1, y1, x2, y2 = (int(v) for v in box)
        cls = int(cls)
        # Guard against class ids beyond the configured name/color tables.
        color = colors[cls % len(colors)].tolist()
        name = class_names[cls] if 0 <= cls < len(class_names) else f"class {cls}"
        cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)
        label = f"{name} {conf:.2f}"
        # Keep the label on-screen for boxes touching the top edge.
        cv2.putText(output_image, label, (x1, max(y1 - 10, 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

    elapsed = time.perf_counter() - start_time
    fps = 1 / elapsed if elapsed > 0 else float("inf")
    cv2.putText(output_image, f"FPS: {fps:.2f}", (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    return output_image
def real_time_detection():
    """Show live detections from webcam 0 in an OpenCV window until 'q' or end of stream.

    The camera and the window belong to the machine running this script, not to
    the browser viewing the Gradio page. The capture device and windows are
    released even if detection or display raises.
    """
    cap = cv2.VideoCapture(0)
    try:
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
        cap.set(cv2.CAP_PROP_FPS, 60)
        while cap.isOpened():
            start_time = time.perf_counter()
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV captures BGR, while the Gradio path feeds detect_objects
            # RGB. Convert so both paths give the model the same channel order,
            # then convert back for cv2.imshow.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            output_frame = cv2.cvtColor(detect_objects(rgb_frame), cv2.COLOR_RGB2BGR)
            cv2.imshow("Real-Time Object Detection", output_frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            elapsed = time.perf_counter() - start_time
            if elapsed > 0:
                print(f"FPS: {1 / elapsed:.2f}")
    finally:
        cap.release()
        cv2.destroyAllWindows()
# Gradio front end: detection on an uploaded image, plus a button that starts
# the OpenCV webcam loop on the host machine.
with gr.Blocks(title="YOLOv5 Real-Time Object Detection") as demo:
    gr.Markdown("""
# Real-Time Object Detection with YOLOv5
**Upload an image or run real-time detection**
""")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Upload Image")
            detect_button = gr.Button(value="Detect Objects")
            start_rt_button = gr.Button(value="Start Real-Time Detection")
        with gr.Column():
            output_image = gr.Image(type="numpy", label="Detection Results")

    # Event wiring must happen inside the Blocks context.
    detect_button.click(fn=detect_objects, inputs=input_image, outputs=output_image)
    # Kept as a lambda: Gradio derives the endpoint name from the function name.
    start_rt_button.click(fn=lambda: real_time_detection(), inputs=None, outputs=None)

demo.launch()