import gradio as gr
import cv2
import requests
import random
from ultralytics import YOLO
import numpy as np
from collections import defaultdict
import sqlite3
import time

# Import the supervision library
import supervision as sv


# --- Initialize SQLite DB for logging ---
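# check_same_thread=False is needed because Gradio may invoke the handlers
# below from worker threads other than the one that created this connection.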
conn = sqlite3.connect("detection_log.db", check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
    CREATE TABLE IF NOT EXISTS detections (
        timestamp REAL,
        frame_number INTEGER,
        bin_name TEXT,
        class_name TEXT,
        count INTEGER
    )
''')
conn.commit()
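
# Optional helper (not called anywhere in the app): summarizes what has been
# logged so far, grouped by bin and class. A small sketch for ad-hoc checks.
def summarize_detections():
    rows = cursor.execute('''
        SELECT bin_name, class_name, SUM(count)
        FROM detections
        GROUP BY bin_name, class_name
    ''').fetchall()
    for bin_name, class_name, total in rows:
        print(f"{bin_name}: {class_name} = {total}")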

# --- File Downloading ---
# File URLs for sample images and video
file_urls = [
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/mix2.jpg?download=true',
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/mix11.jpg?download=true',
    'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/sample_waste.mp4?download=true',
]

def download_file(url, save_name):
    """Downloads a file from a URL, overwriting if it exists."""
    print(f"Downloading from: {url}")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()  # Check for HTTP errors
        with open(save_name, 'wb') as file:
            for chunk in response.iter_content(1024):
                file.write(chunk)
        print(f"Downloaded and overwrote: {save_name}")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")

# Download sample images and video for the examples
for i, url in enumerate(file_urls):
    if 'mp4' in url:
        download_file(url, "video.mp4")
    else:
        download_file(url, f"image_{i}.jpg")

# --- Model and Class Configuration ---
# Load your custom YOLO model
# IMPORTANT: Replace 'best.pt' with the path to your model trained on the 12 classes.
model = YOLO('best.pt')

# Get class names and generate colors dynamically from the loaded model
# This is the best practice as it ensures names and colors match the model's output.
class_names = model.model.names
class_colors = {
    name: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
    for name in class_names.values()
}
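# Note: the colors above are random on each run; seeding the generator first
# (e.g. random.seed(42)) would make them reproducible across restarts.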

# Define paths for Gradio examples
image_example_paths = [['image_0.jpg'], ['image_1.jpg']]
video_example_path = [['video.mp4']]


# --- Image Processing Function ---
def show_preds_image(image_path):
    """Processes a single image and overlays YOLO predictions."""
    image = cv2.imread(image_path)
    outputs = model.predict(source=image_path, verbose=False)

    # Convert to supervision Detections object for easier handling
    detections = sv.Detections.from_ultralytics(outputs[0])

    # Annotate the image with bounding boxes and labels
    for box, conf, cls in zip(detections.xyxy, detections.confidence, detections.class_id):
        x1, y1, x2, y2 = map(int, box)
        class_name = class_names[cls]
        color = class_colors[class_name]
        
        # Draw bounding box
        cv2.rectangle(image, (x1, y1), (x2, y2), color=color, thickness=2, lineType=cv2.LINE_AA)
        
        # Create and display label
        label = f"{class_name}: {conf:.2f}"
        cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)

    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
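

# Alternative sketch using supervision's annotators instead of the manual cv2
# drawing above; not wired into the UI, shown only for comparison, and assumes
# the same module-level `model` and `class_names`.
def show_preds_image_sv(image_path):
    image = cv2.imread(image_path)
    outputs = model.predict(source=image_path, verbose=False)
    detections = sv.Detections.from_ultralytics(outputs[0])
    labels = [
        f"{class_names[cid]}: {conf:.2f}"
        for cid, conf in zip(detections.class_id, detections.confidence)
    ]
    annotated = sv.BoxAnnotator(thickness=2).annotate(image.copy(), detections)
    annotated = sv.LabelAnnotator().annotate(annotated, detections, labels=labels)
    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)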


# --- Video Processing Function (with Supervision) ---
def process_video_with_two_side_bins(video_path):
    """Tracks waste items across frames, counts unique items entering two bin
    regions, logs entries to SQLite, and streams annotated frames."""
    generator = sv.get_video_frames_generator(video_path)

    try:
        first_frame = next(generator)
    except StopIteration:
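        # Empty or unreadable video: emit one black placeholder frame and stop.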
        blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        yield cv2.cvtColor(blank_frame, cv2.COLOR_BGR2RGB)
        return
    
    frame_height, frame_width, _ = first_frame.shape

    bins = [
        {
            "name": "Recycle Bin",
            "coords": (
                int(frame_width * 0.05),
                int(frame_height * 0.5),
                int(frame_width * 0.25),
                int(frame_height * 0.95),
            ),
            "color": (200, 16, 46),  # Blue-ish
        },
        {
            "name": "Trash Bin",
            "coords": (
                int(frame_width * 0.75),
                int(frame_height * 0.5),
                int(frame_width * 0.95),
                int(frame_height * 0.95),
            ),
            "color": (50, 50, 50),  # Red-ish
        },
    ]
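    # The bin rectangles are defined as fractions of the frame size, so they
    # scale with the input resolution: at 1280x720, for example, the Recycle
    # Bin spans x in [64, 320] and y in [360, 684].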

    box_annotator = sv.BoxAnnotator(thickness=2)
    label_annotator = sv.LabelAnnotator(
        text_scale=1.2,
        text_thickness=3,
        text_position=sv.Position.TOP_LEFT,
    )

    tracker = sv.ByteTrack()
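    # ByteTrack assigns a persistent tracker_id to each detection across
    # frames, which is what lets each item be counted only once per bin.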

    items_in_bins = {bin_["name"]: set() for bin_ in bins}
    class_counts_per_bin = {bin_["name"]: defaultdict(int) for bin_ in bins}

    frame_number = 0
    BATCH_SIZE = 10
    LOGGED_OBJECT_TTL_SECONDS = 300  # 5 minutes
    insert_buffer = []   # rows waiting for a batched SQLite insert
    logged_objects = {}  # (track_id, bin_name, class_name) -> last log time

    for frame in generator:
        frame_number += 1
        current_time = time.time()

        # Prune old logged objects every BATCH_SIZE frames
        if frame_number % BATCH_SIZE == 0:
            keys_to_remove = [
                key for key, ts in logged_objects.items()
                if current_time - ts > LOGGED_OBJECT_TTL_SECONDS
            ]
            for key in keys_to_remove:
                del logged_objects[key]

        results = model(frame, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(results)
        tracked_detections = tracker.update_with_detections(detections)

        annotated_frame = frame.copy()

        # Draw bins and labels
        for bin_ in bins:
            x1, y1, x2, y2 = bin_["coords"]
            color = bin_["color"]
            cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color=color, thickness=3)
            cv2.putText(
                annotated_frame,
                bin_["name"],
                (x1 + 5, y1 - 15),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.5,
                color,
                3,
                cv2.LINE_AA,
            )

        if tracked_detections.tracker_id is None:
            yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
            continue

        # Counts accumulate across frames: each tracked item is counted once,
        # the first time its ID enters a bin, so the on-screen totals below
        # are cumulative rather than per-frame.

        for box, track_id, class_id in zip(
            tracked_detections.xyxy,
            tracked_detections.tracker_id,
            tracked_detections.class_id,
        ):
            x1, y1, x2, y2 = map(int, box)
            cx = (x1 + x2) // 2
            cy = (y1 + y2) // 2
            class_name = class_names[class_id]

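            # An object is treated as inside a bin when its box center
            # (cx, cy) falls within the bin rectangle.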
            for bin_ in bins:
                bx1, by1, bx2, by2 = bin_["coords"]
                bin_name = bin_["name"]

                if (bx1 <= cx <= bx2) and (by1 <= cy <= by2):
                    key = (track_id, bin_name, class_name)

                    if track_id not in items_in_bins[bin_name]:
                        items_in_bins[bin_name].add(track_id)
                        class_counts_per_bin[bin_name][class_name] += 1

                    if key not in logged_objects:
                        insert_buffer.append((current_time, frame_number, bin_name, class_name, 1))
                        logged_objects[key] = current_time

        # Batch insert every BATCH_SIZE frames
        if frame_number % BATCH_SIZE == 0 and insert_buffer:
            cursor.executemany('''
                INSERT INTO detections (timestamp, frame_number, bin_name, class_name, count)
                VALUES (?, ?, ?, ?, ?)
            ''', insert_buffer)
            conn.commit()
            insert_buffer.clear()

        labels = [
            f"#{tid} {class_names[cid]}"
            for cid, tid in zip(tracked_detections.class_id, tracked_detections.tracker_id)
        ]

        annotated_frame = box_annotator.annotate(
            scene=annotated_frame, detections=tracked_detections
        )
        annotated_frame = label_annotator.annotate(
            scene=annotated_frame, detections=tracked_detections, labels=labels
        )

        # Display counts per bin
        y_pos = 50
        for bin_name, class_count_dict in class_counts_per_bin.items():
            text = (
                f"{bin_name}: "
                + ", ".join(f"{cls}={count}" for cls, count in class_count_dict.items())
            )
            cv2.putText(
                annotated_frame,
                text,
                (30, y_pos),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.1,
                (255, 255, 255),
                3,
                cv2.LINE_AA,
            )
            y_pos += 40

        yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)

    # Insert any remaining buffered data at end
    if insert_buffer:
        cursor.executemany('''
            INSERT INTO detections (timestamp, frame_number, bin_name, class_name, count)
            VALUES (?, ?, ?, ?, ?)
        ''', insert_buffer)
        conn.commit()
        insert_buffer.clear()



# --- Gradio Interface Setup ---
# Gradio Interface for Image Processing
interface_image = gr.Interface(
    fn=show_preds_image,
    inputs=gr.Image(type="filepath", label="Input Image"),
    outputs=gr.Image(type="numpy", label="Output Image"),
    title="Waste Detection (Image)",
    description="Upload an image to see waste detection results.",
    examples=image_example_paths,
    cache_examples=False,
)

# Gradio Interface for Video Processing
interface_video = gr.Interface(
    fn=process_video_with_two_side_bins,
    inputs=gr.Video(label="Input Video"),
    outputs=gr.Image(type="numpy", label="Output Video Stream"),
    title="Waste Tracking and Counting (Video)",
    description="Upload a video to see real-time object tracking and counting.",
    examples=video_example_path,
    cache_examples=False,
)

# Launch the Gradio App with separate tabs for each interface
gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=['Image Inference', 'Video Inference']
).queue().launch(debug=True)
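
# Note: the video handler is a generator, and Gradio streams generator output
# through its queue; .queue() above is what enables the frame-by-frame stream.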