Spaces:
Runtime error
Runtime error
da03
committed on
Commit
·
9612f89
1
Parent(s):
1ebe6d8
- online_data_generation.py +89 -1
online_data_generation.py
CHANGED
@@ -149,6 +149,9 @@ def process_session_file(log_file, clean_state):
|
|
149 |
conn = sqlite3.connect(DB_FILE)
|
150 |
cursor = conn.cursor()
|
151 |
|
|
|
|
|
|
|
152 |
# Get session details
|
153 |
trajectory = load_trajectory(log_file)
|
154 |
if not trajectory:
|
@@ -204,7 +207,23 @@ def process_session_file(log_file, clean_state):
|
|
204 |
start_time = sub_traj[0]["timestamp"]
|
205 |
end_time = sub_traj[-1]["timestamp"]
|
206 |
|
207 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
try:
|
209 |
logger.info(f"Processing segment {i+1}/{len(sub_trajectories)} from {log_file} as trajectory {next_id}")
|
210 |
|
@@ -285,6 +304,75 @@ def format_trajectory_for_processing(trajectory):
|
|
285 |
return formatted_events
|
286 |
|
287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
def main():
|
289 |
"""Main function to run the data processing pipeline."""
|
290 |
# Initialize database
|
|
|
149 |
conn = sqlite3.connect(DB_FILE)
|
150 |
cursor = conn.cursor()
|
151 |
|
152 |
+
# Ensure output directory exists
|
153 |
+
os.makedirs("generated_videos", exist_ok=True)
|
154 |
+
|
155 |
# Get session details
|
156 |
trajectory = load_trajectory(log_file)
|
157 |
if not trajectory:
|
|
|
207 |
start_time = sub_traj[0]["timestamp"]
|
208 |
end_time = sub_traj[-1]["timestamp"]
|
209 |
|
210 |
+
# STEP 1: Generate a video from the original frames
|
211 |
+
segment_label = f"segment_{i+1}_of_{len(sub_trajectories)}"
|
212 |
+
video_path = os.path.join("generated_videos", f"trajectory_{next_id:06d}_{segment_label}.mp4")
|
213 |
+
|
214 |
+
# Generate video from original frames for comparison
|
215 |
+
success, frame_count = generate_comparison_video(
|
216 |
+
client_id,
|
217 |
+
sub_traj,
|
218 |
+
video_path,
|
219 |
+
start_time,
|
220 |
+
end_time
|
221 |
+
)
|
222 |
+
|
223 |
+
if not success:
|
224 |
+
logger.warning(f"Failed to generate comparison video for segment {i+1}, but continuing with processing")
|
225 |
+
|
226 |
+
# STEP 2: Process with Docker for training data generation
|
227 |
try:
|
228 |
logger.info(f"Processing segment {i+1}/{len(sub_trajectories)} from {log_file} as trajectory {next_id}")
|
229 |
|
|
|
304 |
return formatted_events
|
305 |
|
306 |
|
307 |
+
def _frame_timestamp(frame_path):
    """Return the timestamp encoded in a frame filename, or None if it is not numeric."""
    stem = os.path.basename(frame_path).split('.png')[0]
    try:
        return float(stem)
    except ValueError:
        return None


def generate_comparison_video(client_id, trajectory, output_file, start_time, end_time):
    """
    Generate a video from the original frames for comparison purposes.

    Frames are looked up as ``*.png`` files under
    ``FRAMES_DIR/frames_<client_id>``; each filename (minus ``.png``) is
    parsed as a float timestamp. Frames whose timestamp lies in
    ``[start_time, end_time]`` (inclusive) are written, in timestamp order,
    to an ``mp4v``-encoded video at 10 frames per second.

    Args:
        client_id: Client ID for frame lookup
        trajectory: List of interaction log entries for this segment
            (accepted for interface compatibility; not read by this function)
        output_file: Path to save the output video
        start_time: Start timestamp for this segment
        end_time: End timestamp for this segment

    Returns:
        (bool, int): (success status, number of frames actually written).
        Any failure, including an unexpected exception, yields ``(False, 0)``.
    """
    try:
        # Get frame files for this client
        frame_dir = os.path.join(FRAMES_DIR, f"frames_{client_id}")
        if not os.path.exists(frame_dir):
            logger.warning(f"No frame directory found for client {client_id}")
            return False, 0

        # Parse every filename exactly once. A stray PNG whose name is not a
        # timestamp is skipped instead of aborting the whole video.
        timestamped_frames = []
        for frame_path in glob.glob(os.path.join(frame_dir, "*.png")):
            timestamp = _frame_timestamp(frame_path)
            if timestamp is None:
                logger.warning(f"Skipping frame with non-timestamp filename: {frame_path}")
                continue
            timestamped_frames.append((timestamp, frame_path))

        if not timestamped_frames:
            logger.error(f"No frames found for client {client_id}")
            return False, 0

        # Sort frames by timestamp, then keep those inside this segment's range
        timestamped_frames.sort(key=lambda item: item[0])
        segment_frames = [
            frame_path
            for timestamp, frame_path in timestamped_frames
            if start_time <= timestamp <= end_time
        ]

        if not segment_frames:
            logger.error(f"No frames found in time range for segment {start_time}-{end_time}")
            return False, 0

        # Read the first frame to get dimensions
        first_frame = cv2.imread(segment_frames[0])
        if first_frame is None:
            logger.error(f"Could not read first frame {segment_frames[0]}")
            return False, 0

        height, width = first_frame.shape[:2]

        # Create video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(output_file, fourcc, 10.0, (width, height))
        try:
            if not video.isOpened():
                logger.error(f"Could not open video writer for {output_file}")
                return False, 0

            # The first frame is already in memory; no need to read it again.
            video.write(first_frame)
            frames_written = 1
            frames_skipped = 0

            for frame_file in segment_frames[1:]:
                frame = cv2.imread(frame_file)
                # The writer expects every frame to match the size it was
                # opened with, so unreadable or mismatched frames are
                # skipped and counted rather than silently lost.
                if frame is None or frame.shape[:2] != (height, width):
                    frames_skipped += 1
                    continue
                video.write(frame)
                frames_written += 1
        finally:
            # Always release the writer, even if reading or writing raised.
            video.release()

        if frames_skipped:
            logger.warning(
                f"Skipped {frames_skipped} unreadable or mismatched frames while writing {output_file}"
            )

        logger.info(f"Created comparison video {output_file} with {frames_written} frames")
        return True, frames_written

    except Exception as e:
        # Best-effort step: callers continue processing on failure, so report
        # the error (with traceback) instead of propagating it.
        logger.exception(f"Error generating comparison video: {e}")
        return False, 0
|
374 |
+
|
375 |
+
|
376 |
def main():
|
377 |
"""Main function to run the data processing pipeline."""
|
378 |
# Initialize database
|