iamsuman committed on
Commit
55c29e2
·
verified ·
1 Parent(s): ed3a0b8

updated app file

Browse files
Files changed (1) hide show
  1. app.py +74 -114
app.py CHANGED
@@ -2,167 +2,127 @@ import gradio as gr
2
  import cv2
3
  import requests
4
  import os
5
-
6
  from ultralytics import YOLO
7
 
 
 
 
 
 
 
 
8
  file_urls = [
9
- 'https://huggingface.co/spaces/iamsuman/ripe-and-unripe-tomatoes-detection/resolve/main/samples/riped_tomato_93.jpeg?download=true',
10
- 'https://huggingface.co/spaces/iamsuman/ripe-and-unripe-tomatoes-detection/resolve/main/samples/unriped_tomato_18.jpeg?download=true',
11
- 'https://huggingface.co/spaces/iamsuman/ripe-and-unripe-tomatoes-detection/resolve/main/samples/tomatoes.mp4?download=true',
12
  ]
13
 
 
14
  def download_file(url, save_name):
15
- url = url
16
  if not os.path.exists(save_name):
17
  file = requests.get(url)
18
  open(save_name, 'wb').write(file.content)
19
 
 
20
  for i, url in enumerate(file_urls):
21
  if 'mp4' in file_urls[i]:
22
- download_file(
23
- file_urls[i],
24
- f"video.mp4"
25
- )
26
  else:
27
- download_file(
28
- file_urls[i],
29
- f"image_{i}.jpg"
30
- )
31
 
 
32
  model = YOLO('best.pt')
33
- path = [['image_0.jpg'], ['image_1.jpg']]
34
- video_path = [['video.mp4']]
35
-
36
-
37
 
 
 
 
38
 
 
39
  def show_preds_image(image_path):
40
  image = cv2.imread(image_path)
41
  outputs = model.predict(source=image_path)
42
  results = outputs[0].cpu().numpy()
43
 
44
- # Print the detected objects' information (class, coordinates, and probability)
45
- box = results[0].boxes
46
- names = model.model.names
47
  boxes = results.boxes
 
48
 
49
  for box, conf, cls in zip(boxes.xyxy, boxes.conf, boxes.cls):
50
-
51
  x1, y1, x2, y2 = map(int, box)
52
 
53
  class_name = names[int(cls)]
54
- print(class_name, "class_name", class_name.lower() == 'ripe')
55
- if class_name.lower() == 'ripe':
56
- color = (0, 0, 255) # Red for ripe
57
- else:
58
- color = (0, 255, 0) # Green for unripe
59
-
60
- # Draw rectangle around object
61
- cv2.rectangle(
62
- image,
63
- (x1, y1),
64
- (x2, y2),
65
- color=color,
66
- thickness=2,
67
- lineType=cv2.LINE_AA
68
- )
69
-
70
- # Display class label on top of rectangle
71
  label = f"{class_name.capitalize()}: {conf:.2f}"
72
- cv2.putText(image, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, # Use the same color as the rectangle
73
- 2,
74
- cv2.LINE_AA)
75
-
76
- # Convert image to RGB (Gradio expects RGB format)
77
  return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
 
 
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- inputs_image = [
81
- gr.components.Image(type="filepath", label="Input Image"),
82
- ]
83
- outputs_image = [
84
- gr.components.Image(type="numpy", label="Output Image"),
85
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  interface_image = gr.Interface(
87
  fn=show_preds_image,
88
  inputs=inputs_image,
89
  outputs=outputs_image,
90
- title="Ripe And Unripe Tomatoes Detection",
91
  examples=path,
92
  cache_examples=False,
93
  )
94
 
95
- def show_preds_video(video_path):
96
- cap = cv2.VideoCapture(video_path)
97
- while(cap.isOpened()):
98
- ret, frame = cap.read()
99
- if ret:
100
- frame_copy = frame.copy()
101
- outputs = model.predict(source=frame)
102
- results = outputs[0].cpu().numpy()
103
-
104
- boxes = results.boxes
105
- confidences = boxes.conf
106
- classes = boxes.cls
107
- names = model.model.names
108
-
109
- for box, conf, cls in zip(boxes.xyxy, confidences, classes):
110
- x1, y1, x2, y2 = map(int, box)
111
-
112
- # Determine color based on class
113
- class_name = names[int(cls)]
114
- if class_name.lower() == 'ripe':
115
- color = (0, 0, 255) # Red for ripe
116
- else:
117
- color = (0, 255, 0) # Green for unripe
118
-
119
- # Draw rectangle around object
120
- cv2.rectangle(
121
- frame_copy,
122
- (x1, y1),
123
- (x2, y2),
124
- color=color,
125
- thickness=2,
126
- lineType=cv2.LINE_AA
127
- )
128
-
129
- # Display class label on top of rectangle with capitalized class name
130
- label = f"{class_name.capitalize()}: {conf:.2f}"
131
- cv2.putText(
132
- frame_copy,
133
- label,
134
- (x1, y1 - 10), # Position slightly above the top of the rectangle
135
- cv2.FONT_HERSHEY_SIMPLEX,
136
- 0.5,
137
- color, # Use the same color as the rectangle
138
- 1,
139
- cv2.LINE_AA
140
- )
141
-
142
- # Convert frame to RGB (Gradio expects RGB format)
143
- yield cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
144
- else:
145
- break
146
-
147
- cap.release()
148
-
149
- inputs_video = [
150
- gr.components.Video(label="Input Video"),
151
-
152
- ]
153
- outputs_video = [
154
- gr.components.Image(type="numpy", label="Output Image"),
155
- ]
156
  interface_video = gr.Interface(
157
  fn=show_preds_video,
158
  inputs=inputs_video,
159
  outputs=outputs_video,
160
- title="Ripe And Unripe Tomatoes Detection",
161
  examples=video_path,
162
  cache_examples=False,
163
  )
164
 
 
165
  gr.TabbedInterface(
166
  [interface_image, interface_video],
167
- tab_names=['Image inference', 'Video inference']
168
- ).queue().launch()
 
2
  import cv2
3
  import requests
4
  import os
5
+ import random
6
  from ultralytics import YOLO
7
 
8
# Class-id -> label mapping for the four waste classes the YOLO model detects.
class_names = {0: 'AluCan', 1: 'Glass', 2: 'PET', 3: 'HDPEM'}

# One random BGR color per class id (three randint draws per class, in id order).
class_colors = {
    label_id: tuple(random.randint(0, 255) for _ in range(3))
    for label_id in class_names
}

# Sample media hosted on the Hugging Face Space: two images and one video.
file_urls = [
    'https://huggingface.co/spaces/iamsuman/ripe-and-unripe-tomatoes-detection/resolve/main/samples/AluCan1,000.jpg?download=true',
    'https://huggingface.co/spaces/iamsuman/ripe-and-unripe-tomatoes-detection/resolve/main/samples/Glass847.jpg?download=true',
    'https://huggingface.co/spaces/iamsuman/ripe-and-unripe-tomatoes-detection/resolve/main/samples/sample_waste.mp4?download=true',
]
20
 
21
# Function to download files
def download_file(url, save_name):
    """Download *url* to *save_name*, skipping the request if the file exists.

    Args:
        url: Direct-download URL of a sample asset.
        save_name: Local path to write the downloaded bytes to.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    if os.path.exists(save_name):
        return  # already cached locally; don't re-download on every restart

    response = requests.get(url)
    # Fail loudly on HTTP errors instead of silently saving an error page.
    response.raise_for_status()
    # Context manager guarantees the handle is closed (original leaked it via
    # open(save_name, 'wb').write(...)).
    with open(save_name, 'wb') as f:
        f.write(response.content)
26
 
27
# Download images and video up front so the Gradio example paths resolve locally.
for i, url in enumerate(file_urls):
    # `url` already holds the current element; re-indexing file_urls[i] was redundant.
    if 'mp4' in url:
        download_file(url, "video.mp4")  # single sample video, fixed name
    else:
        download_file(url, f"image_{i}.jpg")  # images keep their position index
 
 
 
33
 
34
# Load the fine-tuned YOLO weights shipped alongside this app; `model` is
# shared by both show_preds_image and show_preds_video below.
model = YOLO('best.pt')

# Gradio example inputs: filenames match what the download loop above writes
# (image_0.jpg / image_1.jpg for the sample images, video.mp4 for the video).
path = [['image_0.jpg'], ['image_1.jpg']]
video_path = [['video.mp4']]
40
 
41
# Function to process and display predictions on images
def show_preds_image(image_path):
    """Run detection on one image and return an annotated RGB frame.

    Args:
        image_path: Filesystem path to the input image (supplied by Gradio).

    Returns:
        The image as an RGB numpy array with a labelled box per detection.
    """
    canvas = cv2.imread(image_path)
    detections = model.predict(source=image_path)[0].cpu().numpy().boxes
    labels = model.model.names

    for xyxy, score, class_id in zip(detections.xyxy, detections.conf, detections.cls):
        left, top, right, bottom = (int(v) for v in xyxy)
        idx = int(class_id)
        # White fallback for any class id missing from the color table.
        draw_color = class_colors.get(idx, (255, 255, 255))

        # Bounding box.
        cv2.rectangle(canvas, (left, top), (right, bottom),
                      color=draw_color, thickness=2, lineType=cv2.LINE_AA)

        # Class label just above the box.
        caption = f"{labels[idx].capitalize()}: {score:.2f}"
        cv2.putText(canvas, caption, (left, top - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, draw_color, 2, cv2.LINE_AA)

    # OpenCV works in BGR; Gradio expects RGB.
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
64
+
65
# Function to process and display predictions on video
def show_preds_video(video_path):
    """Yield an annotated RGB frame for each frame of the input video.

    Args:
        video_path: Filesystem path to the video file (supplied by Gradio).

    Yields:
        Each frame as an RGB numpy array with labelled detection boxes.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_copy = frame.copy()
            results = model.predict(source=frame)[0].cpu().numpy()
            boxes = results.boxes
            names = model.model.names

            for box, conf, cls in zip(boxes.xyxy, boxes.conf, boxes.cls):
                x1, y1, x2, y2 = map(int, box)
                class_name = names[int(cls)]
                # White fallback for class ids not present in the color table.
                color = class_colors.get(int(cls), (255, 255, 255))

                # Draw bounding box.
                cv2.rectangle(frame_copy, (x1, y1), (x2, y2),
                              color=color, thickness=2, lineType=cv2.LINE_AA)

                # Display class label above the box.
                label = f"{class_name.capitalize()}: {conf:.2f}"
                cv2.putText(frame_copy, label, (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)

            # OpenCV reads BGR; Gradio expects RGB.
            yield cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
    finally:
        # Release the capture even when the consumer abandons the generator
        # mid-stream; the original only released on normal loop exit, leaking
        # the VideoCapture if streaming was interrupted.
        cap.release()
99
+
100
# Gradio tab for single-image inference.
inputs_image = [gr.Image(type="filepath", label="Input Image")]
outputs_image = [gr.Image(type="numpy", label="Output Image")]
_image_iface_cfg = dict(
    fn=show_preds_image,
    inputs=inputs_image,
    outputs=outputs_image,
    title="Waste Detection",
    examples=path,
    cache_examples=False,
)
interface_image = gr.Interface(**_image_iface_cfg)
111
 
112
# Gradio tab for video inference; show_preds_video is a generator, so frames
# are streamed into the Image output one by one.
inputs_video = [gr.Video(label="Input Video")]
outputs_video = [gr.Image(type="numpy", label="Output Image")]
_video_iface_cfg = dict(
    fn=show_preds_video,
    inputs=inputs_video,
    outputs=outputs_video,
    title="Waste Detection",
    examples=video_path,
    cache_examples=False,
)
interface_video = gr.Interface(**_video_iface_cfg)

# Bundle both tabs into one app. Queueing is enabled before launch —
# presumably required for streaming the generator output; confirm against
# the Gradio version pinned by this Space.
gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=['Image Inference', 'Video Inference'],
).queue().launch()