|
|
|
|
|
import os
from tempfile import NamedTemporaryFile
from typing import Dict, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from lane_detection import LABEL_MAP, YOLOVideoDetector
|
|
|
|
|
|
|
|
|
def extract_four_points(js) -> Optional[List[Tuple[int, int]]]:
    """Return the four clicked points from drawable-canvas JSON, or None.

    Args:
        js: The canvas JSON payload; only its ``"objects"`` list is read.
            Each usable object must be a mapping with ``"type"`` equal to
            ``"circle"`` or ``"rect"`` and numeric ``"left"``/``"top"``.

    Returns:
        A list of exactly four ``(x, y)`` integer tuples, in drawing order,
        or ``None`` when the payload is missing/empty or does not contain
        exactly four usable objects.
    """
    if not js or "objects" not in js:
        return None

    pts = []
    # ``objects`` may be present but None; treat that like an empty drawing.
    for obj in js["objects"] or ():
        if not isinstance(obj, dict):
            continue
        if obj.get("type") not in {"circle", "rect"}:
            continue
        left, top = obj.get("left"), obj.get("top")
        if left is None or top is None:
            # Malformed object without a position: ignore it instead of
            # raising in the middle of a UI rerun.
            continue
        # Offset by the radius when one is given (absent or None counts as 0).
        radius = obj.get("radius") or 0
        pts.append((int(left + radius), int(top + radius)))

    return pts if len(pts) == 4 else None
|
|
|
|
|
def draw_poly(img: np.ndarray, pts: List[Tuple[int, int]], color: Tuple[int, int, int]):
    """Outline the polygon described by *pts* on *img*, in place.

    The outline is closed (last vertex joined back to the first), two pixels
    thick and anti-aliased. *color* is passed to OpenCV unchanged, so its
    channel order follows whatever order *img* itself uses.
    """
    outline = np.array(pts, np.int32)
    cv2.polylines(
        img,
        [outline],
        isClosed=True,
        color=color,
        thickness=2,
        lineType=cv2.LINE_AA,
    )
|
|
|
|
|
|
|
|
|
# Page chrome: wide layout plus the headline shown at the top of the app.
# NOTE(review): the leading glyphs and the character between "Multi" and
# "Lane" look like mis-decoded UTF-8 (probably an emoji and a dash/hyphen);
# confirm the intended characters and restore them.
st.set_page_config(
    page_title="π¦ MultiβLane Congestion Demo",
    layout="wide",
)
st.title("π¦ MultiβLane Vehicle Congestion Demo")
|
|
|
|
|
# First run of a session: seed every key the wizard below reads. The presence
# of "num_lanes" is used as the marker that seeding already happened.
if "num_lanes" not in st.session_state:
    initial_state = {
        "num_lanes": None,        # lane count not chosen yet
        "current_lane": 0,        # index of the lane currently being drawn
        "lanes": [],              # one 4-point polygon per lane, once confirmed
        "video_path": None,       # temp-file path of the uploaded video
        "video_uploaded": False,  # flips to True after the upload is saved
    }
    for key, value in initial_state.items():
        st.session_state[key] = value
|
|
|
|
|
# Step 1: ask once per session how many lane polygons will be defined.
# NOTE(review): the glyphs inside "(1β8)" and before "Set Number of Lanes"
# look like mis-decoded UTF-8 (probably a dash and an emoji); confirm the
# intended characters.
if st.session_state.num_lanes is None:
    n = st.number_input(
        "How many lanes would you like to define? (1β8)",
        min_value=1,
        max_value=8,
        value=2,
    )
    if st.button("β Set Number of Lanes"):
        st.session_state.num_lanes = int(n)
        # One empty slot per lane; each is filled when that lane is confirmed.
        st.session_state.lanes = [None] * st.session_state.num_lanes
    else:
        # Halt only while no count has been chosen. Stopping unconditionally
        # would leave this form on screen after a successful click until the
        # user interacted a second time; falling through lets the same run
        # continue to the next step, as the upload step does.
        st.stop()
|
|
|
|
|
# Step 2: receive the video and persist it to disk so OpenCV can open it by
# path on this and every later rerun.
if not st.session_state.video_uploaded:
    uploaded = st.file_uploader(
        "Upload video (formats: mp4, avi, mov, mkv)",
        type=["mp4", "avi", "mov", "mkv"],
    )
    if not uploaded:
        # Nothing uploaded yet: end this run and wait for the widget.
        st.stop()

    # Keep the original extension on the temp file.
    suffix = os.path.splitext(uploaded.name)[1]
    # delete=False keeps the file after closing; the `with` block closes (and
    # therefore flushes) the handle before anything else opens the path, so
    # no descriptor is leaked and the file is not held open while it is read.
    with NamedTemporaryFile(delete=False, suffix=suffix) as tmpfile:
        tmpfile.write(uploaded.read())

    st.session_state.video_path = tmpfile.name
    st.session_state.video_uploaded = True
|
|
|
|
|
# Grab only the first frame of the uploaded video; it is the backdrop on
# which the lane polygons are drawn.
capture = cv2.VideoCapture(st.session_state.video_path)
frame_ok, first_frame = capture.read()
capture.release()

if not frame_ok:
    # NOTE(review): the leading glyph looks like mis-decoded UTF-8 (probably
    # an emoji); confirm the intended character.
    st.error("β Could not read the first frame of the video.")
    st.stop()

# OpenCV decodes to BGR; convert once to RGB for everything displayed below.
frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
h_orig, w_orig = frame_rgb.shape[0], frame_rgb.shape[1]
|
|
|
|
|
# Build the frame shown in the drawing canvas: at most MAX_W pixels wide,
# aspect ratio preserved. `scale` is display size / original size and is used
# later to map clicked points back to original-resolution coordinates.
MAX_W = 800
scale = min(1.0, MAX_W / w_orig)

if scale < 1.0:
    # Wider than the cap: shrink to exactly MAX_W.
    disp_w, disp_h = MAX_W, int(h_orig * scale)
    frame_disp = cv2.resize(
        frame_rgb,
        (disp_w, disp_h),
        interpolation=cv2.INTER_AREA,
    )
else:
    # Already small enough: show at native size, on a copy so the original
    # frame is left untouched.
    disp_w, disp_h = w_orig, h_orig
    frame_disp = frame_rgb.copy()
|
|
|
|
|
# Per-lane drawing colors, indexed by lane number modulo the palette length.
# Eight entries, matching the maximum lane count offered by the UI.
colors = [
    (0, 255, 0),
    (255, 0, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (128, 255, 0),
    (255, 128, 0),
]
|
|
|
# Step 3: collect one 4-point polygon per lane, one lane per pass. The app
# halts at the end of this block until every lane has been confirmed.
if st.session_state.current_lane < st.session_state.num_lanes:
    idx = st.session_state.current_lane
    color = colors[idx % len(colors)]
    # NOTE(review): the glyphs after "2" in the subheader and at the start of
    # the warning below look like mis-decoded UTF-8 (probably emoji); confirm
    # the intended characters.
    st.subheader(f"2οΈβ£ Click exactly 4 points for Lane #{idx+1}")
    st.caption("Draw 4 small circles on the image, then press **Confirm Lane**.")

    # The palette tuples are drawn onto RGB images in this script, so the hex
    # string must use the same channel order. Reversing the channels made the
    # canvas stroke a different color from the preview polygon.
    red, green, blue = color
    canvas = st_canvas(
        fill_color="rgba(0,0,0,0)",
        stroke_width=2,
        stroke_color=f"#{red:02X}{green:02X}{blue:02X}",
        background_image=Image.fromarray(frame_disp),
        drawing_mode="point",
        key=f"lane_canvas_{idx}",  # a fresh canvas widget for each lane
        height=disp_h,
        width=disp_w,
        update_streamlit=True,
    )

    # Coordinates here are in display (possibly down-scaled) pixels.
    pts_scaled = extract_four_points(canvas.json_data)
    if pts_scaled:
        preview = frame_disp.copy()
        draw_poly(preview, pts_scaled, color)
        st.image(preview, caption=f"Preview β Lane {idx+1}", use_column_width=True)

    if st.button(f"Confirm Lane {idx+1}"):
        if pts_scaled:
            # Map display coordinates back to the original video resolution.
            orig_pts = [(int(x / scale), int(y / scale)) for x, y in pts_scaled]
            st.session_state.lanes[idx] = orig_pts
            st.session_state.current_lane += 1
            # Rerun immediately so the next lane (or the next step) is shown;
            # otherwise the just-confirmed canvas stays on screen until the
            # user interacts again. Older Streamlit releases only provide
            # experimental_rerun, hence the lookup.
            rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun")
            rerun()
        else:
            st.warning("β Please click exactly 4 points.")

    st.stop()
|
|
|
|
|
# Step 4: show every confirmed lane polygon on the full-resolution frame.
# NOTE(review): the leading glyph looks like mis-decoded UTF-8 and the literal
# had been broken across two lines; it is joined with a single space here.
# The line break hints that the original was a check-mark emoji; confirm.
st.subheader("β All lanes defined:")

confirm_img = frame_rgb.copy()
label_font = cv2.FONT_HERSHEY_SIMPLEX
for i, poly in enumerate(st.session_state.lanes):
    c = colors[i % len(colors)]
    draw_poly(confirm_img, poly, c)

    label = f"L{i+1}"
    label_x, label_y = poly[0]
    # putText anchors the text at its bottom-left corner. Keep the baseline at
    # least one text height below the top edge so the label is not clipped
    # when the first vertex sits near the top of the frame.
    (_, text_h), _ = cv2.getTextSize(label, label_font, 1.0, 2)
    cv2.putText(
        confirm_img,
        label,
        (label_x, max(label_y - 10, text_h + 2)),
        label_font,
        1.0,
        c,
        2,
        cv2.LINE_AA,
    )

st.image(confirm_img, caption="All lane regions overlaid", use_column_width=True)
|
|
|
|
|
# Step 5: thresholds that split the rolling PCE value into green / yellow /
# red when the per-lane plots are colored.
# NOTE(review): the leading glyph looks like mis-decoded UTF-8 (probably an
# emoji); confirm the intended character.
st.subheader("π§ Congestion Thresholds")

col1, col2 = st.columns(2)
with col1:
    low_thresh = st.number_input(
        "Green if PCE <",
        min_value=0.0,
        max_value=20.0,
        value=3.5,
        step=0.1,
        format="%.1f",
        help="Values below this will be colored green",
    )
with col2:
    high_thresh = st.number_input(
        "Red if PCE >",
        min_value=0.0,
        max_value=20.0,
        value=6.5,
        step=0.1,
        format="%.1f",
        help="Values above this will be colored red",
    )
st.caption("Values between green/red thresholds will be yellow.")

# An inverted pair would color values above the red threshold green (the
# green test is evaluated first), so refuse to continue until it is fixed.
if low_thresh > high_thresh:
    st.error("The green threshold must not be greater than the red threshold.")
    st.stop()
|
|
|
# Step 6: run the detector over the whole video, derive per-lane PCE
# (passenger-car-equivalent) series, plot them, and offer the annotated video
# and the CSV.
# NOTE(review): several UI strings in this block start with, or contain,
# glyphs that look like mis-decoded UTF-8 (emoji / dashes). They are kept
# as-is; the success message had additionally been broken across two lines
# and is joined with a single space. Confirm the intended characters.
if st.button("π Run Congestion Analysis"):
    # Reserve a path for the annotated video. Closing the handle right away
    # avoids leaking a descriptor and lets the detector open the path itself.
    with NamedTemporaryFile(delete=False, suffix=".mp4") as out_file:
        out_tmp = out_file.name

    # Lane index -> 4-point polygon in original-resolution coordinates.
    regions: Dict[int, List[Tuple[int, int]]] = {
        i: st.session_state.lanes[i] for i in range(st.session_state.num_lanes)
    }

    detector = YOLOVideoDetector(
        "Weights/last.pt",
        st.session_state.video_path,
        out_tmp,
        regions,
    )
    detector.classes = list(LABEL_MAP.keys())
    detector.conf = 0.35
    detector.scale = 1.5

    with st.spinner("Processing videoβthis may take a while..."):
        # The returned frame is read below through a "Frame Number" column
        # and optional "<vehicle type>_<lane id>" count columns.
        df = detector.process_video()

    st.success("β Detection + annotation complete!")

    # Weight applied to each vehicle type's count when computing PCE.
    PCE = {
        "auto": 0.8,
        "bus": 4.0,
        "car": 1.0,
        "electric-rickshaw": 0.8,
        "large-sized-truck": 4.5,
        "medium-sized-truck": 3.5,
        "motorbike": 0.5,
        "small-sized-truck": 3.0,
    }

    for rid in regions:
        # Column-wise weighted sum: faster than a row-wise apply, still a
        # Series for an empty frame, and tolerant of missing (NaN) counts.
        lane_total = pd.Series(0.0, index=df.index)
        for vt, factor in PCE.items():
            coln = f"{vt}_{rid}"
            if coln not in df.columns:
                # A type with no column for this lane contributes nothing.
                continue
            lane_total = lane_total + df[coln].fillna(0).astype(int) * factor

        df[f"PCE_lane{rid}"] = lane_total
        # Smooth over the 5 most recent rows; shorter windows at the start.
        df[f"PCE_lane{rid}_avg"] = lane_total.rolling(window=5, min_periods=1).mean()

    num_lanes = len(regions)
    # squeeze=False always yields a 2-D axes array, so one lane needs no
    # special case.
    fig, axes = plt.subplots(
        num_lanes,
        1,
        figsize=(10, 3 * num_lanes),
        sharex=True,
        squeeze=False,
    )
    axes = axes[:, 0]

    x = df["Frame Number"].to_numpy()
    for rid, ax in zip(regions, axes):
        y = df[f"PCE_lane{rid}_avg"].to_numpy()

        ax.plot(x, y, color="gray", linewidth=1.2)

        # Green below the low threshold, red above the high one, yellow for
        # everything else (the green test takes precedence).
        point_colors = np.select(
            [y < low_thresh, y > high_thresh],
            ["green", "red"],
            default="yellow",
        ).tolist()
        ax.scatter(x, y, c=point_colors, s=20, edgecolors="black", linewidths=0.3)

        ax.set_title(f"Lane {rid} PCE (rolling average)")
        ax.set_ylabel("PCE")
        ax.grid(alpha=0.3)

    axes[-1].set_xlabel("Frame Number")
    plt.tight_layout()

    st.subheader("π LaneβWise Congestion Plots")
    st.pyplot(fig)
    # Release the figure; pyplot otherwise keeps it alive across reruns.
    plt.close(fig)

    st.subheader("π¬ Annotated Output Video")
    with open(out_tmp, "rb") as f:
        st.video(f.read())

    csv_bytes = df.to_csv(index=False).encode("utf-8")
    st.download_button(
        label="Download full counts + PCE CSV",
        data=csv_bytes,
        file_name="counts_and_pce.csv",
        mime="text/csv",
    )
|
|