hb-setosys's picture
Update app.py
e5e492a verified
raw
history blame
4.42 kB
import os
import cv2
import numpy as np
import torch
import logging
from ultralytics import YOLO
from sort import Sort
import gradio as gr
# Root logger: INFO and above, each line stamped with time and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# YOLOv12x weights file; loading is attempted only when the file is present,
# otherwise the module fails fast at import time.
MODEL_PATH = "yolov12x.pt"
if os.path.exists(MODEL_PATH):
    model = YOLO(MODEL_PATH)
else:
    raise FileNotFoundError(f"Model file '{MODEL_PATH}' not found.")

# Class index treated as "truck" (COCO numbering, per the original author).
TRUCK_CLASS_ID = 7

# Module-level SORT tracker instance.
tracker = Sort()
# A detection is kept only when its confidence is strictly above this value
# (tune according to observed performance).
CONFIDENCE_THRESHOLD = 0.4

# Displacement threshold used by the counting logic when comparing a track's
# current center with its stored position (presumably in pixels - confirm).
DISTANCE_THRESHOLD = 50

# Number words that may occur in a video filename, in ascending order; each
# word maps to its numeric value, which is used as the time interval.
_NUMBER_WORDS = (
    "one", "two", "three", "four", "five", "six",
    "seven", "eight", "nine", "ten", "eleven",
)
TIME_INTERVALS = {word: value for value, word in enumerate(_NUMBER_WORDS, start=1)}
def determine_time_interval(video_filename):
    """Pick the time interval selected by a number word in the filename.

    The keywords of ``TIME_INTERVALS`` are tried in the mapping's iteration
    order and the first one found as a substring of ``video_filename`` wins.

    Args:
        video_filename: Filename text to search for a keyword.

    Returns:
        The interval mapped to the first matching keyword, or ``5`` when no
        keyword occurs in the filename.
    """
    logging.info(f"Checking filename: {video_filename}")

    # First (keyword, interval) pair whose keyword appears in the filename.
    match = next(
        (
            (word, value)
            for word, value in TIME_INTERVALS.items()
            if word in video_filename
        ),
        None,
    )

    if match is None:
        logging.info("No keyword match, using default interval: 5")
        return 5  # Default interval

    keyword, interval = match
    logging.info(f"Matched keyword: {keyword} -> Interval: {interval}")
    return interval
def _extract_truck_detections(results):
    """Convert YOLO results for one frame into SORT-style detection rows.

    Args:
        results: Iterable of result objects returned by calling ``model`` on
            a frame; each exposes ``boxes`` whose items carry ``cls``,
            ``conf`` and ``xyxy``.

    Returns:
        A list of ``[x1, y1, x2, y2, confidence]`` rows, one per box whose
        class equals ``TRUCK_CLASS_ID`` and whose confidence is strictly
        greater than ``CONFIDENCE_THRESHOLD``.
    """
    detections = []
    for result in results:
        for box in result.boxes:
            class_id = int(box.cls.item())
            confidence = float(box.conf.item())
            if class_id == TRUCK_CLASS_ID and confidence > CONFIDENCE_THRESHOLD:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                detections.append([x1, y1, x2, y2, confidence])
    return detections


def count_unique_trucks(video_path):
    """Count unique trucks in a video using YOLOv12x detection and SORT IDs.

    Frames are sampled every ``fps * time_interval`` frames (capped at half
    the video length), trucks are detected on each sampled frame, and every
    distinct track ID reported by SORT is counted once.

    Args:
        video_path: Path to the video file. ``None`` or an empty string is
            treated like a missing file.

    Returns:
        ``{"Total Unique Trucks": <int>}`` on success, or ``{"Error": <str>}``
        when the file is missing or cannot be opened.
    """
    if not video_path or not os.path.exists(video_path):
        return {"Error": "Video file not found."}

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return {"Error": "Unable to open video file."}

        # A fresh tracker per video: reusing one tracker across calls would
        # carry track state and frame counters over from earlier videos.
        video_tracker = Sort()
        unique_truck_ids = set()

        # Get FPS and total frames, with fallbacks when retrieval fails.
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 1

        # The interval is chosen from keywords in the lower-cased filename.
        video_filename = os.path.basename(video_path).lower()
        time_interval = determine_time_interval(video_filename)

        # Never skip more than half the video, and never use a step below 1
        # (which would make the modulo below fail).
        frame_skip = max(1, min(fps * time_interval, max(1, total_frames // 2)))

        frame_count = 0
        # grab() advances to the next frame without decoding it, so skipped
        # frames cost no decode work; retrieve() decodes only sampled frames.
        while cap.grab():
            frame_count += 1
            if frame_count % frame_skip != 0:
                continue  # Skip frames based on interval

            ret, frame = cap.retrieve()
            if not ret:
                break  # Frame could not be decoded; stop processing

            # Run YOLOv12x inference and keep confident truck boxes only.
            detections = _extract_truck_detections(model(frame, verbose=False))

            # SORT's update() is meant to be called for every processed
            # frame, with an empty (0, 5) array when nothing was detected,
            # so that stale tracks can age out.
            if detections:
                tracker_input = np.array(detections)
            else:
                tracker_input = np.empty((0, 5))

            # Column 4 of each returned row is the track ID; the set keeps
            # each ID once no matter how many frames it appears in.
            for track in video_tracker.update(tracker_input):
                unique_truck_ids.add(int(track[4]))
    finally:
        # Release the capture even if inference or tracking raises.
        cap.release()

    return {"Total Unique Trucks": len(unique_truck_ids)}
# Gradio callback
def analyze_video(video_file):
    """Run the truck counter on an uploaded video and format the result.

    Args:
        video_file: Value supplied by the Gradio video input; a falsy value
            means nothing was uploaded.

    Returns:
        A string with one ``key: value`` line per entry of the dict returned
        by ``count_unique_trucks``, or an error message when no file was
        provided.
    """
    if video_file:
        summary = count_unique_trucks(video_file)
        lines = []
        for label, value in summary.items():
            lines.append(f"{label}: {value}")
        return "\n".join(lines)
    return "Error: No video file uploaded."
# Gradio interface: one video upload in, one text box out, wired to
# analyze_video.
_video_input = gr.Video(label="Upload Video")
_result_output = gr.Textbox(label="Analysis Result")

iface = gr.Interface(
    fn=analyze_video,
    inputs=_video_input,
    outputs=_result_output,
    title="YOLOv12x Unique Truck Counter",
    description="Upload a video to count unique trucks using YOLOv12x and SORT tracking.",
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()