da03 committed on
Commit
a92ddb8
·
1 Parent(s): 29a0aca
Files changed (1) hide show
  1. online_data_generation.py +88 -38
online_data_generation.py CHANGED
@@ -21,6 +21,7 @@ import pandas as pd
21
  import ast
22
  import pickle
23
  from moviepy.editor import VideoFileClip
 
24
 
25
  # Import the existing functions
26
  from data.data_collection.synthetic_script_compute_canada import process_trajectory, initialize_clean_state
@@ -44,6 +45,7 @@ os.makedirs(OUTPUT_DIR, exist_ok=True)
44
  SCREEN_WIDTH = 512
45
  SCREEN_HEIGHT = 384
46
  MEMORY_LIMIT = "2g"
 
47
 
48
  # load autoencoder
49
  config = OmegaConf.load('../computer/autoencoder/config_kl4_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_contmar15_acc1.yaml')
@@ -51,6 +53,19 @@ autoencoder = load_model_from_config(config, '../computer/autoencoder/saved_kl4_
51
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
52
  autoencoder = autoencoder.to(device)
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def initialize_database():
55
  """Initialize the SQLite database if it doesn't exist."""
56
  conn = sqlite3.connect(DB_FILE)
@@ -531,7 +546,8 @@ def generate_comparison_video(client_id, trajectory, output_file, start_time, en
531
 
532
  def main():
533
  """Main function to run the data processing pipeline."""
534
-
 
535
  # create a padding image first
536
  if not os.path.exists(os.path.join(OUTPUT_DIR, 'padding.npy')):
537
  logger.info("Creating padding image...")
@@ -543,52 +559,86 @@ def main():
543
  latent = torch.zeros_like(latent).squeeze(0)
544
  np.save(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), latent.cpu().numpy())
545
  os.rename(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), os.path.join(OUTPUT_DIR, 'padding.npy'))
 
546
  # Initialize database
547
  initialize_database()
548
 
549
- # Initialize clean Docker state once
550
  logger.info("Initializing clean container state...")
551
  clean_state = initialize_clean_state()
552
  logger.info(f"Clean state initialized: {clean_state}")
553
 
554
- # Find all log files
555
- log_files = glob.glob(os.path.join(FRAMES_DIR, "session_*.jsonl"))
556
- logger.info(f"Found {len(log_files)} log files")
557
-
558
- # Filter for complete sessions
559
- complete_sessions = [f for f in log_files if is_session_complete(f)]
560
- logger.info(f"Found {len(complete_sessions)} complete sessions")
561
-
562
- # Filter for sessions not yet processed
563
- conn = sqlite3.connect(DB_FILE)
564
- cursor = conn.cursor()
565
- cursor.execute("SELECT log_file FROM processed_sessions")
566
- processed_files = set(row[0] for row in cursor.fetchall())
567
- conn.close()
568
-
569
- new_sessions = [f for f in complete_sessions if f not in processed_files]
570
- logger.info(f"Found {len(new_sessions)} new sessions to process")
571
 
572
- # Filter for valid sessions
573
- valid_sessions = [f for f in new_sessions if is_session_valid(f)]
574
- logger.info(f"Found {len(valid_sessions)} valid new sessions to process")
575
 
576
- # Process each valid session
577
- total_trajectories = 0
578
- for log_file in valid_sessions:
579
- logger.info(f"Processing session file: {log_file}")
580
- processed_ids = process_session_file(log_file, clean_state)
581
- total_trajectories += len(processed_ids)
582
-
583
- # Get next ID for reporting
584
- conn = sqlite3.connect(DB_FILE)
585
- cursor = conn.cursor()
586
- cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
587
- next_id = int(cursor.fetchone()[0])
588
- conn.close()
589
-
590
- logger.info(f"Processing complete. Generated {total_trajectories} trajectories.")
591
- logger.info(f"Next ID will be {next_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592
 
593
 
594
  if __name__ == "__main__":
 
21
  import ast
22
  import pickle
23
  from moviepy.editor import VideoFileClip
24
+ import signal
25
 
26
  # Import the existing functions
27
  from data.data_collection.synthetic_script_compute_canada import process_trajectory, initialize_clean_state
 
45
  SCREEN_WIDTH = 512
46
  SCREEN_HEIGHT = 384
47
  MEMORY_LIMIT = "2g"
48
+ CHECK_INTERVAL = 60 # Check for new data every 60 seconds
49
 
50
  # load autoencoder
51
  config = OmegaConf.load('../computer/autoencoder/config_kl4_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_contmar15_acc1.yaml')
 
53
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
54
  autoencoder = autoencoder.to(device)
55
 
56
+ # Global flag for graceful shutdown
57
+ running = True
58
+
59
+ def signal_handler(sig, frame):
60
+ """Handle Ctrl+C and other termination signals"""
61
+ global running
62
+ logger.info("Shutdown signal received. Finishing current processing and exiting...")
63
+ running = False
64
+
65
+ # Register signal handlers
66
+ signal.signal(signal.SIGINT, signal_handler)
67
+ signal.signal(signal.SIGTERM, signal_handler)
68
+
69
  def initialize_database():
70
  """Initialize the SQLite database if it doesn't exist."""
71
  conn = sqlite3.connect(DB_FILE)
 
546
 
547
  def main():
548
  """Main function to run the data processing pipeline."""
549
+ global running
550
+
551
  # create a padding image first
552
  if not os.path.exists(os.path.join(OUTPUT_DIR, 'padding.npy')):
553
  logger.info("Creating padding image...")
 
559
  latent = torch.zeros_like(latent).squeeze(0)
560
  np.save(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), latent.cpu().numpy())
561
  os.rename(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), os.path.join(OUTPUT_DIR, 'padding.npy'))
562
+
563
  # Initialize database
564
  initialize_database()
565
 
566
+ # Initialize clean Docker state
567
  logger.info("Initializing clean container state...")
568
  clean_state = initialize_clean_state()
569
  logger.info(f"Clean state initialized: {clean_state}")
570
 
571
+ # Ensure output directory exists
572
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
 
574
+ logger.info(f"Starting continuous monitoring for new sessions (check interval: {CHECK_INTERVAL} seconds)")
 
 
575
 
576
+ try:
577
+ # Main monitoring loop
578
+ while running:
579
+ try:
580
+ # Find all log files
581
+ log_files = glob.glob(os.path.join(FRAMES_DIR, "session_*.jsonl"))
582
+ logger.info(f"Found {len(log_files)} log files")
583
+
584
+ # Filter for complete sessions
585
+ complete_sessions = [f for f in log_files if is_session_complete(f)]
586
+ logger.info(f"Found {len(complete_sessions)} complete sessions")
587
+
588
+ # Filter for sessions not yet processed
589
+ conn = sqlite3.connect(DB_FILE)
590
+ cursor = conn.cursor()
591
+ cursor.execute("SELECT log_file FROM processed_sessions")
592
+ processed_files = set(row[0] for row in cursor.fetchall())
593
+ conn.close()
594
+
595
+ new_sessions = [f for f in complete_sessions if f not in processed_files]
596
+ logger.info(f"Found {len(new_sessions)} new sessions to process")
597
+
598
+ # Filter for valid sessions
599
+ valid_sessions = [f for f in new_sessions if is_session_valid(f)]
600
+ logger.info(f"Found {len(valid_sessions)} valid new sessions to process")
601
+
602
+ # Process each valid session
603
+ total_trajectories = 0
604
+ for log_file in valid_sessions:
605
+ if not running:
606
+ logger.info("Shutdown in progress, stopping processing")
607
+ break
608
+
609
+ logger.info(f"Processing session file: {log_file}")
610
+ processed_ids = process_session_file(log_file, clean_state)
611
+ total_trajectories += len(processed_ids)
612
+
613
+ if total_trajectories > 0:
614
+ # Get next ID for reporting
615
+ conn = sqlite3.connect(DB_FILE)
616
+ cursor = conn.cursor()
617
+ cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
618
+ next_id = int(cursor.fetchone()[0])
619
+ conn.close()
620
+
621
+ logger.info(f"Processing cycle complete. Generated {total_trajectories} new trajectories.")
622
+ logger.info(f"Next ID will be {next_id}")
623
+ else:
624
+ logger.info("No new trajectories processed in this cycle")
625
+
626
+ # Sleep until next check, but with periodic wake-ups to check running flag
627
+ remaining_sleep = CHECK_INTERVAL
628
+ while remaining_sleep > 0 and running:
629
+ sleep_chunk = min(5, remaining_sleep) # Check running flag every 5 seconds max
630
+ time.sleep(sleep_chunk)
631
+ remaining_sleep -= sleep_chunk
632
+
633
+ except Exception as e:
634
+ logger.error(f"Error in processing cycle: {e}")
635
+ # Sleep briefly to avoid rapid error loops
636
+ time.sleep(10)
637
+
638
+ except KeyboardInterrupt:
639
+ logger.info("Keyboard interrupt received, shutting down")
640
+ finally:
641
+ logger.info("Shutting down trajectory processor")
642
 
643
 
644
  if __name__ == "__main__":