#!/usr/bin/env python3
"""Continuously turn recorded interaction sessions into encoded training data.

The processor polls ``FRAMES_DIR`` for completed ``session_*.jsonl`` logs,
splits each session at reset markers, and for every segment:

1. renders a comparison video from the raw frame PNGs,
2. replays the segment through ``process_trajectory`` (which is expected to
   leave ``raw_data/raw_data/{videos,actions}/record_<id>.*`` behind),
3. encodes every video frame with the autoencoder into a webdataset tar,
4. merges per-frame action/key-state metadata into shared index files
   under ``OUTPUT_DIR``.

Progress (processed sessions/segments and the next trajectory id) is kept in
a SQLite database so restarts do not redo finished work.
"""
import ast
import glob
import io
import json
import logging
import os
import pickle
import signal
import sqlite3
import subprocess
import time
from contextlib import closing
from datetime import datetime
from typing import Any, Dict, List, Tuple

import cv2
import numpy as np
import pandas as pd
import torch
import webdataset as wds
from einops import rearrange
from moviepy.editor import VideoFileClip
from omegaconf import OmegaConf
from PIL import Image

from computer.util import load_model_from_config

# Import the existing functions
from data.data_collection.synthetic_script_compute_canada import process_trajectory, initialize_clean_state

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("trajectory_processor.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Define constants
DB_FILE = "trajectory_processor.db"
FRAMES_DIR = "interaction_logs"
OUTPUT_DIR = 'train_dataset_encoded_online'
os.makedirs(OUTPUT_DIR, exist_ok=True)
SCREEN_WIDTH = 512
SCREEN_HEIGHT = 384
MEMORY_LIMIT = "2g"
CHECK_INTERVAL = 60  # Check for new data every 60 seconds

# Load the autoencoder once at import time; every encode below reuses it.
config = OmegaConf.load('../computer/autoencoder/config_kl4_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_contmar15_acc1.yaml')
autoencoder = load_model_from_config(config, '../computer/autoencoder/saved_kl4_bsz8_acc8_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_cont_mar15_acc1_cont_1e6_cont_2e7_cont/model-2076000.ckpt')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
autoencoder = autoencoder.to(device)

# Global flag for graceful shutdown; cleared by signal_handler.
running = True

# Browser key names -> names used in KEYS.
KEYMAPPING = {
    'arrowup': 'up',
    'arrowdown': 'down',
    'arrowleft': 'left',
    'arrowright': 'right',
    'meta': 'command',
    'contextmenu': 'apps',
    'control': 'ctrl',
}

# NOTE: the order of this list defines the integer ids in ``stoi``; do not reorder.
KEYS = [
    '\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*',
    '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
    'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    '{', '|', '}', '~',
    'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
    'browserback', 'browserfavorites', 'browserforward', 'browserhome',
    'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
    'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
    'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute',
    'f1', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19',
    'f2', 'f20', 'f21', 'f22', 'f23', 'f24',
    'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
    'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert',
    'junja', 'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
    'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
    'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
    'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
    'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
    'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
    'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract',
    'tab', 'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft',
    'winright', 'yen', 'command', 'option', 'optionleft', 'optionright',
]
INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20', 'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
VALID_KEYS = [key for key in KEYS if key not in INVALID_KEYS]
itos = VALID_KEYS
stoi = {key: i for i, key in enumerate(itos)}


def signal_handler(sig, frame):
    """Handle Ctrl+C and other termination signals by clearing ``running``."""
    global running
    logger.info("Shutdown signal received. Finishing current processing and exiting...")
    running = False


# Register signal handlers
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)


def initialize_database():
    """Create the SQLite tables if missing and seed ``config.next_id`` with 0."""
    with closing(sqlite3.connect(DB_FILE)) as conn:
        cursor = conn.cursor()

        # Create tables if they don't exist
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS processed_sessions (
            id INTEGER PRIMARY KEY,
            log_file TEXT UNIQUE,
            client_id TEXT,
            processed_time TIMESTAMP
        )
        ''')

        cursor.execute('''
        CREATE TABLE IF NOT EXISTS processed_segments (
            id INTEGER PRIMARY KEY,
            log_file TEXT,
            client_id TEXT,
            segment_index INTEGER,
            start_time REAL,
            end_time REAL,
            processed_time TIMESTAMP,
            trajectory_id INTEGER,
            UNIQUE(log_file, segment_index)
        )
        ''')

        cursor.execute('''
        CREATE TABLE IF NOT EXISTS config (
            key TEXT PRIMARY KEY,
            value TEXT
        )
        ''')

        # Initialize next_id if not exists
        cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
        if not cursor.fetchone():
            cursor.execute("INSERT INTO config (key, value) VALUES ('next_id', '0')")

        conn.commit()


def is_session_complete(log_file):
    """Return True if any JSON line in ``log_file`` has a truthy ``is_eos``.

    Malformed lines are skipped; any other error is logged and yields False.
    """
    try:
        with open(log_file, 'r') as f:
            for line in f:
                try:
                    entry = json.loads(line.strip())
                    if entry.get("is_eos", False):
                        return True
                except json.JSONDecodeError:
                    continue
        return False
    except Exception as e:
        logger.error(f"Error checking if session {log_file} is complete: {e}")
        return False


def is_session_valid(log_file):
    """Return True if ``log_file`` has at least one real interaction entry.

    An entry counts as interaction data when it is neither an EOS marker
    (``is_eos``) nor a reset marker (``is_reset``). Malformed JSON lines are
    ignored; any other error is logged and yields False.
    """
    try:
        entry_count = 0
        has_interaction = False

        with open(log_file, 'r') as f:
            for line in f:
                try:
                    entry = json.loads(line.strip())
                    entry_count += 1
                    if not entry.get("is_eos", False) and not entry.get("is_reset", False):
                        has_interaction = True
                except json.JSONDecodeError:
                    continue

        return entry_count > 0 and has_interaction
    except Exception as e:
        logger.error(f"Error checking if session {log_file} is valid: {e}")
        return False


def load_trajectory(log_file):
    """Return the list of JSON entries in ``log_file`` ([] on read errors).

    Lines that are not valid JSON are skipped with a warning.
    """
    trajectory = []
    try:
        with open(log_file, 'r') as f:
            for line in f:
                try:
                    trajectory.append(json.loads(line.strip()))
                except json.JSONDecodeError:
                    logger.warning(f"Skipping invalid JSON line in {log_file}")
                    continue
        return trajectory
    except Exception as e:
        logger.error(f"Error loading trajectory from {log_file}: {e}")
        return []


def _split_at_resets(trajectory: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
    """Split entries into the runs between ``is_reset`` markers (markers dropped, empty runs skipped)."""
    segments = []
    start_idx = 0
    for idx, entry in enumerate(trajectory):
        if entry.get("is_reset", False):
            if idx > start_idx:  # Only add non-empty segments
                segments.append(trajectory[start_idx:idx])
            start_idx = idx + 1  # Start new segment after the reset
    if start_idx < len(trajectory):
        segments.append(trajectory[start_idx:])
    return segments


def _save_debug_reconstruction(images_tensor, latents, image_num):
    """Write original|reconstruction PNGs for one encoded batch into ``OUTPUT_DIR/debug``."""
    debug_dir = os.path.join(OUTPUT_DIR, 'debug')
    os.makedirs(debug_dir, exist_ok=True)

    reconstructions = autoencoder.decode(latents.to(device))
    for idx, (orig, recon) in enumerate(zip(images_tensor, reconstructions)):
        # Denormalize from [-1, 1] to [0, 255], clip, then CHW -> HWC.
        orig = np.clip((orig.cpu().numpy() + 1.0) * 127.5, 0, 255).astype(np.uint8)
        recon = np.clip((recon.cpu().numpy() + 1.0) * 127.5, 0, 255).astype(np.uint8)
        orig = np.transpose(orig, (1, 2, 0))
        recon = np.transpose(recon, (1, 2, 0))

        comparison = np.concatenate([orig, recon], axis=1)
        Image.fromarray(comparison).save(
            os.path.join(debug_dir, f'debug_{image_num}_{idx}_{image_num}.png')
        )
    logger.info(f"Debug visualizations saved to {debug_dir}")


@torch.no_grad()
def _encode_record_to_tar(record_num: int):
    """Encode every frame of ``record_<record_num>`` into ``OUTPUT_DIR/record_<record_num>.tar``.

    Reads the video and action CSV that ``process_trajectory`` left under
    ``raw_data/raw_data``, writes one ``.npy`` latent per frame keyed by the
    frame index, and tracks which keys are held down at each frame.

    Returns:
        (mapping_dict, target_data) where ``mapping_dict`` maps
        ``(record_num, image_num)`` to ``(x, y, left_click, right_click,
        down_keys)`` and ``target_data`` lists the ``(record_num, image_num)``
        pairs that were written.

    Raises:
        ValueError: if the video is not 15 FPS or a key event type is unknown.
        Any I/O or model error propagates; the tar writer is always closed.
    """
    video_file = f'raw_data/raw_data/videos/record_{record_num}.mp4'
    action_file = f'raw_data/raw_data/actions/record_{record_num}.csv'
    mouse_data = pd.read_csv(action_file)

    mapping_dict = {}
    target_data = []

    # A previous (uncommitted or failed) attempt may have left a tar for this id.
    tar_path = os.path.join(OUTPUT_DIR, f'record_{record_num}.tar')
    if os.path.exists(tar_path):
        logger.info(f"Removing existing tar file {tar_path}")
        os.remove(tar_path)

    debug = False  # Set to True to dump side-by-side reconstructions per frame.

    sink = wds.TarWriter(tar_path)
    try:
        with VideoFileClip(video_file) as video:
            fps = video.fps
            # Explicit check instead of `assert` so it survives `python -O`.
            if fps != 15:
                raise ValueError(f"Expected 15 FPS, got {fps}")
            duration = video.duration

            down_keys = set()
            for image_num in range(int(fps * duration)):
                action_row = mouse_data.iloc[image_num]
                x = int(action_row['X'])
                y = int(action_row['Y'])
                left_click = bool(action_row['Left Click'] == 1)
                right_click = bool(action_row['Right Click'] == 1)

                key_events = ast.literal_eval(action_row['Key Events'])
                for key_state, key in key_events:
                    if key_state == "keydown":
                        down_keys.add(key)
                    elif key_state == "keyup":
                        # A keyup without a matching keydown is harmless; don't
                        # abort the whole segment over it.
                        if key not in down_keys:
                            logger.warning(f"keyup for {key!r} without keydown in record {record_num}")
                        down_keys.discard(key)
                    else:
                        raise ValueError(f"Unknown key event type: {key_state}")

                mapping_dict[(record_num, image_num)] = (x, y, left_click, right_click, list(down_keys))
                target_data.append((record_num, image_num))

                frame = video.get_frame(image_num / fps)
                # Normalize to [-1, 1] and reshape HWC -> 1xCxHxW for the encoder.
                image_array = (frame / 127.5 - 1.0).astype(np.float32)
                images_tensor = torch.tensor(image_array).unsqueeze(0)
                images_tensor = rearrange(images_tensor, 'b h w c -> b c h w').to(device)

                posterior = autoencoder.encode(images_tensor)
                latents = posterior.sample().cpu()  # Sample from the posterior

                latent_bytes = io.BytesIO()
                np.save(latent_bytes, latents[0].numpy())
                sink.write({
                    "__key__": str(image_num),
                    "npy": latent_bytes.getvalue(),
                })

                if debug:
                    _save_debug_reconstruction(images_tensor, latents, image_num)
    finally:
        sink.close()

    return mapping_dict, target_data


def _merge_mapping_dict(mapping_dict, record_num):
    """Merge this record's per-frame actions into the shared pickle and replace it atomically.

    Entries already stored for ``record_num`` are dropped first: its tar was
    just rewritten, so any older rows for that id describe frames that no
    longer exist.
    """
    mapping_path = os.path.join(OUTPUT_DIR, 'image_action_mapping_with_key_states.pkl')
    merged = {}
    if os.path.exists(mapping_path):
        # Trusted input: this pickle is only ever written by this script.
        with open(mapping_path, 'rb') as f:
            existing_mapping_dict = pickle.load(f)
        merged = {
            key: value for key, value in existing_mapping_dict.items()
            if key[0] != record_num
        }
    merged.update(mapping_dict)

    temp_path = mapping_path + '.temp'
    with open(temp_path, 'wb') as f:
        pickle.dump(merged, f)
    os.replace(temp_path, mapping_path)


def _merge_target_frames(target_data, record_num):
    """Append this record's frame ids to the shared target CSV (deduplicated) and replace it atomically.

    Rows already stored for ``record_num`` are replaced, mirroring
    ``_merge_mapping_dict``.
    """
    target_path = os.path.join(OUTPUT_DIR, 'train_dataset.target_frames.csv')
    frames = pd.DataFrame(target_data, columns=['record_num', 'image_num'])
    if os.path.exists(target_path):
        existing_target_data = pd.read_csv(target_path)
        existing_target_data = existing_target_data[existing_target_data['record_num'] != record_num]
        frames = pd.concat([existing_target_data, frames])
    frames = frames.drop_duplicates()

    temp_path = target_path + '.temp'
    frames.to_csv(temp_path, index=False)
    os.replace(temp_path, target_path)


@torch.no_grad()
def process_session_file(log_file, clean_state):
    """Process one session log, turning each reset-delimited segment into a trajectory.

    Segments already recorded in ``processed_segments`` (e.g. from a run
    that crashed mid-session) are skipped so they are not re-encoded under a
    new id. Each newly processed segment is committed on its own together
    with the ``next_id`` increment; a failing segment is rolled back and
    skipped. The session is marked processed once at least one of its
    segments is done (now or previously).

    Args:
        log_file: Path to a ``session_*.jsonl`` log.
        clean_state: Opaque value from ``initialize_clean_state``, forwarded
            to ``process_trajectory``.

    Returns:
        The trajectory ids newly created by this call ([] on session-level errors).
    """
    conn = None
    try:
        conn = sqlite3.connect(DB_FILE)
        conn.execute("BEGIN TRANSACTION")  # Explicit transaction
        cursor = conn.cursor()

        os.makedirs("generated_videos", exist_ok=True)

        trajectory = load_trajectory(log_file)
        if not trajectory:
            logger.error(f"Empty trajectory for {log_file}, skipping")
            return []

        client_id = trajectory[0].get("client_id", "unknown")

        has_reset = any(entry.get("is_reset", False) for entry in trajectory)
        has_eos = any(entry.get("is_eos", False) for entry in trajectory)
        # If no resets and no EOS, this is incomplete - skip
        if not has_reset and not has_eos:
            logger.warning(f"Session {log_file} has no resets and no EOS, may be incomplete")
            return []

        sub_trajectories = _split_at_resets(trajectory)
        num_segments = len(sub_trajectories)

        processed_ids = []
        already_processed = 0
        for i, sub_traj in enumerate(sub_trajectories):
            # Skip segments with no interaction data (just control messages)
            if not any(not entry.get("is_reset", False) and not entry.get("is_eos", False) for entry in sub_traj):
                continue

            # Resume support: never re-encode a segment that was already committed.
            cursor.execute(
                "SELECT trajectory_id FROM processed_segments WHERE log_file = ? AND segment_index = ?",
                (log_file, i),
            )
            previous = cursor.fetchone()
            if previous is not None:
                logger.info(f"Segment {i+1}/{num_segments} from {log_file} already processed as trajectory {previous[0]}, skipping")
                already_processed += 1
                continue

            cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
            next_id = int(cursor.fetchone()[0])

            start_time = sub_traj[0]["timestamp"]
            end_time = sub_traj[-1]["timestamp"]

            # STEP 1: Generate a video from the original frames for comparison
            segment_label = f"segment_{i+1}_of_{num_segments}"
            video_path = os.path.join("generated_videos", f"trajectory_{next_id}_{segment_label}.mp4")
            success, frame_count = generate_comparison_video(
                client_id,
                sub_traj,
                video_path,
                start_time,
                end_time
            )
            if not success:
                logger.warning(f"Failed to generate comparison video for segment {i+1}, but continuing with processing")

            # STEP 2: Process with Docker for training data generation
            try:
                logger.info(f"Processing segment {i+1}/{num_segments} from {log_file} as trajectory {next_id}")

                formatted_trajectory = format_trajectory_for_processing(sub_traj)
                record_num = next_id
                process_trajectory((record_num, formatted_trajectory), SCREEN_WIDTH, SCREEN_HEIGHT, clean_state, MEMORY_LIMIT)

                mapping_dict, target_data = _encode_record_to_tar(record_num)
                _merge_mapping_dict(mapping_dict, record_num)
                _merge_target_frames(target_data, record_num)

                # Mark this segment as processed and advance the id in one commit.
                cursor.execute(
                    """INSERT INTO processed_segments 
                       (log_file, client_id, segment_index, start_time, end_time, 
                        processed_time, trajectory_id) 
                       VALUES (?, ?, ?, ?, ?, ?, ?)""",
                    (log_file, client_id, i, start_time, end_time,
                     datetime.now().isoformat(), next_id)
                )
                cursor.execute("UPDATE config SET value = ? WHERE key = 'next_id'", (str(next_id + 1),))
                conn.commit()

                processed_ids.append(next_id)
                logger.info(f"Successfully processed segment {i+1}/{num_segments} from {log_file}")

            except Exception as e:
                # Drop any half-applied statements so a later commit can't persist them.
                conn.rollback()
                logger.error(f"Failed to process segment {i+1}/{num_segments} from {log_file}: {e}")
                continue

        # Mark the session processed once any of its segments is done (now or earlier).
        if processed_ids or already_processed:
            try:
                cursor.execute(
                    "INSERT INTO processed_sessions (log_file, client_id, processed_time) VALUES (?, ?, ?)",
                    (log_file, client_id, datetime.now().isoformat())
                )
                conn.commit()
            except sqlite3.IntegrityError:
                # Session row already exists (e.g. re-processing after partial failure).
                pass

        # Flush anything still pending; segments were already committed individually.
        conn.commit()
        return processed_ids

    except Exception as e:
        logger.error(f"Error processing session {log_file}: {e}")
        if conn:
            conn.rollback()  # Roll back on error
        return []
    finally:
        if conn:
            conn.close()  # Always close connection


def _normalize_key(raw_key):
    """Lower-case a browser key name and translate it through KEYMAPPING; warn if not in ``stoi``."""
    key = raw_key.lower()
    if key in KEYMAPPING:
        logger.debug(f"Key {key} mapped to {KEYMAPPING[key]}")
        key = KEYMAPPING[key]
    if key not in stoi:
        logger.warning(f"Key {key} not found in stoi")
    return key


def format_trajectory_for_processing(trajectory):
    """Convert raw log entries into the event dicts passed to ``process_trajectory``.

    Control entries (``is_reset`` / ``is_eos``) are dropped. For every other
    entry an event is emitted with the mouse position, click flags, and the
    key transitions it causes. Key names are normalized with
    ``_normalize_key``; keys outside ``stoi`` are ignored. A keydown is emitted
    only for a key not already held, and a keyup only for a key that is held,
    so the emitted stream is always balanced.

    Returns:
        A list of dicts with ``pos``, ``left_click``, ``right_click`` and
        ``key_events`` (a list of ``("keydown"|"keyup", key)`` tuples).
    """
    formatted_events = []
    down_keys = set()

    for entry in trajectory:
        if entry.get("is_reset") or entry.get("is_eos"):
            continue

        # `or {}` also covers an explicit `"inputs": null` in the log.
        inputs = entry.get("inputs") or {}
        key_events = []

        for key in map(_normalize_key, inputs.get("keys_down") or []):
            if key in stoi and key not in down_keys:
                down_keys.add(key)
                key_events.append(("keydown", key))

        for key in map(_normalize_key, inputs.get("keys_up") or []):
            if key in stoi and key in down_keys:
                down_keys.remove(key)
                key_events.append(("keyup", key))

        formatted_events.append({
            "pos": (inputs.get("x"), inputs.get("y")),
            "left_click": inputs.get("is_left_click", False),
            "right_click": inputs.get("is_right_click", False),
            "key_events": key_events,
        })

    return formatted_events


def generate_comparison_video(client_id, trajectory, output_file, start_time, end_time):
    """
    Generate a 10 FPS video from the original frame PNGs for comparison purposes.

    Args:
        client_id: Client ID for frame lookup (frames live in FRAMES_DIR/frames_<client_id>)
        trajectory: List of interaction log entries for this segment (unused here)
        output_file: Path to save the output video
        start_time: Start timestamp for this segment (inclusive)
        end_time: End timestamp for this segment (inclusive)

    Returns:
        (bool, int): (success status, number of frames actually written)
    """
    try:
        frame_dir = os.path.join(FRAMES_DIR, f"frames_{client_id}")
        if not os.path.exists(frame_dir):
            logger.warning(f"No frame directory found for client {client_id}")
            return False, 0

        # Frame filenames are timestamps; parse each once and sort by it.
        timed_frames = sorted(
            ((float(os.path.basename(path).split('.png')[0]), path)
             for path in glob.glob(os.path.join(frame_dir, "*.png"))),
            key=lambda item: item[0],
        )
        if not timed_frames:
            logger.error(f"No frames found for client {client_id}")
            return False, 0

        segment_frames = [path for ts, path in timed_frames if start_time <= ts <= end_time]
        if not segment_frames:
            logger.error(f"No frames found in time range for segment {start_time}-{end_time}")
            return False, 0

        first_frame = cv2.imread(segment_frames[0])
        if first_frame is None:
            logger.error(f"Could not read first frame {segment_frames[0]}")
            return False, 0
        height, width = first_frame.shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(output_file, fourcc, 10.0, (width, height))
        # VideoWriter fails silently on write(); detect an unusable writer up front.
        if not video.isOpened():
            logger.error(f"Could not open video writer for {output_file}")
            return False, 0

        written = 0
        try:
            for frame_file in segment_frames:
                frame = cv2.imread(frame_file)
                if frame is not None:
                    video.write(frame)
                    written += 1
        finally:
            video.release()

        logger.info(f"Created comparison video {output_file} with {written} frames")
        return True, written

    except Exception as e:
        logger.error(f"Error generating comparison video: {e}")
        return False, 0


def _session_timestamp(path):
    """Return the int timestamp from ``session_<ts>_<n>.jsonl``; 0 if it cannot be parsed."""
    try:
        return int(os.path.basename(path).split('_')[1])
    except (IndexError, ValueError):
        return 0


def main():
    """Run the polling loop until a shutdown signal clears ``running``.

    On startup this writes a zero padding latent (once), initializes the DB
    and the clean Docker state. Every CHECK_INTERVAL seconds it then
    processes complete, valid, not-yet-processed sessions in timestamp order.
    """
    global running

    # create a padding latent first (all zeros, shaped like a real latent)
    padding_path = os.path.join(OUTPUT_DIR, 'padding.npy')
    if not os.path.exists(padding_path):
        logger.info("Creating padding image...")
        padding_data = np.zeros((SCREEN_HEIGHT, SCREEN_WIDTH, 3), dtype=np.float32)
        padding_tensor = torch.tensor(padding_data).unsqueeze(0)
        padding_tensor = rearrange(padding_tensor, 'b h w c -> b c h w').to(device)
        with torch.no_grad():
            # Encoding is only used to obtain the latent shape.
            latent = autoencoder.encode(padding_tensor).sample()
        latent = torch.zeros_like(latent).squeeze(0)
        tmp_padding_path = os.path.join(OUTPUT_DIR, 'padding.tmp.npy')
        np.save(tmp_padding_path, latent.cpu().numpy())
        os.replace(tmp_padding_path, padding_path)

    initialize_database()

    logger.info("Initializing clean container state...")
    clean_state = initialize_clean_state()
    logger.info(f"Clean state initialized: {clean_state}")

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    logger.info(f"Starting continuous monitoring for new sessions (check interval: {CHECK_INTERVAL} seconds)")

    try:
        while running:
            try:
                log_files = glob.glob(os.path.join(FRAMES_DIR, "session_*.jsonl"))
                logger.info(f"Found {len(log_files)} log files")

                complete_sessions = [f for f in log_files if is_session_complete(f)]
                logger.info(f"Found {len(complete_sessions)} complete sessions")

                # Oldest sessions first.
                complete_sessions.sort(key=_session_timestamp)

                with closing(sqlite3.connect(DB_FILE)) as conn:
                    processed_files = {row[0] for row in conn.execute("SELECT log_file FROM processed_sessions")}

                new_sessions = [f for f in complete_sessions if f not in processed_files]
                logger.info(f"Found {len(new_sessions)} new sessions to process")

                valid_sessions = [f for f in new_sessions if is_session_valid(f)]
                logger.info(f"Found {len(valid_sessions)} valid new sessions to process")

                total_trajectories = 0
                for log_file in valid_sessions:
                    if not running:
                        logger.info("Shutdown in progress, stopping processing")
                        break

                    logger.info(f"Processing session file: {log_file}")
                    processed_ids = process_session_file(log_file, clean_state)
                    total_trajectories += len(processed_ids)

                if total_trajectories > 0:
                    with closing(sqlite3.connect(DB_FILE)) as conn:
                        next_id = int(conn.execute("SELECT value FROM config WHERE key = 'next_id'").fetchone()[0])

                    logger.info(f"Processing cycle complete. Generated {total_trajectories} new trajectories.")
                    logger.info(f"Next ID will be {next_id}")
                else:
                    logger.info("No new trajectories processed in this cycle")

                # Sleep in short chunks so a shutdown signal is noticed within ~5 s.
                remaining_sleep = CHECK_INTERVAL
                while remaining_sleep > 0 and running:
                    sleep_chunk = min(5, remaining_sleep)
                    time.sleep(sleep_chunk)
                    remaining_sleep -= sleep_chunk

            except Exception as e:
                logger.error(f"Error in processing cycle: {e}")
                # Sleep briefly to avoid rapid error loops
                time.sleep(10)

    except KeyboardInterrupt:
        logger.info("Keyboard interrupt received, shutting down")
    finally:
        logger.info("Shutting down trajectory processor")


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logger.error(f"Unhandled exception: {e}", exc_info=True)