# neural-os / main.py
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from typing import Any, Dict, List, Tuple
import numpy as np
from PIL import Image, ImageDraw
import base64
import io
import asyncio
from utils import initialize_model
import torch
import os
import time
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
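# Allow TF32 matmuls/convolutions on Ampere+ GPUs: faster inference at slightly reduced precision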
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
SCREEN_WIDTH = 512
SCREEN_HEIGHT = 384
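# Fewer DDIM steps means lower per-frame latency at some cost in sample quality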
NUM_SAMPLING_STEPS = 8
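# Latent normalization statistics (applied in reverse before decoding)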
DATA_NORMALIZATION = {
'mean': -0.54,
'std': 6.78,
}
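# The first-stage autoencoder downsamples 8x spatially: a 512x384 frame becomes a 4x48x64 latent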
LATENT_DIMS = (4, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialize the model once at application startup
# (the commented-out line loads an alternative CS-LLM config)
#model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
model = initialize_model("config_rnn.yaml", "yuntian-deng/computer-model")
model = model.to(device)
#model = torch.compile(model)
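# A normalized all-zeros latent serves as the "previous frame" for the first timestep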
padding_image = torch.zeros(1, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8, 4)
padding_image = (padding_image - DATA_NORMALIZATION['mean']) / DATA_NORMALIZATION['std']
padding_image = padding_image.to(device)
# Valid keyboard inputs
KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(',
')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7',
'8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~',
'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
'browserback', 'browserfavorites', 'browserforward', 'browserhome',
'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10',
'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20',
'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja',
'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
'command', 'option', 'optionleft', 'optionright']
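# Keys excluded from the vocabulary (presumably unseen or unreliable in the training data)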
INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
VALID_KEYS = [key for key in KEYS if key not in INVALID_KEYS]
itos = VALID_KEYS
stoi = {key: i for i, key in enumerate(itos)}
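# itos maps index -> key name and stoi maps key name -> index; together they define
# the multi-hot key_events vector consumed by prepare_model_inputs below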
app = FastAPI()
# Mount the static directory to serve HTML, JavaScript, and CSS files
app.mount("/static", StaticFiles(directory="static"), name="static")
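# A typical way to serve this app locally (the port is illustrative):
#   uvicorn main:app --host 0.0.0.0 --port 7860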
def prepare_model_inputs(
    previous_frame: torch.Tensor,
    hidden_states: Any,
    x: int,
    y: int,
    right_click: bool,
    left_click: bool,
    keys_down: List[str],
    stoi: Dict[str, int],
    itos: List[str],
    time_step: int
) -> Dict[str, torch.Tensor]:
    """Prepare the per-step conditioning inputs for the model."""
    inputs = {
        'image_features': previous_frame.to(device),
        'is_padding': torch.BoolTensor([time_step == 0]).to(device),
        'x': torch.LongTensor([x if x is not None else 0]).unsqueeze(0).to(device),
        'y': torch.LongTensor([y if y is not None else 0]).unsqueeze(0).to(device),
        'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(device),
        'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(device),
        'key_events': torch.zeros(len(itos), dtype=torch.long).to(device)
    }
    # Multi-hot encode the currently held keys, skipping anything outside the
    # model's key vocabulary rather than raising a KeyError
    for key in keys_down:
        key = key.lower()
        if key in stoi:
            inputs['key_events'][stoi[key]] = 1
    if hidden_states is not None:
        inputs['hidden_states'] = hidden_states
    return inputs
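# Example call with hypothetical values -- first frame (time_step=0), left-click at (100, 200):
#   inputs = prepare_model_inputs(padding_image, None, 100, 200,
#                                 right_click=False, left_click=True,
#                                 keys_down=[], stoi=stoi, itos=itos, time_step=0)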
@torch.no_grad()
def process_frame(
    model: LatentDiffusion,
    inputs: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
    """Process a single frame through the model."""
    timing = {}

    # Temporal encoding: advance the RNN one step to get the UNet conditioning
    start = time.perf_counter()
    output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs)
    timing['temporal_encoder'] = time.perf_counter() - start

    # UNet sampling (a fresh DDIMSampler per frame; construction is lightweight)
    start = time.perf_counter()
    sampler = DDIMSampler(model)
    sample_latent, _ = sampler.sample(
        S=NUM_SAMPLING_STEPS,
        conditioning={'c_concat': output_from_rnn},
        batch_size=1,
        shape=LATENT_DIMS,
        verbose=False
    )
    timing['unet'] = time.perf_counter() - start

    # Decoding: de-normalize the latent, then decode to pixel space
    start = time.perf_counter()
    sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
    sample = model.decode_first_stage(sample)
    sample = sample.squeeze(0).clamp(-1, 1)
    timing['decode'] = time.perf_counter() - start

    # Convert CHW in [-1, 1] to HWC uint8 in [0, 255]
    sample_img = ((sample[:3].permute(1, 2, 0).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)

    timing['total'] = sum(timing.values())
    return sample_latent, sample_img, hidden_states, timing
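# Note: process_frame returns the still-normalized latent, which is fed back as the
# next previous_frame -- consistent with the normalization applied to padding_image above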
def print_timing_stats(timing_info: Dict[str, float], frame_num: int):
    """Print timing statistics for a frame (the caller must have set 'full_frame')."""
    print(f"\nFrame {frame_num} timing (seconds):")
    for key, value in timing_info.items():
        print(f"  {key.title()}: {value:.4f}")
    print(f"  FPS: {1.0 / timing_info['full_frame']:.2f}")
# Serve the index.html file at the root URL
@app.get("/")
async def get():
    with open("static/index.html") as f:  # context manager avoids leaking the file handle
        return HTMLResponse(f.read())
# WebSocket endpoint for continuous user interaction
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
client_id = id(websocket) # Use a unique identifier for each connection
print(f"New WebSocket connection: {client_id}")
await websocket.accept()
try:
previous_frame = padding_image
hidden_states = None
keys_down = set() # Initialize as an empty set
frame_num = -1
# Start timing for global FPS calculation
connection_start_time = time.perf_counter()
frame_count = 0
# Input queue management
input_queue = []
is_processing = False
async def process_input(data):
nonlocal previous_frame, hidden_states, keys_down, frame_num, frame_count, is_processing
try:
process_start_time = time.perf_counter()
print(f"[{process_start_time:.3f}] Starting to process input. Queue size before: {len(input_queue)}")
is_processing = True
frame_num += 1
frame_count += 1 # Increment total frame counter
# Calculate global FPS
total_elapsed = process_start_time - connection_start_time
global_fps = frame_count / total_elapsed if total_elapsed > 0 else 0
x = data.get("x")
y = data.get("y")
is_left_click = data.get("is_left_click")
is_right_click = data.get("is_right_click")
keys_down_list = data.get("keys_down", []) # Get as list
keys_up_list = data.get("keys_up", [])
print(f'[{time.perf_counter():.3f}] Processing: x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}')
# Update the set based on the received data
for key in keys_down_list:
keys_down.add(key)
                for key in keys_up_list:
                    keys_down.discard(key)  # set.discard is a no-op if the key is not held
inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
print(f"[{time.perf_counter():.3f}] Starting model inference...")
previous_frame, sample_img, hidden_states, timing_info = process_frame(model, inputs)
timing_info['full_frame'] = time.perf_counter() - process_start_time
print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {len(input_queue)}")
# Use the provided function to print timing statistics
print_timing_stats(timing_info, frame_num)
# Print global FPS measurement
print(f" Global FPS: {global_fps:.2f} (total: {frame_count} frames in {total_elapsed:.2f}s)")
img = Image.fromarray(sample_img)
buffered = io.BytesIO()
img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
# Send the generated frame back to the client
print(f"[{time.perf_counter():.3f}] Sending image to client...")
await websocket.send_json({"image": img_str})
print(f"[{time.perf_counter():.3f}] Image sent. Queue size before next_input: {len(input_queue)}")
finally:
is_processing = False
print(f"[{time.perf_counter():.3f}] Processing complete. Queue size before checking next input: {len(input_queue)}")
# Check if we have more inputs to process after this one
process_next_input()
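        # Queue coalescing: when inputs back up, jump to the most recent click/key
        # event and drop everything older; if only mouse movements are queued,
        # keep just the newest one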
        def process_next_input():
            nonlocal input_queue, is_processing
current_time = time.perf_counter()
if not input_queue:
print(f"[{current_time:.3f}] No inputs to process. Queue is empty.")
return
if is_processing:
print(f"[{current_time:.3f}] Already processing an input. Will check again later.")
return
print(f"[{current_time:.3f}] Processing next input. Queue size: {len(input_queue)}")
            # Find the most recent interesting input (a click or a key event);
            # an empty keys list is falsy, so plain truthiness checks suffice here
            interesting_indices = [i for i, data in enumerate(input_queue)
                                   if data.get("is_left_click")
                                   or data.get("is_right_click")
                                   or data.get("keys_down")
                                   or data.get("keys_up")]
if interesting_indices:
# There are interesting events - take the most recent one
idx = interesting_indices[-1]
next_input = input_queue[idx]
skipped = idx # Number of events we're skipping
# Clear all inputs up to and including this one
input_queue = input_queue[idx+1:]
print(f"[{current_time:.3f}] Processing interesting input (skipped {skipped} events). Queue size now: {len(input_queue)}")
else:
# No interesting events - just take the most recent movement
skipped = len(input_queue) - 1 # We're processing one, so skipped = total - 1
next_input = input_queue[-1]
input_queue = []
print(f"[{current_time:.3f}] Processing latest movement (skipped {skipped} events). Queue now empty.")
            # Process the selected input asynchronously. Mark busy *before* scheduling:
            # if a message arrived between create_task and the task actually running,
            # the receive loop could otherwise start a second, concurrent process_input.
            is_processing = True
            print(f"[{current_time:.3f}] Creating task to process input...")
            asyncio.create_task(process_input(next_input))
while True:
try:
# Receive user input
print(f"[{time.perf_counter():.3f}] Waiting for input... Queue size: {len(input_queue)}, is_processing: {is_processing}")
data = await websocket.receive_json()
receive_time = time.perf_counter()
if data.get("type") == "heartbeat":
await websocket.send_json({"type": "heartbeat_response"})
continue
# Add the input to our queue
input_queue.append(data)
print(f"[{receive_time:.3f}] Received input. Queue size now: {len(input_queue)}")
# If we're not currently processing, start processing this input
if not is_processing:
print(f"[{receive_time:.3f}] Not currently processing, will call process_next_input()")
process_next_input()
else:
print(f"[{receive_time:.3f}] Currently processing, new input queued for later")
            except asyncio.TimeoutError:
                # Only reachable if the receive is wrapped in asyncio.wait_for elsewhere;
                # a plain receive_json() does not time out on its own
                print("WebSocket connection timed out")
except WebSocketDisconnect:
print("WebSocket disconnected")
break
except Exception as e:
print(f"Error in WebSocket connection {client_id}: {e}")
import traceback
traceback.print_exc()
finally:
# Print final FPS statistics when connection ends
if frame_num >= 0: # Only if we processed at least one frame
total_time = time.perf_counter() - connection_start_time
print(f"\nConnection {client_id} summary:")
print(f" Total frames processed: {frame_count}")
print(f" Total elapsed time: {total_time:.2f} seconds")
print(f" Average FPS: {frame_count/total_time:.2f}")
print(f"WebSocket connection closed: {client_id}")