#!/usr/bin/env python3
import os
import json
import glob
import sqlite3
import logging
from datetime import datetime

# Import the existing processing functions
from latent_diffusion.ldm.data.data_collection import process_trajectory, initialize_clean_state

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("trajectory_processor.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Define constants
DB_FILE = "trajectory_processor.db"
FRAMES_DIR = "interaction_logs"
SCREEN_WIDTH = 512
SCREEN_HEIGHT = 384
MEMORY_LIMIT = "2g"
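
# The expected log format, inferred from the fields this script reads: one
# JSON object per line in each session_*.jsonl file. The values below are
# illustrative, not taken from real data:
#
#   {"client_id": "abc123", "timestamp": 1700000000.0,
#    "is_reset": false, "is_eos": false,
#    "inputs": {"x": 100, "y": 200,
#               "is_left_click": false, "is_right_click": false,
#               "keys_down": ["shift"], "keys_up": []}}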

def initialize_database():
    """Initialize the SQLite database if it doesn't exist."""
    conn = sqlite3.connect(DB_FILE)
    cursor = conn.cursor()
    # Create tables if they don't exist
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS processed_sessions (
            id INTEGER PRIMARY KEY,
            log_file TEXT UNIQUE,
            client_id TEXT,
            processed_time TIMESTAMP
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS processed_segments (
            id INTEGER PRIMARY KEY,
            log_file TEXT,
            client_id TEXT,
            segment_index INTEGER,
            start_time REAL,
            end_time REAL,
            processed_time TIMESTAMP,
            trajectory_id INTEGER,
            UNIQUE(log_file, segment_index)
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS config (
            key TEXT PRIMARY KEY,
            value TEXT
        )
    ''')
    # Initialize next_id if it doesn't exist yet
    cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
    if not cursor.fetchone():
        cursor.execute("INSERT INTO config (key, value) VALUES ('next_id', '1')")
    conn.commit()
    conn.close()
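
# For hand-inspection of processing state, the standard sqlite3 CLI works
# against the same file (a convenience note, not part of the pipeline):
#
#   sqlite3 trajectory_processor.db \
#     "SELECT log_file, segment_index, trajectory_id FROM processed_segments"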

def is_session_complete(log_file):
    """Check if a session is complete (has an EOS marker)."""
    try:
        with open(log_file, 'r') as f:
            for line in f:
                try:
                    entry = json.loads(line.strip())
                    if entry.get("is_eos", False):
                        return True
                except json.JSONDecodeError:
                    continue
        return False
    except Exception as e:
        logger.error(f"Error checking if session {log_file} is complete: {e}")
        return False

def is_session_valid(log_file):
    """
    Check if a session is valid (has more than just control entries).
    Returns True if the log file has at least one entry that is neither
    an EOS nor a reset marker.
    """
    try:
        entry_count = 0
        has_interaction = False
        with open(log_file, 'r') as f:
            for line in f:
                try:
                    entry = json.loads(line.strip())
                    entry_count += 1
                    if not entry.get("is_eos", False) and not entry.get("is_reset", False):
                        has_interaction = True
                except json.JSONDecodeError:
                    continue
        # Valid only if there's at least one entry and at least one
        # non-control (interaction) entry
        return entry_count > 0 and has_interaction
    except Exception as e:
        logger.error(f"Error checking if session {log_file} is valid: {e}")
        return False

def load_trajectory(log_file):
    """Load a trajectory from a log file."""
    trajectory = []
    try:
        with open(log_file, 'r') as f:
            for line in f:
                try:
                    entry = json.loads(line.strip())
                    trajectory.append(entry)
                except json.JSONDecodeError:
                    logger.warning(f"Skipping invalid JSON line in {log_file}")
                    continue
        return trajectory
    except Exception as e:
        logger.error(f"Error loading trajectory from {log_file}: {e}")
        return []

def process_session_file(log_file, clean_state):
    """
    Process a session file, splitting it into multiple trajectories at reset points.
    Returns a list of successfully processed trajectory IDs.
    """
    conn = sqlite3.connect(DB_FILE)
    cursor = conn.cursor()

    # Get session details
    trajectory = load_trajectory(log_file)
    if not trajectory:
        logger.error(f"Empty trajectory for {log_file}, skipping")
        conn.close()
        return []
    client_id = trajectory[0].get("client_id", "unknown")

    # Find all reset points and the EOS marker
    reset_indices = []
    has_eos = False
    for i, entry in enumerate(trajectory):
        if entry.get("is_reset", False):
            reset_indices.append(i)
        if entry.get("is_eos", False):
            has_eos = True

    # If there are no resets and no EOS, the session may be incomplete - skip it
    if not reset_indices and not has_eos:
        logger.warning(f"Session {log_file} has no resets and no EOS, may be incomplete")
        conn.close()
        return []

    # Split the trajectory at reset points
    sub_trajectories = []
    start_idx = 0
    # Add all segments between resets
    for reset_idx in reset_indices:
        if reset_idx > start_idx:  # Only add non-empty segments
            sub_trajectories.append(trajectory[start_idx:reset_idx])
        start_idx = reset_idx + 1  # Start the next segment after the reset
    # Add the final segment if it's not empty
    if start_idx < len(trajectory):
        sub_trajectories.append(trajectory[start_idx:])

    # Process each sub-trajectory
    processed_ids = []
    for i, sub_traj in enumerate(sub_trajectories):
        # Skip segments with no interaction data (just control messages)
        if not any(not entry.get("is_reset", False) and not entry.get("is_eos", False) for entry in sub_traj):
            continue
        # Get the next ID for this sub-trajectory
        cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
        next_id = int(cursor.fetchone()[0])
        # Find the timestamps for this segment; control entries may lack one,
        # so use .get() rather than indexing
        start_time = sub_traj[0].get("timestamp")
        end_time = sub_traj[-1].get("timestamp")
        # Process this sub-trajectory using the external function
        try:
            logger.info(f"Processing segment {i+1}/{len(sub_trajectories)} from {log_file} as trajectory {next_id}")
            # Format the trajectory as expected by the process_trajectory function
            formatted_trajectory = format_trajectory_for_processing(sub_traj)
            # Call the external process_trajectory function
            args = (next_id, formatted_trajectory)
            process_trajectory(args, SCREEN_WIDTH, SCREEN_HEIGHT, clean_state, MEMORY_LIMIT)
            # Mark this segment as processed
            cursor.execute(
                """INSERT INTO processed_segments
                   (log_file, client_id, segment_index, start_time, end_time,
                    processed_time, trajectory_id)
                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (log_file, client_id, i, start_time, end_time,
                 datetime.now().isoformat(), next_id)
            )
            # Increment the next ID
            cursor.execute("UPDATE config SET value = ? WHERE key = 'next_id'", (str(next_id + 1),))
            conn.commit()
            processed_ids.append(next_id)
            logger.info(f"Successfully processed segment {i+1}/{len(sub_trajectories)} from {log_file}")
        except Exception as e:
            logger.error(f"Failed to process segment {i+1}/{len(sub_trajectories)} from {log_file}: {e}")
            continue

    # Mark the entire session as processed only if at least one segment succeeded
    if processed_ids:
        try:
            cursor.execute(
                "INSERT INTO processed_sessions (log_file, client_id, processed_time) VALUES (?, ?, ?)",
                (log_file, client_id, datetime.now().isoformat())
            )
            conn.commit()
        except sqlite3.IntegrityError:
            # This can happen when re-processing a file whose earlier run had failed segments
            pass
    conn.close()
    return processed_ids
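
# A worked example of the splitting logic above (entries are illustrative):
# given [e0, e1, RESET, e2, EOS], the reset at index 2 yields the segments
# [e0, e1] and [e2, EOS]; the EOS entry stays in its segment and is dropped
# later by format_trajectory_for_processing. A file containing only
# [RESET, EOS] yields no processed segments, because every candidate segment
# consists of control messages alone and is skipped.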

def format_trajectory_for_processing(trajectory):
    """
    Format the trajectory in the structure expected by the process_trajectory
    function. The exact format will depend on what your process_trajectory
    function expects; this is a placeholder - modify it based on the actual
    requirements.
    """
    formatted_events = []
    for entry in trajectory:
        # Skip control messages
        if entry.get("is_reset") or entry.get("is_eos"):
            continue
        # Extract input data
        inputs = entry.get("inputs", {})
        key_events = []
        for key in inputs.get("keys_down", []):
            key_events.append(("keydown", key))
        for key in inputs.get("keys_up", []):
            key_events.append(("keyup", key))
        event = {
            "pos": (inputs.get("x"), inputs.get("y")),
            "left_click": inputs.get("is_left_click", False),
            "right_click": inputs.get("is_right_click", False),
            "key_events": key_events,
        }
        formatted_events.append(event)
    return formatted_events
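
# For the sample entry sketched near the top of this file, the formatted
# event would be:
#
#   {"pos": (100, 200), "left_click": False, "right_click": False,
#    "key_events": [("keydown", "shift")]}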

def main():
    """Main function to run the data processing pipeline."""
    # Initialize the database
    initialize_database()

    # Initialize the clean Docker state once
    logger.info("Initializing clean container state...")
    clean_state = initialize_clean_state()
    logger.info(f"Clean state initialized: {clean_state}")

    # Find all log files
    log_files = glob.glob(os.path.join(FRAMES_DIR, "session_*.jsonl"))
    logger.info(f"Found {len(log_files)} log files")

    # Filter for complete sessions
    complete_sessions = [f for f in log_files if is_session_complete(f)]
    logger.info(f"Found {len(complete_sessions)} complete sessions")

    # Filter out sessions that have already been processed
    conn = sqlite3.connect(DB_FILE)
    cursor = conn.cursor()
    cursor.execute("SELECT log_file FROM processed_sessions")
    processed_files = set(row[0] for row in cursor.fetchall())
    conn.close()
    new_sessions = [f for f in complete_sessions if f not in processed_files]
    logger.info(f"Found {len(new_sessions)} new sessions to process")

    # Filter for valid sessions
    valid_sessions = [f for f in new_sessions if is_session_valid(f)]
    logger.info(f"Found {len(valid_sessions)} valid new sessions to process")

    # Process each valid session
    total_trajectories = 0
    for log_file in valid_sessions:
        logger.info(f"Processing session file: {log_file}")
        processed_ids = process_session_file(log_file, clean_state)
        total_trajectories += len(processed_ids)

    # Get the next ID for reporting
    conn = sqlite3.connect(DB_FILE)
    cursor = conn.cursor()
    cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
    next_id = int(cursor.fetchone()[0])
    conn.close()
    logger.info(f"Processing complete. Generated {total_trajectories} trajectories.")
    logger.info(f"Next ID will be {next_id}")

if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logger.error(f"Unhandled exception: {e}", exc_info=True)