File size: 8,601 Bytes
c05d727
 
 
 
55c29e2
c05d727
bb0cd33
 
c05d727
bb0cd33
 
55c29e2
bb0cd33
55c29e2
c05d727
970733c
 
35cc57d
c05d727
 
 
bb0cd33
 
35cc57d
 
 
 
 
 
bb0cd33
35cc57d
 
c05d727
bb0cd33
c05d727
bb0cd33
 
c05d727
bb0cd33
c05d727
bb0cd33
 
 
c05d727
4b5873a
bb0cd33
 
 
 
 
 
 
 
 
 
 
 
4b5873a
bb0cd33
c05d727
bb0cd33
c05d727
bb0cd33
c05d727
4b5873a
bb0cd33
 
4b5873a
bb0cd33
 
4b5873a
bb0cd33
 
 
55c29e2
 
 
bb0cd33
 
 
55c29e2
bb0cd33
55c29e2
 
bb0cd33
 
55c29e2
bb0cd33
 
 
 
55c29e2
bb0cd33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c05d727
 
bb0cd33
 
 
 
 
c05d727
 
 
bb0cd33
c05d727
bb0cd33
 
 
 
 
 
c05d727
 
 
bb0cd33
c05d727
 
55c29e2
bb0cd33
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import gradio as gr
import cv2
import requests
import os
import random
from ultralytics import YOLO
import numpy as np
from collections import defaultdict

# Import the supervision library
import supervision as sv

# --- File Downloading ---
# File URLs for sample images and video
# Image entries are saved by their list index (image_0.jpg, image_1.jpg) and
# referenced under those names by the Gradio examples, so keep this order.
file_urls = [
    "https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/"
    + sample_name
    + "?download=true"
    for sample_name in ("mix2.jpg", "mix11.jpg", "sample_waste.mp4")
]

def download_file(url, save_name, timeout=30):
    """Download ``url`` to ``save_name``, overwriting any existing file.

    The body is streamed into a temporary ``<save_name>.part`` file and moved
    into place only after the whole response has been written. A failed or
    interrupted download therefore never leaves a truncated ``save_name``
    behind (a previous copy, if any, is kept). Failures are reported and
    swallowed so a missing sample does not stop the app from starting.

    Args:
        url: HTTP(S) URL to fetch.
        save_name: Local path to write the downloaded bytes to.
        timeout: Seconds to wait on connect/read. Without a timeout,
            ``requests`` may block indefinitely on an unresponsive server.

    Returns:
        True if the file was fully downloaded and saved, False otherwise.
    """
    print(f"Downloading from: {url}")
    tmp_name = f"{save_name}.part"
    try:
        # The context manager guarantees the streamed connection is released.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Check for HTTP errors
            with open(tmp_name, 'wb') as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        # Atomic swap: readers see either the old file or the complete new one.
        os.replace(tmp_name, save_name)
    except (requests.exceptions.RequestException, OSError) as e:
        print(f"Error downloading {url}: {e}")
        try:
            os.remove(tmp_name)  # best-effort cleanup of a partial download
        except OSError:
            pass
        return False
    print(f"Downloaded and overwrote: {save_name}")
    return True

# Fetch every sample asset used by the Gradio examples. The video is always
# stored as "video.mp4"; each image is stored as "image_<index>.jpg", where
# <index> is its position in file_urls.
for i, url in enumerate(file_urls):
    download_file(url, "video.mp4" if 'mp4' in url else f"image_{i}.jpg")

# --- Model and Class Configuration ---
# Load your custom YOLO model
# IMPORTANT: Replace 'best.pt' with the path to your model trained on the 12 classes.
model = YOLO('best.pt')

# Class-id -> class-name mapping read from the loaded model, so displayed
# names always match what the model actually predicts.
class_names = model.model.names

# One drawing color per class name, used with OpenCV on images loaded via
# cv2.imread (so the tuple is interpreted in OpenCV's BGR channel order).
# A dedicated fixed-seed RNG keeps the palette identical across restarts and
# leaves the global `random` module state untouched.
_color_rng = random.Random(0)
class_colors = {
    name: (_color_rng.randint(0, 255), _color_rng.randint(0, 255), _color_rng.randint(0, 255))
    for name in class_names.values()
}

# Define paths for Gradio examples
image_example_paths = [['image_0.jpg'], ['image_1.jpg']]
video_example_path = [['video.mp4']]


# --- Image Processing Function ---
def show_preds_image(image_path):
    """Run the YOLO model on a single image and draw its detections.

    Each detection gets a bounding box and a "<class>: <confidence>" label,
    colored per class via ``class_colors``.

    Args:
        image_path: Filesystem path of the input image (the Gradio input
            component is configured with ``type="filepath"``), or None when
            nothing was uploaded.

    Returns:
        The annotated image as an RGB numpy array.

    Raises:
        gr.Error: If no image was supplied or OpenCV cannot decode the file.
    """
    if image_path is None:
        raise gr.Error("Please upload an image.")

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports unreadable/unsupported files by returning None
        # rather than raising; fail with a clear message instead of crashing
        # later inside cv2.cvtColor.
        raise gr.Error(f"Could not read image: {image_path}")

    outputs = model.predict(source=image_path, verbose=False)

    # Convert to supervision Detections object for easier handling
    detections = sv.Detections.from_ultralytics(outputs[0])

    for box, conf, cls in zip(detections.xyxy, detections.confidence, detections.class_id):
        x1, y1, x2, y2 = map(int, box)
        class_name = class_names[int(cls)]
        color = class_colors[class_name]

        # Draw bounding box
        cv2.rectangle(image, (x1, y1), (x2, y2), color=color, thickness=2, lineType=cv2.LINE_AA)

        # Place the label just above the box, but never above the top edge of
        # the image (boxes touching the top would otherwise lose their label).
        label = f"{class_name}: {conf:.2f}"
        text_y = max(y1 - 10, 20)
        cv2.putText(image, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)

    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


# --- Video Processing Function (with Supervision) ---
def _build_bins(frame_width, frame_height):
    """Return the two counting zones (lower-left recycle, lower-right trash) scaled to the frame."""
    return [
        {
            "name": "Recycle Bin",
            "coords": (
                int(frame_width * 0.05),
                int(frame_height * 0.5),
                int(frame_width * 0.25),
                int(frame_height * 0.95),
            ),
            "color": (200, 16, 46),  # BGR: blue-ish
        },
        {
            "name": "Trash Bin",
            "coords": (
                int(frame_width * 0.75),
                int(frame_height * 0.5),
                int(frame_width * 0.95),
                int(frame_height * 0.95),
            ),
            "color": (50, 50, 50),  # BGR: dark gray
        },
    ]


def _draw_bins(frame, bins):
    """Draw each bin's rectangle and (large) name label onto ``frame`` in place."""
    for bin_ in bins:
        x1, y1, x2, y2 = bin_["coords"]
        color = bin_["color"]
        cv2.rectangle(frame, (x1, y1), (x2, y2), color=color, thickness=3)
        cv2.putText(
            frame,
            bin_["name"],
            (x1 + 5, y1 - 15),
            cv2.FONT_HERSHEY_SIMPLEX,
            1.5,  # bigger font
            color,
            3,
            cv2.LINE_AA,
        )


def _draw_bin_counts(frame, class_counts_per_bin):
    """Write one "<bin>: <class>=<count>, ..." summary line per bin onto ``frame`` in place."""
    y_pos = 50
    for bin_name, class_count_dict in class_counts_per_bin.items():
        text = (
            f"{bin_name}: "
            + ", ".join(f"{cls}={count}" for cls, count in class_count_dict.items())
        )
        cv2.putText(
            frame,
            text,
            (30, y_pos),
            cv2.FONT_HERSHEY_SIMPLEX,
            1.1,  # bigger font for counts
            (255, 255, 255),
            3,
            cv2.LINE_AA,
        )
        y_pos += 40


def process_video_with_two_side_bins(video_path):
    """Stream annotated video frames while counting tracked items per bin.

    Two rectangular zones are placed in the lower-left ("Recycle Bin") and
    lower-right ("Trash Bin") of the frame. Detections are tracked with
    ByteTrack. The first time a track's box center falls inside a zone, that
    track's class is counted once for that zone.

    Args:
        video_path: Path to the input video, or None when nothing was uploaded
            (in which case nothing is yielded).

    Yields:
        RGB numpy frames with bins, tracked boxes, track labels and per-bin
        class counts drawn on them. If the video has no frames, a single
        640x480 black frame is yielded instead.
    """
    from itertools import chain

    if video_path is None:
        return

    generator = sv.get_video_frames_generator(video_path)

    try:
        first_frame = next(generator)
    except StopIteration:
        print("No frames found in the provided video input.")
        blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        yield cv2.cvtColor(blank_frame, cv2.COLOR_BGR2RGB)
        return

    frame_height, frame_width = first_frame.shape[:2]
    bins = _build_bins(frame_width, frame_height)

    box_annotator = sv.BoxAnnotator(thickness=2)
    label_annotator = sv.LabelAnnotator(
        text_scale=1.2,  # bigger text size
        text_thickness=3,
        text_position=sv.Position.TOP_LEFT,
    )

    tracker = sv.ByteTrack()

    # Track ids already counted per bin, so an item is counted only once
    # however many frames it stays inside the zone.
    items_in_bins = {bin_["name"]: set() for bin_ in bins}
    class_counts_per_bin = {bin_["name"]: defaultdict(int) for bin_ in bins}

    # The first frame was consumed to size the bins; put it back in front so
    # it is tracked and shown like every other frame.
    for frame in chain([first_frame], generator):
        results = model(frame, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(results)
        tracked_detections = tracker.update_with_detections(detections)

        annotated_frame = frame.copy()
        _draw_bins(annotated_frame, bins)

        if tracked_detections.tracker_id is not None:
            for box, track_id, class_id in zip(
                tracked_detections.xyxy,
                tracked_detections.tracker_id,
                tracked_detections.class_id,
            ):
                x1, y1, x2, y2 = map(int, box)
                cx = (x1 + x2) // 2
                cy = (y1 + y2) // 2

                for bin_ in bins:
                    bx1, by1, bx2, by2 = bin_["coords"]
                    inside = (bx1 <= cx <= bx2) and (by1 <= cy <= by2)
                    if inside and track_id not in items_in_bins[bin_["name"]]:
                        items_in_bins[bin_["name"]].add(track_id)
                        class_counts_per_bin[bin_["name"]][class_names[class_id]] += 1

            labels = [
                f"#{tid} {class_names[cid]}"
                for cid, tid in zip(tracked_detections.class_id, tracked_detections.tracker_id)
            ]

            annotated_frame = box_annotator.annotate(
                scene=annotated_frame, detections=tracked_detections
            )
            annotated_frame = label_annotator.annotate(
                scene=annotated_frame, detections=tracked_detections, labels=labels
            )

        # Drawn on every frame (even with no tracks) so the overlay does not flicker.
        _draw_bin_counts(annotated_frame, class_counts_per_bin)

        yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)



# --- Gradio Interface Setup ---
# Image tab: the uploaded file path is handed to show_preds_image, whose
# annotated RGB array is displayed as the output image.
interface_image = gr.Interface(
    fn=show_preds_image,
    inputs=gr.Image(type="filepath", label="Input Image"),
    outputs=gr.Image(type="numpy", label="Output Image"),
    title="Waste Detection (Image)",
    description="Upload an image to see waste detection results.",
    examples=image_example_paths,
    cache_examples=False,
)

# Video tab: process_video_with_two_side_bins is a generator, so its frames
# are streamed one after another into a single output image component.
interface_video = gr.Interface(
    fn=process_video_with_two_side_bins,
    inputs=gr.Video(label="Input Video"),
    outputs=gr.Image(type="numpy", label="Output Video Stream"),
    title="Waste Tracking and Counting (Video)",
    description="Upload a video to see real-time object tracking and counting.",
    examples=video_example_path,
    cache_examples=False,
)

# Serve both interfaces as tabs of a single app with request queuing enabled.
demo = gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=['Image Inference', 'Video Inference'],
)
demo.queue().launch(debug=True)