#!/usr/bin/env python3
import os
import json
import glob
import time
import sqlite3
import logging
import cv2
import numpy as np
import subprocess
from datetime import datetime
from typing import List, Dict, Any, Tuple
from omegaconf import OmegaConf
from computer.util import load_model_from_config
from PIL import Image
import io
import torch
from einops import rearrange
import webdataset as wds
import pandas as pd
import ast
import pickle
from moviepy.editor import VideoFileClip
import signal

# Import the existing functions
from data.data_collection.synthetic_script_compute_canada import process_trajectory, initialize_clean_state
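# As used below, process_trajectory replays a formatted trajectory inside a
# resource-limited Docker container and is expected to leave a video and an
# action CSV under raw_data/raw_data/; initialize_clean_state provides the
# clean container state that every replay starts from.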
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("trajectory_processor.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Define constants
DB_FILE = "trajectory_processor.db"
FRAMES_DIR = "interaction_logs"
OUTPUT_DIR = 'train_dataset_encoded_online'
os.makedirs(OUTPUT_DIR, exist_ok=True)
SCREEN_WIDTH = 512
SCREEN_HEIGHT = 384
MEMORY_LIMIT = "2g"
CHECK_INTERVAL = 60  # Check for new data every 60 seconds
# load autoencoder | |
config = OmegaConf.load('../computer/autoencoder/config_kl4_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_contmar15_acc1.yaml') | |
autoencoder = load_model_from_config(config, '../computer/autoencoder/saved_kl4_bsz8_acc8_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_cont_mar15_acc1_cont_1e6_cont_2e7_cont/model-2076000.ckpt') | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
autoencoder = autoencoder.to(device) | |
# Global flag for graceful shutdown
running = True

# Map browser-style key names (lowercased KeyboardEvent.key values) to the
# names used in KEYS below
KEYMAPPING = {
    'arrowup': 'up',
    'arrowdown': 'down',
    'arrowleft': 'left',
    'arrowright': 'right',
    'meta': 'command',
    'contextmenu': 'apps',
    'control': 'ctrl',
}
KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(',
        ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7',
        '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
        'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~',
        'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
        'browserback', 'browserfavorites', 'browserforward', 'browserhome',
        'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
        'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
        'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10',
        'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20',
        'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
        'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja',
        'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
        'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
        'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
        'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
        'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
        'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
        'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
        'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
        'command', 'option', 'optionleft', 'optionright']
INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
                'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
VALID_KEYS = [key for key in KEYS if key not in INVALID_KEYS]
itos = VALID_KEYS
stoi = {key: i for i, key in enumerate(itos)}
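# itos/stoi define a stable integer vocabulary over valid key names:
# stoi maps a (KEYMAPPING-normalized) key name to its index, and itos is the
# inverse lookup. Only keys present in stoi are kept when formatting events.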
def signal_handler(sig, frame):
    """Handle Ctrl+C and other termination signals."""
    global running
    logger.info("Shutdown signal received. Finishing current processing and exiting...")
    running = False

# Register signal handlers
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
def initialize_database():
    """Initialize the SQLite database if it doesn't exist."""
    conn = sqlite3.connect(DB_FILE)
    cursor = conn.cursor()
    # Create tables if they don't exist
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS processed_sessions (
            id INTEGER PRIMARY KEY,
            log_file TEXT UNIQUE,
            client_id TEXT,
            processed_time TIMESTAMP
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS processed_segments (
            id INTEGER PRIMARY KEY,
            log_file TEXT,
            client_id TEXT,
            segment_index INTEGER,
            start_time REAL,
            end_time REAL,
            processed_time TIMESTAMP,
            trajectory_id INTEGER,
            UNIQUE(log_file, segment_index)
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS config (
            key TEXT PRIMARY KEY,
            value TEXT
        )
    ''')
    # Initialize next_id if it does not exist yet
    cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
    if not cursor.fetchone():
        cursor.execute("INSERT INTO config (key, value) VALUES ('next_id', '0')")
    conn.commit()
    conn.close()
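# Note: processed_sessions and processed_segments make reprocessing idempotent
# (log_file and (log_file, segment_index) are UNIQUE), while the config table
# persists a single monotonically increasing 'next_id' counter so trajectory
# IDs keep growing across restarts.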
def is_session_complete(log_file):
    """Check if a session is complete (has an EOS marker)."""
    try:
        with open(log_file, 'r') as f:
            for line in f:
                try:
                    entry = json.loads(line.strip())
                    if entry.get("is_eos", False):
                        return True
                except json.JSONDecodeError:
                    continue
        return False
    except Exception as e:
        logger.error(f"Error checking if session {log_file} is complete: {e}")
        return False

def is_session_valid(log_file):
    """
    Check if a session is valid (has more than just control entries).
    Returns True if the log file has at least one entry that is neither
    an EOS nor a reset marker.
    """
    try:
        entry_count = 0
        has_non_eos = False
        with open(log_file, 'r') as f:
            for line in f:
                try:
                    entry = json.loads(line.strip())
                    entry_count += 1
                    if not entry.get("is_eos", False) and not entry.get("is_reset", False):
                        has_non_eos = True
                except json.JSONDecodeError:
                    continue
        # Valid if there's at least one entry and at least one non-control entry
        return entry_count > 0 and has_non_eos
    except Exception as e:
        logger.error(f"Error checking if session {log_file} is valid: {e}")
        return False

def load_trajectory(log_file):
    """Load a trajectory from a log file."""
    trajectory = []
    try:
        with open(log_file, 'r') as f:
            for line in f:
                try:
                    entry = json.loads(line.strip())
                    trajectory.append(entry)
                except json.JSONDecodeError:
                    logger.warning(f"Skipping invalid JSON line in {log_file}")
                    continue
        return trajectory
    except Exception as e:
        logger.error(f"Error loading trajectory from {log_file}: {e}")
        return []
def process_session_file(log_file, clean_state):
    """Process a session file, splitting it into multiple trajectories at reset points."""
    conn = None
    try:
        conn = sqlite3.connect(DB_FILE)
        conn.execute("BEGIN TRANSACTION")  # Explicit transaction
        cursor = conn.cursor()
        # Ensure output directory exists
        os.makedirs("generated_videos", exist_ok=True)
        # Get session details
        trajectory = load_trajectory(log_file)
        if not trajectory:
            logger.error(f"Empty trajectory for {log_file}, skipping")
            return []
        client_id = trajectory[0].get("client_id", "unknown")
        # Find all reset points and the EOS marker
        reset_indices = []
        has_eos = False
        for i, entry in enumerate(trajectory):
            if entry.get("is_reset", False):
                reset_indices.append(i)
            if entry.get("is_eos", False):
                has_eos = True
        # If there are no resets and no EOS, the session is incomplete - skip it
        if not reset_indices and not has_eos:
            logger.warning(f"Session {log_file} has no resets and no EOS, may be incomplete")
            return []
        # Split the trajectory at reset points
        sub_trajectories = []
        start_idx = 0
        # Add all segments between resets
        for reset_idx in reset_indices:
            if reset_idx > start_idx:  # Only add non-empty segments
                sub_trajectories.append(trajectory[start_idx:reset_idx])
            start_idx = reset_idx + 1  # Start a new segment after the reset
        # Add the final segment if it's not empty
        if start_idx < len(trajectory):
            sub_trajectories.append(trajectory[start_idx:])
        # Process each sub-trajectory
        processed_ids = []
        for i, sub_traj in enumerate(sub_trajectories):
            # Skip segments with no interaction data (just control messages)
            if not any(not entry.get("is_reset", False) and not entry.get("is_eos", False) for entry in sub_traj):
                continue
            # Get the next ID for this sub-trajectory
            cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
            next_id = int(cursor.fetchone()[0])
            # Find the timestamps for this segment
            start_time = sub_traj[0]["timestamp"]
            end_time = sub_traj[-1]["timestamp"]
            # STEP 1: Generate a video from the original frames for comparison
            segment_label = f"segment_{i+1}_of_{len(sub_trajectories)}"
            video_path = os.path.join("generated_videos", f"trajectory_{next_id}_{segment_label}.mp4")
            success, frame_count = generate_comparison_video(
                client_id,
                sub_traj,
                video_path,
                start_time,
                end_time
            )
            if not success:
                logger.warning(f"Failed to generate comparison video for segment {i+1}, but continuing with processing")
            # STEP 2: Process with Docker for training data generation
            try:
                logger.info(f"Processing segment {i+1}/{len(sub_trajectories)} from {log_file} as trajectory {next_id}")
                # Format the trajectory as expected by the process_trajectory function
                formatted_trajectory = format_trajectory_for_processing(sub_traj)
                record_num = next_id
                # Call the external process_trajectory function
                args = (record_num, formatted_trajectory)
                process_trajectory(args, SCREEN_WIDTH, SCREEN_HEIGHT, clean_state, MEMORY_LIMIT)
                # Prepare training data format
                video_file = f'raw_data/raw_data/videos/record_{record_num}.mp4'
                action_file = f'raw_data/raw_data/actions/record_{record_num}.csv'
                mouse_data = pd.read_csv(action_file)
                mapping_dict = {}
                target_data = []
                # Remove the existing tar file if it exists
                tar_path = os.path.join(OUTPUT_DIR, f'record_{record_num}.tar')
                if os.path.exists(tar_path):
                    logger.info(f"Removing existing tar file {tar_path}")
                    os.remove(tar_path)
                sink = wds.TarWriter(tar_path)
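                # Each record_<n>.tar is a WebDataset shard with one sample per
                # video frame: the sample key is the frame index and the "npy"
                # field holds the encoded latent as a serialized numpy array.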
                with VideoFileClip(video_file) as video:
                    fps = video.fps
                    assert fps == 15, f"Expected 15 FPS, got {fps}"
                    duration = video.duration
                    down_keys = set()
                    for image_num in range(int(fps * duration)):
                        action_row = mouse_data.iloc[image_num]
                        x = int(action_row['X'])
                        y = int(action_row['Y'])
                        left_click = bool(action_row['Left Click'] == 1)
                        right_click = bool(action_row['Right Click'] == 1)
                        key_events = ast.literal_eval(action_row['Key Events'])
                        for key_state, key in key_events:
                            if key_state == "keydown":
                                down_keys.add(key)
                            elif key_state == "keyup":
                                # discard() tolerates a keyup without a matching keydown
                                down_keys.discard(key)
                            else:
                                raise ValueError(f"Unknown key event type: {key_state}")
                        mapping_dict[(record_num, image_num)] = (x, y, left_click, right_click, list(down_keys))
                        target_data.append((record_num, image_num))
                        frame = video.get_frame(image_num / fps)
                        # Normalize to [-1, 1]
                        image_array = (frame / 127.5 - 1.0).astype(np.float32)
                        # Convert to torch tensor and move to device for inference
                        images_tensor = torch.tensor(image_array).unsqueeze(0)
                        images_tensor = rearrange(images_tensor, 'b h w c -> b c h w')
                        images_tensor = images_tensor.to(device)
                        # Encode the frame; no_grad avoids building an autograd
                        # graph (and leaking memory) during pure inference
                        with torch.no_grad():
                            posterior = autoencoder.encode(images_tensor)
                            latents = posterior.sample()  # Sample from the posterior
                        # Move back to CPU for saving
                        latents = latents.cpu()
                        # Save the latent to the tar file, keyed by frame index
                        latent = latents[0]
                        key = str(image_num)
                        # Convert latent to bytes
                        latent_bytes = io.BytesIO()
                        np.save(latent_bytes, latent.numpy())
                        latent_bytes.seek(0)
                        # Write to tar
                        sample = {
                            "__key__": key,
                            "npy": latent_bytes.getvalue(),
                        }
                        sink.write(sample)
                        # Debug: decode the latent back and save a side-by-side
                        # comparison of the original and reconstructed frame
                        debug = False
                        if debug:
                            debug_dir = os.path.join(OUTPUT_DIR, 'debug')
                            os.makedirs(debug_dir, exist_ok=True)
                            # Decode latents back to images
                            with torch.no_grad():
                                reconstructions = autoencoder.decode(latents.to(device))
                            # Save original and reconstructed images side by side
                            for idx, (orig, recon) in enumerate(zip(images_tensor, reconstructions)):
                                # Convert to numpy and move to CPU
                                orig = orig.cpu().numpy()
                                recon = recon.cpu().numpy()
                                # Denormalize from [-1, 1] to [0, 255]
                                orig = (orig + 1.0) * 127.5
                                recon = (recon + 1.0) * 127.5
                                # Clip values to the valid range
                                orig = np.clip(orig, 0, 255).astype(np.uint8)
                                recon = np.clip(recon, 0, 255).astype(np.uint8)
                                # Rearrange from CHW to HWC
                                orig = np.transpose(orig, (1, 2, 0))
                                recon = np.transpose(recon, (1, 2, 0))
                                # Create side-by-side comparison
                                comparison = np.concatenate([orig, recon], axis=1)
                                # Save comparison image
                                Image.fromarray(comparison).save(
                                    os.path.join(debug_dir, f'debug_{image_num}_{idx}_{key}.png')
                                )
                            logger.info(f"Debug visualizations saved to {debug_dir}")
                sink.close()
                # Merge with the existing mapping_dict if one exists
                mapping_path = os.path.join(OUTPUT_DIR, 'image_action_mapping_with_key_states.pkl')
                if os.path.exists(mapping_path):
                    with open(mapping_path, 'rb') as f:
                        existing_mapping_dict = pickle.load(f)
                    for key, value in existing_mapping_dict.items():
                        if key not in mapping_dict:
                            mapping_dict[key] = value
                # Save the mapping_dict atomically (write to a temp file, then rename)
                temp_path = mapping_path + '.temp'
                with open(temp_path, 'wb') as f:
                    pickle.dump(mapping_dict, f)
                os.rename(temp_path, mapping_path)
                # Merge with the existing target_data if it exists
                target_data = pd.DataFrame(target_data, columns=['record_num', 'image_num'])
                target_path = os.path.join(OUTPUT_DIR, 'train_dataset.target_frames.csv')
                if os.path.exists(target_path):
                    existing_target_data = pd.read_csv(target_path)
                    target_data = pd.concat([existing_target_data, target_data])
                # Deduplicate
                target_data = target_data.drop_duplicates()
                # Save the target_data atomically
                temp_path = target_path + '.temp'
                target_data.to_csv(temp_path, index=False)
                os.rename(temp_path, target_path)
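                # The temp-file + os.rename pattern is atomic on POSIX when the
                # temp file lives on the same filesystem, so concurrent readers
                # never observe a partially written mapping or CSV.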
                # Mark this segment as processed
                cursor.execute(
                    """INSERT INTO processed_segments
                       (log_file, client_id, segment_index, start_time, end_time,
                        processed_time, trajectory_id)
                       VALUES (?, ?, ?, ?, ?, ?, ?)""",
                    (log_file, client_id, i, start_time, end_time,
                     datetime.now().isoformat(), next_id)
                )
                # Increment the next ID
                cursor.execute("UPDATE config SET value = ? WHERE key = 'next_id'", (str(next_id + 1),))
                conn.commit()
                processed_ids.append(next_id)
                logger.info(f"Successfully processed segment {i+1}/{len(sub_trajectories)} from {log_file}")
            except Exception as e:
                logger.error(f"Failed to process segment {i+1}/{len(sub_trajectories)} from {log_file}: {e}")
                continue
        # Mark the entire session as processed only if at least one segment succeeded
        if processed_ids:
            try:
                cursor.execute(
                    "INSERT INTO processed_sessions (log_file, client_id, processed_time) VALUES (?, ?, ?)",
                    (log_file, client_id, datetime.now().isoformat())
                )
                conn.commit()
            except sqlite3.IntegrityError:
                # This can happen if we're re-processing a file that had some segments fail
                pass
        # Final commit in case anything is left uncommitted (segments commit individually above)
        conn.commit()
        return processed_ids
    except Exception as e:
        logger.error(f"Error processing session {log_file}: {e}")
        if conn:
            conn.rollback()  # Roll back on error
        return []
    finally:
        if conn:
            conn.close()  # Always close the connection
def format_trajectory_for_processing(trajectory):
    """
    Format the trajectory in the structure expected by the process_trajectory
    function. The exact format depends on what process_trajectory expects;
    this is a placeholder - modify based on the actual requirements.
    """
    formatted_events = []
    down_keys = set()
    for entry in trajectory:
        # Skip control messages
        if entry.get("is_reset") or entry.get("is_eos"):
            continue
        # Extract input data
        inputs = entry.get("inputs", {})
        key_events = []
        for key in inputs.get("keys_down", []):
            key = key.lower()
            if key in KEYMAPPING:
                logger.debug(f"Key {key} mapped to {KEYMAPPING[key]}")
                key = KEYMAPPING[key]
            if key not in stoi:
                logger.warning(f"Key {key} not found in stoi")
            # Only record a keydown for known keys that aren't already down
            if key not in down_keys and key in stoi:
                down_keys.add(key)
                key_events.append(("keydown", key))
        for key in inputs.get("keys_up", []):
            key = key.lower()
            if key in KEYMAPPING:
                logger.debug(f"Key {key} mapped to {KEYMAPPING[key]}")
                key = KEYMAPPING[key]
            if key not in stoi:
                logger.warning(f"Key {key} not found in stoi")
            # Only record a keyup for known keys that are currently down
            if key in down_keys and key in stoi:
                down_keys.remove(key)
                key_events.append(("keyup", key))
        event = {
            "pos": (inputs.get("x"), inputs.get("y")),
            "left_click": inputs.get("is_left_click", False),
            "right_click": inputs.get("is_right_click", False),
            "key_events": key_events,
        }
        formatted_events.append(event)
    return formatted_events
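# Example of the produced format (assumed log schema): an entry like
#     {"inputs": {"x": 10, "y": 20, "is_left_click": True,
#                 "keys_down": ["ArrowUp"], "keys_up": []}}
# is formatted as
#     {"pos": (10, 20), "left_click": True, "right_click": False,
#      "key_events": [("keydown", "up")]}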
def generate_comparison_video(client_id, trajectory, output_file, start_time, end_time):
    """
    Generate a video from the original frames for comparison purposes.

    Args:
        client_id: Client ID for frame lookup
        trajectory: List of interaction log entries for this segment
        output_file: Path to save the output video
        start_time: Start timestamp for this segment
        end_time: End timestamp for this segment

    Returns:
        (bool, int): (success status, frame count)
    """
    try:
        # Get frame files for this client
        frame_dir = os.path.join(FRAMES_DIR, f"frames_{client_id}")
        if not os.path.exists(frame_dir):
            logger.warning(f"No frame directory found for client {client_id}")
            return False, 0
        all_frames = glob.glob(os.path.join(frame_dir, "*.png"))
        if not all_frames:
            logger.error(f"No frames found for client {client_id}")
            return False, 0
        # Sort frames by the timestamp in the filename
        all_frames.sort(key=lambda x: float(os.path.basename(x).split('.png')[0]))
        # Filter frames to the time range of this segment
        # (frame filenames are timestamps, so we can use them for filtering)
        segment_frames = [
            f for f in all_frames
            if start_time <= float(os.path.basename(f).split('.png')[0]) <= end_time
        ]
        if not segment_frames:
            logger.error(f"No frames found in time range for segment {start_time}-{end_time}")
            return False, 0
        # Read the first frame to get dimensions
        first_frame = cv2.imread(segment_frames[0])
        if first_frame is None:
            logger.error(f"Could not read first frame {segment_frames[0]}")
            return False, 0
        height, width, channels = first_frame.shape
        # Create the video writer (playback is a fixed 10 FPS regardless of the
        # actual capture rate, so timing in the comparison video is approximate)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(output_file, fourcc, 10.0, (width, height))
        # Write each frame
        for frame_file in segment_frames:
            frame = cv2.imread(frame_file)
            if frame is not None:
                video.write(frame)
        # Release the video writer
        video.release()
        logger.info(f"Created comparison video {output_file} with {len(segment_frames)} frames")
        return True, len(segment_frames)
    except Exception as e:
        logger.error(f"Error generating comparison video: {e}")
        return False, 0
def main():
    """Main function to run the data processing pipeline."""
    global running
    # Create a padding latent first
    if not os.path.exists(os.path.join(OUTPUT_DIR, 'padding.npy')):
        logger.info("Creating padding image...")
        padding_data = np.zeros((SCREEN_HEIGHT, SCREEN_WIDTH, 3), dtype=np.float32)
        padding_tensor = torch.tensor(padding_data).unsqueeze(0)
        padding_tensor = rearrange(padding_tensor, 'b h w c -> b c h w').to(device)
        # no_grad: inference only, no autograd graph needed
        with torch.no_grad():
            posterior = autoencoder.encode(padding_tensor)
            latent = posterior.sample()
        latent = torch.zeros_like(latent).squeeze(0)
        # Write atomically via a temp file + rename
        np.save(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), latent.cpu().numpy())
        os.rename(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), os.path.join(OUTPUT_DIR, 'padding.npy'))
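    # padding.npy is an all-zero latent with the exact shape the autoencoder
    # produces for one SCREEN_HEIGHT x SCREEN_WIDTH frame; the encode above is
    # only used to determine that shape. Presumably downstream data loaders use
    # it to pad short context windows.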
# Initialize database | |
initialize_database() | |
# Initialize clean Docker state | |
logger.info("Initializing clean container state...") | |
clean_state = initialize_clean_state() | |
logger.info(f"Clean state initialized: {clean_state}") | |
# Ensure output directory exists | |
os.makedirs(OUTPUT_DIR, exist_ok=True) | |
logger.info(f"Starting continuous monitoring for new sessions (check interval: {CHECK_INTERVAL} seconds)") | |
try: | |
# Main monitoring loop | |
while running: | |
try: | |
# Find all log files | |
log_files = glob.glob(os.path.join(FRAMES_DIR, "session_*.jsonl")) | |
logger.info(f"Found {len(log_files)} log files") | |
# Filter for complete sessions | |
complete_sessions = [f for f in log_files if is_session_complete(f)] | |
logger.info(f"Found {len(complete_sessions)} complete sessions") | |
# Sort sessions by the numeric timestamp in the filename (session_<timestamp>_*.jsonl) | |
def _extract_ts(path): | |
"""Return int timestamp from session_<ts>_<n>.jsonl; fallback to 0 if parse fails.""" | |
try: | |
basename = os.path.basename(path) # session_1750138392_3.jsonl | |
ts_part = basename.split('_')[1] # '1750138392' | |
return int(ts_part) | |
except Exception: # noqa: E722 | |
return 0 | |
complete_sessions.sort(key=_extract_ts) | |
# Filter for sessions not yet processed | |
conn = sqlite3.connect(DB_FILE) | |
cursor = conn.cursor() | |
cursor.execute("SELECT log_file FROM processed_sessions") | |
processed_files = set(row[0] for row in cursor.fetchall()) | |
conn.close() | |
new_sessions = [f for f in complete_sessions if f not in processed_files] | |
logger.info(f"Found {len(new_sessions)} new sessions to process") | |
# Filter for valid sessions | |
valid_sessions = [f for f in new_sessions if is_session_valid(f)] | |
logger.info(f"Found {len(valid_sessions)} valid new sessions to process") | |
                # Process each valid session
                total_trajectories = 0
                for log_file in valid_sessions:
                    if not running:
                        logger.info("Shutdown in progress, stopping processing")
                        break
                    logger.info(f"Processing session file: {log_file}")
                    processed_ids = process_session_file(log_file, clean_state)
                    total_trajectories += len(processed_ids)
                if total_trajectories > 0:
                    # Get the next ID for reporting
                    conn = sqlite3.connect(DB_FILE)
                    cursor = conn.cursor()
                    cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
                    next_id = int(cursor.fetchone()[0])
                    conn.close()
                    logger.info(f"Processing cycle complete. Generated {total_trajectories} new trajectories.")
                    logger.info(f"Next ID will be {next_id}")
                else:
                    logger.info("No new trajectories processed in this cycle")
                # Sleep until the next check, with periodic wake-ups to check the running flag
                remaining_sleep = CHECK_INTERVAL
                while remaining_sleep > 0 and running:
                    sleep_chunk = min(5, remaining_sleep)  # Check the running flag at least every 5 seconds
                    time.sleep(sleep_chunk)
                    remaining_sleep -= sleep_chunk
            except Exception as e:
                logger.error(f"Error in processing cycle: {e}")
                # Sleep briefly to avoid rapid error loops
                time.sleep(10)
    except KeyboardInterrupt:
        logger.info("Keyboard interrupt received, shutting down")
    finally:
        logger.info("Shutting down trajectory processor")
if __name__ == "__main__": | |
try: | |
main() | |
except Exception as e: | |
logger.error(f"Unhandled exception: {e}", exc_info=True) | |