da03 committed on
Commit
2d9e199
·
1 Parent(s): 9df0e98
Files changed (1) hide show
  1. online_data_generation.py +479 -0
online_data_generation.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import json
4
+ import glob
5
+ import time
6
+ import sqlite3
7
+ import logging
8
+ import cv2
9
+ import numpy as np
10
+ from datetime import datetime
11
+
12
# Send log records both to a file on disk and to the console.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.FileHandler("video_generator.log"), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Locations used by the pipeline (all relative to the working directory).
DB_FILE = "trajectory_processor.db"  # SQLite bookkeeping database
OUTPUT_DIR = "generated_videos"      # generated .mp4 files are written here
FRAMES_DIR = "interaction_logs"      # session_*.jsonl logs and frames_<client_id>/ directories
27
+
28
+
29
def initialize_database():
    """Create the bookkeeping tables in the SQLite database if they are missing.

    Connects to ``DB_FILE`` and makes sure three tables exist:

    * ``processed_sessions`` -- one row per handled session log file
      (``log_file`` is unique).
    * ``processed_segments`` -- one row per generated segment video, unique
      per ``(log_file, segment_index)``.
    * ``config`` -- key/value settings, seeded with ``next_id = '1'``.

    Calling this repeatedly is safe: existing tables and an existing
    ``next_id`` value are left untouched.
    """
    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS processed_sessions (
                id INTEGER PRIMARY KEY,
                log_file TEXT UNIQUE,
                client_id TEXT,
                processed_time TIMESTAMP
            )
        ''')

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS processed_segments (
                id INTEGER PRIMARY KEY,
                log_file TEXT,
                client_id TEXT,
                segment_index INTEGER,
                start_time REAL,
                end_time REAL,
                processed_time TIMESTAMP,
                video_path TEXT,
                frame_count INTEGER,
                trajectory_id INTEGER,
                UNIQUE(log_file, segment_index)
            )
        ''')

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS config (
                key TEXT PRIMARY KEY,
                value TEXT
            )
        ''')

        # Seed the ID counter only when the key is absent; `key` is the
        # primary key, so OR IGNORE leaves an existing counter value alone.
        cursor.execute("INSERT OR IGNORE INTO config (key, value) VALUES ('next_id', '1')")

        conn.commit()
    finally:
        # Always release the connection, even if a statement above failed.
        conn.close()
74
+
75
+
76
def is_session_complete(log_file):
    """Return True if the session log contains an end-of-session marker.

    The file is read line by line and each line is parsed as JSON. The
    session counts as complete as soon as one parsed object has a truthy
    ``"is_eos"`` field.

    Args:
        log_file: Path to the session log file.

    Returns:
        bool: True if an EOS entry was found; False if none was found or the
        file could not be read (the error is logged).
    """
    try:
        with open(log_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    entry = json.loads(line.strip())
                except json.JSONDecodeError:
                    continue
                # A line can be valid JSON without being an object (e.g. a
                # list or number); skip it instead of aborting the scan.
                if isinstance(entry, dict) and entry.get("is_eos", False):
                    return True
        return False
    except Exception as e:
        logger.error(f"Error checking if session {log_file} is complete: {e}")
        return False
91
+
92
+
93
def is_session_valid(log_file):
    """Return True if the session log has at least one real interaction entry.

    An entry counts as a real interaction when it is a JSON object whose
    ``"is_eos"`` and ``"is_reset"`` fields are both absent or falsy. Lines
    that are not valid JSON, or that are not JSON objects, are ignored.

    Args:
        log_file: Path to the session log file.

    Returns:
        bool: True if such an entry exists; False otherwise, including when
        the file could not be read (the error is logged).
    """
    try:
        with open(log_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    entry = json.loads(line.strip())
                except json.JSONDecodeError:
                    continue
                if not isinstance(entry, dict):
                    continue
                # One non-control entry is enough; no need to read further.
                if not entry.get("is_eos", False) and not entry.get("is_reset", False):
                    return True
        return False

    except Exception as e:
        logger.error(f"Error checking if session {log_file} is valid: {e}")
        return False
118
+
119
+
120
def load_trajectory(log_file):
    """Load the entries of a session log file.

    Each non-blank line is parsed as JSON. Lines that fail to parse, or that
    parse to something other than a JSON object, are skipped with a warning
    so that callers can rely on every entry supporting ``.get``.

    Args:
        log_file: Path to the session log file.

    Returns:
        list[dict]: Parsed entries in file order; an empty list if the file
        could not be read (the error is logged).
    """
    trajectory = []

    try:
        with open(log_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines are harmless; don't report them as bad JSON.
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    logger.warning(f"Skipping invalid JSON line in {log_file}")
                    continue
                if not isinstance(entry, dict):
                    logger.warning(f"Skipping non-object JSON line in {log_file}")
                    continue
                trajectory.append(entry)
        return trajectory

    except Exception as e:
        logger.error(f"Error loading trajectory from {log_file}: {e}")
        return []
138
+
139
+
140
def get_frame_files(client_id):
    """Return the PNG frame paths for a client, ordered by timestamp.

    Frames are looked up in ``FRAMES_DIR/frames_<client_id>/``. The part of
    each file name before ``.png`` is parsed as a float timestamp and used as
    the sort key. Files whose name is not a number are ignored with a
    warning, so a single stray PNG cannot make the whole lookup fail.

    Args:
        client_id: Client identifier used to build the frame directory name.

    Returns:
        list[str]: Frame paths in ascending timestamp order; an empty list if
        the directory does not exist.
    """
    frame_dir = os.path.join(FRAMES_DIR, f"frames_{client_id}")

    if not os.path.exists(frame_dir):
        logger.warning(f"No frame directory found for client {client_id}")
        return []

    timestamped = []
    for path in glob.glob(os.path.join(frame_dir, "*.png")):
        stem = os.path.basename(path).split('.png')[0]
        try:
            timestamped.append((float(stem), path))
        except ValueError:
            logger.warning(f"Ignoring frame with non-timestamp filename: {path}")

    # Sort numerically (not lexically) so "10.5" comes after "2.25"; equal
    # timestamps fall back to path order, keeping the result deterministic.
    timestamped.sort()
    return [path for _, path in timestamped]
152
+
153
+
154
def process_trajectory(trajectory, output_file):
    """Create a video from all frames recorded for a trajectory's client.

    The client is identified by the ``"client_id"`` field of the first
    trajectory entry; every frame returned by ``get_frame_files`` for that
    client is written, in order, to an ``mp4v``-encoded video at 10 fps.

    Args:
        trajectory: List of interaction log entries (dicts).
        output_file: Path to save the output video.

    Returns:
        (bool, int): (success status, number of frames actually written).
        ``(False, 0)`` is returned, and the reason logged, when the
        trajectory is empty, has no client_id, has no readable frames, or
        the video writer cannot be opened.
    """
    if not trajectory:
        logger.error("Cannot process empty trajectory")
        return False, 0

    try:
        # Extract client_id from the first entry
        client_id = trajectory[0].get("client_id")
        if not client_id:
            logger.error("Trajectory missing client_id")
            return False, 0

        # Get all frame files for this client
        frame_files = get_frame_files(client_id)
        if not frame_files:
            logger.error(f"No frames found for client {client_id}")
            return False, 0

        # The first frame fixes the video dimensions
        first_frame = cv2.imread(frame_files[0])
        if first_frame is None:
            logger.error(f"Could not read first frame {frame_files[0]}")
            return False, 0

        height, width = first_frame.shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(output_file, fourcc, 10.0, (width, height))
        written = 0
        try:
            if not video.isOpened():
                logger.error(f"Could not open video writer for {output_file}")
                return False, 0

            for frame_file in frame_files:
                frame = cv2.imread(frame_file)
                if frame is None:
                    logger.warning(f"Skipping unreadable frame {frame_file}")
                    continue
                # The writer expects every frame to match the size it was
                # opened with; skip mismatches so the count stays accurate.
                if frame.shape[:2] != (height, width):
                    logger.warning(f"Skipping frame with mismatched size {frame_file}")
                    continue
                video.write(frame)
                written += 1
        finally:
            # Release the writer on every path so the file is finalized
            # and the handle is not leaked.
            video.release()

        logger.info(f"Successfully created video {output_file} with {written} frames")
        return True, written

    except Exception as e:
        logger.error(f"Error processing trajectory: {e}")
        return False, 0
209
+
210
+
211
def get_next_id():
    """Return the next available trajectory ID stored in the database.

    Reads the ``next_id`` entry of the ``config`` table in ``DB_FILE``.

    Returns:
        int: The stored counter value.

    Raises:
        RuntimeError: If the ``config`` table has no ``next_id`` entry.
    """
    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
        result = cursor.fetchone()
    finally:
        conn.close()

    if result is None:
        # Fail with a clear message instead of a TypeError on None[0].
        raise RuntimeError("config table has no 'next_id' entry; run initialize_database() first")
    return int(result[0])
222
+
223
+
224
def increment_next_id():
    """Increase the ``next_id`` counter in the ``config`` table by one.

    The change is committed before returning. If the ``next_id`` row does
    not exist the UPDATE matches nothing and the call is a no-op.
    """
    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()
        # `value` is declared TEXT, so make the numeric conversion explicit
        # rather than relying on implicit coercion in the addition.
        cursor.execute(
            "UPDATE config SET value = CAST(value AS INTEGER) + 1 WHERE key = 'next_id'"
        )
        conn.commit()
    finally:
        conn.close()
233
+
234
+
235
def is_session_processed(log_file):
    """Return True if ``log_file`` is recorded in ``processed_sessions``.

    Args:
        log_file: Session log path, compared exactly against the stored
            ``log_file`` column.

    Returns:
        bool: True if a matching row exists, False otherwise.
    """
    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT 1 FROM processed_sessions WHERE log_file = ?", (log_file,))
        return cursor.fetchone() is not None
    finally:
        # Close the connection even if the query raises.
        conn.close()
245
+
246
+
247
def mark_session_processed(log_file, client_id, video_path, frame_count):
    """Record a session as processed in the ``processed_sessions`` table.

    The table created by ``initialize_database`` only has ``log_file``,
    ``client_id`` and ``processed_time`` columns, so inserting ``video_path``
    and ``frame_count`` unconditionally fails. Those two values are stored
    only when the table actually has matching columns.

    Args:
        log_file: Path of the session log file.
        client_id: Client identifier of the session.
        video_path: Path of the generated video (stored if the column exists).
        frame_count: Number of frames in the video (stored if the column exists).
    """
    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()

        # Column names come from this fixed list, never from caller input,
        # so building the column list into the SQL text is safe.
        cursor.execute("PRAGMA table_info(processed_sessions)")
        existing = {row[1] for row in cursor.fetchall()}

        columns = ["log_file", "client_id", "processed_time"]
        values = [log_file, client_id, datetime.now().isoformat()]
        for name, value in (("video_path", video_path), ("frame_count", frame_count)):
            if name in existing:
                columns.append(name)
                values.append(value)

        placeholders = ", ".join("?" for _ in columns)
        cursor.execute(
            f"INSERT INTO processed_sessions ({', '.join(columns)}) VALUES ({placeholders})",
            values,
        )

        conn.commit()
    finally:
        conn.close()
259
+
260
+
261
def process_session_file(log_file):
    """Split a session log at reset markers and render one video per segment.

    The log is loaded with ``load_trajectory``. Entries flagged ``is_reset``
    act as separators and are not part of any segment. Each segment that
    contains at least one non-control entry is rendered with
    ``process_trajectory_segment`` using the timestamps of its first and last
    timestamped entries. Every successful segment is recorded in
    ``processed_segments`` and consumes one value of the ``next_id`` counter.
    Finally the log file is recorded in ``processed_sessions``.

    Args:
        log_file: Path to the session log file.

    Returns:
        list[int]: Trajectory IDs of the segments rendered by this call.
        Empty if the log could not be loaded or has neither a reset nor an
        EOS entry (in which case nothing is recorded in the database).
    """
    # Load and validate BEFORE opening the database so the early exits below
    # cannot leak a connection.
    trajectory = load_trajectory(log_file)
    if not trajectory:
        logger.error(f"Empty trajectory for {log_file}, skipping")
        return []

    client_id = trajectory[0].get("client_id", "unknown")

    # Find all reset points and EOS
    reset_indices = [i for i, entry in enumerate(trajectory) if entry.get("is_reset", False)]
    has_eos = any(entry.get("is_eos", False) for entry in trajectory)

    # If no resets and no EOS, this is incomplete - skip
    if not reset_indices and not has_eos:
        logger.warning(f"Session {log_file} has no resets and no EOS, may be incomplete")
        return []

    # Split trajectory at reset points, dropping the reset entries themselves
    sub_trajectories = []
    start_idx = 0
    for reset_idx in reset_indices:
        if reset_idx > start_idx:  # Only add non-empty segments
            sub_trajectories.append(trajectory[start_idx:reset_idx])
        start_idx = reset_idx + 1
    if start_idx < len(trajectory):
        sub_trajectories.append(trajectory[start_idx:])

    total = len(sub_trajectories)
    processed_ids = []

    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()

        for i, sub_traj in enumerate(sub_trajectories):
            # Skip segments with no interaction data (just control messages)
            if all(entry.get("is_reset", False) or entry.get("is_eos", False) for entry in sub_traj):
                continue

            # A previous run may have stopped after committing some segments
            # but before marking the session; re-inserting those would
            # violate UNIQUE(log_file, segment_index), so skip them.
            cursor.execute(
                "SELECT 1 FROM processed_segments WHERE log_file = ? AND segment_index = ?",
                (log_file, i),
            )
            if cursor.fetchone() is not None:
                logger.info(f"Segment {i+1}/{total} from {log_file} already processed, skipping")
                continue

            # Control entries (e.g. a trailing EOS) may lack a timestamp;
            # use the first and last entries that carry one.
            timestamps = [entry["timestamp"] for entry in sub_traj if "timestamp" in entry]
            if not timestamps:
                logger.error(f"Segment {i+1}/{total} from {log_file} has no timestamps, skipping")
                continue
            start_time, end_time = timestamps[0], timestamps[-1]

            # Get the next ID for this sub-trajectory
            cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
            next_id = int(cursor.fetchone()[0])

            # Define output path
            segment_label = f"segment_{i+1}_of_{total}"
            output_file = os.path.join(OUTPUT_DIR, f"trajectory_{next_id:06d}_{segment_label}.mp4")

            success, frame_count = process_trajectory_segment(
                client_id,
                sub_traj,
                output_file,
                start_time,
                end_time
            )

            if success:
                # Record the segment and advance the counter in one commit
                cursor.execute(
                    """INSERT INTO processed_segments
                    (log_file, client_id, segment_index, start_time, end_time,
                    processed_time, video_path, frame_count, trajectory_id)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                    (log_file, client_id, i, start_time, end_time,
                     datetime.now().isoformat(), output_file, frame_count, next_id)
                )
                cursor.execute("UPDATE config SET value = ? WHERE key = 'next_id'", (str(next_id + 1),))
                conn.commit()

                processed_ids.append(next_id)
                logger.info(f"Successfully processed segment {i+1}/{total} from {log_file}")
            else:
                logger.error(f"Failed to process segment {i+1}/{total} from {log_file}")

        # Mark the entire session as processed. NOTE(review): this happens
        # even when some segments failed, so they are not retried on later
        # runs -- confirm that is intended. OR IGNORE keeps the UNIQUE
        # log_file constraint from raising if the row already exists.
        cursor.execute(
            "INSERT OR IGNORE INTO processed_sessions (log_file, client_id, processed_time) VALUES (?, ?, ?)",
            (log_file, client_id, datetime.now().isoformat())
        )
        conn.commit()
    finally:
        conn.close()

    return processed_ids
364
+
365
+
366
def process_trajectory_segment(client_id, trajectory, output_file, start_time, end_time):
    """Create a video from the client's frames that fall inside a time range.

    Frame file names (without ``.png``) are interpreted as timestamps; frames
    with ``start_time <= timestamp <= end_time`` are written, in the order
    returned by ``get_frame_files``, to an ``mp4v``-encoded video at 10 fps.

    Args:
        client_id: Client ID for frame lookup
        trajectory: List of interaction log entries for this segment
            (not read by this function)
        output_file: Path to save the output video
        start_time: Start timestamp for this segment (inclusive)
        end_time: End timestamp for this segment (inclusive)

    Returns:
        (bool, int): (success status, number of frames actually written).
        ``(False, 0)`` is returned, and the reason logged, when no frames
        are available for the range or the video cannot be created.
    """
    try:
        # Get frame files for this client
        all_frames = get_frame_files(client_id)
        if not all_frames:
            logger.error(f"No frames found for client {client_id}")
            return False, 0

        # Frame filenames are timestamps, so we can use them for filtering
        segment_frames = [
            f for f in all_frames
            if start_time <= float(os.path.basename(f).split('.png')[0]) <= end_time
        ]

        if not segment_frames:
            logger.error(f"No frames found in time range for segment {start_time}-{end_time}")
            return False, 0

        # The first frame fixes the video dimensions
        first_frame = cv2.imread(segment_frames[0])
        if first_frame is None:
            logger.error(f"Could not read first frame {segment_frames[0]}")
            return False, 0

        height, width = first_frame.shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(output_file, fourcc, 10.0, (width, height))
        written = 0
        try:
            if not video.isOpened():
                logger.error(f"Could not open video writer for {output_file}")
                return False, 0

            for frame_file in segment_frames:
                frame = cv2.imread(frame_file)
                if frame is None:
                    logger.warning(f"Skipping unreadable frame {frame_file}")
                    continue
                # The writer expects every frame to match the size it was
                # opened with; skip mismatches so the count stays accurate.
                if frame.shape[:2] != (height, width):
                    logger.warning(f"Skipping frame with mismatched size {frame_file}")
                    continue
                video.write(frame)
                written += 1
        finally:
            # Release the writer on every path so the file is finalized
            # and the handle is not leaked.
            video.release()

        logger.info(f"Created video {output_file} with {written} frames")
        return True, written

    except Exception as e:
        logger.error(f"Error processing trajectory segment: {e}")
        return False, 0
425
+
426
+
427
def main():
    """Run the pipeline: find finished session logs and turn them into videos.

    Steps: ensure the output directory and database exist, collect
    ``session_*.jsonl`` files under ``FRAMES_DIR``, keep those that are
    complete, not yet recorded in ``processed_sessions`` and valid, process
    each one, then log how many videos were generated and the next ID.
    """
    # Prepare output location and bookkeeping database
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    initialize_database()

    # Discover candidate session logs
    pattern = os.path.join(FRAMES_DIR, "session_*.jsonl")
    log_files = glob.glob(pattern)
    logger.info(f"Found {len(log_files)} log files")

    # Keep only sessions that contain an EOS marker
    complete_sessions = []
    for candidate in log_files:
        if is_session_complete(candidate):
            complete_sessions.append(candidate)
    logger.info(f"Found {len(complete_sessions)} complete sessions")

    # Drop sessions the database already knows about
    conn = sqlite3.connect(DB_FILE)
    cursor = conn.cursor()
    cursor.execute("SELECT log_file FROM processed_sessions")
    processed_files = {row[0] for row in cursor.fetchall()}
    conn.close()

    new_sessions = [path for path in complete_sessions if path not in processed_files]
    logger.info(f"Found {len(new_sessions)} new sessions to process")

    # Drop sessions that contain nothing but control entries
    valid_sessions = [path for path in new_sessions if is_session_valid(path)]
    logger.info(f"Found {len(valid_sessions)} valid new sessions to process")

    # Render every remaining session and tally the generated videos
    total_trajectories = 0
    for log_file in valid_sessions:
        logger.info(f"Processing session file: {log_file}")
        total_trajectories += len(process_session_file(log_file))

    # Read back the counter purely for the final report
    conn = sqlite3.connect(DB_FILE)
    cursor = conn.cursor()
    cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
    row = cursor.fetchone()
    next_id = int(row[0])
    conn.close()

    logger.info(f"Processing complete. Generated {total_trajectories} trajectory videos.")
    logger.info(f"Next ID will be {next_id}")
473
+
474
+
475
# Script entry point: run the pipeline and log any uncaught error together
# with its traceback instead of letting it propagate.
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(f"Unhandled exception: {e}")