iamsuman committed on
Commit bb0cd33 · 1 Parent(s): 7df41a5

support bins count

Files changed (3)
  1. app.py +192 -79
  2. best.pt +2 -2
  3. requirements.txt +6 -46
app.py CHANGED
@@ -4,13 +4,13 @@ import requests
  import os
  import random
  from ultralytics import YOLO

- # Define class names based on YOLO labels
- class_names = {0: 'AluCan', 1: 'Glass', 2: 'PET', 3: 'HDPEM'}
-
- # Generate random colors for each class
- class_colors = {cls: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) for cls in class_names}

  # File URLs for sample images and video
  file_urls = [
      'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/mix2.jpg?download=true',
@@ -18,119 +18,232 @@ file_urls = [
      'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/sample_waste.mp4?download=true',
  ]

- # Function to download files (always overwrites existing ones)
  def download_file(url, save_name):
-     print(f"Downloading from: {url}")  # Log the URL
      try:
          response = requests.get(url, stream=True)
          response.raise_for_status()  # Check for HTTP errors
          with open(save_name, 'wb') as file:
              for chunk in response.iter_content(1024):
                  file.write(chunk)
-         print(f"Downloaded and overwritten: {save_name}")
      except requests.exceptions.RequestException as e:
          print(f"Error downloading {url}: {e}")

- # Download images and video
  for i, url in enumerate(file_urls):
-     print(i, url)
-     if 'mp4' in file_urls[i]:
-         download_file(file_urls[i], f"video.mp4")
      else:
-         download_file(file_urls[i], f"image_{i}.jpg")

- # Load YOLO model
  model = YOLO('best.pt')

- # Sample paths
- path = [['image_0.jpg'], ['image_1.jpg']]
- video_path = [['video.mp4']]

- # Function to process and display predictions on images
  def show_preds_image(image_path):
      image = cv2.imread(image_path)
-     outputs = model.predict(source=image_path)
      results = outputs[0].cpu().numpy()

-     boxes = results.boxes
-     names = model.model.names

-     for box, conf, cls in zip(boxes.xyxy, boxes.conf, boxes.cls):
          x1, y1, x2, y2 = map(int, box)
-
-         class_name = names[int(cls)]
-         color = class_colors.get(int(cls), (255, 255, 255))  # Default to white if class is unknown
-
          # Draw bounding box
          cv2.rectangle(image, (x1, y1), (x2, y2), color=color, thickness=2, lineType=cv2.LINE_AA)
-
-         # Display class label
-         label = f"{class_name.capitalize()}: {conf:.2f}"
-         cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2, cv2.LINE_AA)
-
-     return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
- # Function to process and display predictions on video
- def show_preds_video(video_path):
-     cap = cv2.VideoCapture(video_path)
-
-     while cap.isOpened():
-         ret, frame = cap.read()
-         if not ret:
-             break
-
-         frame_copy = frame.copy()
-         outputs = model.predict(source=frame)
-         results = outputs[0].cpu().numpy()

-         boxes = results.boxes
-         confidences = boxes.conf
-         classes = boxes.cls
-         names = model.model.names
-
-         for box, conf, cls in zip(boxes.xyxy, confidences, classes):
-             x1, y1, x2, y2 = map(int, box)
-
-             class_name = names[int(cls)]
-             color = class_colors.get(int(cls), (255, 255, 255))  # Default to white if class is unknown

-             # Draw bounding box
-             cv2.rectangle(frame_copy, (x1, y1), (x2, y2), color=color, thickness=2, lineType=cv2.LINE_AA)

-             # Display class label
-             label = f"{class_name.capitalize()}: {conf:.2f}"
-             cv2.putText(frame_copy, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)

-             yield cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)

-     cap.release()

- # Gradio Image Interface
- inputs_image = [gr.Image(type="filepath", label="Input Image")]
- outputs_image = [gr.Image(type="numpy", label="Output Image")]

  interface_image = gr.Interface(
      fn=show_preds_image,
-     inputs=inputs_image,
-     outputs=outputs_image,
-     title="Waste Detection",
-     examples=path,
      cache_examples=False,
  )

- # Gradio Video Interface
- inputs_video = [gr.Video(label="Input Video")]
- outputs_video = [gr.Image(type="numpy", label="Output Image")]
  interface_video = gr.Interface(
-     fn=show_preds_video,
-     inputs=inputs_video,
-     outputs=outputs_video,
-     title="Waste Detection",
-     examples=video_path,
      cache_examples=False,
  )

- # Launch Gradio App
  gr.TabbedInterface(
      [interface_image, interface_video],
      tab_names=['Image Inference', 'Video Inference']
- ).queue().launch()
 
  import os
  import random
  from ultralytics import YOLO
+ import numpy as np
+ from collections import defaultdict
+
+ # Import the supervision library
+ import supervision as sv

+ # --- File Downloading ---
  # File URLs for sample images and video
  file_urls = [
      'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/mix2.jpg?download=true',
      'https://huggingface.co/spaces/iamsuman/waste-detection/resolve/main/samples/sample_waste.mp4?download=true',
  ]

  def download_file(url, save_name):
+     """Download a file from a URL, overwriting it if it exists."""
+     print(f"Downloading from: {url}")
      try:
          response = requests.get(url, stream=True)
          response.raise_for_status()  # Check for HTTP errors
          with open(save_name, 'wb') as file:
              for chunk in response.iter_content(1024):
                  file.write(chunk)
+         print(f"Downloaded and overwrote: {save_name}")
      except requests.exceptions.RequestException as e:
          print(f"Error downloading {url}: {e}")

+ # Download sample images and video for the examples
  for i, url in enumerate(file_urls):
+     if 'mp4' in url:
+         download_file(url, "video.mp4")
      else:
+         download_file(url, f"image_{i}.jpg")

+ # --- Model and Class Configuration ---
+ # Load the custom YOLO model
+ # IMPORTANT: Replace 'best.pt' with the path to your trained model.
  model = YOLO('best.pt')

+ # Get class names and generate colors dynamically from the loaded model,
+ # so names and colors always match the model's output.
+ class_names = model.model.names
+ class_colors = {
+     name: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+     for name in class_names.values()
+ }
+
+ # Define paths for Gradio examples
+ image_example_paths = [['image_0.jpg'], ['image_1.jpg']]
+ video_example_path = [['video.mp4']]
+

+ # --- Image Processing Function ---
  def show_preds_image(image_path):
+     """Process a single image and overlay YOLO predictions."""
      image = cv2.imread(image_path)
+     outputs = model.predict(source=image_path, verbose=False)
      results = outputs[0].cpu().numpy()

+     # Convert to a supervision Detections object for easier handling
+     detections = sv.Detections.from_ultralytics(outputs[0])

+     # Annotate the image with bounding boxes and labels
+     for i, (box, conf, cls) in enumerate(zip(detections.xyxy, detections.confidence, detections.class_id)):
          x1, y1, x2, y2 = map(int, box)
+         class_name = class_names[cls]
+         color = class_colors[class_name]
+
          # Draw bounding box
          cv2.rectangle(image, (x1, y1), (x2, y2), color=color, thickness=2, lineType=cv2.LINE_AA)

+         # Create and display label
+         label = f"{class_name}: {conf:.2f}"
+         cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)

+     return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


+ # --- Video Processing Function (with Supervision) ---
+ def process_video_with_two_side_bins(video_path):

+     if video_path is None:
+         return
+
+     generator = sv.get_video_frames_generator(video_path)

+     try:
+         first_frame = next(generator)
+     except StopIteration:
+         print("No frames found in the provided video input.")
+         # Yield a blank black frame of fixed size so the output still updates, then stop.
+         blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
+         yield cv2.cvtColor(blank_frame, cv2.COLOR_BGR2RGB)
+         return
+
+     frame_height, frame_width, _ = first_frame.shape
+
+     # Define two bins: recycle and trash sides
+     bins = [
+         {
+             "name": "Recycle Bin",
+             "coords": (
+                 int(frame_width * 0.05),
+                 int(frame_height * 0.5),
+                 int(frame_width * 0.25),
+                 int(frame_height * 0.95),
+             ),
+             "color": (200, 16, 46),  # blue-ish (BGR)
+         },
+         {
+             "name": "Trash Bin",
+             "coords": (
+                 int(frame_width * 0.75),
+                 int(frame_height * 0.5),
+                 int(frame_width * 0.95),
+                 int(frame_height * 0.95),
+             ),
+             "color": (50, 50, 50),  # dark gray (BGR)
+         },
+     ]
+
+     box_annotator = sv.BoxAnnotator(thickness=2)
+     label_annotator = sv.LabelAnnotator(
+         text_scale=1.2,  # bigger text size
+         text_thickness=3,
+         text_position=sv.Position.TOP_LEFT,
+     )
+
+     tracker = sv.ByteTrack()
+
+     items_in_bins = {bin_["name"]: set() for bin_ in bins}
+     class_counts_per_bin = {bin_["name"]: defaultdict(int) for bin_ in bins}
+
+     for i, frame in enumerate(generator):
+         results = model(frame, verbose=False)[0]
+         detections = sv.Detections.from_ultralytics(results)
+         tracked_detections = tracker.update_with_detections(detections)
+
+         annotated_frame = frame.copy()
+
+         # Draw bins and bigger labels
+         for bin_ in bins:
+             x1, y1, x2, y2 = bin_["coords"]
+             color = bin_["color"]
+             cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), color=color, thickness=3)
+             cv2.putText(
+                 annotated_frame,
+                 bin_["name"],
+                 (x1 + 5, y1 - 15),
+                 cv2.FONT_HERSHEY_SIMPLEX,
+                 1.5,  # bigger font
+                 color,
+                 3,
+                 cv2.LINE_AA,
+             )
+
+         if tracked_detections.tracker_id is None:
+             yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
+             continue
+
+         for box, track_id, class_id in zip(
+             tracked_detections.xyxy,
+             tracked_detections.tracker_id,
+             tracked_detections.class_id,
+         ):
+             x1, y1, x2, y2 = map(int, box)
+             cx = (x1 + x2) // 2
+             cy = (y1 + y2) // 2
+
+             # Count each tracker ID at most once per bin when its center falls inside the bin
+             for bin_ in bins:
+                 bx1, by1, bx2, by2 = bin_["coords"]
+                 if (bx1 <= cx <= bx2) and (by1 <= cy <= by2):
+                     if track_id not in items_in_bins[bin_["name"]]:
+                         items_in_bins[bin_["name"]].add(track_id)
+                         class_name = class_names[class_id]
+                         class_counts_per_bin[bin_["name"]][class_name] += 1
+
+         labels = [
+             f"#{tid} {class_names[cid]}"
+             for cid, tid in zip(tracked_detections.class_id, tracked_detections.tracker_id)
+         ]
+
+         annotated_frame = box_annotator.annotate(
+             scene=annotated_frame, detections=tracked_detections
+         )
+         annotated_frame = label_annotator.annotate(
+             scene=annotated_frame, detections=tracked_detections, labels=labels
+         )
+
+         # Show counts per bin with bigger font
+         y_pos = 50
+         for bin_name, class_count_dict in class_counts_per_bin.items():
+             text = (
+                 f"{bin_name}: "
+                 + ", ".join(f"{cls}={count}" for cls, count in class_count_dict.items())
+             )
+             cv2.putText(
+                 annotated_frame,
+                 text,
+                 (30, y_pos),
+                 cv2.FONT_HERSHEY_SIMPLEX,
+                 1.1,  # bigger font for counts
+                 (255, 255, 255),
+                 3,
+                 cv2.LINE_AA,
+             )
+             y_pos += 40
+
+         yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
+
+
+ # --- Gradio Interface Setup ---
+ # Gradio Interface for Image Processing
  interface_image = gr.Interface(
      fn=show_preds_image,
+     inputs=gr.Image(type="filepath", label="Input Image"),
+     outputs=gr.Image(type="numpy", label="Output Image"),
+     title="Waste Detection (Image)",
+     description="Upload an image to see waste detection results.",
+     examples=image_example_paths,
      cache_examples=False,
  )

+ # Gradio Interface for Video Processing
  interface_video = gr.Interface(
+     fn=process_video_with_two_side_bins,
+     inputs=gr.Video(label="Input Video"),
+     outputs=gr.Image(type="numpy", label="Output Video Stream"),
+     title="Waste Tracking and Counting (Video)",
+     description="Upload a video to see real-time object tracking and counting.",
+     examples=video_example_path,
      cache_examples=False,
  )

+ # Launch the Gradio App with separate tabs for each interface
  gr.TabbedInterface(
      [interface_image, interface_video],
      tab_names=['Image Inference', 'Video Inference']
+ ).queue().launch(debug=True)
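
Note: the video tab works because process_video_with_two_side_bins is a generator and Gradio streams each yielded frame to the gr.Image output while the queue is enabled. A minimal, self-contained sketch of that streaming pattern (the gray-frame demo function below is illustrative only and not part of this Space):

import gradio as gr
import numpy as np

def stream_gray_frames(_video_path):
    # Each yielded array replaces the current gr.Image output in the browser,
    # which is the same mechanism process_video_with_two_side_bins relies on.
    for step in range(30):
        yield np.full((240, 320, 3), step * 8, dtype=np.uint8)

gr.Interface(
    fn=stream_gray_frames,
    inputs=gr.Video(label="Input Video"),
    outputs=gr.Image(type="numpy", label="Output Video Stream"),
).queue().launch()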
best.pt CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:6ba9e2e924f6acd401c8ef4c4d2b10fe80049dcbcf46597be2196e03986271d6
- size 6211875
+ oid sha256:ea3a57f09eaca3340d951c0f5fcf90604f9544cf0a2643f949a182616e936af1
+ size 6215843
requirements.txt CHANGED
@@ -1,47 +1,7 @@
- # Ultralytics requirements
- # Usage: pip install -r requirements.txt
-
- # Base ----------------------------------------
- hydra-core>=1.2.0
- matplotlib>=3.2.2
- numpy>=1.18.5
- opencv-python>=4.1.1
- Pillow>=7.1.2
- PyYAML>=5.3.1
- requests>=2.23.0
- scipy>=1.4.1
- torch>=1.7.0
- torchvision>=0.8.1
- tqdm>=4.64.0
  ultralytics
-
- # Logging -------------------------------------
- tensorboard>=2.4.1
- # clearml
- # comet
-
- # Plotting ------------------------------------
- pandas>=1.1.4
- seaborn>=0.11.0
-
- # Export --------------------------------------
- # coremltools>=6.0  # CoreML export
- # onnx>=1.12.0  # ONNX export
- # onnx-simplifier>=0.4.1  # ONNX simplifier
- # nvidia-pyindex  # TensorRT export
- # nvidia-tensorrt  # TensorRT export
- # scikit-learn==0.19.2  # CoreML quantization
- # tensorflow>=2.4.1  # TF exports (-cpu, -aarch64, -macos)
- # tensorflowjs>=3.9.0  # TF.js export
- # openvino-dev  # OpenVINO export
-
- # Extras --------------------------------------
- ipython  # interactive notebook
- psutil  # system utilization
- thop>=0.1.1  # FLOPs computation
- # albumentations>=1.0.3
- # pycocotools>=2.0.6  # COCO mAP
- # roboflow
-
- # HUB -----------------------------------------
- GitPython>=3.1.24

+ gradio==4.10.0
  ultralytics
+ Pillow
+ pydantic==2.8.2
+ pydantic-core==2.20.1
+ fastapi==0.112.4
+ supervision>=0.26.1  # detection tracking and annotation utilities
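
Note: a quick way to confirm the trimmed requirements still cover the new pipeline is to exercise the same supervision calls app.py now depends on. A minimal sketch, assuming best.pt and one of the downloaded sample images (image_0.jpg) are in the working directory:

import supervision as sv
from ultralytics import YOLO

model = YOLO("best.pt")
tracker = sv.ByteTrack()

# One prediction pushed through the same supervision objects app.py uses.
result = model.predict(source="image_0.jpg", verbose=False)[0]
detections = sv.Detections.from_ultralytics(result)
tracked = tracker.update_with_detections(detections)
print(f"{len(tracked)} tracked detections")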