import gradio as gr
import cv2
import requests
import random
from ultralytics import YOLO
import numpy as np
from collections import defaultdict
# Import the supervision library
import supervision as sv
# --- File Downloading ---
# File URLs for sample images and video
file_urls = [
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/mix2.jpg?download=true',
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/mix11.jpg?download=true',
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/sample_waste.mp4?download=true',
]

def download_file(url, save_name):
    """Downloads a file from a URL, overwriting if it exists."""
    print(f"Downloading from: {url}")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()  # Raise an exception on HTTP errors
        with open(save_name, 'wb') as file:
            for chunk in response.iter_content(1024):
                file.write(chunk)
        print(f"Downloaded and overwrote: {save_name}")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")

# Download sample images and video for the examples
for i, url in enumerate(file_urls):
    if 'mp4' in url:
        download_file(url, "video.mp4")
    else:
        download_file(url, f"image_{i}.jpg")
# --- Model and Class Configuration ---
# Load your custom YOLO model.
# IMPORTANT: Replace 'best.pt' with the path to your model trained on the 12 classes.
model = YOLO('best.pt')
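# Optional fallback for a quick smoke test (an assumption, not part of this app's
# setup): a stock checkpoint such as 'yolov8n.pt' is auto-downloaded by
# ultralytics and loads the same way, though its COCO classes will not match
# the waste categories this Space expects.
# model = YOLO('yolov8n.pt')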
# Get class names and generate colors dynamically from the loaded model.
# This is the best practice, as it ensures names and colors match the model's output.
class_names = model.model.names
class_colors = {
    name: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
    for name in class_names.values()
}
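# Optional tweak: seeding the RNG before the block above (e.g. random.seed(0))
# would make the per-class colors reproducible across restarts.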
# Define paths for Gradio examples
image_example_paths = [['image_0.jpg'], ['image_1.jpg']]
video_example_path = [['video.mp4']]
# --- Image Processing Function ---
def show_preds_image(image_path):
    """Processes a single image and overlays YOLO predictions."""
    image = cv2.imread(image_path)
    outputs = model.predict(source=image_path, verbose=False)
    # Convert to a supervision Detections object for easier handling
    detections = sv.Detections.from_ultralytics(outputs[0])
    # Annotate the image with bounding boxes and labels
    for box, conf, cls in zip(detections.xyxy, detections.confidence, detections.class_id):
        x1, y1, x2, y2 = map(int, box)
        class_name = class_names[int(cls)]
        color = class_colors[class_name]
        # Draw bounding box
        cv2.rectangle(image, (x1, y1), (x2, y2), color=color, thickness=2, lineType=cv2.LINE_AA)
        # Create and draw the label
        label = f"{class_name}: {conf:.2f}"
        cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
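# A minimal local sanity check, assuming 'image_0.jpg' was fetched by the
# download step above; uncomment to preview predictions without Gradio:
# preview = show_preds_image('image_0.jpg')
# cv2.imwrite('preview.jpg', cv2.cvtColor(preview, cv2.COLOR_RGB2BGR))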
# --- Video Processing Function (with Supervision) ---
def process_video_with_two_side_bins(video_path):
    if video_path is None:
        return
    generator = sv.get_video_frames_generator(video_path)
    try:
        first_frame = next(generator)
    except StopIteration:
        print("No frames found in the provided video input.")
        # Yield a blank black frame so the output component still renders, then stop
        blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        yield cv2.cvtColor(blank_frame, cv2.COLOR_BGR2RGB)
        return
    frame_height, frame_width, _ = first_frame.shape

    # Define two bins: recycle and trash sides
    bins = [
        {
            "name": "Recycle Bin",
            "coords": (
                int(frame_width * 0.05),
                int(frame_height * 0.5),
                int(frame_width * 0.25),
                int(frame_height * 0.95),
            ),
            "color": (200, 16, 46),  # Blue-ish (BGR)
        },
        {
            "name": "Trash Bin",
            "coords": (
                int(frame_width * 0.75),
                int(frame_height * 0.5),
                int(frame_width * 0.95),
                int(frame_height * 0.95),
            ),
            "color": (50, 50, 50),  # Dark gray (BGR)
        },
    ]
    box_annotator = sv.BoxAnnotator(thickness=2)
    label_annotator = sv.LabelAnnotator(
        text_scale=1.2,  # larger text size
        text_thickness=3,
        text_position=sv.Position.TOP_LEFT,
    )
    tracker = sv.ByteTrack()
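    # ByteTrack assigns each object a persistent tracker_id across frames, which
    # is what lets the loop below count a physical item at most once per bin.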
    items_in_bins = {bin_["name"]: set() for bin_ in bins}
    class_counts_per_bin = {bin_["name"]: defaultdict(int) for bin_ in bins}
    for frame in generator:
        results = model(frame, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(results)
        tracked_detections = tracker.update_with_detections(detections)
        annotated_frame = frame.copy()

        # Draw the bins and their labels
        for bin_ in bins:
            x1, y1, x2, y2 = bin_["coords"]
            color = bin_["color"]
            cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color=color, thickness=3)
            cv2.putText(
                annotated_frame,
                bin_["name"],
                (x1 + 5, y1 - 15),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.5,  # larger font
                color,
                3,
                cv2.LINE_AA,
            )

        if tracked_detections.tracker_id is None:
            yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
            continue
        for box, track_id, class_id in zip(
            tracked_detections.xyxy,
            tracked_detections.tracker_id,
            tracked_detections.class_id,
        ):
            x1, y1, x2, y2 = map(int, box)
            # Use the box center to decide whether the item lies inside a bin
            cx = (x1 + x2) // 2
            cy = (y1 + y2) // 2
            for bin_ in bins:
                bx1, by1, bx2, by2 = bin_["coords"]
                if (bx1 <= cx <= bx2) and (by1 <= cy <= by2):
                    # Count each tracked item at most once per bin
                    if track_id not in items_in_bins[bin_["name"]]:
                        items_in_bins[bin_["name"]].add(track_id)
                        class_name = class_names[int(class_id)]
                        class_counts_per_bin[bin_["name"]][class_name] += 1
        labels = [
            f"#{tid} {class_names[int(cid)]}"
            for cid, tid in zip(tracked_detections.class_id, tracked_detections.tracker_id)
        ]
        annotated_frame = box_annotator.annotate(
            scene=annotated_frame, detections=tracked_detections
        )
        annotated_frame = label_annotator.annotate(
            scene=annotated_frame, detections=tracked_detections, labels=labels
        )

        # Show running counts per bin with a larger font
        y_pos = 50
        for bin_name, class_count_dict in class_counts_per_bin.items():
            text = (
                f"{bin_name}: "
                + ", ".join(f"{cls}={count}" for cls, count in class_count_dict.items())
            )
            cv2.putText(
                annotated_frame,
                text,
                (30, y_pos),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.1,  # larger font for counts
                (255, 255, 255),
                3,
                cv2.LINE_AA,
            )
            y_pos += 40

        yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
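# A minimal local check, assuming 'video.mp4' from the download step above;
# uncomment to step through a few frames without launching Gradio:
# for idx, rgb_frame in enumerate(process_video_with_two_side_bins('video.mp4')):
#     if idx >= 5:
#         break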
# --- Gradio Interface Setup ---
# Gradio interface for image processing
interface_image = gr.Interface(
    fn=show_preds_image,
    inputs=gr.Image(type="filepath", label="Input Image"),
    outputs=gr.Image(type="numpy", label="Output Image"),
    title="Waste Detection (Image)",
    description="Upload an image to see waste detection results.",
    examples=image_example_paths,
    cache_examples=False,
)

# Gradio interface for video processing
interface_video = gr.Interface(
    fn=process_video_with_two_side_bins,
    inputs=gr.Video(label="Input Video"),
    outputs=gr.Image(type="numpy", label="Output Video Stream"),
    title="Waste Tracking and Counting (Video)",
    description="Upload a video to see real-time object tracking and counting.",
    examples=video_example_path,
    cache_examples=False,
)

# Launch the Gradio app with a separate tab for each interface
gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=['Image Inference', 'Video Inference']
).queue().launch(debug=True)