neural-os / main.py
da03
.
8b76adf
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from typing import List, Tuple
import numpy as np
from PIL import Image, ImageDraw
import base64
import io
import json
import asyncio
from utils import initialize_model, sample_frame
import torch
import os
import time
from typing import Any, Dict
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
import concurrent.futures
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEBUG_MODE = False
DEBUG_MODE_2 = False
NUM_MAX_FRAMES = 1
TIMESTEPS = 1000
SCREEN_WIDTH = 512
SCREEN_HEIGHT = 384
NUM_SAMPLING_STEPS = 32
USE_RNN = False
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-384k"
MODEL_NAME = "yuntian-deng/computer-model-noss-forsure"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-2k"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-10k"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-54k"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-160k"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-368k"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-198k"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-674k"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-74k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-online-70k"
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-x0-46k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-142k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-338k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-ddpm32-x0-140k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-ddpm32-eps-144k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-70k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-joint-onlineonly-eps22-40k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-22-38k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222-42k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-2222-70k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222-48k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-06k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-114k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-136k"
#MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-184k"
#MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-272k"
#MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-272k"
#MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-oo-eps222222k72-270k"
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k72-108k"
print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')
with open('latent_stats.json', 'r') as f:
latent_stats = json.load(f)
DATA_NORMALIZATION = {'mean': torch.tensor(latent_stats['mean']).to(device), 'std': torch.tensor(latent_stats['std']).to(device)}
LATENT_DIMS = (16, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8)
# Initialize the model at the start of your application
#model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
#model = initialize_model("config_rnn.yaml", "yuntian-deng/computer-model")
#model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model-noss")
#model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model")
if 'origunet' in MODEL_NAME:
if 'x0' in MODEL_NAME and 'eps' not in MODEL_NAME:
if 'ddpm32' in MODEL_NAME:
TIMESTEPS = 32
model = initialize_model("config_final_model_origunet_nospatial_x0_ddpm32.yaml", MODEL_NAME)
else:
model = initialize_model("config_final_model_origunet_nospatial_x0.yaml", MODEL_NAME)
else:
if 'ddpm32' in MODEL_NAME:
TIMESTEPS = 32
model = initialize_model("config_final_model_origunet_nospatial_ddpm32.yaml", MODEL_NAME)
else:
model = initialize_model("config_final_model_origunet_nospatial.yaml", MODEL_NAME)
else:
model = initialize_model("config_final_model.yaml", MODEL_NAME)
model = model.to(device)
#model = torch.compile(model)
padding_image = torch.zeros(*LATENT_DIMS).unsqueeze(0).to(device)
padding_image = (padding_image - DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)) / DATA_NORMALIZATION['std'].view(1, -1, 1, 1)
# Valid keyboard inputs
KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(',
')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7',
'8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~',
'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
'browserback', 'browserfavorites', 'browserforward', 'browserhome',
'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10',
'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20',
'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja',
'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
'command', 'option', 'optionleft', 'optionright']
KEYMAPPING = {
'arrowup': 'up',
'arrowdown': 'down',
'arrowleft': 'left',
'arrowright': 'right',
'meta': 'command',
'contextmenu': 'apps',
'control': 'ctrl',
}
INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
VALID_KEYS = [key for key in KEYS if key not in INVALID_KEYS]
itos = VALID_KEYS
stoi = {key: i for i, key in enumerate(itos)}
app = FastAPI()
# Mount the static directory to serve HTML, JavaScript, and CSS files
app.mount("/static", StaticFiles(directory="static"), name="static")
# Add this at the top with other global variables
connection_counter = 0
# Connection timeout settings
CONNECTION_TIMEOUT = 20 + 1 # 20 seconds timeout plus 1 second grace period
WARNING_TIME = 10 + 1 # 10 seconds warning before timeout plus 1 second grace period
# Create a thread pool executor
thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
def prepare_model_inputs(
previous_frame: torch.Tensor,
hidden_states: Any,
x: int,
y: int,
right_click: bool,
left_click: bool,
keys_down: List[str],
stoi: Dict[str, int],
itos: List[str],
time_step: int
) -> Dict[str, torch.Tensor]:
"""Prepare inputs for the model."""
# Clamp coordinates to valid ranges
x = min(max(0, x), SCREEN_WIDTH - 1) if x is not None else 0
y = min(max(0, y), SCREEN_HEIGHT - 1) if y is not None else 0
if DEBUG_MODE:
print ('DEBUG MODE, SETTING TIME STEP TO 0')
time_step = 0
if DEBUG_MODE_2:
if time_step > NUM_MAX_FRAMES-1:
print ('DEBUG MODE_2, SETTING TIME STEP TO 0')
time_step = 0
inputs = {
'image_features': previous_frame.to(device),
'is_padding': torch.BoolTensor([time_step == 0]).to(device),
'x': torch.LongTensor([x]).unsqueeze(0).to(device),
'y': torch.LongTensor([y]).unsqueeze(0).to(device),
'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(device),
'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(device),
'key_events': torch.zeros(len(itos), dtype=torch.long).to(device)
}
for key in keys_down:
key = key.lower()
if key in KEYMAPPING:
key = KEYMAPPING[key]
if key in stoi:
inputs['key_events'][stoi[key]] = 1
else:
print (f'Key {key} not found in stoi')
if hidden_states is not None:
inputs['hidden_states'] = hidden_states
if DEBUG_MODE:
print ('DEBUG MODE, REMOVING INPUTS')
if 'hidden_states' in inputs:
del inputs['hidden_states']
if DEBUG_MODE_2:
if time_step > NUM_MAX_FRAMES-1:
print ('DEBUG MODE_2, REMOVING HIDDEN STATES')
if 'hidden_states' in inputs:
del inputs['hidden_states']
print (f'Time step: {time_step}')
return inputs
@torch.no_grad()
async def process_frame(
model: LatentDiffusion,
inputs: Dict[str, torch.Tensor],
use_rnn: bool = False,
num_sampling_steps: int = 32
) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
"""Process a single frame through the model."""
# Run the heavy computation in a separate thread
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
thread_executor,
lambda: _process_frame_sync(model, inputs, use_rnn, num_sampling_steps)
)
def _process_frame_sync(model, inputs, use_rnn, num_sampling_steps):
"""Synchronous version of process_frame that runs in a thread"""
timing = {}
# Temporal encoding
start = time.perf_counter()
output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs)
timing['temporal_encoder'] = time.perf_counter() - start
# UNet sampling
start = time.perf_counter()
print (f"model.clip_denoised: {model.clip_denoised}")
model.clip_denoised = False
print (f"USE_RNN: {use_rnn}, NUM_SAMPLING_STEPS: {num_sampling_steps}")
if use_rnn:
sample_latent = output_from_rnn[:, :16]
else:
#NUM_SAMPLING_STEPS = 8
if num_sampling_steps >= TIMESTEPS:
sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
else:
if num_sampling_steps == 1:
x = torch.randn([1, *LATENT_DIMS], device=device)
t = torch.full((1,), TIMESTEPS-1, device=device, dtype=torch.long)
sample_latent = model.apply_model(x, t, {'c_concat': output_from_rnn})
else:
sampler = DDIMSampler(model)
sample_latent, _ = sampler.sample(
S=num_sampling_steps,
conditioning={'c_concat': output_from_rnn},
batch_size=1,
shape=LATENT_DIMS,
verbose=False
)
timing['unet'] = time.perf_counter() - start
# Decoding
start = time.perf_counter()
sample = sample_latent * DATA_NORMALIZATION['std'].view(1, -1, 1, 1) + DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)
# Use time.sleep(10) here since it's in a separate thread
#time.sleep(10)
sample = model.decode_first_stage(sample)
sample = sample.squeeze(0).clamp(-1, 1)
timing['decode'] = time.perf_counter() - start
# Convert to image
sample_img = ((sample[:3].transpose(0,1).transpose(1,2).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)
timing['total'] = sum(timing.values())
return sample_latent, sample_img, hidden_states, timing
def print_timing_stats(timing_info: Dict[str, float], frame_num: int):
"""Print timing statistics for a frame."""
print(f"\nFrame {frame_num} timing (seconds):")
for key, value in timing_info.items():
print(f" {key.title()}: {value:.4f}")
print(f" FPS: {1.0/timing_info['full_frame']:.2f}")
# Serve the index.html file at the root URL
@app.get("/")
async def get():
return HTMLResponse(open("static/index.html").read())
# WebSocket endpoint for continuous user interaction
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
global connection_counter
connection_counter += 1
client_id = f"{int(time.time())}_{connection_counter}"
print(f"New WebSocket connection: {client_id}")
await websocket.accept()
try:
previous_frame = padding_image
hidden_states = None
keys_down = set() # Initialize as an empty set
frame_num = -1
# Client-specific settings
client_settings = {
"use_rnn": USE_RNN, # Start with default global value
"sampling_steps": NUM_SAMPLING_STEPS # Start with default global value
}
# Connection timeout tracking
last_user_activity_time = time.perf_counter()
timeout_warning_sent = False
timeout_task = None
connection_active = True # Flag to track if connection is still active
user_has_interacted = False # Flag to track if user has started interacting
# Start timing for global FPS calculation
connection_start_time = time.perf_counter()
frame_count = 0
# Input queue management - use asyncio.Queue instead of a list
input_queue = asyncio.Queue()
is_processing = False
# Add a function to reset the simulation
async def reset_simulation():
nonlocal previous_frame, hidden_states, keys_down, frame_num, is_processing, input_queue, user_has_interacted
# Keep the client settings during reset
temp_client_settings = client_settings.copy()
# Log the reset action
log_interaction(
client_id,
{"type": "reset"},
is_end_of_session=False,
is_reset=True # Add this parameter to the log_interaction function
)
# Clear the input queue
while not input_queue.empty():
try:
input_queue.get_nowait()
input_queue.task_done()
except asyncio.QueueEmpty:
break
# Reset all state variables
previous_frame = padding_image
hidden_states = None
keys_down = set()
frame_num = -1
is_processing = False
user_has_interacted = False # Reset user interaction state
# Restore client settings
client_settings.update(temp_client_settings)
print(f"[{time.perf_counter():.3f}] Simulation reset to initial state (preserved settings: USE_RNN={client_settings['use_rnn']}, SAMPLING_STEPS={client_settings['sampling_steps']})")
print(f"[{time.perf_counter():.3f}] User interaction state reset - waiting for user to interact again")
# Send confirmation to client
await websocket.send_json({"type": "reset_confirmed"})
# Also send the current settings to update the UI
await websocket.send_json({
"type": "settings",
"sampling_steps": client_settings["sampling_steps"],
"use_rnn": client_settings["use_rnn"]
})
# Add a function to update sampling steps
async def update_sampling_steps(steps):
nonlocal client_settings
# Validate the input
if steps < 1:
print(f"[{time.perf_counter():.3f}] Invalid sampling steps value: {steps}")
await websocket.send_json({"type": "error", "message": "Invalid sampling steps value"})
return
# Update the client-specific setting
old_steps = client_settings["sampling_steps"]
client_settings["sampling_steps"] = steps
print(f"[{time.perf_counter():.3f}] Updated sampling steps for client {client_id} from {old_steps} to {steps}")
# Send confirmation to client
await websocket.send_json({"type": "steps_updated", "steps": steps})
# Add a function to update USE_RNN setting
async def update_use_rnn(use_rnn):
nonlocal client_settings
# Update the client-specific setting
old_setting = client_settings["use_rnn"]
client_settings["use_rnn"] = use_rnn
print(f"[{time.perf_counter():.3f}] Updated USE_RNN for client {client_id} from {old_setting} to {use_rnn}")
# Send confirmation to client
await websocket.send_json({"type": "rnn_updated", "use_rnn": use_rnn})
# Add timeout checking function
async def check_timeout():
nonlocal timeout_warning_sent, timeout_task, connection_active, user_has_interacted
while True:
try:
# Check if WebSocket is still connected and connection is still active
if not connection_active or websocket.client_state.value >= 2: # CLOSING or CLOSED
print(f"[{time.perf_counter():.3f}] Connection inactive or WebSocket closed, stopping timeout check for client {client_id}")
return
# Don't start timeout tracking until user has actually interacted
if not user_has_interacted:
print(f"[{time.perf_counter():.3f}] User hasn't interacted yet, skipping timeout check for client {client_id}")
await asyncio.sleep(1) # Check every second
continue
current_time = time.perf_counter()
time_since_activity = current_time - last_user_activity_time
print(f"[{current_time:.3f}] Timeout check - time_since_activity: {time_since_activity:.1f}s, WARNING_TIME: {WARNING_TIME}s, CONNECTION_TIMEOUT: {CONNECTION_TIMEOUT}s")
# Send warning at 10 seconds
if time_since_activity >= WARNING_TIME and not timeout_warning_sent:
print(f"[{current_time:.3f}] Sending timeout warning to client {client_id}")
await websocket.send_json({
"type": "timeout_warning",
"timeout_in": CONNECTION_TIMEOUT - WARNING_TIME
})
timeout_warning_sent = True
print(f"[{current_time:.3f}] Timeout warning sent, timeout_warning_sent: {timeout_warning_sent}")
# Close connection at 20 seconds
if time_since_activity >= CONNECTION_TIMEOUT:
print(f"[{current_time:.3f}] TIMEOUT REACHED! Closing connection {client_id} due to timeout")
print(f"[{current_time:.3f}] time_since_activity: {time_since_activity:.1f}s >= CONNECTION_TIMEOUT: {CONNECTION_TIMEOUT}s")
# Clear the input queue before closing
queue_size_before = input_queue.qsize()
print(f"[{current_time:.3f}] Clearing input queue, size before: {queue_size_before}")
while not input_queue.empty():
try:
input_queue.get_nowait()
input_queue.task_done()
except asyncio.QueueEmpty:
break
print(f"[{current_time:.3f}] Input queue cleared, size after: {input_queue.qsize()}")
print(f"[{current_time:.3f}] About to close WebSocket connection...")
await websocket.close(code=1000, reason="User inactivity timeout")
print(f"[{current_time:.3f}] WebSocket.close() called, returning from check_timeout")
return
await asyncio.sleep(1) # Check every second
except Exception as e:
print(f"[{time.perf_counter():.3f}] Error in timeout check for client {client_id}: {e}")
import traceback
traceback.print_exc()
break
# Function to update user activity
def update_user_activity():
nonlocal last_user_activity_time, timeout_warning_sent
old_time = last_user_activity_time
last_user_activity_time = time.perf_counter()
print(f"[{time.perf_counter():.3f}] User activity detected for client {client_id}")
print(f"[{time.perf_counter():.3f}] last_user_activity_time updated: {old_time:.3f} -> {last_user_activity_time:.3f}")
if timeout_warning_sent:
print(f"[{time.perf_counter():.3f}] User activity detected, resetting timeout warning for client {client_id}")
timeout_warning_sent = False
print(f"[{time.perf_counter():.3f}] timeout_warning_sent reset to: {timeout_warning_sent}")
# Send activity reset notification to client
asyncio.create_task(websocket.send_json({"type": "activity_reset"}))
print(f"[{time.perf_counter():.3f}] Activity reset message sent to client")
# Start timeout checking
timeout_task = asyncio.create_task(check_timeout())
print(f"[{time.perf_counter():.3f}] Timeout task started for client {client_id} (waiting for user interaction)")
async def process_input(data):
nonlocal previous_frame, hidden_states, keys_down, frame_num, frame_count, is_processing, user_has_interacted
try:
process_start_time = time.perf_counter()
queue_size = input_queue.qsize()
print(f"[{process_start_time:.3f}] Starting to process input. Queue size before: {queue_size}")
frame_num += 1
frame_count += 1 # Increment total frame counter
# Calculate global FPS
total_elapsed = process_start_time - connection_start_time
global_fps = frame_count / total_elapsed if total_elapsed > 0 else 0
# change x and y to be between 0 and width/height-1 in data
data['x'] = max(0, min(data['x'], SCREEN_WIDTH - 1))
data['y'] = max(0, min(data['y'], SCREEN_HEIGHT - 1))
x = data.get("x")
y = data.get("y")
assert 0 <= x < SCREEN_WIDTH, f"x: {x} is out of range"
assert 0 <= y < SCREEN_HEIGHT, f"y: {y} is out of range"
is_left_click = data.get("is_left_click")
is_right_click = data.get("is_right_click")
keys_down_list = data.get("keys_down", []) # Get as list
keys_up_list = data.get("keys_up", [])
is_auto_input = data.get("is_auto_input", False)
if is_auto_input:
print (f'[{time.perf_counter():.3f}] Auto-input detected')
else:
# Update user activity for non-auto inputs
update_user_activity()
# Mark that user has started interacting
if not user_has_interacted:
user_has_interacted = True
print(f"[{time.perf_counter():.3f}] User has started interacting with canvas for client {client_id}")
wheel_delta_x = data.get("wheel_delta_x", 0)
wheel_delta_y = data.get("wheel_delta_y", 0)
print(f'[{time.perf_counter():.3f}] Processing: x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}, wheel: ({wheel_delta_x},{wheel_delta_y}), time_since_activity: {time.perf_counter() - last_user_activity_time:.3f}')
# Update the set based on the received data
for key in keys_down_list:
key = key.lower()
if key in KEYMAPPING:
key = KEYMAPPING[key]
keys_down.add(key)
for key in keys_up_list:
key = key.lower()
if key in KEYMAPPING:
key = KEYMAPPING[key]
if key in keys_down: # Check if key exists to avoid KeyError
keys_down.remove(key)
if DEBUG_MODE:
print (f"DEBUG MODE, REMOVING HIDDEN STATES")
previous_frame = padding_image
if DEBUG_MODE_2:
print (f'dsfdasdf frame_num: {frame_num}')
if frame_num > NUM_MAX_FRAMES-1:
print (f"DEBUG MODE_2, REMOVING HIDDEN STATES")
previous_frame = padding_image
frame_num = 0
inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
# Use client-specific settings
client_use_rnn = client_settings["use_rnn"]
client_sampling_steps = client_settings["sampling_steps"]
print(f"[{time.perf_counter():.3f}] Starting model inference with client settings - USE_RNN: {client_use_rnn}, SAMPLING_STEPS: {client_sampling_steps}...")
# Pass client-specific settings to process_frame
previous_frame, sample_img, hidden_states, timing_info = await process_frame(
model,
inputs,
use_rnn=client_use_rnn,
num_sampling_steps=client_sampling_steps
)
print (f'Client {client_id} settings: USE_RNN: {client_use_rnn}, SAMPLING_STEPS: {client_sampling_steps}')
timing_info['full_frame'] = time.perf_counter() - process_start_time
print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}")
# Use the provided function to print timing statistics
print_timing_stats(timing_info, frame_num)
# Print global FPS measurement
print(f" Global FPS: {global_fps:.2f} (total: {frame_count} frames in {total_elapsed:.2f}s)")
img = Image.fromarray(sample_img)
buffered = io.BytesIO()
img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
# Send the generated frame back to the client
print(f"[{time.perf_counter():.3f}] Sending image to client...")
try:
await websocket.send_json({"image": img_str})
print(f"[{time.perf_counter():.3f}] Image sent. Queue size before next_input: {input_queue.qsize()}")
except RuntimeError as e:
if "Cannot call 'send' once a close message has been sent" in str(e):
print(f"[{time.perf_counter():.3f}] WebSocket closed, skipping image send")
else:
raise e
except Exception as e:
print(f"[{time.perf_counter():.3f}] Error sending image: {e}")
# Log the input
log_interaction(client_id, data, generated_frame=sample_img)
finally:
is_processing = False
print(f"[{time.perf_counter():.3f}] Processing complete. Queue size before checking next input: {input_queue.qsize()}")
# Check if we have more inputs to process after this one
if not input_queue.empty():
print(f"[{time.perf_counter():.3f}] Queue not empty, processing next input")
asyncio.create_task(process_next_input())
async def process_next_input():
nonlocal is_processing
current_time = time.perf_counter()
if input_queue.empty():
print(f"[{current_time:.3f}] No inputs to process. Queue is empty.")
is_processing = False
return
# Check if WebSocket is still open by checking if it's in a closed state
if websocket.client_state.value >= 2: # CLOSING or CLOSED
print(f"[{current_time:.3f}] WebSocket in closed state ({websocket.client_state.value}), stopping processing")
is_processing = False
return
#if is_processing:
# print(f"[{current_time:.3f}] Already processing an input. Will check again later.")
# return
# Set is_processing to True before proceeding
is_processing = True
queue_size = input_queue.qsize()
print(f"[{current_time:.3f}] Processing next input. Queue size: {queue_size}")
try:
# Initialize variables to track progress
skipped = 0
latest_input = None
# Process the queue one item at a time
while not input_queue.empty():
current_input = await input_queue.get()
input_queue.task_done()
# Always update the latest input
latest_input = current_input
# Check if this is an interesting event
is_interesting = (current_input.get("is_left_click") or
current_input.get("is_right_click") or
(current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or
(current_input.get("keys_up") and len(current_input.get("keys_up")) > 0) or
current_input.get("wheel_delta_x", 0) != 0 or
current_input.get("wheel_delta_y", 0) != 0)
# Process immediately if interesting
if is_interesting:
print(f"[{current_time:.3f}] Found interesting input (skipped {skipped} events)")
await process_input(current_input) # AWAIT here instead of creating a task
is_processing = False
return
# Otherwise, continue to the next item
skipped += 1
# If this is the last item and no interesting inputs were found
if input_queue.empty():
print(f"[{current_time:.3f}] No interesting inputs, processing latest movement (skipped {skipped-1} events)")
await process_input(latest_input) # AWAIT here instead of creating a task
is_processing = False
return
except Exception as e:
print(f"[{current_time:.3f}] Error in process_next_input: {e}")
import traceback
traceback.print_exc()
is_processing = False # Make sure to reset on error
while True:
try:
# Receive user input
print(f"[{time.perf_counter():.3f}] Waiting for input... Queue size: {input_queue.qsize()}, is_processing: {is_processing}")
data = await websocket.receive_json()
receive_time = time.perf_counter()
if data.get("type") == "heartbeat":
await websocket.send_json({"type": "heartbeat_response"})
continue
# Handle reset command
if data.get("type") == "reset":
print(f"[{receive_time:.3f}] Received reset command")
update_user_activity() # Reset activity timer
await reset_simulation()
continue
# Handle sampling steps update
if data.get("type") == "update_sampling_steps":
print(f"[{receive_time:.3f}] Received request to update sampling steps")
update_user_activity() # Reset activity timer
await update_sampling_steps(data.get("steps", 32))
continue
# Handle USE_RNN update
if data.get("type") == "update_use_rnn":
print(f"[{receive_time:.3f}] Received request to update USE_RNN")
update_user_activity() # Reset activity timer
await update_use_rnn(data.get("use_rnn", False))
continue
# Handle settings request
if data.get("type") == "get_settings":
print(f"[{receive_time:.3f}] Received request for current settings")
update_user_activity() # Reset activity timer
await websocket.send_json({
"type": "settings",
"sampling_steps": client_settings["sampling_steps"],
"use_rnn": client_settings["use_rnn"]
})
continue
# Add the input to our queue
await input_queue.put(data)
print(f"[{receive_time:.3f}] Received input. Queue size now: {input_queue.qsize()}")
# Check if WebSocket is still open before processing
if websocket.client_state.value >= 2: # CLOSING or CLOSED
print(f"[{receive_time:.3f}] WebSocket closed, skipping processing")
continue
# If we're not currently processing, start processing this input
if not is_processing:
print(f"[{receive_time:.3f}] Not currently processing, will call process_next_input()")
is_processing = True
asyncio.create_task(process_next_input()) # Create task but don't await it
else:
print(f"[{receive_time:.3f}] Currently processing, new input queued for later")
except asyncio.TimeoutError:
print("WebSocket connection timed out")
except WebSocketDisconnect:
# Log final EOS entry
log_interaction(client_id, {}, is_end_of_session=True)
print(f"[{time.perf_counter():.3f}] WebSocket disconnected: {client_id}")
print(f"[{time.perf_counter():.3f}] WebSocketDisconnect exception caught")
break
except Exception as e:
print(f"Error in WebSocket connection {client_id}: {e}")
import traceback
traceback.print_exc()
finally:
# Clean up timeout task
print(f"[{time.perf_counter():.3f}] Cleaning up connection {client_id}")
connection_active = False # Signal that connection is being cleaned up
if timeout_task and not timeout_task.done():
print(f"[{time.perf_counter():.3f}] Cancelling timeout task for client {client_id}")
timeout_task.cancel()
try:
await timeout_task
print(f"[{time.perf_counter():.3f}] Timeout task cancelled successfully for client {client_id}")
except asyncio.CancelledError:
print(f"[{time.perf_counter():.3f}] Timeout task cancelled with CancelledError for client {client_id}")
pass
else:
print(f"[{time.perf_counter():.3f}] Timeout task already done or doesn't exist for client {client_id}")
# Print final FPS statistics when connection ends
if frame_num >= 0: # Only if we processed at least one frame
total_time = time.perf_counter() - connection_start_time
print(f"\nConnection {client_id} summary:")
print(f" Total frames processed: {frame_count}")
print(f" Total elapsed time: {total_time:.2f} seconds")
print(f" Average FPS: {frame_count/total_time:.2f}")
print(f"WebSocket connection closed: {client_id}")
def log_interaction(client_id, data, generated_frame=None, is_end_of_session=False, is_reset=False):
"""Log user interaction and optionally the generated frame."""
timestamp = time.time()
# Create directory structure if it doesn't exist
os.makedirs("interaction_logs", exist_ok=True)
# Structure the log entry
log_entry = {
"timestamp": timestamp,
"client_id": client_id,
"is_eos": is_end_of_session,
"is_reset": is_reset
}
# Include type if present (for reset, etc.)
if data.get("type"):
log_entry["type"] = data.get("type")
# Only include input data if this isn't just a control message
if not is_end_of_session and not is_reset:
log_entry["inputs"] = {
"x": data.get("x"),
"y": data.get("y"),
"is_left_click": data.get("is_left_click"),
"is_right_click": data.get("is_right_click"),
"keys_down": data.get("keys_down", []),
"keys_up": data.get("keys_up", []),
"wheel_delta_x": data.get("wheel_delta_x", 0),
"wheel_delta_y": data.get("wheel_delta_y", 0),
"is_auto_input": data.get("is_auto_input", False)
}
else:
# For EOS/reset records, just include minimal info
log_entry["inputs"] = None
# Save to a file (one file per session)
if not os.path.exists("interaction_logs"):
os.makedirs("interaction_logs", exist_ok=True)
session_file = f"interaction_logs/session_{client_id}.jsonl"
with open(session_file, "a") as f:
f.write(json.dumps(log_entry) + "\n")
# Optionally save the frame if provided
if generated_frame is not None and not is_end_of_session and not is_reset:
frame_dir = f"interaction_logs/frames_{client_id}"
os.makedirs(frame_dir, exist_ok=True)
frame_file = f"{frame_dir}/{timestamp:.6f}.png"
# Save the frame as PNG
Image.fromarray(generated_frame).save(frame_file)