Spaces:
Runtime error
Runtime error
da03
committed on
Commit
·
a92ddb8
1
Parent(s):
29a0aca
- online_data_generation.py +88 -38
online_data_generation.py
CHANGED
@@ -21,6 +21,7 @@ import pandas as pd
|
|
21 |
import ast
|
22 |
import pickle
|
23 |
from moviepy.editor import VideoFileClip
|
|
|
24 |
|
25 |
# Import the existing functions
|
26 |
from data.data_collection.synthetic_script_compute_canada import process_trajectory, initialize_clean_state
|
@@ -44,6 +45,7 @@ os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
44 |
SCREEN_WIDTH = 512
|
45 |
SCREEN_HEIGHT = 384
|
46 |
MEMORY_LIMIT = "2g"
|
|
|
47 |
|
48 |
# load autoencoder
|
49 |
config = OmegaConf.load('../computer/autoencoder/config_kl4_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_contmar15_acc1.yaml')
|
@@ -51,6 +53,19 @@ autoencoder = load_model_from_config(config, '../computer/autoencoder/saved_kl4_
|
|
51 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
52 |
autoencoder = autoencoder.to(device)
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
def initialize_database():
|
55 |
"""Initialize the SQLite database if it doesn't exist."""
|
56 |
conn = sqlite3.connect(DB_FILE)
|
@@ -531,7 +546,8 @@ def generate_comparison_video(client_id, trajectory, output_file, start_time, en
|
|
531 |
|
532 |
def main():
|
533 |
"""Main function to run the data processing pipeline."""
|
534 |
-
|
|
|
535 |
# create a padding image first
|
536 |
if not os.path.exists(os.path.join(OUTPUT_DIR, 'padding.npy')):
|
537 |
logger.info("Creating padding image...")
|
@@ -543,52 +559,86 @@ def main():
|
|
543 |
latent = torch.zeros_like(latent).squeeze(0)
|
544 |
np.save(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), latent.cpu().numpy())
|
545 |
os.rename(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), os.path.join(OUTPUT_DIR, 'padding.npy'))
|
|
|
546 |
# Initialize database
|
547 |
initialize_database()
|
548 |
|
549 |
-
# Initialize clean Docker state
|
550 |
logger.info("Initializing clean container state...")
|
551 |
clean_state = initialize_clean_state()
|
552 |
logger.info(f"Clean state initialized: {clean_state}")
|
553 |
|
554 |
-
#
|
555 |
-
|
556 |
-
logger.info(f"Found {len(log_files)} log files")
|
557 |
-
|
558 |
-
# Filter for complete sessions
|
559 |
-
complete_sessions = [f for f in log_files if is_session_complete(f)]
|
560 |
-
logger.info(f"Found {len(complete_sessions)} complete sessions")
|
561 |
-
|
562 |
-
# Filter for sessions not yet processed
|
563 |
-
conn = sqlite3.connect(DB_FILE)
|
564 |
-
cursor = conn.cursor()
|
565 |
-
cursor.execute("SELECT log_file FROM processed_sessions")
|
566 |
-
processed_files = set(row[0] for row in cursor.fetchall())
|
567 |
-
conn.close()
|
568 |
-
|
569 |
-
new_sessions = [f for f in complete_sessions if f not in processed_files]
|
570 |
-
logger.info(f"Found {len(new_sessions)} new sessions to process")
|
571 |
|
572 |
-
|
573 |
-
valid_sessions = [f for f in new_sessions if is_session_valid(f)]
|
574 |
-
logger.info(f"Found {len(valid_sessions)} valid new sessions to process")
|
575 |
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
592 |
|
593 |
|
594 |
if __name__ == "__main__":
|
|
|
21 |
import ast
|
22 |
import pickle
|
23 |
from moviepy.editor import VideoFileClip
|
24 |
+
import signal
|
25 |
|
26 |
# Import the existing functions
|
27 |
from data.data_collection.synthetic_script_compute_canada import process_trajectory, initialize_clean_state
|
|
|
45 |
SCREEN_WIDTH = 512
|
46 |
SCREEN_HEIGHT = 384
|
47 |
MEMORY_LIMIT = "2g"
|
48 |
+
CHECK_INTERVAL = 60 # Check for new data every 60 seconds
|
49 |
|
50 |
# load autoencoder
|
51 |
config = OmegaConf.load('../computer/autoencoder/config_kl4_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_contmar15_acc1.yaml')
|
|
|
53 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
54 |
autoencoder = autoencoder.to(device)
|
55 |
|
56 |
+
# Global flag for graceful shutdown
|
57 |
+
running = True
|
58 |
+
|
59 |
+
def signal_handler(sig, frame):
    """Request a graceful shutdown when a termination signal arrives.

    Registered for SIGINT and SIGTERM. The handler clears the module-level
    ``running`` flag; the monitoring loop in ``main`` polls that flag and
    stops before starting the next session or the next sleep chunk.

    Args:
        sig: Number of the delivered signal (unused).
        frame: Stack frame that was interrupted (unused).
    """
    global running
    # Clear the flag before doing anything that can fail. Logging from inside
    # a signal handler is not guaranteed to be safe; if the log call raised
    # first, ``running`` would stay True and the shutdown request would be
    # lost once the main loop's generic exception handler swallowed the error.
    running = False
    logger.info("Shutdown signal received. Finishing current processing and exiting...")
|
64 |
+
|
65 |
+
# Register signal handlers
|
66 |
+
signal.signal(signal.SIGINT, signal_handler)
|
67 |
+
signal.signal(signal.SIGTERM, signal_handler)
|
68 |
+
|
69 |
def initialize_database():
|
70 |
"""Initialize the SQLite database if it doesn't exist."""
|
71 |
conn = sqlite3.connect(DB_FILE)
|
|
|
546 |
|
547 |
def main():
|
548 |
"""Main function to run the data processing pipeline."""
|
549 |
+
global running
|
550 |
+
|
551 |
# create a padding image first
|
552 |
if not os.path.exists(os.path.join(OUTPUT_DIR, 'padding.npy')):
|
553 |
logger.info("Creating padding image...")
|
|
|
559 |
latent = torch.zeros_like(latent).squeeze(0)
|
560 |
np.save(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), latent.cpu().numpy())
|
561 |
os.rename(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), os.path.join(OUTPUT_DIR, 'padding.npy'))
|
562 |
+
|
563 |
# Initialize database
|
564 |
initialize_database()
|
565 |
|
566 |
+
# Initialize clean Docker state
|
567 |
logger.info("Initializing clean container state...")
|
568 |
clean_state = initialize_clean_state()
|
569 |
logger.info(f"Clean state initialized: {clean_state}")
|
570 |
|
571 |
+
# Ensure output directory exists
|
572 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
+
logger.info(f"Starting continuous monitoring for new sessions (check interval: {CHECK_INTERVAL} seconds)")
|
|
|
|
|
575 |
|
576 |
+
try:
|
577 |
+
# Main monitoring loop
|
578 |
+
while running:
|
579 |
+
try:
|
580 |
+
# Find all log files
|
581 |
+
log_files = glob.glob(os.path.join(FRAMES_DIR, "session_*.jsonl"))
|
582 |
+
logger.info(f"Found {len(log_files)} log files")
|
583 |
+
|
584 |
+
# Filter for complete sessions
|
585 |
+
complete_sessions = [f for f in log_files if is_session_complete(f)]
|
586 |
+
logger.info(f"Found {len(complete_sessions)} complete sessions")
|
587 |
+
|
588 |
+
# Filter for sessions not yet processed
|
589 |
+
conn = sqlite3.connect(DB_FILE)
|
590 |
+
cursor = conn.cursor()
|
591 |
+
cursor.execute("SELECT log_file FROM processed_sessions")
|
592 |
+
processed_files = set(row[0] for row in cursor.fetchall())
|
593 |
+
conn.close()
|
594 |
+
|
595 |
+
new_sessions = [f for f in complete_sessions if f not in processed_files]
|
596 |
+
logger.info(f"Found {len(new_sessions)} new sessions to process")
|
597 |
+
|
598 |
+
# Filter for valid sessions
|
599 |
+
valid_sessions = [f for f in new_sessions if is_session_valid(f)]
|
600 |
+
logger.info(f"Found {len(valid_sessions)} valid new sessions to process")
|
601 |
+
|
602 |
+
# Process each valid session
|
603 |
+
total_trajectories = 0
|
604 |
+
for log_file in valid_sessions:
|
605 |
+
if not running:
|
606 |
+
logger.info("Shutdown in progress, stopping processing")
|
607 |
+
break
|
608 |
+
|
609 |
+
logger.info(f"Processing session file: {log_file}")
|
610 |
+
processed_ids = process_session_file(log_file, clean_state)
|
611 |
+
total_trajectories += len(processed_ids)
|
612 |
+
|
613 |
+
if total_trajectories > 0:
|
614 |
+
# Get next ID for reporting
|
615 |
+
conn = sqlite3.connect(DB_FILE)
|
616 |
+
cursor = conn.cursor()
|
617 |
+
cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
|
618 |
+
next_id = int(cursor.fetchone()[0])
|
619 |
+
conn.close()
|
620 |
+
|
621 |
+
logger.info(f"Processing cycle complete. Generated {total_trajectories} new trajectories.")
|
622 |
+
logger.info(f"Next ID will be {next_id}")
|
623 |
+
else:
|
624 |
+
logger.info("No new trajectories processed in this cycle")
|
625 |
+
|
626 |
+
# Sleep until next check, but with periodic wake-ups to check running flag
|
627 |
+
remaining_sleep = CHECK_INTERVAL
|
628 |
+
while remaining_sleep > 0 and running:
|
629 |
+
sleep_chunk = min(5, remaining_sleep) # Check running flag every 5 seconds max
|
630 |
+
time.sleep(sleep_chunk)
|
631 |
+
remaining_sleep -= sleep_chunk
|
632 |
+
|
633 |
+
except Exception as e:
|
634 |
+
logger.error(f"Error in processing cycle: {e}")
|
635 |
+
# Sleep briefly to avoid rapid error loops
|
636 |
+
time.sleep(10)
|
637 |
+
|
638 |
+
except KeyboardInterrupt:
|
639 |
+
logger.info("Keyboard interrupt received, shutting down")
|
640 |
+
finally:
|
641 |
+
logger.info("Shutting down trajectory processor")
|
642 |
|
643 |
|
644 |
if __name__ == "__main__":
|