Upload 2 files

- lane_detection.py (+282, -0)
- streamlit_app.py (+280, -0)
lane_detection.py
ADDED
@@ -0,0 +1,282 @@
import time
from collections import defaultdict
from typing import Dict, List, Tuple

import cv2
import numpy as np
import pandas as pd
import torch
from ultralytics import YOLO

# Previous class map, kept for reference:
# LABEL_MAP = {
#     0: "auto",
#     1: "bus",
#     2: "car",
#     3: "motorcycle",
#     4: "mini-bus",
#     5: "scooter",
#     6: "truck",
# }
LABEL_MAP = {
    0: "auto",
    1: "bus",
    2: "car",
    3: "electric-rickshaw",
    4: "large-sized-truck",
    5: "medium-sized-truck",
    6: "motorbike",
    7: "small-sized-truck",
}


def draw_text_with_background(
    image,
    text,
    position,
    font=cv2.FONT_HERSHEY_SIMPLEX,
    font_scale=1,
    font_thickness=2,
    text_color=(255, 255, 255),
    bg_color=(0, 0, 0),
    padding=5,
):
    """Draw `text` on `image` with a filled rectangle behind it."""
    (text_width, text_height), baseline = cv2.getTextSize(
        text, font, font_scale, font_thickness
    )
    x, y = position
    rect_y1 = y - text_height - padding - baseline // 2
    rect_y2 = y + padding - baseline // 2

    cv2.rectangle(
        image,
        (x, rect_y1),
        (x + text_width + 2 * padding, rect_y2),
        bg_color,
        -1,
    )
    cv2.putText(
        image,
        text,
        (x + padding, y - baseline // 2),
        font,
        font_scale,
        text_color,
        font_thickness,
        cv2.LINE_AA,
    )


def get_color_for_class(cls_id: int):
    """Deterministic bright color for each class index."""
    np.random.seed(cls_id + 37)
    return tuple(np.random.randint(100, 256, size=3).tolist())


def _inside(pt: Tuple[int, int], poly: np.ndarray) -> bool:
    """Point-in-polygon test using OpenCV (non-negative means inside or on the edge)."""
    return cv2.pointPolygonTest(poly, pt, False) >= 0


class YOLOVideoDetector:
    """
    Detect objects on a video and count them **per region**.

    * `regions`: Dict[int, List[Tuple[int, int]]], mapping region id (0, 1, ...) to
      4+ vertices in *pixel* coordinates (clockwise or anticlockwise).
    * For each frame, counts are stored in a DataFrame column named
      `<label>_<region>` (e.g. `car_0`, `bus_1`).
    """

    def __init__(
        self,
        model_path: str,
        video_path: str,
        output_path: str,
        regions: Dict[int, List[Tuple[int, int]]],
        classes=None,
        conf: float = 0.35,
        scale_factor: float = 1.5,
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        self.model = YOLO(model_path)
        self.video_path = video_path
        self.output_path = output_path
        self.conf = conf
        self.classes = classes
        self.scale = scale_factor

        # Region polygons, stored as int32 arrays in original-frame coordinates.
        self.regions = {
            rid: np.array(pts, np.int32) for rid, pts in regions.items() if pts
        }
        if not self.regions:
            raise ValueError("`regions` cannot be empty; provide at least one polygon.")

        # Prepare DataFrame columns once
        self.df_columns = [
            "Frame Number",
            *[
                f"{LABEL_MAP[c]}_{rid}"
                for rid in self.regions
                for c in LABEL_MAP.keys()
            ],
        ]

    # ----------------------------------------------------------------
    def process_video(self) -> pd.DataFrame:
        cap = cv2.VideoCapture(self.video_path)
        if not cap.isOpened():
            raise ValueError(f"Cannot open video: {self.video_path}")

        ok, first_frame_original = cap.read()
        if not ok:
            cap.release()
            raise ValueError(f"Cannot read first frame from: {self.video_path}")

        h_orig, w_orig = first_frame_original.shape[:2]
        prediction_counter_df = pd.DataFrame(columns=self.df_columns)

        first_frame_processed = first_frame_original
        frame_was_rotated = False

        if w_orig < h_orig:
            print(
                f"Original frame (h,w): ({h_orig}, {w_orig}). Portrait: rotating 90° CW."
            )
            first_frame_processed = cv2.rotate(
                first_frame_original, cv2.ROTATE_90_CLOCKWISE
            )
            frame_was_rotated = True
        else:
            print(f"Original frame (h,w): ({h_orig}, {w_orig}). Processing as landscape.")

        # ----------------------------------------------------------------
        base_h, base_w = first_frame_processed.shape[:2]
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

        out_w, out_h = int(base_w * self.scale), int(base_h * self.scale)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(self.output_path, fourcc, fps, (out_w, out_h))

        prev_t = time.time()
        frame_count = 1
        frame_up = cv2.resize(
            first_frame_processed, (out_w, out_h), interpolation=cv2.INTER_LINEAR
        )
        prev_t = self._process_and_write_frame(
            frame_up, out, prev_t, prediction_counter_df, frame_count
        )

        while True:
            ok, frame_original_loop = cap.read()
            if not ok:
                break

            # Frame skipping: analyze roughly two frames per second of source video.
            if frame_count % (int(fps // 2) or 1) == 0:
                frame_processed_loop = (
                    cv2.rotate(frame_original_loop, cv2.ROTATE_90_CLOCKWISE)
                    if frame_was_rotated
                    else frame_original_loop
                )
                frame_up = cv2.resize(
                    frame_processed_loop, (out_w, out_h), interpolation=cv2.INTER_LINEAR
                )
                prev_t = self._process_and_write_frame(
                    frame_up, out, prev_t, prediction_counter_df, frame_count
                )

            frame_count += 1

        cap.release()
        out.release()
        cv2.destroyAllWindows()
        print(f"Processed {frame_count} frames. Finished: {self.output_path}")
        return prediction_counter_df.fillna(0)

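    # Note on the skipping cadence above: with a 30 fps source,
    # int(fps // 2) == 15, so every 15th frame (about two per second) is
    # analyzed and written. The writer is opened at the source fps but
    # receives only those frames, so the annotated output plays back
    # time-compressed.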
    # ----------------------------------------------------------------
    def _process_and_write_frame(
        self,
        frame_up: np.ndarray,
        out_writer: cv2.VideoWriter,
        prev_t: float,
        prediction_counter_df: pd.DataFrame,
        frame_count: int,
    ) -> float:
        """Run YOLO on one frame, count per region, annotate, write, return timestamp."""
        # Scale the region polygons from original-frame to output coordinates
        # once per frame, then draw them.
        regions_up = {
            rid: (poly * self.scale).astype(np.int32)
            for rid, poly in self.regions.items()
        }
        for rid, poly_up in regions_up.items():
            cv2.polylines(frame_up, [poly_up], True, (255, 255, 0), 2)
            draw_text_with_background(frame_up, f"R{rid}", tuple(poly_up[0]), font_scale=0.8)

        results = self.model.predict(
            frame_up,
            conf=self.conf,
            classes=self.classes,
            verbose=False,
            device=self.device,
        )

        # counts[region][cls_id] -> int
        counts: Dict[int, Dict[int, int]] = {
            rid: defaultdict(int) for rid in self.regions
        }

        if results and len(results[0].boxes):
            xyxy = results[0].boxes.xyxy.cpu().numpy()
            scores = results[0].boxes.conf.cpu().numpy()
            cls_ids = results[0].boxes.cls.int().cpu().tolist()

            for (x1, y1, x2, y2), score, cls_id in zip(xyxy, scores, cls_ids):
                color = get_color_for_class(cls_id)
                cv2.rectangle(
                    frame_up, (int(x1), int(y1)), (int(x2), int(y2)), color, 2
                )
                label = LABEL_MAP.get(cls_id, f"Class {cls_id}")
                draw_text_with_background(
                    frame_up,
                    f"{label}: {score:.2f}",
                    (int(x1), int(y1) - 10),
                    font_scale=0.6,
                    font_thickness=1,
                    bg_color=color,
                    padding=3,
                )

                # Region assignment based on the *centre* of the box
                cx, cy = int((x1 + x2) / 2), int((y1 + y2) / 2)
                for rid, poly_up in regions_up.items():
                    if _inside((cx, cy), poly_up):
                        counts[rid][cls_id] += 1
                        break  # one region per detection

        # Overlay per-region counts and update the DataFrame
        df_idx = len(prediction_counter_df)
        prediction_counter_df.at[df_idx, "Frame Number"] = frame_count

        y_off = 30
        for rid, cls_dict in counts.items():
            for cls_id, cnt in cls_dict.items():
                label = LABEL_MAP.get(cls_id, f"Class {cls_id}")
                col_name = f"{label}_{rid}"
                prediction_counter_df.at[df_idx, col_name] = cnt
                draw_text_with_background(
                    frame_up,
                    f"{label}_{rid}: {cnt}",
                    (10, y_off),
                    font_scale=0.7,
                    font_thickness=2,
                    padding=6,
                )
                y_off += 25

        # FPS overlay
        now = time.time()
        fps_live = 1.0 / (now - prev_t) if (now - prev_t) > 0 else 0.0
        draw_text_with_background(
            frame_up,
            f"FPS: {fps_live:.1f}",
            (10, frame_up.shape[0] - 20),
            bg_color=(0, 0, 0),
            text_color=(0, 255, 0),
            font_scale=0.8,
            padding=4,
        )

        out_writer.write(frame_up)
        return now
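For reference, the detector can also be driven outside Streamlit. A minimal sketch (the file name `example_run.py`, the input/output paths, and the region coordinates are placeholders; `Weights/last.pt` is the checkpoint path the app below assumes):

# example_run.py - standalone usage sketch (not part of the upload)
from lane_detection import YOLOVideoDetector

if __name__ == "__main__":
    # One quadrilateral region in original-frame pixel coordinates (placeholder values).
    regions = {0: [(100, 400), (600, 400), (600, 700), (100, 700)]}

    detector = YOLOVideoDetector(
        model_path="Weights/last.pt",   # checkpoint path assumed by the app
        video_path="input.mp4",         # placeholder input video
        output_path="annotated.mp4",    # placeholder output video
        regions=regions,
    )
    counts = detector.process_video()   # per-frame, per-region counts
    counts.to_csv("counts.csv", index=False)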
streamlit_app.py
ADDED
@@ -0,0 +1,280 @@
# app.py

import os
from tempfile import NamedTemporaryFile
from typing import Dict, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from lane_detection import YOLOVideoDetector, LABEL_MAP  # detector + label map from lane_detection.py


# ------------------------- Helper Functions -------------------------

def extract_four_points(js) -> Optional[List[Tuple[int, int]]]:
    """
    Return exactly four (x, y) clicks from streamlit_drawable_canvas JSON, or None.
    """
    if not js or "objects" not in js:
        return None
    pts = []
    for obj in js["objects"]:
        if obj.get("type") in {"circle", "rect"}:
            x = int(obj["left"] + obj.get("radius", 0))
            y = int(obj["top"] + obj.get("radius", 0))
            pts.append((x, y))
    if len(pts) == 4:
        return pts
    return None

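# For orientation, canvas.json_data from streamlit-drawable-canvas in "point"
# mode is Fabric.js-style JSON, roughly (illustrative values, not real output):
#   {"objects": [{"type": "circle", "left": 212.0, "top": 148.0, "radius": 3, ...}, ...]}
# which is why extract_four_points reads "left", "top", and "radius" above.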

def draw_poly(img: np.ndarray, pts: List[Tuple[int, int]], color: Tuple[int, int, int]):
    """
    Draw a closed polygon (4 points) on img in the specified color.
    """
    cv2.polylines(img, [np.array(pts, np.int32)], True, color, 2, cv2.LINE_AA)


# ------------------------- Streamlit App -------------------------

st.set_page_config(page_title="🚦 Multi-Lane Congestion Demo", layout="wide")
st.title("🚦 Multi-Lane Vehicle Congestion Demo")

# Initialize session state
if "num_lanes" not in st.session_state:
    st.session_state.num_lanes = None
    st.session_state.current_lane = 0
    st.session_state.lanes = []
    st.session_state.video_path = None
    st.session_state.video_uploaded = False

# ------------------- Step 1: Choose Number of Lanes -------------------
if st.session_state.num_lanes is None:
    n = st.number_input(
        "How many lanes would you like to define? (1-8)",
        min_value=1,
        max_value=8,
        value=2,
    )
    if st.button("✅ Set Number of Lanes"):
        st.session_state.num_lanes = int(n)
        st.session_state.lanes = [None] * st.session_state.num_lanes
        st.rerun()  # refresh so Step 2 appears (st.experimental_rerun on older Streamlit)
    st.stop()

# ------------------- Step 2: Upload a Video -------------------
if not st.session_state.video_uploaded:
    uploaded = st.file_uploader(
        "Upload video (formats: mp4, avi, mov, mkv)",
        type=["mp4", "avi", "mov", "mkv"],
    )
    if uploaded:
        tmpfile = NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded.name)[1])
        tmpfile.write(uploaded.read())
        tmpfile.flush()
        tmpfile.close()
        st.session_state.video_path = tmpfile.name
        st.session_state.video_uploaded = True
    else:
        st.stop()

# ------------------- Step 3: Grab First Frame & Scale -------------------
cap = cv2.VideoCapture(st.session_state.video_path)
ret, first_frame = cap.read()
cap.release()
if not ret:
    st.error("❌ Could not read the first frame of the video.")
    st.stop()

frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
h_orig, w_orig = frame_rgb.shape[:2]

# If the frame is wider than 800 px, scale it down for display
MAX_W = 800
if w_orig > MAX_W:
    scale = MAX_W / w_orig
    disp_w = MAX_W
    disp_h = int(h_orig * scale)
    frame_disp = cv2.resize(frame_rgb, (disp_w, disp_h), interpolation=cv2.INTER_AREA)
else:
    scale = 1.0
    disp_w, disp_h = w_orig, h_orig
    frame_disp = frame_rgb.copy()

# ------------------- Step 4: Draw 4 Points Per Lane -------------------
colors = [
    (0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0),
    (255, 0, 255), (0, 255, 255), (128, 255, 0), (255, 128, 0),
]

if st.session_state.current_lane < st.session_state.num_lanes:
    idx = st.session_state.current_lane
    color = colors[idx % len(colors)]
    st.subheader(f"2️⃣ Click exactly 4 points for Lane #{idx + 1}")
    st.caption("Draw 4 small circles on the image, then press **Confirm Lane**.")

    canvas = st_canvas(
        fill_color="rgba(0,0,0,0)",
        stroke_width=2,
        # `colors` entries are drawn onto RGB frames, so the hex string is
        # built in the same RGB channel order.
        stroke_color=f"#{color[0]:02X}{color[1]:02X}{color[2]:02X}",
        background_image=Image.fromarray(frame_disp),
        drawing_mode="point",
        key=f"lane_canvas_{idx}",
        height=disp_h,
        width=disp_w,
        update_streamlit=True,
    )

    pts_scaled = extract_four_points(canvas.json_data)
    if pts_scaled:
        preview = frame_disp.copy()
        draw_poly(preview, pts_scaled, color)
        st.image(preview, caption=f"Preview: Lane {idx + 1}", use_column_width=True)

    if st.button(f"Confirm Lane {idx + 1}"):
        if pts_scaled and len(pts_scaled) == 4:
            # Convert display-scaled points back to original resolution
            orig_pts = [(int(x / scale), int(y / scale)) for (x, y) in pts_scaled]
            st.session_state.lanes[idx] = orig_pts
            st.session_state.current_lane += 1
            st.rerun()  # advance to the next lane (or the summary view)
        else:
            st.warning("⚠️ Please click exactly 4 points.")
    st.stop()

# ------------------- Step 5: Display All Lane Polygons -------------------
st.subheader("✅ All lanes defined:")
confirm_img = frame_rgb.copy()
for i, poly in enumerate(st.session_state.lanes):
    c = colors[i % len(colors)]
    draw_poly(confirm_img, poly, c)
    cv2.putText(
        confirm_img, f"L{i + 1}", (poly[0][0], poly[0][1] - 10),
        cv2.FONT_HERSHEY_SIMPLEX, 1.0, c, 2, cv2.LINE_AA,
    )
st.image(confirm_img, caption="All lane regions overlaid", use_column_width=True)

# --------- Step 6: Input Thresholds & Run Congestion Analysis ---------
st.subheader("🔧 Congestion Thresholds")
col1, col2 = st.columns(2)
with col1:
    low_thresh = st.number_input(
        "Green if PCE <",
        min_value=0.0,
        max_value=20.0,
        value=3.5,
        step=0.1,
        format="%.1f",
        help="Values below this will be colored green",
    )
with col2:
    high_thresh = st.number_input(
        "Red if PCE >",
        min_value=0.0,
        max_value=20.0,
        value=6.5,
        step=0.1,
        format="%.1f",
        help="Values above this will be colored red",
    )
st.caption("Values between the green and red thresholds will be yellow.")

if st.button("🚀 Run Congestion Analysis"):
    out_tmp = NamedTemporaryFile(delete=False, suffix=".mp4").name

    regions: Dict[int, List[Tuple[int, int]]] = {
        i: st.session_state.lanes[i] for i in range(st.session_state.num_lanes)
    }

    # Instantiate YOLOVideoDetector with the four required positional args
    detector = YOLOVideoDetector(
        "Weights/last.pt",  # <-- replace with your actual .pt path
        st.session_state.video_path,
        out_tmp,
        regions,
    )
    # Assign optional attributes
    detector.classes = list(LABEL_MAP.keys())
    detector.conf = 0.35
    detector.scale = 1.5

    with st.spinner("Processing video; this may take a while..."):
        df = detector.process_video()

    st.success("✅ Detection + annotation complete!")

    # --------------- Compute Per-Lane PCE & Rolling Average ---------------
    # Passenger Car Equivalent (PCE) weight per vehicle class.
    PCE = {
        "auto": 0.8,
        "bus": 4.0,
        "car": 1.0,
        "electric-rickshaw": 0.8,
        "large-sized-truck": 4.5,
        "medium-sized-truck": 3.5,
        "motorbike": 0.5,
        "small-sized-truck": 3.0,
    }

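    # Worked example for one lane in one frame: 2 cars, 1 bus and
    # 3 motorbikes give 2*1.0 + 1*4.0 + 3*0.5 = 7.5 passenger-car
    # equivalents, which would plot red under the default 6.5 threshold.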
    for rid in regions.keys():
        def lane_pce(row, rid=rid):
            total = 0.0
            for vt, factor in PCE.items():
                coln = f"{vt}_{rid}"
                cnt = row.get(coln, 0)
                total += int(cnt) * factor
            return total

        df[f"PCE_lane{rid}"] = df.apply(lane_pce, axis=1)
        df[f"PCE_lane{rid}_avg"] = df[f"PCE_lane{rid}"].rolling(window=5, min_periods=1).mean()

    # ------- Lane-Wise Subplots with Smooth Line + Colored Markers -------
    num_lanes = len(regions)
    fig, axes = plt.subplots(num_lanes, 1, figsize=(10, 3 * num_lanes), sharex=True)

    if num_lanes == 1:
        axes = [axes]

    for rid, ax in zip(regions.keys(), axes):
        x = df["Frame Number"].values
        y = df[f"PCE_lane{rid}_avg"].values

        # 1) Plot a smooth continuous line in gray
        ax.plot(x, y, color="gray", linewidth=1.2)

        # 2) Overlay colored markers at each frame
        colors_list = []
        for yi in y:
            if yi < low_thresh:
                colors_list.append("green")
            elif yi > high_thresh:
                colors_list.append("red")
            else:
                colors_list.append("yellow")

        ax.scatter(x, y, c=colors_list, s=20, edgecolors="black", linewidths=0.3)

        ax.set_title(f"Lane {rid} PCE (rolling average)")
        ax.set_ylabel("PCE")
        ax.grid(alpha=0.3)

    axes[-1].set_xlabel("Frame Number")
    plt.tight_layout()

    st.subheader("📊 Lane-Wise Congestion Plots")
    st.pyplot(fig)

    # ----------- Display Annotated Video & CSV Download -----------
    st.subheader("🎬 Annotated Output Video")
    with open(out_tmp, "rb") as f:
        st.video(f.read())

    csv_bytes = df.to_csv(index=False).encode("utf-8")
    st.download_button(
        label="Download full counts + PCE CSV",
        data=csv_bytes,
        file_name="counts_and_pce.csv",
        mime="text/csv",
    )
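To try the demo locally, assuming the usual PyPI package names and a trained checkpoint at `Weights/last.pt` next to the scripts:

pip install streamlit streamlit-drawable-canvas ultralytics opencv-python pandas matplotlib
streamlit run streamlit_app.py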