import gradio as gr
import cv2
import requests
import os
import random
from itertools import chain
from ultralytics import YOLO
import numpy as np
from collections import defaultdict
import sqlite3
import time

# Import the supervision library
import supervision as sv

# --- Initialize SQLite DB for logging ---
conn = sqlite3.connect("detection_log.db", check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
    CREATE TABLE IF NOT EXISTS detections (
        timestamp REAL,
        frame_number INTEGER,
        bin_name TEXT,
        class_name TEXT,
        count INTEGER
    )
''')
conn.commit()

# --- File Downloading ---
# File URLs for sample images and video
file_urls = [
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/mix2.jpg?download=true',
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/mix11.jpg?download=true',
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/sample_waste.mp4?download=true',
]


def download_file(url, save_name):
    """Download a file from a URL, overwriting any existing copy."""
    print(f"Downloading from: {url}")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()  # Raise on HTTP errors
        with open(save_name, 'wb') as file:
            for chunk in response.iter_content(1024):
                file.write(chunk)
        print(f"Downloaded: {save_name}")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")


# Download sample images and video for the examples
for i, url in enumerate(file_urls):
    if 'mp4' in url:
        download_file(url, "video.mp4")
    else:
        download_file(url, f"image_{i}.jpg")

# --- Model and Class Configuration ---
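# Optional safeguard (an addition, not part of the original script): fail fast
# with a clear message when the custom checkpoint is missing, instead of letting
# YOLO() try to resolve 'best.pt' as a downloadable asset and fail less legibly.
if not os.path.exists('best.pt'):
    raise FileNotFoundError(
        "best.pt not found: place your trained checkpoint next to this script."
    )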
# Load your custom YOLO model.
# IMPORTANT: Replace 'best.pt' with the path to your model trained on the 12 classes.
model = YOLO('best.pt')

# Get class names and generate colors dynamically from the loaded model.
# This guarantees the names and colors match the model's output.
class_names = model.model.names
class_colors = {
    name: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
    for name in class_names.values()
}

# Define paths for Gradio examples
image_example_paths = [['image_0.jpg'], ['image_1.jpg']]
video_example_path = [['video.mp4']]


# --- Image Processing Function ---
def show_preds_image(image_path):
    """Process a single image and overlay YOLO predictions."""
    image = cv2.imread(image_path)
    outputs = model.predict(source=image_path, verbose=False)

    # Convert to a supervision Detections object for easier handling
    detections = sv.Detections.from_ultralytics(outputs[0])

    # Annotate the image with bounding boxes and labels
    for box, conf, cls in zip(detections.xyxy, detections.confidence, detections.class_id):
        x1, y1, x2, y2 = map(int, box)
        class_name = class_names[int(cls)]
        color = class_colors[class_name]
        # Draw the bounding box
        cv2.rectangle(image, (x1, y1), (x2, y2), color=color, thickness=2, lineType=cv2.LINE_AA)
        # Draw the "name: confidence" label above the box
        label = f"{class_name}: {conf:.2f}"
        cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)

    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


# --- Video Processing Function (with Supervision) ---
def process_video_with_two_side_bins(video_path):
    generator = sv.get_video_frames_generator(video_path)
    try:
        first_frame = next(generator)
    except StopIteration:
        # Empty video: yield a single blank frame and stop.
        blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        yield cv2.cvtColor(blank_frame, cv2.COLOR_BGR2RGB)
        return

    frame_height, frame_width, _ = first_frame.shape

    # Two fixed regions of interest, sized relative to the frame (colors are BGR).
    bins = [
        {
            "name": "Recycle Bin",
            "coords": (
                int(frame_width * 0.05),
                int(frame_height * 0.5),
                int(frame_width * 0.25),
                int(frame_height * 0.95),
            ),
            "color": (200, 16, 46),  # Blue-ish
        },
        {
            "name": "Trash Bin",
            "coords": (
                int(frame_width * 0.75),
                int(frame_height * 0.5),
                int(frame_width * 0.95),
                int(frame_height * 0.95),
            ),
            "color": (50, 50, 50),  # Dark gray
        },
    ]

    box_annotator = sv.BoxAnnotator(thickness=2)
    label_annotator = sv.LabelAnnotator(
        text_scale=1.2,
        text_thickness=3,
        text_position=sv.Position.TOP_LEFT,
    )
    tracker = sv.ByteTrack()

    # items_in_bins records every tracker ID ever seen inside each bin;
    # class_counts_per_bin accumulates per-class totals for the on-screen display.
    items_in_bins = {bin_["name"]: set() for bin_ in bins}
    class_counts_per_bin = {bin_["name"]: defaultdict(int) for bin_ in bins}

    frame_number = 0
    BATCH_SIZE = 10
    LOGGED_OBJECT_TTL_SECONDS = 300  # 5 minutes
    insert_buffer = []
    logged_objects = {}

    # Re-attach the first frame (consumed above for sizing) so it is processed too.
    for frame in chain([first_frame], generator):
        frame_number += 1
        current_time = time.time()

        # Prune stale entries from logged_objects every BATCH_SIZE frames
        if frame_number % BATCH_SIZE == 0:
            keys_to_remove = [
                key for key, ts in logged_objects.items()
                if current_time - ts > LOGGED_OBJECT_TTL_SECONDS
            ]
            for key in keys_to_remove:
                del logged_objects[key]

        results = model(frame, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(results)
        tracked_detections = tracker.update_with_detections(detections)

        annotated_frame = frame.copy()

        # Draw the bins and their labels
        for bin_ in bins:
            x1, y1, x2, y2 = bin_["coords"]
            color = bin_["color"]
            cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color=color, thickness=3)
            cv2.putText(
                annotated_frame,
                bin_["name"],
                (x1 + 5, y1 - 15),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.5,
                color,
                3,
                cv2.LINE_AA,
            )

        # Nothing tracked in this frame: show the bins and move on.
        if tracked_detections.tracker_id is None:
            yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
            continue
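        # Bin membership below uses a simple centroid test: an object counts as
        # "in" a bin when the center of its bounding box lies inside the bin
        # rectangle, i.e. bx1 <= cx <= bx2 and by1 <= cy <= by2. If the bins
        # ever become non-rectangular, supervision's PolygonZone would be the
        # library-native alternative; the rectangle check is sufficient here.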
        for box, track_id, class_id in zip(
            tracked_detections.xyxy,
            tracked_detections.tracker_id,
            tracked_detections.class_id,
        ):
            x1, y1, x2, y2 = map(int, box)
            cx = (x1 + x2) // 2
            cy = (y1 + y2) // 2
            class_name = class_names[int(class_id)]

            for bin_ in bins:
                bx1, by1, bx2, by2 = bin_["coords"]
                bin_name = bin_["name"]
                if (bx1 <= cx <= bx2) and (by1 <= cy <= by2):
                    key = (track_id, bin_name, class_name)
                    # Count each tracker ID at most once per bin; totals are
                    # cumulative across the whole video.
                    if track_id not in items_in_bins[bin_name]:
                        items_in_bins[bin_name].add(track_id)
                        class_counts_per_bin[bin_name][class_name] += 1
                    # Buffer one log row per (track, bin, class) combination
                    if key not in logged_objects:
                        insert_buffer.append(
                            (current_time, frame_number, bin_name, class_name, 1)
                        )
                        logged_objects[key] = current_time

        # Flush the buffered rows to SQLite every BATCH_SIZE frames
        if frame_number % BATCH_SIZE == 0 and insert_buffer:
            cursor.executemany('''
                INSERT INTO detections (timestamp, frame_number, bin_name, class_name, count)
                VALUES (?, ?, ?, ?, ?)
            ''', insert_buffer)
            conn.commit()
            insert_buffer.clear()

        labels = [
            f"#{tid} {class_names[int(cid)]}"
            for cid, tid in zip(tracked_detections.class_id, tracked_detections.tracker_id)
        ]

        annotated_frame = box_annotator.annotate(
            scene=annotated_frame, detections=tracked_detections
        )
        annotated_frame = label_annotator.annotate(
            scene=annotated_frame, detections=tracked_detections, labels=labels
        )

        # Display the cumulative counts per bin
        y_pos = 50
        for bin_name, class_count_dict in class_counts_per_bin.items():
            text = (
                f"{bin_name}: "
                + ", ".join(f"{cls}={count}" for cls, count in class_count_dict.items())
            )
            cv2.putText(
                annotated_frame,
                text,
                (30, y_pos),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.1,
                (255, 255, 255),
                3,
                cv2.LINE_AA,
            )
            y_pos += 40

        yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)

    # Flush any rows still buffered when the video ends
    if insert_buffer:
        cursor.executemany('''
            INSERT INTO detections (timestamp, frame_number, bin_name, class_name, count)
            VALUES (?, ?, ?, ?, ?)
        ''', insert_buffer)
        conn.commit()
        insert_buffer.clear()


# --- Gradio Interface Setup ---
# Gradio Interface for Image Processing
interface_image = gr.Interface(
    fn=show_preds_image,
    inputs=gr.Image(type="filepath", label="Input Image"),
    outputs=gr.Image(type="numpy", label="Output Image"),
    title="Waste Detection (Image)",
    description="Upload an image to see waste detection results.",
    examples=image_example_paths,
    cache_examples=False,
)

# Gradio Interface for Video Processing
interface_video = gr.Interface(
    fn=process_video_with_two_side_bins,
    inputs=gr.Video(label="Input Video"),
    outputs=gr.Image(type="numpy", label="Output Video Stream"),
    title="Waste Tracking and Counting (Video)",
    description="Upload a video to see real-time object tracking and counting.",
    examples=video_example_path,
    cache_examples=False,
)

# Launch the Gradio app with separate tabs for each interface
gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=['Image Inference', 'Video Inference']
).queue().launch(debug=True)
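# Example query (for reference only, hence commented out; code placed here would
# run only after the Gradio server shuts down): aggregate the per-bin, per-class
# totals that a session logged to detection_log.db.
#
#   rows = cursor.execute(
#       "SELECT bin_name, class_name, SUM(count) FROM detections "
#       "GROUP BY bin_name, class_name"
#   ).fetchall()
#   for bin_name, class_name, total in rows:
#       print(f"{bin_name}: {class_name} x {total}")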