da03 commited on
Commit
9612f89
·
1 Parent(s): 1ebe6d8
Files changed (1) hide show
  1. online_data_generation.py +89 -1
online_data_generation.py CHANGED
@@ -149,6 +149,9 @@ def process_session_file(log_file, clean_state):
149
  conn = sqlite3.connect(DB_FILE)
150
  cursor = conn.cursor()
151
 
 
 
 
152
  # Get session details
153
  trajectory = load_trajectory(log_file)
154
  if not trajectory:
@@ -204,7 +207,23 @@ def process_session_file(log_file, clean_state):
204
  start_time = sub_traj[0]["timestamp"]
205
  end_time = sub_traj[-1]["timestamp"]
206
 
207
- # Process this sub-trajectory using the external function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  try:
209
  logger.info(f"Processing segment {i+1}/{len(sub_trajectories)} from {log_file} as trajectory {next_id}")
210
 
@@ -285,6 +304,75 @@ def format_trajectory_for_processing(trajectory):
285
  return formatted_events
286
 
287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  def main():
289
  """Main function to run the data processing pipeline."""
290
  # Initialize database
 
149
  conn = sqlite3.connect(DB_FILE)
150
  cursor = conn.cursor()
151
 
152
+ # Ensure output directory exists
153
+ os.makedirs("generated_videos", exist_ok=True)
154
+
155
  # Get session details
156
  trajectory = load_trajectory(log_file)
157
  if not trajectory:
 
207
  start_time = sub_traj[0]["timestamp"]
208
  end_time = sub_traj[-1]["timestamp"]
209
 
210
+ # STEP 1: Generate a video from the original frames
211
+ segment_label = f"segment_{i+1}_of_{len(sub_trajectories)}"
212
+ video_path = os.path.join("generated_videos", f"trajectory_{next_id:06d}_{segment_label}.mp4")
213
+
214
+ # Generate video from original frames for comparison
215
+ success, frame_count = generate_comparison_video(
216
+ client_id,
217
+ sub_traj,
218
+ video_path,
219
+ start_time,
220
+ end_time
221
+ )
222
+
223
+ if not success:
224
+ logger.warning(f"Failed to generate comparison video for segment {i+1}, but continuing with processing")
225
+
226
+ # STEP 2: Process with Docker for training data generation
227
  try:
228
  logger.info(f"Processing segment {i+1}/{len(sub_trajectories)} from {log_file} as trajectory {next_id}")
229
 
 
304
  return formatted_events
305
 
306
 
307
+ def generate_comparison_video(client_id, trajectory, output_file, start_time, end_time):
308
+ """
309
+ Generate a video from the original frames for comparison purposes.
310
+
311
+ Args:
312
+ client_id: Client ID for frame lookup
313
+ trajectory: List of interaction log entries for this segment
314
+ output_file: Path to save the output video
315
+ start_time: Start timestamp for this segment
316
+ end_time: End timestamp for this segment
317
+
318
+ Returns:
319
+ (bool, int): (success status, frame count)
320
+ """
321
+ try:
322
+ # Get frame files for this client
323
+ frame_dir = os.path.join(FRAMES_DIR, f"frames_{client_id}")
324
+ if not os.path.exists(frame_dir):
325
+ logger.warning(f"No frame directory found for client {client_id}")
326
+ return False, 0
327
+
328
+ all_frames = glob.glob(os.path.join(frame_dir, "*.png"))
329
+ # Sort frames by timestamp in filename
330
+ all_frames.sort(key=lambda x: float(os.path.basename(x).split('.png')[0]))
331
+
332
+ if not all_frames:
333
+ logger.error(f"No frames found for client {client_id}")
334
+ return False, 0
335
+
336
+ # Filter frames to the time range of this segment
337
+ # Frame filenames are timestamps, so we can use them for filtering
338
+ segment_frames = [
339
+ f for f in all_frames
340
+ if start_time <= float(os.path.basename(f).split('.png')[0]) <= end_time
341
+ ]
342
+
343
+ if not segment_frames:
344
+ logger.error(f"No frames found in time range for segment {start_time}-{end_time}")
345
+ return False, 0
346
+
347
+ # Read the first frame to get dimensions
348
+ first_frame = cv2.imread(segment_frames[0])
349
+ if first_frame is None:
350
+ logger.error(f"Could not read first frame {segment_frames[0]}")
351
+ return False, 0
352
+
353
+ height, width, channels = first_frame.shape
354
+
355
+ # Create video writer
356
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
357
+ video = cv2.VideoWriter(output_file, fourcc, 10.0, (width, height))
358
+
359
+ # Process each frame
360
+ for frame_file in segment_frames:
361
+ frame = cv2.imread(frame_file)
362
+ if frame is not None:
363
+ video.write(frame)
364
+
365
+ # Release the video writer
366
+ video.release()
367
+
368
+ logger.info(f"Created comparison video {output_file} with {len(segment_frames)} frames")
369
+ return True, len(segment_frames)
370
+
371
+ except Exception as e:
372
+ logger.error(f"Error generating comparison video: {e}")
373
+ return False, 0
374
+
375
+
376
  def main():
377
  """Main function to run the data processing pipeline."""
378
  # Initialize database