Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -2,10 +2,26 @@
|
|
2 |
import streamlit as st
|
3 |
import cv2
|
4 |
import numpy as np
|
|
|
5 |
from ultralytics import YOLO
|
6 |
import supervision as sv
|
7 |
from scipy.spatial import distance as dist
|
8 |
import tempfile
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# Initialize components
|
11 |
model = YOLO('yolov8s.pt')
|
@@ -20,6 +36,10 @@ calibration_dist = st.number_input("Field width in meters (for speed calibration
|
|
20 |
if 'player_data' not in st.session_state:
|
21 |
st.session_state.player_data = {}
|
22 |
|
|
|
|
|
|
|
|
|
23 |
if uploaded_video:
|
24 |
tfile = tempfile.NamedTemporaryFile(delete=False)
|
25 |
tfile.write(uploaded_video.read())
|
@@ -32,51 +52,108 @@ if uploaded_video:
|
|
32 |
st_frame = st.empty()
|
33 |
frame_count = 0
|
34 |
|
|
|
|
|
|
|
|
|
35 |
while cap.isOpened():
|
36 |
ret, frame = cap.read()
|
37 |
if not ret:
|
38 |
break
|
39 |
|
40 |
-
# Detection + tracking
|
41 |
results = model(frame)[0]
|
42 |
detections = sv.Detections.from_ultralytics(results)
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
detections = tracker.update_with_detections(detections)
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
'positions': [centroid],
|
55 |
-
'timestamps': [frame_count/fps],
|
56 |
-
'distance': 0.0,
|
57 |
-
'speeds': []
|
58 |
-
}
|
59 |
-
else:
|
60 |
-
# Calculate movement metrics
|
61 |
-
prev_pos = st.session_state.player_data[track_id]['positions'][-1]
|
62 |
-
time_diff = (frame_count/fps) - st.session_state.player_data[track_id]['timestamps'][-1]
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
speed = speed_px * pixels_per_meter # Convert to m/s
|
67 |
|
68 |
-
#
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
st_frame.image(frame, channels="BGR")
|
82 |
frame_count += 1
|
@@ -84,13 +161,48 @@ if uploaded_video:
|
|
84 |
cap.release()
|
85 |
|
86 |
# Display final analytics
|
87 |
-
st.subheader("Player Statistics")
|
|
|
|
|
|
|
88 |
for player_id, data in st.session_state.player_data.items():
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import streamlit as st
|
3 |
import cv2
|
4 |
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
from ultralytics import YOLO
|
7 |
import supervision as sv
|
8 |
from scipy.spatial import distance as dist
|
9 |
import tempfile
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
|
12 |
+
# Sidebar with logo and team information
|
13 |
+
with st.sidebar:
|
14 |
+
st.image("https://cs.christuniversity.in/softex/resources/img/christ_university_Black.png", width=200)
|
15 |
+
st.markdown("<h4 style='text-align: center; margin-top: 0;'>Christ University Kengeri Campus</h4>", unsafe_allow_html=True)
|
16 |
+
st.markdown("<h5 style='text-align: center; margin-top: 0;'>Sports Department</h5>", unsafe_allow_html=True)
|
17 |
+
|
18 |
+
st.markdown("### Team Members")
|
19 |
+
st.markdown("""
|
20 |
+
- Harsh Vardhan Lal 2262069 6BTCSAIML B
|
21 |
+
- Harsheet Sandeep Thakur 2262070 6BTCSAIML B
|
22 |
+
- Nandalal C B 2262115 6BTCSAIML B
|
23 |
+
- Athul Nambiar 2262041 6BTCSAIML B
|
24 |
+
""")
|
25 |
|
26 |
# Initialize components
|
27 |
model = YOLO('yolov8s.pt')
|
|
|
36 |
if 'player_data' not in st.session_state:
|
37 |
st.session_state.player_data = {}
|
38 |
|
39 |
+
# Create a placeholder for the live tracking table
|
40 |
+
live_table_container = st.container()
|
41 |
+
live_table = live_table_container.empty()
|
42 |
+
|
43 |
if uploaded_video:
|
44 |
tfile = tempfile.NamedTemporaryFile(delete=False)
|
45 |
tfile.write(uploaded_video.read())
|
|
|
52 |
st_frame = st.empty()
|
53 |
frame_count = 0
|
54 |
|
55 |
+
# Clear previous player data
|
56 |
+
st.session_state.player_data = {}
|
57 |
+
player_count = 0 # Track the number of distinct players
|
58 |
+
|
59 |
while cap.isOpened():
|
60 |
ret, frame = cap.read()
|
61 |
if not ret:
|
62 |
break
|
63 |
|
64 |
+
# Detection + tracking with person class filter (class 0)
|
65 |
results = model(frame)[0]
|
66 |
detections = sv.Detections.from_ultralytics(results)
|
67 |
+
|
68 |
+
# Filter for person class only (class 0 in COCO dataset)
|
69 |
+
person_mask = np.array([class_id == 0 for class_id in detections.class_id])
|
70 |
+
detections = detections[person_mask]
|
71 |
+
|
72 |
+
# Apply tracking - this will assign consistent IDs
|
73 |
detections = tracker.update_with_detections(detections)
|
74 |
+
|
75 |
+
# Extract bounding boxes, track IDs, etc.
|
76 |
+
boxes = detections.xyxy # Get bounding boxes [x1, y1, x2, y2]
|
77 |
+
track_ids = detections.tracker_id # Get track IDs
|
78 |
+
|
79 |
+
if boxes is not None and len(boxes) > 0:
|
80 |
+
for i, (box, track_id) in enumerate(zip(boxes, track_ids)):
|
81 |
+
# Skip detections without a track_id
|
82 |
+
if track_id is None:
|
83 |
+
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
+
# Use the track_id directly from ByteTrack
|
86 |
+
player_id = int(track_id)
|
|
|
87 |
|
88 |
+
# Calculate center point of bounding box
|
89 |
+
x1, y1, x2, y2 = box
|
90 |
+
centroid = (int((x1 + x2) / 2), int((y1 + y2) / 2))
|
91 |
+
|
92 |
+
# Initialize speed to 0 for all cases
|
93 |
+
speed = 0.0
|
94 |
+
|
95 |
+
# Initialize new player if not seen before
|
96 |
+
if player_id not in st.session_state.player_data:
|
97 |
+
st.session_state.player_data[player_id] = {
|
98 |
+
'positions': [centroid],
|
99 |
+
'timestamps': [frame_count / fps],
|
100 |
+
'distance': 0.0,
|
101 |
+
'speeds': [0.0], # Initialize with zero speed
|
102 |
+
'last_seen': frame_count
|
103 |
+
}
|
104 |
+
player_count += 1 # Increment player counter
|
105 |
+
else:
|
106 |
+
# Calculate movement metrics
|
107 |
+
prev_pos = st.session_state.player_data[player_id]['positions'][-1]
|
108 |
+
time_diff = (frame_count / fps) - st.session_state.player_data[player_id]['timestamps'][-1]
|
109 |
+
|
110 |
+
# Calculate distance
|
111 |
+
pixel_dist = dist.euclidean(prev_pos, centroid)
|
112 |
+
|
113 |
+
# Apply motion smoothing to ignore unrealistic movements
|
114 |
+
if pixel_dist < frame_width * 0.1: # Max 10% of screen width per frame
|
115 |
+
speed_px = pixel_dist / time_diff if time_diff > 0 else 0.0
|
116 |
+
|
117 |
+
# Convert to meters per second
|
118 |
+
speed = speed_px / pixels_per_meter
|
119 |
+
meter_dist = pixel_dist / pixels_per_meter
|
120 |
+
|
121 |
+
# Update player data
|
122 |
+
st.session_state.player_data[player_id]['positions'].append(centroid)
|
123 |
+
st.session_state.player_data[player_id]['timestamps'].append(frame_count / fps)
|
124 |
+
st.session_state.player_data[player_id]['distance'] += meter_dist
|
125 |
+
st.session_state.player_data[player_id]['speeds'].append(speed)
|
126 |
+
st.session_state.player_data[player_id]['last_seen'] = frame_count
|
127 |
+
else:
|
128 |
+
# Just update position without adding to distance/speed
|
129 |
+
st.session_state.player_data[player_id]['positions'].append(centroid)
|
130 |
+
st.session_state.player_data[player_id]['timestamps'].append(frame_count / fps)
|
131 |
+
st.session_state.player_data[player_id]['last_seen'] = frame_count
|
132 |
+
# Maintain previous speed
|
133 |
+
speed = st.session_state.player_data[player_id]['speeds'][-1] if st.session_state.player_data[player_id]['speeds'] else 0
|
134 |
+
|
135 |
+
# Draw player info on frame
|
136 |
+
label = f"ID:{player_id} Speed:{speed:.1f}m/s"
|
137 |
+
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
|
138 |
+
cv2.putText(frame, label, (int(x1), int(y1-10)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
|
139 |
|
140 |
+
# Update live table
|
141 |
+
if st.session_state.player_data:
|
142 |
+
live_data = []
|
143 |
+
for player_id, data in st.session_state.player_data.items():
|
144 |
+
# Only show recently seen players
|
145 |
+
if frame_count - data['last_seen'] < 30:
|
146 |
+
current_speed = data['speeds'][-1] if data['speeds'] else 0
|
147 |
+
live_data.append({
|
148 |
+
"Player": f"Player {player_id}",
|
149 |
+
"Current Speed (m/s)": round(current_speed, 2),
|
150 |
+
"Distance (m)": round(data['distance'], 2),
|
151 |
+
"Time (s)": round(data['timestamps'][-1], 2) if data['timestamps'] else 0
|
152 |
+
})
|
153 |
+
|
154 |
+
live_df = pd.DataFrame(live_data)
|
155 |
+
if not live_df.empty:
|
156 |
+
live_table.dataframe(live_df, use_container_width=True, hide_index=True)
|
157 |
|
158 |
st_frame.image(frame, channels="BGR")
|
159 |
frame_count += 1
|
|
|
161 |
cap.release()
|
162 |
|
163 |
# Display final analytics
|
164 |
+
st.subheader(f"Player Statistics ({player_count} players detected)")
|
165 |
+
|
166 |
+
# Create player stats table
|
167 |
+
stats_data = []
|
168 |
for player_id, data in st.session_state.player_data.items():
|
169 |
+
if len(data['speeds']) > 0:
|
170 |
+
avg_speed = np.mean(data['speeds'])
|
171 |
+
max_speed = np.max(data['speeds'])
|
172 |
+
else:
|
173 |
+
avg_speed = 0
|
174 |
+
max_speed = 0
|
175 |
+
|
176 |
+
duration = data['timestamps'][-1] - data['timestamps'][0] if len(data['timestamps']) > 1 else 0
|
177 |
+
|
178 |
+
# Only include players with significant tracking data
|
179 |
+
if len(data['positions']) > 10:
|
180 |
+
stats_data.append({
|
181 |
+
"Player": f"Player {player_id}",
|
182 |
+
"Total Distance (m)": round(data['distance'], 2),
|
183 |
+
"Avg Speed (m/s)": round(avg_speed, 2),
|
184 |
+
"Max Speed (m/s)": round(max_speed, 2),
|
185 |
+
"Duration (s)": round(duration, 2)
|
186 |
+
})
|
187 |
+
|
188 |
+
stats_df = pd.DataFrame(stats_data)
|
189 |
+
st.dataframe(stats_df, use_container_width=True, hide_index=True)
|
190 |
+
|
191 |
+
# Create speed comparison chart
|
192 |
+
if st.session_state.player_data:
|
193 |
+
st.subheader("Player Speed Comparison")
|
194 |
+
|
195 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
196 |
+
|
197 |
+
for player_id, data in st.session_state.player_data.items():
|
198 |
+
# Only plot players with significant data
|
199 |
+
if len(data['speeds']) > 10:
|
200 |
+
ax.plot(data['timestamps'], data['speeds'], label=f"Player {player_id}")
|
201 |
+
|
202 |
+
ax.set_xlabel('Time (seconds)')
|
203 |
+
ax.set_ylabel('Speed (m/s)')
|
204 |
+
ax.set_title('Player Speed Over Time')
|
205 |
+
ax.legend()
|
206 |
+
ax.grid(True)
|
207 |
+
|
208 |
+
st.pyplot(fig)
|