Spaces:
Runtime error
Runtime error
File size: 30,345 Bytes
2d9e199 5c64b10 2d9e199 5c64b10 ab7919c a92ddb8 5c64b10 1ebe6d8 2d9e199 5c64b10 2d9e199 ab7919c 5c64b10 a92ddb8 2d9e199 ab7919c a92ddb8 9efab58 a92ddb8 2d9e199 68b9a19 2d9e199 6134734 5c64b10 ab7919c 2d9e199 ab7919c 2d9e199 ab7919c 2d9e199 ab7919c 9612f89 ab7919c 9612f89 ab7919c 9612f89 ab7919c 5c64b10 ab7919c 5c64b10 ab7919c 5c64b10 ab7919c 2d9e199 ab7919c 5c64b10 ab7919c 5434768 ab7919c f08f4f7 ab7919c 801bf02 ab7919c 2d9e199 5c64b10 2d9e199 5c64b10 2d9e199 5c64b10 2d9e199 5c64b10 93c5e4d 5c64b10 2d9e199 5c64b10 9efab58 5c64b10 9efab58 5c64b10 2d9e199 5c64b10 2d9e199 9612f89 2d9e199 a92ddb8 dffe378 e4cd0fb dba2df7 e4cd0fb 29a0aca a92ddb8 2d9e199 a92ddb8 5c64b10 a92ddb8 2d9e199 a92ddb8 2d9e199 a92ddb8 9147f8e a92ddb8 2d9e199 1ebe6d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 |
#!/usr/bin/env python3
import os
import json
import glob
import time
import sqlite3
import logging
import cv2
import numpy as np
import subprocess
from datetime import datetime
from typing import List, Dict, Any, Tuple
from omegaconf import OmegaConf
from computer.util import load_model_from_config
from PIL import Image
import io
import torch
from einops import rearrange
import webdataset as wds
import pandas as pd
import ast
import pickle
from moviepy.editor import VideoFileClip
import signal
# Import the existing functions
from data.data_collection.synthetic_script_compute_canada import process_trajectory, initialize_clean_state
# Configure logging: INFO and above go both to trajectory_processor.log
# (in the current working directory) and to the console.
_log_handlers = [
    logging.FileHandler("trajectory_processor.log"),
    logging.StreamHandler(),
]
logging.basicConfig(
    handlers=_log_handlers,
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Define constants
DB_FILE = "trajectory_processor.db"          # SQLite bookkeeping database
FRAMES_DIR = "interaction_logs"              # holds session_*.jsonl logs and frames_<client_id>/ directories
OUTPUT_DIR = 'train_dataset_encoded_online'  # destination of the encoded training data
os.makedirs(OUTPUT_DIR, exist_ok=True)
SCREEN_WIDTH, SCREEN_HEIGHT = 512, 384
MEMORY_LIMIT = "2g"   # passed through to process_trajectory (presumably a Docker memory limit — confirm)
CHECK_INTERVAL = 60   # seconds between scans for new data

# Frame autoencoder, loaded once at import time. Both paths are relative to the
# current working directory, so the script must be started from the right place.
_AUTOENCODER_CONFIG_PATH = '../computer/autoencoder/config_kl4_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_contmar15_acc1.yaml'
_AUTOENCODER_CKPT_PATH = '../computer/autoencoder/saved_kl4_bsz8_acc8_lr4.5e6_load_acc1_512_384_mar10_keyboard_init_16_cont_mar15_acc1_cont_1e6_cont_2e7_cont/model-2076000.ckpt'
config = OmegaConf.load(_AUTOENCODER_CONFIG_PATH)
autoencoder = load_model_from_config(config, _AUTOENCODER_CKPT_PATH)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
autoencoder = autoencoder.to(device)

# Global flag for graceful shutdown; cleared by signal_handler.
running = True
# Aliases from lower-cased key names as they appear in the session logs to the
# canonical names used in KEYS.
KEYMAPPING = {
    'arrowup': 'up',
    'arrowdown': 'down',
    'arrowleft': 'left',
    'arrowright': 'right',
    'meta': 'command',
    'contextmenu': 'apps',
    'control': 'ctrl',
}

# Every key name a trajectory may reference.
KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(',
    ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7',
    '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
    'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~',
    'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
    'browserback', 'browserfavorites', 'browserforward', 'browserhome',
    'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
    'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
    'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10',
    'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20',
    'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
    'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja',
    'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
    'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
    'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
    'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
    'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
    'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
    'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
    'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
    'command', 'option', 'optionleft', 'optionright']

# Keys excluded from the vocabulary.
INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
    'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']

_INVALID_KEY_SET = frozenset(INVALID_KEYS)
VALID_KEYS = [name for name in KEYS if name not in _INVALID_KEY_SET]

# Index <-> key-name lookup tables over the valid vocabulary.
itos = VALID_KEYS
stoi = dict(zip(itos, range(len(itos))))
def signal_handler(sig, frame):
    """Request a graceful shutdown when SIGINT or SIGTERM arrives.

    Only flips the module-level ``running`` flag to False; whatever is in
    progress is allowed to finish, and loops that poll the flag then stop.
    """
    global running
    logger.info("Shutdown signal received. Finishing current processing and exiting...")
    running = False


# Register signal handlers: Ctrl+C and termination requests share one handler.
for _shutdown_signal in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_shutdown_signal, signal_handler)
del _shutdown_signal
def initialize_database(db_file=None):
    """Create the bookkeeping tables and seed the ID counter if they are missing.

    Every statement is idempotent, so this is safe to call on each start-up.
    The connection is always closed, even if a statement fails; in that case
    nothing is committed.

    Args:
        db_file: Path of the SQLite database. Defaults to the module-level
            ``DB_FILE``.
    """
    if db_file is None:
        db_file = DB_FILE
    conn = sqlite3.connect(db_file)
    try:
        cursor = conn.cursor()
        # One row per session log file that has been handled.
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS processed_sessions (
                id INTEGER PRIMARY KEY,
                log_file TEXT UNIQUE,
                client_id TEXT,
                processed_time TIMESTAMP
            )
        ''')
        # One row per reset-delimited segment of a session; trajectory_id is
        # the record number that segment was stored under.
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS processed_segments (
                id INTEGER PRIMARY KEY,
                log_file TEXT,
                client_id TEXT,
                segment_index INTEGER,
                start_time REAL,
                end_time REAL,
                processed_time TIMESTAMP,
                trajectory_id INTEGER,
                UNIQUE(log_file, segment_index)
            )
        ''')
        # Key/value settings; 'next_id' is the next record number to assign.
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS config (
                key TEXT PRIMARY KEY,
                value TEXT
            )
        ''')
        # `key` is the primary key, so this seeds the counter only once.
        cursor.execute("INSERT OR IGNORE INTO config (key, value) VALUES ('next_id', '0')")
        conn.commit()
    finally:
        conn.close()
def is_session_complete(log_file):
    """Return True once *log_file* contains an entry whose ``is_eos`` is truthy.

    Lines that are not valid JSON are ignored. Scanning stops at the first
    EOS entry. Any other error while reading is logged and reported as False.
    """
    def _entries(handle):
        # Yield each parseable JSON line, silently skipping malformed ones.
        for raw_line in handle:
            try:
                yield json.loads(raw_line.strip())
            except json.JSONDecodeError:
                continue

    try:
        with open(log_file, 'r') as handle:
            return any(record.get("is_eos", False) for record in _entries(handle))
    except Exception as e:
        logger.error(f"Error checking if session {log_file} is complete: {e}")
        return False
def is_session_valid(log_file):
    """Return True if *log_file* holds at least one real interaction entry.

    An entry is an interaction when it is neither an ``is_eos`` nor an
    ``is_reset`` control marker. Lines that are not valid JSON are ignored.
    An empty file, a file of only control markers, or a read error (which is
    logged) all yield False.
    """
    try:
        parsed_entries = 0
        saw_interaction = False
        with open(log_file, 'r') as stream:
            for raw_line in stream:
                try:
                    record = json.loads(raw_line.strip())
                except json.JSONDecodeError:
                    continue
                parsed_entries += 1
                is_control = record.get("is_eos", False) or record.get("is_reset", False)
                if not is_control:
                    saw_interaction = True
        # Valid if there's at least one entry and at least one non-EOS entry
        return saw_interaction and parsed_entries > 0
    except Exception as e:
        logger.error(f"Error checking if session {log_file} is valid: {e}")
        return False
def load_trajectory(log_file):
    """Read a JSON-lines session log into a list of decoded entries.

    Malformed lines are skipped with a warning. Any other failure (for example,
    the file cannot be opened) is logged and yields an empty list rather than
    a partial one.
    """
    entries = []
    try:
        with open(log_file, 'r') as stream:
            for raw_line in stream:
                try:
                    entries.append(json.loads(raw_line.strip()))
                except json.JSONDecodeError:
                    logger.warning(f"Skipping invalid JSON line in {log_file}")
        return entries
    except Exception as e:
        logger.error(f"Error loading trajectory from {log_file}: {e}")
        return []
def _split_at_resets(trajectory):
    """Split session entries into segments separated by ``is_reset`` markers.

    Reset entries are dropped and empty segments are skipped, so
    ``[a, RESET, RESET, b, c]`` becomes ``[[a], [b, c]]``.
    """
    segments = []
    current = []
    for entry in trajectory:
        if entry.get("is_reset", False):
            if current:
                segments.append(current)
            current = []
        else:
            current.append(entry)
    if current:
        segments.append(current)
    return segments


def _save_debug_comparison(images_tensor, latents, image_num):
    """Decode *latents* and save original|reconstruction PNGs under OUTPUT_DIR/debug."""
    debug_dir = os.path.join(OUTPUT_DIR, 'debug')
    os.makedirs(debug_dir, exist_ok=True)
    reconstructions = autoencoder.decode(latents.to(device))

    def _to_uint8_hwc(t):
        # Denormalize from [-1, 1] to [0, 255], clip, and convert CHW -> HWC.
        arr = (t.cpu().numpy() + 1.0) * 127.5
        return np.transpose(np.clip(arr, 0, 255).astype(np.uint8), (1, 2, 0))

    for idx, (orig, recon) in enumerate(zip(images_tensor, reconstructions)):
        comparison = np.concatenate([_to_uint8_hwc(orig), _to_uint8_hwc(recon)], axis=1)
        Image.fromarray(comparison).save(
            os.path.join(debug_dir, f'debug_{image_num}_{idx}_{image_num}.png')
        )
    logger.info(f"Debug visualizations saved to {debug_dir}")


@torch.no_grad()
def _encode_record_latents(record_num, debug=False):
    """Encode every frame of record *record_num* into a webdataset tar of latents.

    Reads the video and per-frame action CSV that ``process_trajectory`` wrote
    under ``raw_data/raw_data/``. Writes ``OUTPUT_DIR/record_<n>.tar`` with one
    ``<image_num>.npy`` latent per frame. The tar is first written to a
    ``.temp`` path and renamed into place only after every frame was encoded,
    so a failure never leaves a truncated tar behind.

    Returns:
        ``(mapping_dict, target_rows)`` where mapping_dict maps
        ``(record_num, image_num)`` to ``(x, y, left_click, right_click,
        down_keys)`` and target_rows is a list of ``(record_num, image_num)``.

    Raises:
        ValueError: if the video is not 15 FPS or a key event type is unknown.
    """
    video_file = f'raw_data/raw_data/videos/record_{record_num}.mp4'
    action_file = f'raw_data/raw_data/actions/record_{record_num}.csv'
    mouse_data = pd.read_csv(action_file)

    tar_path = os.path.join(OUTPUT_DIR, f'record_{record_num}.tar')
    temp_tar_path = tar_path + '.temp'
    # A tar left over from an earlier use of this record number is stale.
    if os.path.exists(tar_path):
        logger.info(f"Removing existing tar file {tar_path}")
        os.remove(tar_path)

    mapping_dict = {}
    target_rows = []
    sink = wds.TarWriter(temp_tar_path)
    completed = False
    try:
        with VideoFileClip(video_file) as video:
            fps = video.fps
            if fps != 15:
                raise ValueError(f"Expected 15 FPS, got {fps}")
            down_keys = set()
            for image_num in range(int(fps * video.duration)):
                action_row = mouse_data.iloc[image_num]
                x = int(action_row['X'])
                y = int(action_row['Y'])
                left_click = True if action_row['Left Click'] == 1 else False
                right_click = True if action_row['Right Click'] == 1 else False
                # Replay this frame's key transitions to get the keys held at this frame.
                for key_state, key in ast.literal_eval(action_row['Key Events']):
                    if key_state == "keydown":
                        down_keys.add(key)
                    elif key_state == "keyup":
                        down_keys.remove(key)
                    else:
                        raise ValueError(f"Unknown key event type: {key_state}")
                mapping_dict[(record_num, image_num)] = (x, y, left_click, right_click, list(down_keys))
                target_rows.append((record_num, image_num))

                frame = video.get_frame(image_num / fps)
                # Normalize to [-1, 1] and reshape to a (1, C, H, W) batch.
                image_array = (frame / 127.5 - 1.0).astype(np.float32)
                images_tensor = rearrange(
                    torch.tensor(image_array).unsqueeze(0), 'b h w c -> b c h w'
                ).to(device)
                latents = autoencoder.encode(images_tensor).sample().cpu()

                latent_bytes = io.BytesIO()
                np.save(latent_bytes, latents[0].numpy())
                sink.write({"__key__": str(image_num), "npy": latent_bytes.getvalue()})

                if debug:
                    _save_debug_comparison(images_tensor, latents, image_num)
        completed = True
    finally:
        sink.close()
        if not completed and os.path.exists(temp_tar_path):
            os.remove(temp_tar_path)
    os.replace(temp_tar_path, tar_path)
    return mapping_dict, target_rows


def _merge_mapping_pickle(mapping_dict):
    """Merge *mapping_dict* into OUTPUT_DIR's action-mapping pickle and rewrite it atomically.

    On key collisions the new entries win.
    """
    final_path = os.path.join(OUTPUT_DIR, 'image_action_mapping_with_key_states.pkl')
    if os.path.exists(final_path):
        # The pickle is produced by this script itself, not by external input.
        with open(final_path, 'rb') as f:
            existing_mapping_dict = pickle.load(f)
        for key, value in existing_mapping_dict.items():
            mapping_dict.setdefault(key, value)
    temp_path = final_path + '.temp'
    with open(temp_path, 'wb') as f:
        pickle.dump(mapping_dict, f)
    os.replace(temp_path, final_path)


def _merge_target_frames(target_rows):
    """Append *target_rows* to OUTPUT_DIR's target-frames CSV, de-duplicate, and rewrite it atomically."""
    final_path = os.path.join(OUTPUT_DIR, 'train_dataset.target_frames.csv')
    target_data = pd.DataFrame(target_rows, columns=['record_num', 'image_num'])
    if os.path.exists(final_path):
        existing_target_data = pd.read_csv(final_path)
        target_data = pd.concat([existing_target_data, target_data])
    target_data = target_data.drop_duplicates()
    temp_path = final_path + '.temp'
    target_data.to_csv(temp_path, index=False)
    os.replace(temp_path, final_path)


@torch.no_grad()
def process_session_file(log_file, clean_state, debug=False):
    """Process a session file, splitting it into one trajectory per reset-delimited segment.

    For every segment with at least one interaction entry, this function:
      1. renders a comparison video from the raw frames (best effort);
      2. replays the segment through ``process_trajectory`` to produce the
         record's video and action CSV;
      3. encodes the frames into a latent tar and merges the per-frame actions
         and target rows into the shared files in OUTPUT_DIR;
      4. records the segment in ``processed_segments`` and advances ``next_id``.

    Each segment commits on its own. A failing segment is logged, its database
    changes are rolled back, and the next segment is attempted; its record
    number is then reused. Segments already recorded in ``processed_segments``
    are skipped. The session is recorded in ``processed_sessions`` once at
    least one of its segments is done (now or in an earlier run).

    Args:
        log_file: Path of the session ``.jsonl`` log.
        clean_state: Value from ``initialize_clean_state()``, forwarded to
            ``process_trajectory``.
        debug: If True, also save original/reconstruction comparison PNGs.

    Returns:
        List of record ids (``next_id`` values) created by this call; ``[]``
        if the session was skipped or failed outright.
    """
    conn = None
    try:
        # No explicit BEGIN: SELECTs run outside a transaction, so nothing is
        # held open during the slow per-segment work; each segment's writes
        # form their own short transaction.
        conn = sqlite3.connect(DB_FILE)
        cursor = conn.cursor()
        os.makedirs("generated_videos", exist_ok=True)

        trajectory = load_trajectory(log_file)
        if not trajectory:
            logger.error(f"Empty trajectory for {log_file}, skipping")
            return []
        client_id = trajectory[0].get("client_id", "unknown")

        has_reset = any(entry.get("is_reset", False) for entry in trajectory)
        has_eos = any(entry.get("is_eos", False) for entry in trajectory)
        if not has_reset and not has_eos:
            logger.warning(f"Session {log_file} has no resets and no EOS, may be incomplete")
            return []

        sub_trajectories = _split_at_resets(trajectory)
        num_segments = len(sub_trajectories)
        processed_ids = []
        already_done = 0
        for i, sub_traj in enumerate(sub_trajectories):
            # Skip segments with no interaction data (just control messages).
            if not any(not e.get("is_reset", False) and not e.get("is_eos", False) for e in sub_traj):
                continue

            # Never redo a segment that is already recorded: that would rewrite
            # data files under a record number that is then handed out again.
            cursor.execute(
                "SELECT 1 FROM processed_segments WHERE log_file = ? AND segment_index = ?",
                (log_file, i),
            )
            if cursor.fetchone() is not None:
                logger.info(f"Segment {i+1}/{num_segments} from {log_file} already processed, skipping")
                already_done += 1
                continue

            cursor.execute("SELECT value FROM config WHERE key = 'next_id'")
            next_id = int(cursor.fetchone()[0])

            try:
                start_time = sub_traj[0]["timestamp"]
                end_time = sub_traj[-1]["timestamp"]

                # STEP 1: comparison video from the original frames (best effort).
                segment_label = f"segment_{i+1}_of_{num_segments}"
                video_path = os.path.join("generated_videos", f"trajectory_{next_id}_{segment_label}.mp4")
                success, _ = generate_comparison_video(client_id, sub_traj, video_path, start_time, end_time)
                if not success:
                    logger.warning(f"Failed to generate comparison video for segment {i+1}, but continuing with processing")

                # STEP 2: replay in the container to produce the record's video/actions.
                logger.info(f"Processing segment {i+1}/{num_segments} from {log_file} as trajectory {next_id}")
                formatted_trajectory = format_trajectory_for_processing(sub_traj)
                process_trajectory((next_id, formatted_trajectory), SCREEN_WIDTH, SCREEN_HEIGHT, clean_state, MEMORY_LIMIT)

                # STEP 3: encode frames and merge into the shared training files.
                mapping_dict, target_rows = _encode_record_latents(next_id, debug=debug)
                _merge_mapping_pickle(mapping_dict)
                _merge_target_frames(target_rows)

                # STEP 4: record the segment and advance the id counter atomically.
                cursor.execute(
                    """INSERT INTO processed_segments
                    (log_file, client_id, segment_index, start_time, end_time,
                    processed_time, trajectory_id)
                    VALUES (?, ?, ?, ?, ?, ?, ?)""",
                    (log_file, client_id, i, start_time, end_time,
                     datetime.now().isoformat(), next_id)
                )
                cursor.execute("UPDATE config SET value = ? WHERE key = 'next_id'", (str(next_id + 1),))
                conn.commit()
            except Exception as e:
                # Discard any partial writes so they cannot ride along with the next commit.
                conn.rollback()
                logger.exception(f"Failed to process segment {i+1}/{num_segments} from {log_file}: {e}")
                continue

            processed_ids.append(next_id)
            logger.info(f"Successfully processed segment {i+1}/{num_segments} from {log_file}")

        if processed_ids or already_done:
            cursor.execute(
                "INSERT OR IGNORE INTO processed_sessions (log_file, client_id, processed_time) VALUES (?, ?, ?)",
                (log_file, client_id, datetime.now().isoformat())
            )
            conn.commit()
        return processed_ids
    except Exception as e:
        logger.exception(f"Error processing session {log_file}: {e}")
        if conn is not None:
            conn.rollback()
        return []
    finally:
        if conn is not None:
            conn.close()
def _normalize_key(raw_key):
    """Return the vocabulary name for *raw_key*, or None if it is not in ``stoi``.

    Lower-cases the name and applies the KEYMAPPING aliases first.
    """
    key = raw_key.lower()
    if key in KEYMAPPING:
        logger.debug(f"Key {key} mapped to {KEYMAPPING[key]}")
        key = KEYMAPPING[key]
    if key not in stoi:
        logger.warning(f"Key {key} not found in stoi")
        return None
    return key


def format_trajectory_for_processing(trajectory):
    """Convert session log entries into the event list passed to ``process_trajectory``.

    Control entries (``is_reset`` / ``is_eos``) are skipped. Each remaining
    entry becomes a dict with:
      - ``pos``: ``(x, y)`` from ``inputs`` (either may be None);
      - ``left_click`` / ``right_click``: from ``is_left_click`` / ``is_right_click``;
      - ``key_events``: ``("keydown", key)`` / ``("keyup", key)`` tuples.

    Key names are normalized with ``_normalize_key``; keys outside the
    vocabulary are dropped. The set of held keys is tracked across entries:
    a keydown for a key already held and a keyup for a key not held emit
    nothing, so the event stream always alternates cleanly per key.
    """
    formatted_events = []
    down_keys = set()
    for entry in trajectory:
        if entry.get("is_reset") or entry.get("is_eos"):
            continue
        # `or {}` / `or []` also cover fields that are present but null.
        inputs = entry.get("inputs") or {}
        key_events = []
        for raw_key in inputs.get("keys_down") or []:
            key = _normalize_key(raw_key)
            if key is not None and key not in down_keys:
                down_keys.add(key)
                key_events.append(("keydown", key))
        for raw_key in inputs.get("keys_up") or []:
            key = _normalize_key(raw_key)
            if key is not None and key in down_keys:
                down_keys.remove(key)
                key_events.append(("keyup", key))
        formatted_events.append({
            "pos": (inputs.get("x"), inputs.get("y")),
            "left_click": inputs.get("is_left_click", False),
            "right_click": inputs.get("is_right_click", False),
            "key_events": key_events,
        })
    return formatted_events
def _frame_timestamp(path):
    """Return the float timestamp encoded in a ``<timestamp>.png`` filename, or None."""
    try:
        return float(os.path.basename(path).split('.png')[0])
    except ValueError:
        return None


def generate_comparison_video(client_id, trajectory, output_file, start_time, end_time, fps=10.0):
    """Generate a video from the original frames for comparison purposes.

    Frames are read from ``FRAMES_DIR/frames_<client_id>/<timestamp>.png``.
    Only frames with ``start_time <= timestamp <= end_time`` are used, in
    timestamp order. Files whose name is not a number are ignored. Unreadable
    frames are skipped. Frames whose size differs from the first are resized
    to match, since VideoWriter only accepts frames of the size it was
    opened with.

    Args:
        client_id: Client ID used to locate the frame directory.
        trajectory: Log entries for this segment (currently unused).
        output_file: Path of the ``.mp4`` to write.
        start_time: Start timestamp of the segment (inclusive).
        end_time: End timestamp of the segment (inclusive).
        fps: Frame rate of the output video.

    Returns:
        (bool, int): success flag and the number of frames actually written.
    """
    video = None
    try:
        frame_dir = os.path.join(FRAMES_DIR, f"frames_{client_id}")
        if not os.path.exists(frame_dir):
            logger.warning(f"No frame directory found for client {client_id}")
            return False, 0

        timestamped = []
        for path in glob.glob(os.path.join(frame_dir, "*.png")):
            ts = _frame_timestamp(path)
            if ts is None:
                logger.warning(f"Ignoring frame with non-numeric name {path}")
                continue
            timestamped.append((ts, path))
        if not timestamped:
            logger.error(f"No frames found for client {client_id}")
            return False, 0

        timestamped.sort(key=lambda item: item[0])
        segment_frames = [path for ts, path in timestamped if start_time <= ts <= end_time]
        if not segment_frames:
            logger.error(f"No frames found in time range for segment {start_time}-{end_time}")
            return False, 0

        first_frame = cv2.imread(segment_frames[0])
        if first_frame is None:
            logger.error(f"Could not read first frame {segment_frames[0]}")
            return False, 0
        height, width = first_frame.shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
        if not video.isOpened():
            logger.error(f"Could not open video writer for {output_file}")
            return False, 0

        written = 0
        for frame_file in segment_frames:
            frame = cv2.imread(frame_file)
            if frame is None:
                logger.warning(f"Skipping unreadable frame {frame_file}")
                continue
            if frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height))
            video.write(frame)
            written += 1

        # Finalize the file before reporting success.
        video.release()
        video = None
        logger.info(f"Created comparison video {output_file} with {written} frames")
        return True, written
    except Exception as e:
        logger.error(f"Error generating comparison video: {e}")
        return False, 0
    finally:
        if video is not None:
            video.release()
def _session_timestamp(path):
    """Return the int timestamp from ``session_<ts>_<n>.jsonl``; 0 if it cannot be parsed."""
    try:
        return int(os.path.basename(path).split('_')[1])
    except (IndexError, ValueError):
        return 0


def _read_processed_files():
    """Return the set of log files already recorded in ``processed_sessions``."""
    conn = sqlite3.connect(DB_FILE)
    try:
        return {row[0] for row in conn.execute("SELECT log_file FROM processed_sessions")}
    finally:
        conn.close()


def _read_next_id():
    """Return the current value of the ``next_id`` counter."""
    conn = sqlite3.connect(DB_FILE)
    try:
        return int(conn.execute("SELECT value FROM config WHERE key = 'next_id'").fetchone()[0])
    finally:
        conn.close()


def _ensure_padding_latent():
    """Create ``OUTPUT_DIR/padding.npy`` if it does not exist yet.

    The stored padding is an all-zero latent. The zero image is encoded only
    to learn the encoder's latent shape for a SCREEN_HEIGHT x SCREEN_WIDTH
    frame. The file is written to a temp name and then renamed.
    """
    padding_path = os.path.join(OUTPUT_DIR, 'padding.npy')
    if os.path.exists(padding_path):
        return
    logger.info("Creating padding image...")
    with torch.no_grad():  # inference only; no autograd graph needed
        padding_data = np.zeros((SCREEN_HEIGHT, SCREEN_WIDTH, 3), dtype=np.float32)
        padding_tensor = rearrange(
            torch.tensor(padding_data).unsqueeze(0), 'b h w c -> b c h w'
        ).to(device)
        latent = autoencoder.encode(padding_tensor).sample()
    latent = torch.zeros_like(latent).squeeze(0)
    tmp_path = os.path.join(OUTPUT_DIR, 'padding.tmp.npy')
    np.save(tmp_path, latent.cpu().numpy())
    os.replace(tmp_path, padding_path)


def _sleep_while_running(seconds, chunk=5):
    """Sleep up to *seconds*, waking every *chunk* seconds to honour the ``running`` flag."""
    remaining = seconds
    while remaining > 0 and running:
        step = min(chunk, remaining)
        time.sleep(step)
        remaining -= step


def _run_cycle(clean_state):
    """Scan FRAMES_DIR once and process every new, complete, valid session, oldest first."""
    log_files = glob.glob(os.path.join(FRAMES_DIR, "session_*.jsonl"))
    logger.info(f"Found {len(log_files)} log files")

    # Drop already-processed files first, so only new logs are read in full by
    # the completeness/validity checks below.
    processed_files = _read_processed_files()
    unprocessed = [f for f in log_files if f not in processed_files]
    logger.info(f"Found {len(unprocessed)} log files not yet processed")

    new_sessions = [f for f in unprocessed if is_session_complete(f)]
    new_sessions.sort(key=_session_timestamp)
    logger.info(f"Found {len(new_sessions)} new sessions to process")

    valid_sessions = [f for f in new_sessions if is_session_valid(f)]
    logger.info(f"Found {len(valid_sessions)} valid new sessions to process")

    total_trajectories = 0
    for log_file in valid_sessions:
        if not running:
            logger.info("Shutdown in progress, stopping processing")
            break
        logger.info(f"Processing session file: {log_file}")
        total_trajectories += len(process_session_file(log_file, clean_state))

    if total_trajectories > 0:
        logger.info(f"Processing cycle complete. Generated {total_trajectories} new trajectories.")
        logger.info(f"Next ID will be {_read_next_id()}")
    else:
        logger.info("No new trajectories processed in this cycle")


def main():
    """Run the data processing pipeline until a shutdown signal is received.

    Setup ensures OUTPUT_DIR and the padding latent exist, initializes the
    database, and obtains a clean container state. After that, one scan cycle
    runs every CHECK_INTERVAL seconds. A failing cycle is logged with its
    traceback and retried after a 10 second pause.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    _ensure_padding_latent()
    initialize_database()

    logger.info("Initializing clean container state...")
    clean_state = initialize_clean_state()
    logger.info(f"Clean state initialized: {clean_state}")

    logger.info(f"Starting continuous monitoring for new sessions (check interval: {CHECK_INTERVAL} seconds)")
    try:
        while running:
            try:
                _run_cycle(clean_state)
                _sleep_while_running(CHECK_INTERVAL)
            except Exception as e:
                logger.exception(f"Error in processing cycle: {e}")
                # Sleep briefly to avoid rapid error loops
                time.sleep(10)
    except KeyboardInterrupt:
        # Normally unreachable while signal_handler owns SIGINT; kept as a fallback.
        logger.info("Keyboard interrupt received, shutting down")
    finally:
        logger.info("Shutting down trajectory processor")
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Record the crash with its traceback, then exit non-zero so a
        # supervisor does not mistake a crash for a clean shutdown.
        logger.exception(f"Unhandled exception: {e}")
        raise SystemExit(1)
|