from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from typing import Any, Dict, List, Tuple
import numpy as np
from PIL import Image, ImageDraw
import base64
import io
import asyncio
from utils import initialize_model, sample_frame
import torch
import os
import time
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler

# Allow TensorFloat-32 matmuls/convolutions for faster inference on Ampere+ GPUs
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
SCREEN_WIDTH = 512
SCREEN_HEIGHT = 384
NUM_SAMPLING_STEPS = 8
DATA_NORMALIZATION = {
    'mean': -0.54,
    'std': 6.78,
}
LATENT_DIMS = (4, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8)
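# Note: the //8 factors presumably reflect the first-stage autoencoder's spatial
# downsampling, and 4 is the number of latent channels, so LATENT_DIMS describes
# one latent frame at 64x48 for a 512x384 screen.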
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model at the start of the application
#model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
model = initialize_model("config_rnn_measure_latency.yaml", "yuntian-deng/computer-model")
model = model.to(device)
#model = torch.compile(model)

padding_image = torch.zeros(1, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8, 4)
padding_image = (padding_image - DATA_NORMALIZATION['mean']) / DATA_NORMALIZATION['std']
padding_image = padding_image.to(device)
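# Note: padding_image is a normalized all-zeros latent that stands in for the
# "previous frame" on the first timestep, before any real frame has been generated.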
# Valid keyboard inputs
KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(',
        ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7',
        '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
        'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~',
        'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
        'browserback', 'browserfavorites', 'browserforward', 'browserhome',
        'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
        'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
        'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10',
        'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20',
        'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
        'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja',
        'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
        'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
        'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
        'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
        'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
        'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
        'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
        'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
        'command', 'option', 'optionleft', 'optionright']
INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
                'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
VALID_KEYS = [key for key in KEYS if key not in INVALID_KEYS]

itos = VALID_KEYS
stoi = {key: i for i, key in enumerate(itos)}
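# Note: stoi maps each valid key name to its index in the key_events vector
# built in prepare_model_inputs below; itos is the inverse (index -> key name).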
app = FastAPI()

# Mount the static directory to serve HTML, JavaScript, and CSS files
app.mount("/static", StaticFiles(directory="static"), name="static")
def prepare_model_inputs(
    previous_frame: torch.Tensor,
    hidden_states: Any,
    x: int,
    y: int,
    right_click: bool,
    left_click: bool,
    keys_down: List[str],
    stoi: Dict[str, int],
    itos: List[str],
    time_step: int,
) -> Dict[str, torch.Tensor]:
    """Prepare inputs for the model."""
    inputs = {
        'image_features': previous_frame.to(device),
        'is_padding': torch.BoolTensor([time_step == 0]).to(device),
        'x': torch.LongTensor([x if x is not None else 0]).unsqueeze(0).to(device),
        'y': torch.LongTensor([y if y is not None else 0]).unsqueeze(0).to(device),
        'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(device),
        'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(device),
        'key_events': torch.zeros(len(itos), dtype=torch.long).to(device),
    }
    # Mark every currently held key in the multi-hot key_events vector,
    # skipping any key name that is not in the model's key vocabulary.
    for key in keys_down:
        key = key.lower()
        if key in stoi:
            inputs['key_events'][stoi[key]] = 1
    if hidden_states is not None:
        inputs['hidden_states'] = hidden_states
    return inputs
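# Note: 'is_padding' is True only at time_step 0, where previous_frame is the
# zero padding latent rather than a real model output from the prior step.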
def process_frame(
    model: LatentDiffusion,
    inputs: Dict[str, torch.Tensor],
) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
    """Process a single frame through the model."""
    timing = {}

    # Temporal encoding: advance the RNN one step to get the conditioning signal
    start = time.perf_counter()
    output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs)
    timing['temporal_encoder'] = time.perf_counter() - start

    # UNet sampling (note: a fresh DDIMSampler is constructed for every frame)
    start = time.perf_counter()
    sampler = DDIMSampler(model)
    sample_latent, _ = sampler.sample(
        S=NUM_SAMPLING_STEPS,
        conditioning={'c_concat': output_from_rnn},
        batch_size=1,
        shape=LATENT_DIMS,
        verbose=False,
    )
    timing['unet'] = time.perf_counter() - start

    # Decoding: denormalize the latent and decode it to pixel space
    start = time.perf_counter()
    sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
    sample = model.decode_first_stage(sample)
    sample = sample.squeeze(0).clamp(-1, 1)
    timing['decode'] = time.perf_counter() - start

    # Convert from CHW in [-1, 1] to HWC uint8 in [0, 255]
    sample_img = ((sample[:3].permute(1, 2, 0).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)

    timing['total'] = sum(timing.values())
    return sample_latent, sample_img, hidden_states, timing
def print_timing_stats(timing_info: Dict[str, float], frame_num: int):
    """Print timing statistics for a frame."""
    print(f"\nFrame {frame_num} timing (seconds):")
    for key, value in timing_info.items():
        print(f"  {key.title()}: {value:.4f}")
    print(f"  FPS: {1.0 / timing_info['full_frame']:.2f}")
# Serve the index.html file at the root URL
@app.get("/")
async def get():
    return HTMLResponse(open("static/index.html").read())
# WebSocket endpoint for continuous user interaction
# (the "/ws" route path is assumed; it must match the URL the client-side JS connects to)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    client_id = id(websocket)  # Use a unique identifier for each connection
    print(f"New WebSocket connection: {client_id}")
    await websocket.accept()
    try:
        previous_frame = padding_image
        hidden_states = None
        keys_down = set()  # Keys currently held down
        frame_num = -1

        # Start timing for global FPS calculation
        connection_start_time = time.perf_counter()
        frame_count = 0

        while True:
            try:
                # Receive user input (timeout variant kept for reference)
                #data = await asyncio.wait_for(websocket.receive_json(), timeout=90000.0)
                data = await websocket.receive_json()
                if data.get("type") == "heartbeat":
                    await websocket.send_json({"type": "heartbeat_response"})
                    continue

                frame_num += 1
                frame_count += 1  # Increment total frame counter
                start_frame = time.perf_counter()

                # Calculate global FPS
                total_elapsed = start_frame - connection_start_time
                global_fps = frame_count / total_elapsed if total_elapsed > 0 else 0

                x = data.get("x")
                y = data.get("y")
                is_left_click = data.get("is_left_click")
                is_right_click = data.get("is_right_click")
                keys_down_list = data.get("keys_down", [])  # Keys pressed since the last message
                keys_up_list = data.get("keys_up", [])      # Keys released since the last message
                print(f'x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}')

                # Update the set of held keys based on the received data
                for key in keys_down_list:
                    keys_down.add(key)
                for key in keys_up_list:
                    keys_down.discard(key)  # discard() avoids a KeyError for unknown keys

                inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
                previous_frame, sample_img, hidden_states, timing_info = process_frame(model, inputs)
                timing_info['full_frame'] = time.perf_counter() - start_frame

                # Print per-frame and global timing statistics
                print_timing_stats(timing_info, frame_num)
                print(f"  Global FPS: {global_fps:.2f} (total: {frame_count} frames in {total_elapsed:.2f}s)")

                # Encode the generated frame as a base64 PNG and send it back to the client
                img = Image.fromarray(sample_img)
                buffered = io.BytesIO()
                img.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode()
                await websocket.send_json({"image": img_str})
            except asyncio.TimeoutError:
                print("WebSocket connection timed out")
                #break  # Exit the loop on timeout
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except Exception as e:
        print(f"Error in WebSocket connection {client_id}: {e}")
    finally:
        # Print final FPS statistics when the connection ends
        if frame_num >= 0:  # Only if we processed at least one frame
            total_time = time.perf_counter() - connection_start_time
            print(f"\nConnection {client_id} summary:")
            print(f"  Total frames processed: {frame_count}")
            print(f"  Total elapsed time: {total_time:.2f} seconds")
            print(f"  Average FPS: {frame_count / total_time:.2f}")
        print(f"WebSocket connection closed: {client_id}")
        #await websocket.close()  # Ensure the WebSocket is closed
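
# Minimal entry point for running the app locally (a sketch: the host and port
# are assumptions, 7860 being the Hugging Face Spaces default; any ASGI server
# invocation that serves `app` would work equally well).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)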