from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from typing import List, Tuple import numpy as np from PIL import Image, ImageDraw import base64 import io import json import asyncio from utils import initialize_model, sample_frame import torch import os import time from typing import Any, Dict from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler import concurrent.futures torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') DEBUG_MODE = False DEBUG_MODE_2 = False NUM_MAX_FRAMES = 1 SCREEN_WIDTH = 512 SCREEN_HEIGHT = 384 NUM_SAMPLING_STEPS = 32 MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-384k" MODEL_NAME = "yuntian-deng/computer-model-noss-forsure" MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-2k" MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-10k" MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-54k" MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-unfreezernn-160k" MODEL_NAME = "yuntian-deng/computer-model-ss005-cont-lr2e5-computecanada-newnewd-freezernn-origunet-nospatial-368k" print (f'setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}, MODEL_NAME: {MODEL_NAME}') with open('latent_stats.json', 'r') as f: latent_stats = json.load(f) DATA_NORMALIZATION = {'mean': torch.tensor(latent_stats['mean']).to(device), 'std': torch.tensor(latent_stats['std']).to(device)} LATENT_DIMS = (16, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8) # Initialize the model at the start of your application #model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model") #model = initialize_model("config_rnn.yaml", "yuntian-deng/computer-model") #model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model-noss") #model = initialize_model("config_final_model.yaml", "yuntian-deng/computer-model") #model = initialize_model("config_final_model.yaml", MODEL_NAME) model = initialize_model("config_final_model_origunet_nospatial.yaml", MODEL_NAME) model = model.to(device) #model = torch.compile(model) padding_image = torch.zeros(*LATENT_DIMS).unsqueeze(0).to(device) padding_image = (padding_image - DATA_NORMALIZATION['mean'].view(1, -1, 1, 1)) / DATA_NORMALIZATION['std'].view(1, -1, 1, 1) # Valid keyboard inputs KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', 'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace', 'browserback', 'browserfavorites', 'browserforward', 'browserhome', 'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear', 'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete', 'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20', 'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja', 'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail', 'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack', 'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6', 'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn', 'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn', 'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator', 'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab', 'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen', 'command', 'option', 'optionleft', 'optionright'] INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20', 'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute'] VALID_KEYS = [key for key in KEYS if key not in INVALID_KEYS] itos = VALID_KEYS stoi = {key: i for i, key in enumerate(itos)} app = FastAPI() # Mount the static directory to serve HTML, JavaScript, and CSS files app.mount("/static", StaticFiles(directory="static"), name="static") # Add this at the top with other global variables connection_counter = 0 # Create a thread pool executor thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) def prepare_model_inputs( previous_frame: torch.Tensor, hidden_states: Any, x: int, y: int, right_click: bool, left_click: bool, keys_down: List[str], stoi: Dict[str, int], itos: List[str], time_step: int ) -> Dict[str, torch.Tensor]: """Prepare inputs for the model.""" # Clamp coordinates to valid ranges x = min(max(0, x), SCREEN_WIDTH - 1) if x is not None else 0 y = min(max(0, y), SCREEN_HEIGHT - 1) if y is not None else 0 if DEBUG_MODE: print ('DEBUG MODE, SETTING TIME STEP TO 0') time_step = 0 if DEBUG_MODE_2: if time_step > NUM_MAX_FRAMES-1: print ('DEBUG MODE_2, SETTING TIME STEP TO 0') time_step = 0 inputs = { 'image_features': previous_frame.to(device), 'is_padding': torch.BoolTensor([time_step == 0]).to(device), 'x': torch.LongTensor([x]).unsqueeze(0).to(device), 'y': torch.LongTensor([y]).unsqueeze(0).to(device), 'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(device), 'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(device), 'key_events': torch.zeros(len(itos), dtype=torch.long).to(device) } for key in keys_down: key = key.lower() if key in stoi: inputs['key_events'][stoi[key]] = 1 else: print (f'Key {key} not found in stoi') if hidden_states is not None: inputs['hidden_states'] = hidden_states if DEBUG_MODE: print ('DEBUG MODE, REMOVING INPUTS') if 'hidden_states' in inputs: del inputs['hidden_states'] if DEBUG_MODE_2: if time_step > NUM_MAX_FRAMES-1: print ('DEBUG MODE_2, REMOVING HIDDEN STATES') if 'hidden_states' in inputs: del inputs['hidden_states'] print (f'Time step: {time_step}') return inputs @torch.no_grad() async def process_frame( model: LatentDiffusion, inputs: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]: """Process a single frame through the model.""" # Run the heavy computation in a separate thread loop = asyncio.get_running_loop() return await loop.run_in_executor( thread_executor, lambda: _process_frame_sync(model, inputs) ) def _process_frame_sync(model, inputs): """Synchronous version of process_frame that runs in a thread""" timing = {} # Temporal encoding start = time.perf_counter() output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs) timing['temporal_encoder'] = time.perf_counter() - start # UNet sampling start = time.perf_counter() use_rnn = False print (f"use_rnn: {use_rnn}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}") if use_rnn: sample_latent = output_from_rnn[:, :16] else: #NUM_SAMPLING_STEPS = 8 if NUM_SAMPLING_STEPS >= 1000: sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True) else: sampler = DDIMSampler(model) sample_latent, _ = sampler.sample( S=NUM_SAMPLING_STEPS, conditioning={'c_concat': output_from_rnn}, batch_size=1, shape=LATENT_DIMS, verbose=False ) timing['unet'] = time.perf_counter() - start # Decoding start = time.perf_counter() sample = sample_latent * DATA_NORMALIZATION['std'].view(1, -1, 1, 1) + DATA_NORMALIZATION['mean'].view(1, -1, 1, 1) # Use time.sleep(10) here since it's in a separate thread #time.sleep(10) sample = model.decode_first_stage(sample) sample = sample.squeeze(0).clamp(-1, 1) timing['decode'] = time.perf_counter() - start # Convert to image sample_img = ((sample[:3].transpose(0,1).transpose(1,2).cpu().float().numpy() + 1) * 127.5).astype(np.uint8) timing['total'] = sum(timing.values()) return sample_latent, sample_img, hidden_states, timing def print_timing_stats(timing_info: Dict[str, float], frame_num: int): """Print timing statistics for a frame.""" print(f"\nFrame {frame_num} timing (seconds):") for key, value in timing_info.items(): print(f" {key.title()}: {value:.4f}") print(f" FPS: {1.0/timing_info['full_frame']:.2f}") # Serve the index.html file at the root URL @app.get("/") async def get(): return HTMLResponse(open("static/index.html").read()) # WebSocket endpoint for continuous user interaction @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): global connection_counter connection_counter += 1 client_id = f"{int(time.time())}_{connection_counter}" print(f"New WebSocket connection: {client_id}") await websocket.accept() try: previous_frame = padding_image hidden_states = None keys_down = set() # Initialize as an empty set frame_num = -1 # Start timing for global FPS calculation connection_start_time = time.perf_counter() frame_count = 0 # Input queue management - use asyncio.Queue instead of a list input_queue = asyncio.Queue() is_processing = False # Add a function to reset the simulation async def reset_simulation(): nonlocal previous_frame, hidden_states, keys_down, frame_num, is_processing, input_queue # Log the reset action log_interaction( client_id, {"type": "reset"}, is_end_of_session=False, is_reset=True # Add this parameter to the log_interaction function ) # Clear the input queue while not input_queue.empty(): try: input_queue.get_nowait() input_queue.task_done() except asyncio.QueueEmpty: break # Reset all state variables previous_frame = padding_image hidden_states = None keys_down = set() frame_num = -1 is_processing = False print(f"[{time.perf_counter():.3f}] Simulation reset to initial state") # Send confirmation to client await websocket.send_json({"type": "reset_confirmed"}) # Add a function to update sampling steps async def update_sampling_steps(steps): global NUM_SAMPLING_STEPS # Validate the input if steps < 1: print(f"[{time.perf_counter():.3f}] Invalid sampling steps value: {steps}") await websocket.send_json({"type": "error", "message": "Invalid sampling steps value"}) return # Update the global variable old_steps = NUM_SAMPLING_STEPS NUM_SAMPLING_STEPS = steps print(f"[{time.perf_counter():.3f}] Updated NUM_SAMPLING_STEPS from {old_steps} to {steps}") # Send confirmation to client await websocket.send_json({"type": "steps_updated", "steps": steps}) async def process_input(data): nonlocal previous_frame, hidden_states, keys_down, frame_num, frame_count, is_processing try: process_start_time = time.perf_counter() queue_size = input_queue.qsize() print(f"[{process_start_time:.3f}] Starting to process input. Queue size before: {queue_size}") frame_num += 1 frame_count += 1 # Increment total frame counter # Calculate global FPS total_elapsed = process_start_time - connection_start_time global_fps = frame_count / total_elapsed if total_elapsed > 0 else 0 x = data.get("x") y = data.get("y") is_left_click = data.get("is_left_click") is_right_click = data.get("is_right_click") keys_down_list = data.get("keys_down", []) # Get as list keys_up_list = data.get("keys_up", []) print(f'[{time.perf_counter():.3f}] Processing: x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}') # Update the set based on the received data for key in keys_down_list: keys_down.add(key) for key in keys_up_list: if key in keys_down: # Check if key exists to avoid KeyError keys_down.remove(key) if DEBUG_MODE: print (f"DEBUG MODE, REMOVING HIDDEN STATES") previous_frame = padding_image if DEBUG_MODE_2: print (f'dsfdasdf frame_num: {frame_num}') if frame_num > NUM_MAX_FRAMES-1: print (f"DEBUG MODE_2, REMOVING HIDDEN STATES") previous_frame = padding_image frame_num = 0 inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num) print(f"[{time.perf_counter():.3f}] Starting model inference...") previous_frame, sample_img, hidden_states, timing_info = await process_frame(model, inputs) print (f'aaa setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}') timing_info['full_frame'] = time.perf_counter() - process_start_time print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}") # Use the provided function to print timing statistics print_timing_stats(timing_info, frame_num) # Print global FPS measurement print(f" Global FPS: {global_fps:.2f} (total: {frame_count} frames in {total_elapsed:.2f}s)") img = Image.fromarray(sample_img) buffered = io.BytesIO() img.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode() # Send the generated frame back to the client print(f"[{time.perf_counter():.3f}] Sending image to client...") await websocket.send_json({"image": img_str}) print(f"[{time.perf_counter():.3f}] Image sent. Queue size before next_input: {input_queue.qsize()}") # Log the input log_interaction(client_id, data, generated_frame=sample_img) finally: is_processing = False print(f"[{time.perf_counter():.3f}] Processing complete. Queue size before checking next input: {input_queue.qsize()}") # Check if we have more inputs to process after this one if not input_queue.empty(): print(f"[{time.perf_counter():.3f}] Queue not empty, processing next input") asyncio.create_task(process_next_input()) async def process_next_input(): nonlocal is_processing current_time = time.perf_counter() if input_queue.empty(): print(f"[{current_time:.3f}] No inputs to process. Queue is empty.") is_processing = False return #if is_processing: # print(f"[{current_time:.3f}] Already processing an input. Will check again later.") # return # Set is_processing to True before proceeding is_processing = True queue_size = input_queue.qsize() print(f"[{current_time:.3f}] Processing next input. Queue size: {queue_size}") try: # Initialize variables to track progress skipped = 0 latest_input = None # Process the queue one item at a time while not input_queue.empty(): current_input = await input_queue.get() input_queue.task_done() # Always update the latest input latest_input = current_input # Check if this is an interesting event is_interesting = (current_input.get("is_left_click") or current_input.get("is_right_click") or (current_input.get("keys_down") and len(current_input.get("keys_down")) > 0) or (current_input.get("keys_up") and len(current_input.get("keys_up")) > 0)) # Process immediately if interesting if is_interesting: print(f"[{current_time:.3f}] Found interesting input (skipped {skipped} events)") await process_input(current_input) # AWAIT here instead of creating a task is_processing = False return # Otherwise, continue to the next item skipped += 1 # If this is the last item and no interesting inputs were found if input_queue.empty(): print(f"[{current_time:.3f}] No interesting inputs, processing latest movement (skipped {skipped-1} events)") await process_input(latest_input) # AWAIT here instead of creating a task is_processing = False return except Exception as e: print(f"[{current_time:.3f}] Error in process_next_input: {e}") import traceback traceback.print_exc() is_processing = False # Make sure to reset on error while True: try: # Receive user input print(f"[{time.perf_counter():.3f}] Waiting for input... Queue size: {input_queue.qsize()}, is_processing: {is_processing}") data = await websocket.receive_json() receive_time = time.perf_counter() if data.get("type") == "heartbeat": await websocket.send_json({"type": "heartbeat_response"}) continue # Handle reset command if data.get("type") == "reset": print(f"[{receive_time:.3f}] Received reset command") await reset_simulation() continue # Handle sampling steps update if data.get("type") == "update_sampling_steps": print(f"[{receive_time:.3f}] Received request to update sampling steps") await update_sampling_steps(data.get("steps", 32)) continue # Add the input to our queue await input_queue.put(data) print(f"[{receive_time:.3f}] Received input. Queue size now: {input_queue.qsize()}") # If we're not currently processing, start processing this input if not is_processing: print(f"[{receive_time:.3f}] Not currently processing, will call process_next_input()") is_processing = True asyncio.create_task(process_next_input()) # Create task but don't await it else: print(f"[{receive_time:.3f}] Currently processing, new input queued for later") except asyncio.TimeoutError: print("WebSocket connection timed out") except WebSocketDisconnect: # Log final EOS entry log_interaction(client_id, {}, is_end_of_session=True) print(f"WebSocket disconnected: {client_id}") break except Exception as e: print(f"Error in WebSocket connection {client_id}: {e}") import traceback traceback.print_exc() finally: # Print final FPS statistics when connection ends if frame_num >= 0: # Only if we processed at least one frame total_time = time.perf_counter() - connection_start_time print(f"\nConnection {client_id} summary:") print(f" Total frames processed: {frame_count}") print(f" Total elapsed time: {total_time:.2f} seconds") print(f" Average FPS: {frame_count/total_time:.2f}") print(f"WebSocket connection closed: {client_id}") def log_interaction(client_id, data, generated_frame=None, is_end_of_session=False, is_reset=False): """Log user interaction and optionally the generated frame.""" timestamp = time.time() # Create directory structure if it doesn't exist os.makedirs("interaction_logs", exist_ok=True) # Structure the log entry log_entry = { "timestamp": timestamp, "client_id": client_id, "is_eos": is_end_of_session, "is_reset": is_reset } # Include type if present (for reset, etc.) if data.get("type"): log_entry["type"] = data.get("type") # Only include input data if this isn't just a control message if not is_end_of_session and not is_reset: log_entry["inputs"] = { "x": data.get("x"), "y": data.get("y"), "is_left_click": data.get("is_left_click"), "is_right_click": data.get("is_right_click"), "keys_down": data.get("keys_down", []), "keys_up": data.get("keys_up", []) } else: # For EOS/reset records, just include minimal info log_entry["inputs"] = None # Save to a file (one file per session) session_file = f"interaction_logs/session_{client_id}.jsonl" with open(session_file, "a") as f: f.write(json.dumps(log_entry) + "\n") # Optionally save the frame if provided if generated_frame is not None and not is_end_of_session and not is_reset: frame_dir = f"interaction_logs/frames_{client_id}" os.makedirs(frame_dir, exist_ok=True) frame_file = f"{frame_dir}/{timestamp:.6f}.png" # Save the frame as PNG Image.fromarray(generated_frame).save(frame_file)