iamsuman commited on
Commit
6e700f4
·
1 Parent(s): 0d23f0d

show prediction on video batch

Browse files
Files changed (1) hide show
  1. app.py +116 -53
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import cv2
3
  import requests
4
  import os
 
5
 
6
  from ultralytics import YOLO
7
 
@@ -91,68 +92,130 @@ interface_image = gr.Interface(
91
  examples=path,
92
  cache_examples=False,
93
  )
94
-
95
- def show_preds_video(video_path):
96
- results = model.track(source=video_path, persist=True, tracker="bytetrack.yaml", verbose=False, stream=True)
97
-
98
- ripe_ids = set()
99
- unripe_ids = set()
100
-
101
- # Get video frame dimensions for centering text
102
  cap = cv2.VideoCapture(video_path)
103
  if not cap.isOpened():
104
  print("Error: Could not open video.")
105
  return
 
106
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  cap.release()
 
108
 
109
- for output in results:
110
- frame = output.orig_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- if output.boxes and output.boxes.id is not None:
113
- names = model.model.names
114
- boxes = output.boxes
115
- ids = boxes.id.cpu().numpy().astype(int)
116
- classes = boxes.cls.cpu().numpy().astype(int)
117
-
118
- for box, cls, track_id in zip(boxes.xyxy, classes, ids):
119
- x1, y1, x2, y2 = map(int, box)
120
- class_name = names[cls]
121
-
122
- # Define BGR colors directly for OpenCV functions
123
- if class_name.lower() == "ripe":
124
- # To get RED in Gradio (RGB), you need to use (255, 0, 0) BGR
125
- # Note: You were using (0, 0, 255) which is Blue in RGB after conversion.
126
- color = (0, 0, 255)
127
- ripe_ids.add(track_id)
128
- else:
129
- # To get GREEN in Gradio (RGB), you need to use (0, 255, 0) BGR.
130
- # This color is already correct.
131
- color = (0, 255, 0)
132
- unripe_ids.add(track_id)
133
-
134
- cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
135
- cv2.putText(frame, f"{class_name.capitalize()} ID:{track_id}",
136
- (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
137
-
138
- ripe_count_text = f"Ripe: {len(ripe_ids)}"
139
- unripe_count_text = f"Unripe: {len(unripe_ids)}"
140
- full_text = f"{ripe_count_text} | {unripe_count_text}"
141
-
142
- # Get text size to center it
143
- (text_width, text_height), baseline = cv2.getTextSize(full_text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
144
- text_x = (frame_width - text_width) // 2
145
- text_y = 40 # A fixed position at the top
146
-
147
- # Display the counts at the top center
148
- cv2.putText(frame, full_text, (text_x, text_y),
149
- cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
150
 
151
- # This line is crucial for the fix.
152
- # It correctly converts the frame from BGR to RGB for Gradio.
153
- yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
154
 
155
- print(f"Final Counts β†’ Ripe: {len(ripe_ids)}, Unripe: {len(unripe_ids)}")
 
 
156
 
157
  inputs_video = [
158
  gr.components.Video(label="Input Video"),
@@ -162,7 +225,7 @@ outputs_video = [
162
  gr.components.Image(type="numpy", label="Output Image"),
163
  ]
164
  interface_video = gr.Interface(
165
- fn=show_preds_video,
166
  inputs=inputs_video,
167
  outputs=outputs_video,
168
  title="Ripe And Unripe Tomatoes Detection",
 
2
  import cv2
3
  import requests
4
  import os
5
+ from collections import deque
6
 
7
  from ultralytics import YOLO
8
 
 
92
  examples=path,
93
  cache_examples=False,
94
  )
95
def show_preds_video_batch(video_path, batch_size=16):
    """Stream annotated RGB frames from a video, running YOLO tracking in batches.

    Frames are read with OpenCV, buffered into groups of ``batch_size`` and
    passed to ``model.track`` (ByteTrack, ``persist=True`` so track IDs survive
    across batch boundaries).  Each annotated frame is yielded as an RGB array,
    letting Gradio render the video progressively.

    Args:
        video_path: Path to the input video file.
        batch_size: Number of frames per tracking call (default 16).

    Yields:
        numpy.ndarray: Each annotated frame, converted from BGR to RGB.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    ripe_ids, unripe_ids = set(), set()
    names = model.model.names  # cache model class names

    def process_batch(frames, results):
        # Draw boxes/labels for one batch and yield each frame in RGB.
        # The ID sets are mutated in place, so no `nonlocal` is needed.
        for frame, output in zip(frames, results):
            if output.boxes and output.boxes.id is not None:
                boxes = output.boxes
                ids = boxes.id.cpu().numpy().astype(int)
                classes = boxes.cls.cpu().numpy().astype(int)

                for box, cls, track_id in zip(boxes.xyxy, classes, ids):
                    x1, y1, x2, y2 = map(int, box)
                    class_name = names[cls]
                    # BGR colors: red for ripe, green for unripe (the frame
                    # is converted to RGB below before yielding to Gradio).
                    if class_name.lower() == "ripe":
                        color = (0, 0, 255)
                        ripe_ids.add(track_id)
                    else:
                        color = (0, 255, 0)
                        unripe_ids.add(track_id)

                    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                    cv2.putText(frame, f"{class_name.capitalize()} ID:{track_id}",
                                (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

            # Draw running totals of unique track IDs, centered at the top.
            full_text = f"Ripe: {len(ripe_ids)} | Unripe: {len(unripe_ids)}"
            (text_width, _), _ = cv2.getTextSize(full_text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
            text_x = (frame_width - text_width) // 2
            cv2.putText(frame, full_text, (text_x, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # A plain list suffices here (we only append and clear); the previous
    # deque also forced an extra list(...) copy on every tracking call.
    frame_buffer = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_buffer.append(frame)

            if len(frame_buffer) == batch_size:
                results = model.track(source=frame_buffer, persist=True,
                                      tracker="bytetrack.yaml", verbose=False)
                yield from process_batch(frame_buffer, results)
                frame_buffer = []

        # Process the tail shorter than batch_size.
        if frame_buffer:
            results = model.track(source=frame_buffer, persist=True,
                                  tracker="bytetrack.yaml", verbose=False)
            yield from process_batch(frame_buffer, results)
    finally:
        # Release the capture even if the consumer (Gradio) stops iterating
        # early or a tracking call raises — otherwise the handle leaks.
        cap.release()

    print(f"Final Counts → Ripe: {len(ripe_ids)}, Unripe: {len(unripe_ids)}")
155
 
156
+ # def show_preds_video(video_path):
157
+ # results = model.track(source=video_path, persist=True, tracker="bytetrack.yaml", verbose=False, stream=True)
158
+
159
+ # ripe_ids = set()
160
+ # unripe_ids = set()
161
+
162
+ # # Get video frame dimensions for centering text
163
+ # cap = cv2.VideoCapture(video_path)
164
+ # if not cap.isOpened():
165
+ # print("Error: Could not open video.")
166
+ # return
167
+ # frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
168
+ # cap.release()
169
+
170
+ # for output in results:
171
+ # frame = output.orig_img
172
 
173
+ # if output.boxes and output.boxes.id is not None:
174
+ # names = model.model.names
175
+ # boxes = output.boxes
176
+ # ids = boxes.id.cpu().numpy().astype(int)
177
+ # classes = boxes.cls.cpu().numpy().astype(int)
178
+
179
+ # for box, cls, track_id in zip(boxes.xyxy, classes, ids):
180
+ # x1, y1, x2, y2 = map(int, box)
181
+ # class_name = names[cls]
182
+
183
+ # # Define BGR colors directly for OpenCV functions
184
+ # if class_name.lower() == "ripe":
185
+ # # To get RED in Gradio (RGB), you need to use (255, 0, 0) BGR
186
+ # # Note: You were using (0, 0, 255) which is Blue in RGB after conversion.
187
+ # color = (0, 0, 255)
188
+ # ripe_ids.add(track_id)
189
+ # else:
190
+ # # To get GREEN in Gradio (RGB), you need to use (0, 255, 0) BGR.
191
+ # # This color is already correct.
192
+ # color = (0, 255, 0)
193
+ # unripe_ids.add(track_id)
194
+
195
+ # cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
196
+ # cv2.putText(frame, f"{class_name.capitalize()} ID:{track_id}",
197
+ # (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
198
+
199
+ # ripe_count_text = f"Ripe: {len(ripe_ids)}"
200
+ # unripe_count_text = f"Unripe: {len(unripe_ids)}"
201
+ # full_text = f"{ripe_count_text} | {unripe_count_text}"
202
+
203
+ # # Get text size to center it
204
+ # (text_width, text_height), baseline = cv2.getTextSize(full_text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
205
+ # text_x = (frame_width - text_width) // 2
206
+ # text_y = 40 # A fixed position at the top
207
+
208
+ # # Display the counts at the top center
209
+ # cv2.putText(frame, full_text, (text_x, text_y),
210
+ # cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
211
 
212
+ # # This line is crucial for the fix.
213
+ # # It correctly converts the frame from BGR to RGB for Gradio.
214
+ # yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
215
 
216
+ # print(f"Final Counts β†’ Ripe: {len(ripe_ids)}, Unripe: {len(unripe_ids)}")
217
+
218
+
219
 
220
  inputs_video = [
221
  gr.components.Video(label="Input Video"),
 
225
  gr.components.Image(type="numpy", label="Output Image"),
226
  ]
227
  interface_video = gr.Interface(
228
+ fn=show_preds_video_batch,
229
  inputs=inputs_video,
230
  outputs=outputs_video,
231
  title="Ripe And Unripe Tomatoes Detection",