show prediction on video batch
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
|
|
2 |
import cv2
|
3 |
import requests
|
4 |
import os
|
|
|
5 |
|
6 |
from ultralytics import YOLO
|
7 |
|
@@ -91,68 +92,130 @@ interface_image = gr.Interface(
|
|
91 |
examples=path,
|
92 |
cache_examples=False,
|
93 |
)
|
94 |
-
|
95 |
-
def show_preds_video(video_path):
|
96 |
-
results = model.track(source=video_path, persist=True, tracker="bytetrack.yaml", verbose=False, stream=True)
|
97 |
-
|
98 |
-
ripe_ids = set()
|
99 |
-
unripe_ids = set()
|
100 |
-
|
101 |
-
# Get video frame dimensions for centering text
|
102 |
cap = cv2.VideoCapture(video_path)
|
103 |
if not cap.isOpened():
|
104 |
print("Error: Could not open video.")
|
105 |
return
|
|
|
106 |
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
cap.release()
|
|
|
108 |
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
|
155 |
-
|
|
|
|
|
156 |
|
157 |
inputs_video = [
|
158 |
gr.components.Video(label="Input Video"),
|
@@ -162,7 +225,7 @@ outputs_video = [
|
|
162 |
gr.components.Image(type="numpy", label="Output Image"),
|
163 |
]
|
164 |
interface_video = gr.Interface(
|
165 |
-
fn=
|
166 |
inputs=inputs_video,
|
167 |
outputs=outputs_video,
|
168 |
title="Ripe And Unripe Tomatoes Detection",
|
|
|
2 |
import cv2
|
3 |
import requests
|
4 |
import os
|
5 |
+
from collections import deque
|
6 |
|
7 |
from ultralytics import YOLO
|
8 |
|
|
|
92 |
examples=path,
|
93 |
cache_examples=False,
|
94 |
)
|
95 |
+
def show_preds_video_batch(video_path, batch_size=16):
    """Stream annotated RGB frames from a video, running tracking in batches.

    Reads the video at *video_path*, buffers frames into batches of
    *batch_size*, runs the module-level YOLO ``model`` with the ByteTrack
    tracker (``persist=True`` so track IDs carry across batches), draws a
    box + track ID for each detection plus a running ripe/unripe unique-ID
    count, and yields each annotated frame converted to RGB for Gradio.

    Args:
        video_path: Path to the input video file.
        batch_size: Number of frames to accumulate before each tracking call.

    Yields:
        numpy.ndarray: Annotated frame, converted from BGR to RGB.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    ripe_ids, unripe_ids = set(), set()
    names = model.model.names  # cache model class names once

    def process_batch(frames, results):
        # Annotate each frame of one batch and yield it in RGB.
        # Note: ripe_ids/unripe_ids are only mutated, so no `nonlocal` needed.
        for frame, output in zip(frames, results):
            if output.boxes and output.boxes.id is not None:
                boxes = output.boxes
                ids = boxes.id.cpu().numpy().astype(int)
                classes = boxes.cls.cpu().numpy().astype(int)

                for box, cls, track_id in zip(boxes.xyxy, classes, ids):
                    x1, y1, x2, y2 = map(int, box)
                    class_name = names[cls]
                    is_ripe = class_name.lower() == "ripe"
                    # BGR colors: red for ripe, green for unripe
                    # (becomes correct RGB after the cvtColor below).
                    color = (0, 0, 255) if is_ripe else (0, 255, 0)
                    (ripe_ids if is_ripe else unripe_ids).add(track_id)

                    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                    cv2.putText(frame, f"{class_name.capitalize()} ID:{track_id}",
                                (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

            # Draw total counts, centered horizontally near the top.
            full_text = f"Ripe: {len(ripe_ids)} | Unripe: {len(unripe_ids)}"
            (text_width, _), _ = cv2.getTextSize(full_text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
            text_x = (frame_width - text_width) // 2
            cv2.putText(frame, full_text, (text_x, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Plain list suffices: the buffer is only appended to and reset,
    # and model.track accepts a list of frames directly (no deque copy).
    frame_buffer = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_buffer.append(frame)

            if len(frame_buffer) == batch_size:
                results = model.track(source=frame_buffer, persist=True,
                                      tracker="bytetrack.yaml", verbose=False)
                yield from process_batch(frame_buffer, results)
                frame_buffer = []

        # Process remaining frames that did not fill a full batch.
        if frame_buffer:
            results = model.track(source=frame_buffer, persist=True,
                                  tracker="bytetrack.yaml", verbose=False)
            yield from process_batch(frame_buffer, results)
    finally:
        # Always release the capture, even if tracking/annotation raises.
        cap.release()

    print(f"Final Counts β Ripe: {len(ripe_ids)}, Unripe: {len(unripe_ids)}")
|
155 |
|
156 |
+
# def show_preds_video(video_path):
|
157 |
+
# results = model.track(source=video_path, persist=True, tracker="bytetrack.yaml", verbose=False, stream=True)
|
158 |
+
|
159 |
+
# ripe_ids = set()
|
160 |
+
# unripe_ids = set()
|
161 |
+
|
162 |
+
# # Get video frame dimensions for centering text
|
163 |
+
# cap = cv2.VideoCapture(video_path)
|
164 |
+
# if not cap.isOpened():
|
165 |
+
# print("Error: Could not open video.")
|
166 |
+
# return
|
167 |
+
# frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
168 |
+
# cap.release()
|
169 |
+
|
170 |
+
# for output in results:
|
171 |
+
# frame = output.orig_img
|
172 |
|
173 |
+
# if output.boxes and output.boxes.id is not None:
|
174 |
+
# names = model.model.names
|
175 |
+
# boxes = output.boxes
|
176 |
+
# ids = boxes.id.cpu().numpy().astype(int)
|
177 |
+
# classes = boxes.cls.cpu().numpy().astype(int)
|
178 |
+
|
179 |
+
# for box, cls, track_id in zip(boxes.xyxy, classes, ids):
|
180 |
+
# x1, y1, x2, y2 = map(int, box)
|
181 |
+
# class_name = names[cls]
|
182 |
+
|
183 |
+
# # Define BGR colors directly for OpenCV functions
|
184 |
+
# if class_name.lower() == "ripe":
|
185 |
+
# # To get RED in Gradio (RGB), you need to use (255, 0, 0) BGR
|
186 |
+
# # Note: You were using (0, 0, 255) which is Blue in RGB after conversion.
|
187 |
+
# color = (0, 0, 255)
|
188 |
+
# ripe_ids.add(track_id)
|
189 |
+
# else:
|
190 |
+
# # To get GREEN in Gradio (RGB), you need to use (0, 255, 0) BGR.
|
191 |
+
# # This color is already correct.
|
192 |
+
# color = (0, 255, 0)
|
193 |
+
# unripe_ids.add(track_id)
|
194 |
+
|
195 |
+
# cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
|
196 |
+
# cv2.putText(frame, f"{class_name.capitalize()} ID:{track_id}",
|
197 |
+
# (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
198 |
+
|
199 |
+
# ripe_count_text = f"Ripe: {len(ripe_ids)}"
|
200 |
+
# unripe_count_text = f"Unripe: {len(unripe_ids)}"
|
201 |
+
# full_text = f"{ripe_count_text} | {unripe_count_text}"
|
202 |
+
|
203 |
+
# # Get text size to center it
|
204 |
+
# (text_width, text_height), baseline = cv2.getTextSize(full_text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
|
205 |
+
# text_x = (frame_width - text_width) // 2
|
206 |
+
# text_y = 40 # A fixed position at the top
|
207 |
+
|
208 |
+
# # Display the counts at the top center
|
209 |
+
# cv2.putText(frame, full_text, (text_x, text_y),
|
210 |
+
# cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|
211 |
|
212 |
+
# # This line is crucial for the fix.
|
213 |
+
# # It correctly converts the frame from BGR to RGB for Gradio.
|
214 |
+
# yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
215 |
|
216 |
+
# print(f"Final Counts β Ripe: {len(ripe_ids)}, Unripe: {len(unripe_ids)}")
|
217 |
+
|
218 |
+
|
219 |
|
220 |
inputs_video = [
|
221 |
gr.components.Video(label="Input Video"),
|
|
|
225 |
gr.components.Image(type="numpy", label="Output Image"),
|
226 |
]
|
227 |
interface_video = gr.Interface(
|
228 |
+
fn=show_preds_video_batch,
|
229 |
inputs=inputs_video,
|
230 |
outputs=outputs_video,
|
231 |
title="Ripe And Unripe Tomatoes Detection",
|