da03 committed on
Commit
5c64b10
·
1 Parent(s): 2d9e199
Files changed (1) hide show
  1. online_data_generation.py +76 -213
online_data_generation.py CHANGED
@@ -7,14 +7,19 @@ import sqlite3
7
  import logging
8
  import cv2
9
  import numpy as np
 
10
  from datetime import datetime
 
 
 
 
11
 
12
  # Configure logging
13
  logging.basicConfig(
14
  level=logging.INFO,
15
  format='%(asctime)s - %(levelname)s - %(message)s',
16
  handlers=[
17
- logging.FileHandler("video_generator.log"),
18
  logging.StreamHandler()
19
  ]
20
  )
@@ -22,9 +27,10 @@ logger = logging.getLogger(__name__)
22
 
23
  # Define constants
24
  DB_FILE = "trajectory_processor.db"
25
- OUTPUT_DIR = "generated_videos"
26
  FRAMES_DIR = "interaction_logs"
27
-
 
 
28
 
29
  def initialize_database():
30
  """Initialize the SQLite database if it doesn't exist."""
@@ -50,8 +56,6 @@ def initialize_database():
50
  start_time REAL,
51
  end_time REAL,
52
  processed_time TIMESTAMP,
53
- video_path TEXT,
54
- frame_count INTEGER,
55
  trajectory_id INTEGER,
56
  UNIQUE(log_file, segment_index)
57
  )
@@ -137,128 +141,7 @@ def load_trajectory(log_file):
137
  return []
138
 
139
 
140
- def get_frame_files(client_id):
141
- """Get all frame files for a client ID, sorted by timestamp."""
142
- frame_dir = os.path.join(FRAMES_DIR, f"frames_{client_id}")
143
-
144
- if not os.path.exists(frame_dir):
145
- logger.warning(f"No frame directory found for client {client_id}")
146
- return []
147
-
148
- frames = glob.glob(os.path.join(frame_dir, "*.png"))
149
- # Sort frames by timestamp in filename
150
- frames.sort(key=lambda x: float(os.path.basename(x).split('.png')[0]))
151
- return frames
152
-
153
-
154
- def process_trajectory(trajectory, output_file):
155
- """
156
- Process a trajectory and create a video file.
157
-
158
- Args:
159
- trajectory: List of interaction log entries
160
- output_file: Path to save the output video
161
-
162
- Returns:
163
- (bool, int): (success status, frame count)
164
- """
165
- if not trajectory:
166
- logger.error("Cannot process empty trajectory")
167
- return False, 0
168
-
169
- try:
170
- # Extract client_id from the first entry
171
- client_id = trajectory[0].get("client_id")
172
- if not client_id:
173
- logger.error("Trajectory missing client_id")
174
- return False, 0
175
-
176
- # Get all frame files for this client
177
- frame_files = get_frame_files(client_id)
178
- if not frame_files:
179
- logger.error(f"No frames found for client {client_id}")
180
- return False, 0
181
-
182
- # Read the first frame to get dimensions
183
- first_frame = cv2.imread(frame_files[0])
184
- if first_frame is None:
185
- logger.error(f"Could not read first frame {frame_files[0]}")
186
- return False, 0
187
-
188
- height, width, channels = first_frame.shape
189
-
190
- # Create video writer
191
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
192
- video = cv2.VideoWriter(output_file, fourcc, 10.0, (width, height))
193
-
194
- # Process each frame
195
- for frame_file in frame_files:
196
- frame = cv2.imread(frame_file)
197
- if frame is not None:
198
- video.write(frame)
199
-
200
- # Release the video writer
201
- video.release()
202
-
203
- logger.info(f"Successfully created video {output_file} with {len(frame_files)} frames")
204
- return True, len(frame_files)
205
-
206
- except Exception as e:
207
- logger.error(f"Error processing trajectory: {e}")
208
- return False, 0
209
-
210
-
211
- def get_next_id():
212
- """Get the next available ID from the database."""
213
- conn = sqlite3.connect(DB_FILE)
214
- cursor = conn.cursor()
215
-
216
- cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
217
- result = cursor.fetchone()
218
- next_id = int(result[0])
219
-
220
- conn.close()
221
- return next_id
222
-
223
-
224
- def increment_next_id():
225
- """Increment the next ID in the database."""
226
- conn = sqlite3.connect(DB_FILE)
227
- cursor = conn.cursor()
228
-
229
- cursor.execute("UPDATE config SET value = value + 1 WHERE key = 'next_id'")
230
- conn.commit()
231
-
232
- conn.close()
233
-
234
-
235
- def is_session_processed(log_file):
236
- """Check if a session has already been processed."""
237
- conn = sqlite3.connect(DB_FILE)
238
- cursor = conn.cursor()
239
-
240
- cursor.execute("SELECT 1 FROM processed_sessions WHERE log_file = ?", (log_file,))
241
- result = cursor.fetchone() is not None
242
-
243
- conn.close()
244
- return result
245
-
246
-
247
- def mark_session_processed(log_file, client_id, video_path, frame_count):
248
- """Mark a session as processed in the database."""
249
- conn = sqlite3.connect(DB_FILE)
250
- cursor = conn.cursor()
251
-
252
- cursor.execute(
253
- "INSERT INTO processed_sessions (log_file, client_id, processed_time, video_path, frame_count) VALUES (?, ?, ?, ?, ?)",
254
- (log_file, client_id, datetime.now().isoformat(), video_path, frame_count)
255
- )
256
-
257
- conn.commit()
258
- conn.close()
259
-
260
-
261
- def process_session_file(log_file):
262
  """
263
  Process a session file, splitting into multiple trajectories at reset points.
264
  Returns a list of successfully processed trajectory IDs.
@@ -270,6 +153,7 @@ def process_session_file(log_file):
270
  trajectory = load_trajectory(log_file)
271
  if not trajectory:
272
  logger.error(f"Empty trajectory for {log_file}, skipping")
 
273
  return []
274
 
275
  client_id = trajectory[0].get("client_id", "unknown")
@@ -287,6 +171,7 @@ def process_session_file(log_file):
287
  # If no resets and no EOS, this is incomplete - skip
288
  if not reset_indices and not has_eos:
289
  logger.warning(f"Session {log_file} has no resets and no EOS, may be incomplete")
 
290
  return []
291
 
292
  # Split trajectory at reset points
@@ -315,32 +200,29 @@ def process_session_file(log_file):
315
  cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
316
  next_id = int(cursor.fetchone()[0])
317
 
318
- # Define output path
319
- segment_label = f"segment_{i+1}_of_{len(sub_trajectories)}"
320
- output_file = os.path.join(OUTPUT_DIR, f"trajectory_{next_id:06d}_{segment_label}.mp4")
321
-
322
- # Find timestamps for this segment to get corresponding frames
323
  start_time = sub_traj[0]["timestamp"]
324
  end_time = sub_traj[-1]["timestamp"]
325
 
326
- # Process this sub-trajectory
327
- success, frame_count = process_trajectory_segment(
328
- client_id,
329
- sub_traj,
330
- output_file,
331
- start_time,
332
- end_time
333
- )
334
-
335
- if success:
 
336
  # Mark this segment as processed
337
  cursor.execute(
338
  """INSERT INTO processed_segments
339
  (log_file, client_id, segment_index, start_time, end_time,
340
- processed_time, video_path, frame_count, trajectory_id)
341
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
342
  (log_file, client_id, i, start_time, end_time,
343
- datetime.now().isoformat(), output_file, frame_count, next_id)
344
  )
345
 
346
  # Increment the next ID
@@ -349,89 +231,70 @@ def process_session_file(log_file):
349
 
350
  processed_ids.append(next_id)
351
  logger.info(f"Successfully processed segment {i+1}/{len(sub_trajectories)} from {log_file}")
352
- else:
353
- logger.error(f"Failed to process segment {i+1}/{len(sub_trajectories)} from {log_file}")
 
 
354
 
355
- # Mark the entire session as processed
356
- cursor.execute(
357
- "INSERT INTO processed_sessions (log_file, client_id, processed_time) VALUES (?, ?, ?)",
358
- (log_file, client_id, datetime.now().isoformat())
359
- )
360
- conn.commit()
361
- conn.close()
 
 
 
 
362
 
 
363
  return processed_ids
364
 
365
 
366
- def process_trajectory_segment(client_id, trajectory, output_file, start_time, end_time):
367
  """
368
- Process a segment of a trajectory between timestamps and create a video.
369
 
370
- Args:
371
- client_id: Client ID for frame lookup
372
- trajectory: List of interaction log entries for this segment
373
- output_file: Path to save the output video
374
- start_time: Start timestamp for this segment
375
- end_time: End timestamp for this segment
376
-
377
- Returns:
378
- (bool, int): (success status, frame count)
379
  """
380
- try:
381
- # Get frame files for this client
382
- all_frames = get_frame_files(client_id)
383
- if not all_frames:
384
- logger.error(f"No frames found for client {client_id}")
385
- return False, 0
386
-
387
- # Filter frames to the time range of this segment
388
- # Frame filenames are timestamps, so we can use them for filtering
389
- segment_frames = [
390
- f for f in all_frames
391
- if start_time <= float(os.path.basename(f).split('.png')[0]) <= end_time
392
- ]
393
-
394
- if not segment_frames:
395
- logger.error(f"No frames found in time range for segment {start_time}-{end_time}")
396
- return False, 0
397
-
398
- # Read the first frame to get dimensions
399
- first_frame = cv2.imread(segment_frames[0])
400
- if first_frame is None:
401
- logger.error(f"Could not read first frame {segment_frames[0]}")
402
- return False, 0
403
 
404
- height, width, channels = first_frame.shape
405
-
406
- # Create video writer
407
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
408
- video = cv2.VideoWriter(output_file, fourcc, 10.0, (width, height))
409
-
410
- # Process each frame
411
- for frame_file in segment_frames:
412
- frame = cv2.imread(frame_file)
413
- if frame is not None:
414
- video.write(frame)
415
-
416
- # Release the video writer
417
- video.release()
418
 
419
- logger.info(f"Created video {output_file} with {len(segment_frames)} frames")
420
- return True, len(segment_frames)
421
-
422
- except Exception as e:
423
- logger.error(f"Error processing trajectory segment: {e}")
424
- return False, 0
425
 
426
 
427
  def main():
428
  """Main function to run the data processing pipeline."""
429
- # Create output directory if it doesn't exist
430
- os.makedirs(OUTPUT_DIR, exist_ok=True)
431
-
432
  # Initialize database
433
  initialize_database()
434
 
 
 
 
 
 
435
  # Find all log files
436
  log_files = glob.glob(os.path.join(FRAMES_DIR, "session_*.jsonl"))
437
  logger.info(f"Found {len(log_files)} log files")
@@ -458,7 +321,7 @@ def main():
458
  total_trajectories = 0
459
  for log_file in valid_sessions:
460
  logger.info(f"Processing session file: {log_file}")
461
- processed_ids = process_session_file(log_file)
462
  total_trajectories += len(processed_ids)
463
 
464
  # Get next ID for reporting
@@ -468,7 +331,7 @@ def main():
468
  next_id = int(cursor.fetchone()[0])
469
  conn.close()
470
 
471
- logger.info(f"Processing complete. Generated {total_trajectories} trajectory videos.")
472
  logger.info(f"Next ID will be {next_id}")
473
 
474
 
 
7
  import logging
8
  import cv2
9
  import numpy as np
10
+ import subprocess
11
  from datetime import datetime
12
+ from typing import List, Dict, Any, Tuple
13
+
14
+ # Import the existing functions
15
+ from latent_diffusion.ldm.data.data_collection import process_trajectory, initialize_clean_state
16
 
17
  # Configure logging
18
  logging.basicConfig(
19
  level=logging.INFO,
20
  format='%(asctime)s - %(levelname)s - %(message)s',
21
  handlers=[
22
+ logging.FileHandler("trajectory_processor.log"),
23
  logging.StreamHandler()
24
  ]
25
  )
 
27
 
28
  # Define constants
29
  DB_FILE = "trajectory_processor.db"
 
30
  FRAMES_DIR = "interaction_logs"
31
+ SCREEN_WIDTH = 512
32
+ SCREEN_HEIGHT = 384
33
+ MEMORY_LIMIT = "2g"
34
 
35
  def initialize_database():
36
  """Initialize the SQLite database if it doesn't exist."""
 
56
  start_time REAL,
57
  end_time REAL,
58
  processed_time TIMESTAMP,
 
 
59
  trajectory_id INTEGER,
60
  UNIQUE(log_file, segment_index)
61
  )
 
141
  return []
142
 
143
 
144
+ def process_session_file(log_file, clean_state):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  """
146
  Process a session file, splitting into multiple trajectories at reset points.
147
  Returns a list of successfully processed trajectory IDs.
 
153
  trajectory = load_trajectory(log_file)
154
  if not trajectory:
155
  logger.error(f"Empty trajectory for {log_file}, skipping")
156
+ conn.close()
157
  return []
158
 
159
  client_id = trajectory[0].get("client_id", "unknown")
 
171
  # If no resets and no EOS, this is incomplete - skip
172
  if not reset_indices and not has_eos:
173
  logger.warning(f"Session {log_file} has no resets and no EOS, may be incomplete")
174
+ conn.close()
175
  return []
176
 
177
  # Split trajectory at reset points
 
200
  cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
201
  next_id = int(cursor.fetchone()[0])
202
 
203
+ # Find timestamps for this segment
 
 
 
 
204
  start_time = sub_traj[0]["timestamp"]
205
  end_time = sub_traj[-1]["timestamp"]
206
 
207
+ # Process this sub-trajectory using the external function
208
+ try:
209
+ logger.info(f"Processing segment {i+1}/{len(sub_trajectories)} from {log_file} as trajectory {next_id}")
210
+
211
+ # Format the trajectory as needed by process_trajectory function
212
+ formatted_trajectory = format_trajectory_for_processing(sub_traj)
213
+
214
+ # Call the external process_trajectory function
215
+ args = (next_id, formatted_trajectory)
216
+ process_trajectory(args, SCREEN_WIDTH, SCREEN_HEIGHT, clean_state, MEMORY_LIMIT)
217
+
218
  # Mark this segment as processed
219
  cursor.execute(
220
  """INSERT INTO processed_segments
221
  (log_file, client_id, segment_index, start_time, end_time,
222
+ processed_time, trajectory_id)
223
+ VALUES (?, ?, ?, ?, ?, ?, ?)""",
224
  (log_file, client_id, i, start_time, end_time,
225
+ datetime.now().isoformat(), next_id)
226
  )
227
 
228
  # Increment the next ID
 
231
 
232
  processed_ids.append(next_id)
233
  logger.info(f"Successfully processed segment {i+1}/{len(sub_trajectories)} from {log_file}")
234
+
235
+ except Exception as e:
236
+ logger.error(f"Failed to process segment {i+1}/{len(sub_trajectories)} from {log_file}: {e}")
237
+ continue
238
 
239
+ # Mark the entire session as processed only if at least one segment succeeded
240
+ if processed_ids:
241
+ try:
242
+ cursor.execute(
243
+ "INSERT INTO processed_sessions (log_file, client_id, processed_time) VALUES (?, ?, ?)",
244
+ (log_file, client_id, datetime.now().isoformat())
245
+ )
246
+ conn.commit()
247
+ except sqlite3.IntegrityError:
248
+ # This can happen if we're re-processing a file that had some segments fail
249
+ pass
250
 
251
+ conn.close()
252
  return processed_ids
253
 
254
 
255
+ def format_trajectory_for_processing(trajectory):
256
  """
257
+ Format the trajectory in the structure expected by process_trajectory function.
258
 
259
+ The exact format will depend on what your process_trajectory function expects.
260
+ This is a placeholder - modify based on the actual requirements.
 
 
 
 
 
 
 
261
  """
262
+ formatted_events = []
263
+
264
+ for entry in trajectory:
265
+ # Skip control messages
266
+ if entry.get("is_reset") or entry.get("is_eos"):
267
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
+ # Extract input data
270
+ inputs = entry.get("inputs", {})
271
+ key_events = []
272
+ for key in inputs.get("keys_down", []):
273
+ key_events.append(("keydown", key))
274
+ for key in inputs.get("keys_up", []):
275
+ key_events.append(("keyup", key))
276
+ event = {
277
+ "pos": (inputs.get("x"), inputs.get("y")),
278
+ "left_click": inputs.get("is_left_click", False),
279
+ "right_click": inputs.get("is_right_click", False),
280
+ "key_events": key_events,
281
+ }
 
282
 
283
+ formatted_events.append(event)
284
+
285
+ return formatted_events
 
 
 
286
 
287
 
288
  def main():
289
  """Main function to run the data processing pipeline."""
 
 
 
290
  # Initialize database
291
  initialize_database()
292
 
293
+ # Initialize clean Docker state once
294
+ logger.info("Initializing clean container state...")
295
+ clean_state = initialize_clean_state()
296
+ logger.info(f"Clean state initialized: {clean_state}")
297
+
298
  # Find all log files
299
  log_files = glob.glob(os.path.join(FRAMES_DIR, "session_*.jsonl"))
300
  logger.info(f"Found {len(log_files)} log files")
 
321
  total_trajectories = 0
322
  for log_file in valid_sessions:
323
  logger.info(f"Processing session file: {log_file}")
324
+ processed_ids = process_session_file(log_file, clean_state)
325
  total_trajectories += len(processed_ids)
326
 
327
  # Get next ID for reporting
 
331
  next_id = int(cursor.fetchone()[0])
332
  conn.close()
333
 
334
+ logger.info(f"Processing complete. Generated {total_trajectories} trajectories.")
335
  logger.info(f"Next ID will be {next_id}")
336
 
337