from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from typing import Any, Dict, List, Tuple
import numpy as np
from PIL import Image, ImageDraw
import base64
import io
import json
import asyncio
from utils import initialize_model, sample_frame
import torch
import os
import time
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
import concurrent.futures

# Allow TF32 matmuls for faster inference on Ampere+ GPUs
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Debug flags and generation settings
DEBUG_MODE = False
DEBUG_MODE_2 = False
NUM_MAX_FRAMES = 1
TIMESTEPS = 1000
SCREEN_WIDTH = 512
SCREEN_HEIGHT = 384
NUM_SAMPLING_STEPS = 32
USE_RNN = False
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-384k" | |
MODEL_NAME = "yuntian-deng/computer-model-noss-forsure" | |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-2k" | |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-10k" | |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-54k" | |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-160k" | |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-368k" | |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-198k" | |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-674k" | |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-74k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-online-70k" | |
MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-online-x0-46k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-142k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-338k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-ddpm32-x0-140k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-ddpm32-eps-144k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-70k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-joint-onlineonly-eps22-40k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-22-38k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222-42k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-2222-70k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222-48k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-06k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-114k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-136k" | |
#MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-184k" | |
#MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-272k" | |
#MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k7-272k" | |
#MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-oo-eps222222k72-270k" | |
MODEL_NAME = "yuntian-deng/computer-model-s-newnewd-freezernn-origunet-nospatial-online-x0-joint-onlineonly-222222k72-108k" | |
print(f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}')

# Per-channel latent statistics used to normalize/denormalize latents
with open('latent_stats.json', 'r') as f:
    latent_stats = json.load(f)
DATA_NORMALIZATION = {'mean': torch.tensor(latent_stats['mean']).to(device),
                      'std': torch.tensor(latent_stats['std']).to(device)}
# 16 latent channels; the autoencoder downsamples each spatial dimension by 8
LATENT_DIMS = (16, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8)
# Initialize the model at application startup.
# Earlier configurations, kept for reference:
#model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
#model = initialize_model("config_rnn.yaml", "yuntian-deng/computer-model")
#model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model-noss")
#model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model")

# Pick the config that matches the checkpoint name: original UNet vs. custom,
# x0- vs. eps-parameterization, and 32-step DDPM vs. the default 1000-step schedule.
if 'origunet' in MODEL_NAME:
    if 'x0' in MODEL_NAME and 'eps' not in MODEL_NAME:
        if 'ddpm32' in MODEL_NAME:
            TIMESTEPS = 32
            model = initialize_model("config_final_model_origunet_nospatial_x0_ddpm32.yaml", MODEL_NAME)
        else:
            model = initialize_model("config_final_model_origunet_nospatial_x0.yaml", MODEL_NAME)
    else:
        if 'ddpm32' in MODEL_NAME:
            TIMESTEPS = 32
            model = initialize_model("config_final_model_origunet_nospatial_ddpm32.yaml", MODEL_NAME)
        else:
            model = initialize_model("config_final_model_origunet_nospatial.yaml", MODEL_NAME)
else:
    model = initialize_model("config_final_model.yaml", MODEL_NAME)
model = model.to(device)
#model = torch.compile(model)

# All-zero latent frame used as context padding at the start of a session,
# normalized with the same statistics as real latents
padding_image = torch.zeros(*LATENT_DIMS).unsqueeze(0).to(device)
padding_image = (padding_image - DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)) / DATA_NORMALIZATION['std'].view(1, -1, 1, 1)
# Valid keyboard inputs
KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(',
        ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7',
        '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
        'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~',
        'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
        'browserback', 'browserfavorites', 'browserforward', 'browserhome',
        'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
        'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
        'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10',
        'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20',
        'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
        'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja',
        'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
        'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
        'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
        'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
        'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
        'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
        'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
        'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
        'command', 'option', 'optionleft', 'optionright']
# Map browser key names to the names used in KEYS
KEYMAPPING = {
    'arrowup': 'up',
    'arrowdown': 'down',
    'arrowleft': 'left',
    'arrowright': 'right',
    'meta': 'command',
    'contextmenu': 'apps',
    'control': 'ctrl',
}

# Keys excluded from the model's key-event vocabulary
INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
                'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
VALID_KEYS = [key for key in KEYS if key not in INVALID_KEYS]

# Key-token vocabulary: index -> key name (itos) and key name -> index (stoi)
itos = VALID_KEYS
stoi = {key: i for i, key in enumerate(itos)}
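
# Example: if the browser reports 'Control' held down, the input loop
# lowercases it to 'control', maps it through KEYMAPPING to 'ctrl', and
# prepare_model_inputs below encodes it by setting
# inputs['key_events'][stoi['ctrl']] = 1.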
app = FastAPI()

# Mount the static directory to serve HTML, JavaScript, and CSS files
app.mount("/static", StaticFiles(directory="static"), name="static")

# Global counter used to build unique per-connection client IDs
connection_counter = 0

# Connection timeout settings (seconds)
CONNECTION_TIMEOUT = 20 + 1  # 20-second timeout plus a 1-second grace period
WARNING_TIME = 10 + 1  # warn 10 seconds before timeout, plus a 1-second grace period

# Single-worker thread pool: inference runs off the event loop, and
# max_workers=1 serializes GPU work across all connected clients
thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
def prepare_model_inputs(
    previous_frame: torch.Tensor,
    hidden_states: Any,
    x: int,
    y: int,
    right_click: bool,
    left_click: bool,
    keys_down: List[str],
    stoi: Dict[str, int],
    itos: List[str],
    time_step: int
) -> Dict[str, torch.Tensor]:
    """Prepare conditioning inputs (cursor position, clicks, key events) for the model."""
    # Clamp coordinates to valid ranges
    x = min(max(0, x), SCREEN_WIDTH - 1) if x is not None else 0
    y = min(max(0, y), SCREEN_HEIGHT - 1) if y is not None else 0

    if DEBUG_MODE:
        print('DEBUG MODE, SETTING TIME STEP TO 0')
        time_step = 0
    if DEBUG_MODE_2:
        if time_step > NUM_MAX_FRAMES - 1:
            print('DEBUG MODE_2, SETTING TIME STEP TO 0')
            time_step = 0

    inputs = {
        'image_features': previous_frame.to(device),
        'is_padding': torch.BoolTensor([time_step == 0]).to(device),
        'x': torch.LongTensor([x]).unsqueeze(0).to(device),
        'y': torch.LongTensor([y]).unsqueeze(0).to(device),
        'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(device),
        'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(device),
        'key_events': torch.zeros(len(itos), dtype=torch.long).to(device)
    }
    # Multi-hot encoding of currently held keys
    for key in keys_down:
        key = key.lower()
        if key in KEYMAPPING:
            key = KEYMAPPING[key]
        if key in stoi:
            inputs['key_events'][stoi[key]] = 1
        else:
            print(f'Key {key} not found in stoi')

    if hidden_states is not None:
        inputs['hidden_states'] = hidden_states
    if DEBUG_MODE:
        print('DEBUG MODE, REMOVING INPUTS')
        if 'hidden_states' in inputs:
            del inputs['hidden_states']
    if DEBUG_MODE_2:
        if time_step > NUM_MAX_FRAMES - 1:
            print('DEBUG MODE_2, REMOVING HIDDEN STATES')
            if 'hidden_states' in inputs:
                del inputs['hidden_states']
    print(f'Time step: {time_step}')
    return inputs
async def process_frame(
    model: LatentDiffusion,
    inputs: Dict[str, torch.Tensor],
    use_rnn: bool = False,
    num_sampling_steps: int = 32
) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
    """Process a single frame through the model."""
    # Run the heavy computation in a separate thread
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        thread_executor,
        lambda: _process_frame_sync(model, inputs, use_rnn, num_sampling_steps)
    )
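
# Design note: offloading to the executor keeps the event loop free to handle
# heartbeats, timeout checks, and input queueing while the GPU-bound sampling
# runs; with max_workers=1 above, frames from all clients are generated one at
# a time.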
def _process_frame_sync(model, inputs, use_rnn, num_sampling_steps):
    """Synchronous version of process_frame that runs in a worker thread."""
    timing = {}

    # Temporal encoding: advance the RNN one step with the new inputs
    start = time.perf_counter()
    output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs)
    timing['temporal_encoder'] = time.perf_counter() - start

    # UNet sampling
    start = time.perf_counter()
    print(f"model.clip_denoised: {model.clip_denoised}")
    model.clip_denoised = False
    print(f"USE_RNN: {use_rnn}, NUM_SAMPLING_STEPS: {num_sampling_steps}")
    if use_rnn:
        # Skip diffusion entirely and take the latent directly from the RNN output
        sample_latent = output_from_rnn[:, :16]
    else:
        if num_sampling_steps >= TIMESTEPS:
            # Full DDPM sampling over all timesteps
            sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
        elif num_sampling_steps == 1:
            # Single forward pass: predict the latent from pure noise in one step
            x = torch.randn([1, *LATENT_DIMS], device=device)
            t = torch.full((1,), TIMESTEPS - 1, device=device, dtype=torch.long)
            sample_latent = model.apply_model(x, t, {'c_concat': output_from_rnn})
        else:
            # DDIM sampling with a reduced number of steps
            sampler = DDIMSampler(model)
            sample_latent, _ = sampler.sample(
                S=num_sampling_steps,
                conditioning={'c_concat': output_from_rnn},
                batch_size=1,
                shape=LATENT_DIMS,
                verbose=False
            )
    timing['unet'] = time.perf_counter() - start

    # Decoding: denormalize the latent, then decode it to pixel space
    start = time.perf_counter()
    sample = sample_latent * DATA_NORMALIZATION['std'].view(1, -1, 1, 1) + DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)
    sample = model.decode_first_stage(sample)
    sample = sample.squeeze(0).clamp(-1, 1)
    timing['decode'] = time.perf_counter() - start

    # Convert CHW in [-1, 1] to HWC uint8 in [0, 255]
    sample_img = ((sample[:3].permute(1, 2, 0).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)

    timing['total'] = sum(timing.values())
    return sample_latent, sample_img, hidden_states, timing
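
# For reference, timing ends up shaped like (values illustrative only):
#   {'temporal_encoder': ..., 'unet': ..., 'decode': ..., 'total': ...}
# typically dominated by 'unet' when diffusion sampling is used; the caller
# adds a 'full_frame' entry covering the whole request.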
def print_timing_stats(timing_info: Dict[str, float], frame_num: int):
    """Print timing statistics for a frame."""
    print(f"\nFrame {frame_num} timing (seconds):")
    for key, value in timing_info.items():
        print(f"  {key.title()}: {value:.4f}")
    print(f"  FPS: {1.0 / timing_info['full_frame']:.2f}")
# Serve the index.html file at the root URL
@app.get("/")
async def get():
    return HTMLResponse(open("static/index.html").read())
# WebSocket endpoint for continuous user interaction
# (the route path here is assumed; it must match the URL the client connects to)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    global connection_counter
    connection_counter += 1
    client_id = f"{int(time.time())}_{connection_counter}"
    print(f"New WebSocket connection: {client_id}")
    await websocket.accept()
    try:
        previous_frame = padding_image
        hidden_states = None
        keys_down = set()  # Currently held keys
        frame_num = -1

        # Client-specific settings, initialized from the global defaults
        client_settings = {
            "use_rnn": USE_RNN,
            "sampling_steps": NUM_SAMPLING_STEPS
        }

        # Connection timeout tracking
        last_user_activity_time = time.perf_counter()
        timeout_warning_sent = False
        timeout_task = None
        connection_active = True  # Flag to track if the connection is still active
        user_has_interacted = False  # Flag to track if the user has started interacting

        # Start timing for global FPS calculation
        connection_start_time = time.perf_counter()
        frame_count = 0

        # Input queue management: an asyncio.Queue rather than a plain list
        input_queue = asyncio.Queue()
        is_processing = False
        # Reset the simulation to its initial state
        async def reset_simulation():
            nonlocal previous_frame, hidden_states, keys_down, frame_num, is_processing, input_queue, user_has_interacted
            # Keep the client settings during reset
            temp_client_settings = client_settings.copy()

            # Log the reset action
            log_interaction(
                client_id,
                {"type": "reset"},
                is_end_of_session=False,
                is_reset=True
            )

            # Clear the input queue
            while not input_queue.empty():
                try:
                    input_queue.get_nowait()
                    input_queue.task_done()
                except asyncio.QueueEmpty:
                    break

            # Reset all state variables
            previous_frame = padding_image
            hidden_states = None
            keys_down = set()
            frame_num = -1
            is_processing = False
            user_has_interacted = False  # Wait for the user to interact again

            # Restore client settings
            client_settings.update(temp_client_settings)
            print(f"[{time.perf_counter():.3f}] Simulation reset to initial state (preserved settings: USE_RNN={client_settings['use_rnn']}, SAMPLING_STEPS={client_settings['sampling_steps']})")
            print(f"[{time.perf_counter():.3f}] User interaction state reset - waiting for user to interact again")

            # Send confirmation to client
            await websocket.send_json({"type": "reset_confirmed"})
            # Also send the current settings to update the UI
            await websocket.send_json({
                "type": "settings",
                "sampling_steps": client_settings["sampling_steps"],
                "use_rnn": client_settings["use_rnn"]
            })
        # Update the per-client number of diffusion sampling steps
        async def update_sampling_steps(steps):
            nonlocal client_settings
            # Validate the input
            if steps < 1:
                print(f"[{time.perf_counter():.3f}] Invalid sampling steps value: {steps}")
                await websocket.send_json({"type": "error", "message": "Invalid sampling steps value"})
                return
            # Update the client-specific setting
            old_steps = client_settings["sampling_steps"]
            client_settings["sampling_steps"] = steps
            print(f"[{time.perf_counter():.3f}] Updated sampling steps for client {client_id} from {old_steps} to {steps}")
            # Send confirmation to client
            await websocket.send_json({"type": "steps_updated", "steps": steps})

        # Update the per-client USE_RNN setting
        async def update_use_rnn(use_rnn):
            nonlocal client_settings
            # Update the client-specific setting
            old_setting = client_settings["use_rnn"]
            client_settings["use_rnn"] = use_rnn
            print(f"[{time.perf_counter():.3f}] Updated USE_RNN for client {client_id} from {old_setting} to {use_rnn}")
            # Send confirmation to client
            await websocket.send_json({"type": "rnn_updated", "use_rnn": use_rnn})
        # Timeout checking loop: warn after WARNING_TIME of inactivity and
        # close the connection after CONNECTION_TIMEOUT
        async def check_timeout():
            nonlocal timeout_warning_sent, timeout_task, connection_active, user_has_interacted
            while True:
                try:
                    # Stop if the connection was cleaned up or the WebSocket is closing/closed
                    if not connection_active or websocket.client_state.value >= 2:  # CLOSING or CLOSED
                        print(f"[{time.perf_counter():.3f}] Connection inactive or WebSocket closed, stopping timeout check for client {client_id}")
                        return

                    # Don't start timeout tracking until the user has actually interacted
                    if not user_has_interacted:
                        print(f"[{time.perf_counter():.3f}] User hasn't interacted yet, skipping timeout check for client {client_id}")
                        await asyncio.sleep(1)  # Check every second
                        continue

                    current_time = time.perf_counter()
                    time_since_activity = current_time - last_user_activity_time
                    print(f"[{current_time:.3f}] Timeout check - time_since_activity: {time_since_activity:.1f}s, WARNING_TIME: {WARNING_TIME}s, CONNECTION_TIMEOUT: {CONNECTION_TIMEOUT}s")

                    # Send a warning once WARNING_TIME has elapsed
                    if time_since_activity >= WARNING_TIME and not timeout_warning_sent:
                        print(f"[{current_time:.3f}] Sending timeout warning to client {client_id}")
                        await websocket.send_json({
                            "type": "timeout_warning",
                            "timeout_in": CONNECTION_TIMEOUT - WARNING_TIME
                        })
                        timeout_warning_sent = True
                        print(f"[{current_time:.3f}] Timeout warning sent, timeout_warning_sent: {timeout_warning_sent}")

                    # Close the connection once CONNECTION_TIMEOUT has elapsed
                    if time_since_activity >= CONNECTION_TIMEOUT:
                        print(f"[{current_time:.3f}] TIMEOUT REACHED! Closing connection {client_id} due to timeout")
                        print(f"[{current_time:.3f}] time_since_activity: {time_since_activity:.1f}s >= CONNECTION_TIMEOUT: {CONNECTION_TIMEOUT}s")
                        # Clear the input queue before closing
                        queue_size_before = input_queue.qsize()
                        print(f"[{current_time:.3f}] Clearing input queue, size before: {queue_size_before}")
                        while not input_queue.empty():
                            try:
                                input_queue.get_nowait()
                                input_queue.task_done()
                            except asyncio.QueueEmpty:
                                break
                        print(f"[{current_time:.3f}] Input queue cleared, size after: {input_queue.qsize()}")
                        print(f"[{current_time:.3f}] About to close WebSocket connection...")
                        await websocket.close(code=1000, reason="User inactivity timeout")
                        print(f"[{current_time:.3f}] WebSocket.close() called, returning from check_timeout")
                        return

                    await asyncio.sleep(1)  # Check every second
                except Exception as e:
                    print(f"[{time.perf_counter():.3f}] Error in timeout check for client {client_id}: {e}")
                    import traceback
                    traceback.print_exc()
                    break
        # Record user activity and reset any pending timeout warning
        def update_user_activity():
            nonlocal last_user_activity_time, timeout_warning_sent
            old_time = last_user_activity_time
            last_user_activity_time = time.perf_counter()
            print(f"[{time.perf_counter():.3f}] User activity detected for client {client_id}")
            print(f"[{time.perf_counter():.3f}] last_user_activity_time updated: {old_time:.3f} -> {last_user_activity_time:.3f}")
            if timeout_warning_sent:
                print(f"[{time.perf_counter():.3f}] User activity detected, resetting timeout warning for client {client_id}")
                timeout_warning_sent = False
                print(f"[{time.perf_counter():.3f}] timeout_warning_sent reset to: {timeout_warning_sent}")
                # Send activity reset notification to client
                asyncio.create_task(websocket.send_json({"type": "activity_reset"}))
                print(f"[{time.perf_counter():.3f}] Activity reset message sent to client")

        # Start timeout checking
        timeout_task = asyncio.create_task(check_timeout())
        print(f"[{time.perf_counter():.3f}] Timeout task started for client {client_id} (waiting for user interaction)")
        # Run one input event through the model and send the result to the client
        async def process_input(data):
            nonlocal previous_frame, hidden_states, keys_down, frame_num, frame_count, is_processing, user_has_interacted
            try:
                process_start_time = time.perf_counter()
                queue_size = input_queue.qsize()
                print(f"[{process_start_time:.3f}] Starting to process input. Queue size before: {queue_size}")

                frame_num += 1
                frame_count += 1  # Increment total frame counter

                # Calculate global FPS
                total_elapsed = process_start_time - connection_start_time
                global_fps = frame_count / total_elapsed if total_elapsed > 0 else 0

                # Clamp x and y in data to [0, SCREEN_WIDTH-1] / [0, SCREEN_HEIGHT-1]
                data['x'] = max(0, min(data['x'], SCREEN_WIDTH - 1))
                data['y'] = max(0, min(data['y'], SCREEN_HEIGHT - 1))
                x = data.get("x")
                y = data.get("y")
                assert 0 <= x < SCREEN_WIDTH, f"x: {x} is out of range"
                assert 0 <= y < SCREEN_HEIGHT, f"y: {y} is out of range"
                is_left_click = data.get("is_left_click")
                is_right_click = data.get("is_right_click")
                keys_down_list = data.get("keys_down", [])  # Keys pressed in this event
                keys_up_list = data.get("keys_up", [])  # Keys released in this event
                is_auto_input = data.get("is_auto_input", False)
                if is_auto_input:
                    print(f'[{time.perf_counter():.3f}] Auto-input detected')
                else:
                    # Update user activity for non-auto inputs
                    update_user_activity()
                    # Mark that the user has started interacting
                    if not user_has_interacted:
                        user_has_interacted = True
                        print(f"[{time.perf_counter():.3f}] User has started interacting with canvas for client {client_id}")
                wheel_delta_x = data.get("wheel_delta_x", 0)
                wheel_delta_y = data.get("wheel_delta_y", 0)
                print(f'[{time.perf_counter():.3f}] Processing: x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}, wheel: ({wheel_delta_x},{wheel_delta_y}), time_since_activity: {time.perf_counter() - last_user_activity_time:.3f}')

                # Update the set of held keys based on the received data
                for key in keys_down_list:
                    key = key.lower()
                    if key in KEYMAPPING:
                        key = KEYMAPPING[key]
                    keys_down.add(key)
                for key in keys_up_list:
                    key = key.lower()
                    if key in KEYMAPPING:
                        key = KEYMAPPING[key]
                    if key in keys_down:  # Check if key exists to avoid KeyError
                        keys_down.remove(key)

                if DEBUG_MODE:
                    print("DEBUG MODE, REMOVING HIDDEN STATES")
                    previous_frame = padding_image
                if DEBUG_MODE_2:
                    print(f'DEBUG MODE_2, frame_num: {frame_num}')
                    if frame_num > NUM_MAX_FRAMES - 1:
                        print("DEBUG MODE_2, REMOVING HIDDEN STATES")
                        previous_frame = padding_image
                        frame_num = 0

                inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)

                # Use client-specific settings
                client_use_rnn = client_settings["use_rnn"]
                client_sampling_steps = client_settings["sampling_steps"]
                print(f"[{time.perf_counter():.3f}] Starting model inference with client settings - USE_RNN: {client_use_rnn}, SAMPLING_STEPS: {client_sampling_steps}...")

                # Pass client-specific settings to process_frame
                previous_frame, sample_img, hidden_states, timing_info = await process_frame(
                    model,
                    inputs,
                    use_rnn=client_use_rnn,
                    num_sampling_steps=client_sampling_steps
                )
                print(f'Client {client_id} settings: USE_RNN: {client_use_rnn}, SAMPLING_STEPS: {client_sampling_steps}')
                timing_info['full_frame'] = time.perf_counter() - process_start_time
                print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}")

                # Print timing statistics for this frame
                print_timing_stats(timing_info, frame_num)
                # Print global FPS measurement
                print(f"  Global FPS: {global_fps:.2f} (total: {frame_count} frames in {total_elapsed:.2f}s)")

                # Encode the generated frame as a base64 PNG
                img = Image.fromarray(sample_img)
                buffered = io.BytesIO()
                img.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode()

                # Send the generated frame back to the client
                print(f"[{time.perf_counter():.3f}] Sending image to client...")
                try:
                    await websocket.send_json({"image": img_str})
                    print(f"[{time.perf_counter():.3f}] Image sent. Queue size before next_input: {input_queue.qsize()}")
                except RuntimeError as e:
                    if "Cannot call 'send' once a close message has been sent" in str(e):
                        print(f"[{time.perf_counter():.3f}] WebSocket closed, skipping image send")
                    else:
                        raise e
                except Exception as e:
                    print(f"[{time.perf_counter():.3f}] Error sending image: {e}")

                # Log the input
                log_interaction(client_id, data, generated_frame=sample_img)
            finally:
                is_processing = False
                print(f"[{time.perf_counter():.3f}] Processing complete. Queue size before checking next input: {input_queue.qsize()}")
                # Check if we have more inputs to process after this one
                if not input_queue.empty():
                    print(f"[{time.perf_counter():.3f}] Queue not empty, processing next input")
                    asyncio.create_task(process_next_input())
        # Drain the queue: process click/key/scroll events immediately, but
        # coalesce runs of plain mouse movements down to the most recent one
        async def process_next_input():
            nonlocal is_processing
            current_time = time.perf_counter()
            if input_queue.empty():
                print(f"[{current_time:.3f}] No inputs to process. Queue is empty.")
                is_processing = False
                return
            # Check if the WebSocket is still open
            if websocket.client_state.value >= 2:  # CLOSING or CLOSED
                print(f"[{current_time:.3f}] WebSocket in closed state ({websocket.client_state.value}), stopping processing")
                is_processing = False
                return

            # Set is_processing to True before proceeding
            is_processing = True
            queue_size = input_queue.qsize()
            print(f"[{current_time:.3f}] Processing next input. Queue size: {queue_size}")
            try:
                # Track progress through the queue
                skipped = 0
                latest_input = None
                # Process the queue one item at a time
                while not input_queue.empty():
                    current_input = await input_queue.get()
                    input_queue.task_done()
                    # Always keep the most recent input
                    latest_input = current_input
                    # An input is "interesting" if it carries a click, key event, or scroll
                    is_interesting = (current_input.get("is_left_click") or
                                      current_input.get("is_right_click") or
                                      (current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or
                                      (current_input.get("keys_up") and len(current_input.get("keys_up")) > 0) or
                                      current_input.get("wheel_delta_x", 0) != 0 or
                                      current_input.get("wheel_delta_y", 0) != 0)
                    # Process immediately if interesting
                    if is_interesting:
                        print(f"[{current_time:.3f}] Found interesting input (skipped {skipped} events)")
                        await process_input(current_input)  # Await here instead of creating a task
                        is_processing = False
                        return
                    # Otherwise, continue to the next item
                    skipped += 1
                    # If this was the last item and no interesting inputs were found
                    if input_queue.empty():
                        print(f"[{current_time:.3f}] No interesting inputs, processing latest movement (skipped {skipped - 1} events)")
                        await process_input(latest_input)  # Await here instead of creating a task
                        is_processing = False
                        return
            except Exception as e:
                print(f"[{current_time:.3f}] Error in process_next_input: {e}")
                import traceback
                traceback.print_exc()
                is_processing = False  # Make sure to reset on error
        while True:
            try:
                # Receive user input
                print(f"[{time.perf_counter():.3f}] Waiting for input... Queue size: {input_queue.qsize()}, is_processing: {is_processing}")
                data = await websocket.receive_json()
                receive_time = time.perf_counter()

                if data.get("type") == "heartbeat":
                    await websocket.send_json({"type": "heartbeat_response"})
                    continue

                # Handle reset command
                if data.get("type") == "reset":
                    print(f"[{receive_time:.3f}] Received reset command")
                    update_user_activity()  # Reset activity timer
                    await reset_simulation()
                    continue

                # Handle sampling steps update
                if data.get("type") == "update_sampling_steps":
                    print(f"[{receive_time:.3f}] Received request to update sampling steps")
                    update_user_activity()  # Reset activity timer
                    await update_sampling_steps(data.get("steps", 32))
                    continue

                # Handle USE_RNN update
                if data.get("type") == "update_use_rnn":
                    print(f"[{receive_time:.3f}] Received request to update USE_RNN")
                    update_user_activity()  # Reset activity timer
                    await update_use_rnn(data.get("use_rnn", False))
                    continue

                # Handle settings request
                if data.get("type") == "get_settings":
                    print(f"[{receive_time:.3f}] Received request for current settings")
                    update_user_activity()  # Reset activity timer
                    await websocket.send_json({
                        "type": "settings",
                        "sampling_steps": client_settings["sampling_steps"],
                        "use_rnn": client_settings["use_rnn"]
                    })
                    continue

                # Add the input to our queue
                await input_queue.put(data)
                print(f"[{receive_time:.3f}] Received input. Queue size now: {input_queue.qsize()}")

                # Check if the WebSocket is still open before processing
                if websocket.client_state.value >= 2:  # CLOSING or CLOSED
                    print(f"[{receive_time:.3f}] WebSocket closed, skipping processing")
                    continue

                # If we're not currently processing, start processing this input
                if not is_processing:
                    print(f"[{receive_time:.3f}] Not currently processing, will call process_next_input()")
                    is_processing = True
                    asyncio.create_task(process_next_input())  # Create task but don't await it
                else:
                    print(f"[{receive_time:.3f}] Currently processing, new input queued for later")
            except asyncio.TimeoutError:
                print("WebSocket connection timed out")
            except WebSocketDisconnect:
                # Log the final end-of-session entry
                log_interaction(client_id, {}, is_end_of_session=True)
                print(f"[{time.perf_counter():.3f}] WebSocket disconnected: {client_id}")
                print(f"[{time.perf_counter():.3f}] WebSocketDisconnect exception caught")
                break
    except Exception as e:
        print(f"Error in WebSocket connection {client_id}: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Clean up the timeout task
        print(f"[{time.perf_counter():.3f}] Cleaning up connection {client_id}")
        connection_active = False  # Signal that the connection is being cleaned up
        if timeout_task and not timeout_task.done():
            print(f"[{time.perf_counter():.3f}] Cancelling timeout task for client {client_id}")
            timeout_task.cancel()
            try:
                await timeout_task
                print(f"[{time.perf_counter():.3f}] Timeout task cancelled successfully for client {client_id}")
            except asyncio.CancelledError:
                print(f"[{time.perf_counter():.3f}] Timeout task cancelled with CancelledError for client {client_id}")
        else:
            print(f"[{time.perf_counter():.3f}] Timeout task already done or doesn't exist for client {client_id}")

        # Print final FPS statistics when the connection ends
        if frame_num >= 0:  # Only if we processed at least one frame
            total_time = time.perf_counter() - connection_start_time
            print(f"\nConnection {client_id} summary:")
            print(f"  Total frames processed: {frame_count}")
            print(f"  Total elapsed time: {total_time:.2f} seconds")
            print(f"  Average FPS: {frame_count / total_time:.2f}")
        print(f"WebSocket connection closed: {client_id}")
def log_interaction(client_id, data, generated_frame=None, is_end_of_session=False, is_reset=False):
    """Log a user interaction and optionally the generated frame."""
    timestamp = time.time()

    # Create the log directory if it doesn't exist
    os.makedirs("interaction_logs", exist_ok=True)

    # Structure the log entry
    log_entry = {
        "timestamp": timestamp,
        "client_id": client_id,
        "is_eos": is_end_of_session,
        "is_reset": is_reset
    }
    # Include type if present (for reset, etc.)
    if data.get("type"):
        log_entry["type"] = data.get("type")

    # Only include input data if this isn't just a control message
    if not is_end_of_session and not is_reset:
        log_entry["inputs"] = {
            "x": data.get("x"),
            "y": data.get("y"),
            "is_left_click": data.get("is_left_click"),
            "is_right_click": data.get("is_right_click"),
            "keys_down": data.get("keys_down", []),
            "keys_up": data.get("keys_up", []),
            "wheel_delta_x": data.get("wheel_delta_x", 0),
            "wheel_delta_y": data.get("wheel_delta_y", 0),
            "is_auto_input": data.get("is_auto_input", False)
        }
    else:
        # For end-of-session/reset records, include only minimal info
        log_entry["inputs"] = None

    # Append to a JSONL file (one file per session)
    session_file = f"interaction_logs/session_{client_id}.jsonl"
    with open(session_file, "a") as f:
        f.write(json.dumps(log_entry) + "\n")

    # Optionally save the frame if provided
    if generated_frame is not None and not is_end_of_session and not is_reset:
        frame_dir = f"interaction_logs/frames_{client_id}"
        os.makedirs(frame_dir, exist_ok=True)
        frame_file = f"{frame_dir}/{timestamp:.6f}.png"
        # Save the frame as a PNG
        Image.fromarray(generated_frame).save(frame_file)
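
# Local entrypoint: a minimal sketch, assuming the app is launched directly
# rather than via an external `uvicorn app:app` command. The host and port are
# assumptions (7860 is the conventional Hugging Face Spaces port), not values
# taken from this file.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)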