athulnambiar commited on
Commit
ce8c119
·
verified ·
1 Parent(s): b75c09a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -43
app.py CHANGED
@@ -2,10 +2,26 @@
2
  import streamlit as st
3
  import cv2
4
  import numpy as np
 
5
  from ultralytics import YOLO
6
  import supervision as sv
7
  from scipy.spatial import distance as dist
8
  import tempfile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # Initialize components
11
  model = YOLO('yolov8s.pt')
@@ -20,6 +36,10 @@ calibration_dist = st.number_input("Field width in meters (for speed calibration
20
  if 'player_data' not in st.session_state:
21
  st.session_state.player_data = {}
22
 
 
 
 
 
23
  if uploaded_video:
24
  tfile = tempfile.NamedTemporaryFile(delete=False)
25
  tfile.write(uploaded_video.read())
@@ -32,51 +52,108 @@ if uploaded_video:
32
  st_frame = st.empty()
33
  frame_count = 0
34
 
 
 
 
 
35
  while cap.isOpened():
36
  ret, frame = cap.read()
37
  if not ret:
38
  break
39
 
40
- # Detection + tracking
41
  results = model(frame)[0]
42
  detections = sv.Detections.from_ultralytics(results)
 
 
 
 
 
 
43
  detections = tracker.update_with_detections(detections)
44
-
45
- for detection in detections:
46
- track_id = int(detection[4])
47
- bbox = detection[0]
48
- centroid = (int((bbox[0]+bbox[2])/2), int((bbox[1]+bbox[3])/2))
49
- speed = 0.0 # Initialize speed for all cases
50
-
51
- # Initialize new player
52
- if track_id not in st.session_state.player_data:
53
- st.session_state.player_data[track_id] = {
54
- 'positions': [centroid],
55
- 'timestamps': [frame_count/fps],
56
- 'distance': 0.0,
57
- 'speeds': []
58
- }
59
- else:
60
- # Calculate movement metrics
61
- prev_pos = st.session_state.player_data[track_id]['positions'][-1]
62
- time_diff = (frame_count/fps) - st.session_state.player_data[track_id]['timestamps'][-1]
63
 
64
- pixel_dist = dist.euclidean(prev_pos, centroid)
65
- speed_px = pixel_dist / time_diff if time_diff > 0 else 0.0
66
- speed = speed_px * pixels_per_meter # Convert to m/s
67
 
68
- # Update player record
69
- st.session_state.player_data[track_id]['positions'].append(centroid)
70
- st.session_state.player_data[track_id]['timestamps'].append(frame_count/fps)
71
- st.session_state.player_data[track_id]['distance'] += pixel_dist
72
- st.session_state.player_data[track_id]['speeds'].append(speed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- # Annotation with failsafe
75
- label = f"ID:{track_id} Speed:{speed:.1f}m/s"
76
- cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])),
77
- (int(bbox[2]), int(bbox[3])), (0,255,0), 2)
78
- cv2.putText(frame, label, (int(bbox[0]), int(bbox[1]-10)),
79
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 1)
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  st_frame.image(frame, channels="BGR")
82
  frame_count += 1
@@ -84,13 +161,48 @@ if uploaded_video:
84
  cap.release()
85
 
86
  # Display final analytics
87
- st.subheader("Player Statistics")
 
 
 
88
  for player_id, data in st.session_state.player_data.items():
89
- avg_speed = np.mean(data['speeds']) if data['speeds'] else 0
90
- st.write(f"""
91
- **Player {player_id}**
92
- - Total distance: {data['distance'] * pixels_per_meter:.2f}m
93
- - Avg speed: {avg_speed:.1f}m/s
94
- - Max speed: {np.max(data['speeds']):.1f}m/s
95
- - Active duration: {data['timestamps'][-1]-data['timestamps'][0]:.1f}s
96
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import streamlit as st
3
  import cv2
4
  import numpy as np
5
+ import pandas as pd
6
  from ultralytics import YOLO
7
  import supervision as sv
8
  from scipy.spatial import distance as dist
9
  import tempfile
10
+ import matplotlib.pyplot as plt
11
+
12
+ # Sidebar with logo and team information
13
+ with st.sidebar:
14
+ st.image("https://cs.christuniversity.in/softex/resources/img/christ_university_Black.png", width=200)
15
+ st.markdown("<h4 style='text-align: center; margin-top: 0;'>Christ University Kengeri Campus</h4>", unsafe_allow_html=True)
16
+ st.markdown("<h5 style='text-align: center; margin-top: 0;'>Sports Department</h5>", unsafe_allow_html=True)
17
+
18
+ st.markdown("### Team Members")
19
+ st.markdown("""
20
+ - Harsh Vardhan Lal 2262069 6BTCSAIML B
21
+ - Harsheet Sandeep Thakur 2262070 6BTCSAIML B
22
+ - Nandalal C B 2262115 6BTCSAIML B
23
+ - Athul Nambiar 2262041 6BTCSAIML B
24
+ """)
25
 
26
  # Initialize components
27
  model = YOLO('yolov8s.pt')
 
36
  if 'player_data' not in st.session_state:
37
  st.session_state.player_data = {}
38
 
39
+ # Create a placeholder for the live tracking table
40
+ live_table_container = st.container()
41
+ live_table = live_table_container.empty()
42
+
43
  if uploaded_video:
44
  tfile = tempfile.NamedTemporaryFile(delete=False)
45
  tfile.write(uploaded_video.read())
 
52
  st_frame = st.empty()
53
  frame_count = 0
54
 
55
+ # Clear previous player data
56
+ st.session_state.player_data = {}
57
+ player_count = 0 # Track the number of distinct players
58
+
59
  while cap.isOpened():
60
  ret, frame = cap.read()
61
  if not ret:
62
  break
63
 
64
+ # Detection + tracking with person class filter (class 0)
65
  results = model(frame)[0]
66
  detections = sv.Detections.from_ultralytics(results)
67
+
68
+ # Filter for person class only (class 0 in COCO dataset)
69
+ person_mask = np.array([class_id == 0 for class_id in detections.class_id])
70
+ detections = detections[person_mask]
71
+
72
+ # Apply tracking - this will assign consistent IDs
73
  detections = tracker.update_with_detections(detections)
74
+
75
+ # Extract bounding boxes, track IDs, etc.
76
+ boxes = detections.xyxy # Get bounding boxes [x1, y1, x2, y2]
77
+ track_ids = detections.tracker_id # Get track IDs
78
+
79
+ if boxes is not None and len(boxes) > 0:
80
+ for i, (box, track_id) in enumerate(zip(boxes, track_ids)):
81
+ # Skip detections without a track_id
82
+ if track_id is None:
83
+ continue
 
 
 
 
 
 
 
 
 
84
 
85
+ # Use the track_id directly from ByteTrack
86
+ player_id = int(track_id)
 
87
 
88
+ # Calculate center point of bounding box
89
+ x1, y1, x2, y2 = box
90
+ centroid = (int((x1 + x2) / 2), int((y1 + y2) / 2))
91
+
92
+ # Initialize speed to 0 for all cases
93
+ speed = 0.0
94
+
95
+ # Initialize new player if not seen before
96
+ if player_id not in st.session_state.player_data:
97
+ st.session_state.player_data[player_id] = {
98
+ 'positions': [centroid],
99
+ 'timestamps': [frame_count / fps],
100
+ 'distance': 0.0,
101
+ 'speeds': [0.0], # Initialize with zero speed
102
+ 'last_seen': frame_count
103
+ }
104
+ player_count += 1 # Increment player counter
105
+ else:
106
+ # Calculate movement metrics
107
+ prev_pos = st.session_state.player_data[player_id]['positions'][-1]
108
+ time_diff = (frame_count / fps) - st.session_state.player_data[player_id]['timestamps'][-1]
109
+
110
+ # Calculate distance
111
+ pixel_dist = dist.euclidean(prev_pos, centroid)
112
+
113
+ # Apply motion smoothing to ignore unrealistic movements
114
+ if pixel_dist < frame_width * 0.1: # Max 10% of screen width per frame
115
+ speed_px = pixel_dist / time_diff if time_diff > 0 else 0.0
116
+
117
+ # Convert to meters per second
118
+ speed = speed_px / pixels_per_meter
119
+ meter_dist = pixel_dist / pixels_per_meter
120
+
121
+ # Update player data
122
+ st.session_state.player_data[player_id]['positions'].append(centroid)
123
+ st.session_state.player_data[player_id]['timestamps'].append(frame_count / fps)
124
+ st.session_state.player_data[player_id]['distance'] += meter_dist
125
+ st.session_state.player_data[player_id]['speeds'].append(speed)
126
+ st.session_state.player_data[player_id]['last_seen'] = frame_count
127
+ else:
128
+ # Just update position without adding to distance/speed
129
+ st.session_state.player_data[player_id]['positions'].append(centroid)
130
+ st.session_state.player_data[player_id]['timestamps'].append(frame_count / fps)
131
+ st.session_state.player_data[player_id]['last_seen'] = frame_count
132
+ # Maintain previous speed
133
+ speed = st.session_state.player_data[player_id]['speeds'][-1] if st.session_state.player_data[player_id]['speeds'] else 0
134
+
135
+ # Draw player info on frame
136
+ label = f"ID:{player_id} Speed:{speed:.1f}m/s"
137
+ cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
138
+ cv2.putText(frame, label, (int(x1), int(y1-10)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
139
 
140
+ # Update live table
141
+ if st.session_state.player_data:
142
+ live_data = []
143
+ for player_id, data in st.session_state.player_data.items():
144
+ # Only show recently seen players
145
+ if frame_count - data['last_seen'] < 30:
146
+ current_speed = data['speeds'][-1] if data['speeds'] else 0
147
+ live_data.append({
148
+ "Player": f"Player {player_id}",
149
+ "Current Speed (m/s)": round(current_speed, 2),
150
+ "Distance (m)": round(data['distance'], 2),
151
+ "Time (s)": round(data['timestamps'][-1], 2) if data['timestamps'] else 0
152
+ })
153
+
154
+ live_df = pd.DataFrame(live_data)
155
+ if not live_df.empty:
156
+ live_table.dataframe(live_df, use_container_width=True, hide_index=True)
157
 
158
  st_frame.image(frame, channels="BGR")
159
  frame_count += 1
 
161
  cap.release()
162
 
163
  # Display final analytics
164
+ st.subheader(f"Player Statistics ({player_count} players detected)")
165
+
166
+ # Create player stats table
167
+ stats_data = []
168
  for player_id, data in st.session_state.player_data.items():
169
+ if len(data['speeds']) > 0:
170
+ avg_speed = np.mean(data['speeds'])
171
+ max_speed = np.max(data['speeds'])
172
+ else:
173
+ avg_speed = 0
174
+ max_speed = 0
175
+
176
+ duration = data['timestamps'][-1] - data['timestamps'][0] if len(data['timestamps']) > 1 else 0
177
+
178
+ # Only include players with significant tracking data
179
+ if len(data['positions']) > 10:
180
+ stats_data.append({
181
+ "Player": f"Player {player_id}",
182
+ "Total Distance (m)": round(data['distance'], 2),
183
+ "Avg Speed (m/s)": round(avg_speed, 2),
184
+ "Max Speed (m/s)": round(max_speed, 2),
185
+ "Duration (s)": round(duration, 2)
186
+ })
187
+
188
+ stats_df = pd.DataFrame(stats_data)
189
+ st.dataframe(stats_df, use_container_width=True, hide_index=True)
190
+
191
+ # Create speed comparison chart
192
+ if st.session_state.player_data:
193
+ st.subheader("Player Speed Comparison")
194
+
195
+ fig, ax = plt.subplots(figsize=(10, 6))
196
+
197
+ for player_id, data in st.session_state.player_data.items():
198
+ # Only plot players with significant data
199
+ if len(data['speeds']) > 10:
200
+ ax.plot(data['timestamps'], data['speeds'], label=f"Player {player_id}")
201
+
202
+ ax.set_xlabel('Time (seconds)')
203
+ ax.set_ylabel('Speed (m/s)')
204
+ ax.set_title('Player Speed Over Time')
205
+ ax.legend()
206
+ ax.grid(True)
207
+
208
+ st.pyplot(fig)