# waste-detection / app.py

import gradio as gr
import cv2
import requests
import random
from ultralytics import YOLO
import numpy as np
from collections import defaultdict
# Import the supervision library
import supervision as sv
# --- File Downloading ---
# File URLs for sample images and video
file_urls = [
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/mix2.jpg?download=true',
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/mix11.jpg?download=true',
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/sample_waste.mp4?download=true',
]
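# Note: the '?download=true' query string asks the Hugging Face Hub to serve
# the raw file rather than an HTML preview page.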
def download_file(url, save_name):
"""Downloads a file from a URL, overwriting if it exists."""
print(f"Downloading from: {url}")
try:
response = requests.get(url, stream=True)
response.raise_for_status() # Check for HTTP errors
with open(save_name, 'wb') as file:
for chunk in response.iter_content(1024):
file.write(chunk)
print(f"Downloaded and overwrote: {save_name}")
except requests.exceptions.RequestException as e:
print(f"Error downloading {url}: {e}")
# Download sample images and video for the examples
for i, url in enumerate(file_urls):
    if 'mp4' in url:
        download_file(url, "video.mp4")
    else:
        download_file(url, f"image_{i}.jpg")
# --- Model and Class Configuration ---
# Load your custom YOLO model
# IMPORTANT: Replace 'best.pt' with the path to your model trained on the 12 classes.
model = YOLO('best.pt')
# Get class names from the loaded model and assign each one a random color.
# Reading names from the model ensures labels and colors match its output indices.
class_names = model.model.names
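# class_names maps integer class ids to label strings.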
class_colors = {
    name: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
    for name in class_names.values()
}
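# Note: OpenCV interprets these color tuples as BGR, not RGB.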
# Define paths for Gradio examples
image_example_paths = [['image_0.jpg'], ['image_1.jpg']]
video_example_path = [['video.mp4']]
# --- Image Processing Function ---
def show_preds_image(image_path):
"""Processes a single image and overlays YOLO predictions."""
image = cv2.imread(image_path)
outputs = model.predict(source=image_path, verbose=False)
results = outputs[0].cpu().numpy()
# Convert to supervision Detections object for easier handling
detections = sv.Detections.from_ultralytics(outputs[0])
# Annotate the image with bounding boxes and labels
for i, (box, conf, cls) in enumerate(zip(detections.xyxy, detections.confidence, detections.class_id)):
x1, y1, x2, y2 = map(int, box)
class_name = class_names[cls]
color = class_colors[class_name]
# Draw bounding box
cv2.rectangle(image, (x1, y1), (x2, y2), color=color, thickness=2, lineType=cv2.LINE_AA)
# Create and display label
label = f"{class_name}: {conf:.2f}"
cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# --- Video Processing Function (with Supervision) ---
def process_video_with_two_side_bins(video_path):
    """Streams annotated frames, tracking items dropped into two bin regions."""
    if video_path is None:
        return

    generator = sv.get_video_frames_generator(video_path)
    try:
        # Peek at the first frame only to read the video dimensions
        first_frame = next(generator)
    except StopIteration:
        print("No frames found in the provided video input.")
        # Yield a blank black frame so the UI still receives an image
        blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        yield cv2.cvtColor(blank_frame, cv2.COLOR_BGR2RGB)
        return

    frame_height, frame_width, _ = first_frame.shape

    # Define two bin regions: a recycle side and a trash side
    bins = [
        {
            "name": "Recycle Bin",
            "coords": (
                int(frame_width * 0.05),
                int(frame_height * 0.5),
                int(frame_width * 0.25),
                int(frame_height * 0.95),
            ),
            "color": (200, 16, 46),  # Blue-ish (BGR)
        },
        {
            "name": "Trash Bin",
            "coords": (
                int(frame_width * 0.75),
                int(frame_height * 0.5),
                int(frame_width * 0.95),
                int(frame_height * 0.95),
            ),
            "color": (50, 50, 50),  # Dark gray (BGR)
        },
    ]
    box_annotator = sv.BoxAnnotator(thickness=2)
    label_annotator = sv.LabelAnnotator(
        text_scale=1.2,  # larger text size
        text_thickness=3,
        text_position=sv.Position.TOP_LEFT,
    )

    # ByteTrack assigns a persistent tracker id to each object across frames,
    # so each item can be counted exactly once when it enters a bin.
    tracker = sv.ByteTrack()

    # Per-bin bookkeeping: tracker ids already counted, and per-class totals
    items_in_bins = {bin_["name"]: set() for bin_ in bins}
    class_counts_per_bin = {bin_["name"]: defaultdict(int) for bin_ in bins}
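
    # Main loop: detect and track on every remaining frame, then stream the
    # annotated frame back to Gradio as soon as it is ready.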
    for frame in generator:
        results = model(frame, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(results)
        tracked_detections = tracker.update_with_detections(detections)

        annotated_frame = frame.copy()

        # Draw the bin regions and their labels
        for bin_ in bins:
            x1, y1, x2, y2 = bin_["coords"]
            color = bin_["color"]
            cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color=color, thickness=3)
            cv2.putText(
                annotated_frame,
                bin_["name"],
                (x1 + 5, y1 - 15),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.5,  # larger font
                color,
                3,
                cv2.LINE_AA,
            )

        # No confirmed tracks yet: show the frame with only the bins drawn
        if tracked_detections.tracker_id is None:
            yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
            continue
        for box, track_id, class_id in zip(
            tracked_detections.xyxy,
            tracked_detections.tracker_id,
            tracked_detections.class_id,
        ):
            x1, y1, x2, y2 = map(int, box)
            # Use the box center to decide whether the item is inside a bin
            cx = (x1 + x2) // 2
            cy = (y1 + y2) // 2

            for bin_ in bins:
                bx1, by1, bx2, by2 = bin_["coords"]
                if (bx1 <= cx <= bx2) and (by1 <= cy <= by2):
                    # Count each tracker id only once per bin
                    if track_id not in items_in_bins[bin_["name"]]:
                        items_in_bins[bin_["name"]].add(track_id)
                        class_name = class_names[class_id]
                        class_counts_per_bin[bin_["name"]][class_name] += 1
        labels = [
            f"#{tid} {class_names[cid]}"
            for cid, tid in zip(tracked_detections.class_id, tracked_detections.tracker_id)
        ]
        annotated_frame = box_annotator.annotate(
            scene=annotated_frame, detections=tracked_detections
        )
        annotated_frame = label_annotator.annotate(
            scene=annotated_frame, detections=tracked_detections, labels=labels
        )
        # Overlay the running counts for each bin
        y_pos = 50
        for bin_name, class_count_dict in class_counts_per_bin.items():
            text = (
                f"{bin_name}: "
                + ", ".join(f"{cls}={count}" for cls, count in class_count_dict.items())
            )
            cv2.putText(
                annotated_frame,
                text,
                (30, y_pos),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.1,  # larger font for the counts
                (255, 255, 255),
                3,
                cv2.LINE_AA,
            )
            y_pos += 40

        yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)

# --- Gradio Interface Setup ---
# Gradio Interface for Image Processing
interface_image = gr.Interface(
    fn=show_preds_image,
    inputs=gr.Image(type="filepath", label="Input Image"),
    outputs=gr.Image(type="numpy", label="Output Image"),
    title="Waste Detection (Image)",
    description="Upload an image to see waste detection results.",
    examples=image_example_paths,
    cache_examples=False,
)
# Gradio Interface for Video Processing
interface_video = gr.Interface(
    fn=process_video_with_two_side_bins,
    inputs=gr.Video(label="Input Video"),
    outputs=gr.Image(type="numpy", label="Output Video Stream"),
    title="Waste Tracking and Counting (Video)",
    description="Upload a video to see real-time object tracking and counting.",
    examples=video_example_path,
    cache_examples=False,
)
# Launch the Gradio App with separate tabs for each interface
gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=['Image Inference', 'Video Inference']
).queue().launch(debug=True)
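# Note: .queue() is what lets the video function's yielded frames stream to the
# browser one by one; debug=True keeps the process attached and prints errors.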