from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from typing import List, Tuple
import numpy as np
from PIL import Image, ImageDraw
import base64
import io
import asyncio
from utils import initialize_model, sample_frame
import torch
import os
import time

DEBUG = True
app = FastAPI()

# Mount the static directory to serve HTML, JavaScript, and CSS files
app.mount("/static", StaticFiles(directory="static"), name="static")

# Serve the index.html file at the root URL
@app.get("/")
async def get():
    return HTMLResponse(open("static/index.html").read())
def generate_random_image(width: int, height: int) -> np.ndarray:
    return np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    prev_x, prev_y = None, None
    for i, (action_type, position) in enumerate(previous_actions):
        color = (255, 0, 0) if action_type == "move" else (0, 255, 0)
        x, y = position
        if DEBUG:
            # Scale 1024-pixel action coordinates down to the 256-pixel frame
            x = x * 256 / 1024
            y = y * 256 / 1024
        draw.ellipse([x - 2, y - 2, x + 2, y + 2], fill=color)
        if i > 0:
            draw.line([prev_x, prev_y, x, y], fill=color, width=1)
        prev_x, prev_y = x, y
    return np.array(pil_image)
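# Illustrative usage of draw_trace (not part of the original app flow): overlay a
# short cursor trace on a random frame.  The coordinates are arbitrary and, with
# DEBUG=True, are interpreted in the 1024-pixel action space before scaling.
def _trace_demo() -> np.ndarray:
    frame = generate_random_image(256, 256)
    actions = [("move", [100, 100]), ("move", [400, 300]), ("left_click", [400, 300])]
    return draw_trace(frame, actions)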
# Initialize the model at the start of your application
model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
def load_initial_images(width, height):
    initial_images = []
    for i in range(7):
        initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
        # Original variant: load image_{i}.png from disk, falling back to a blank frame
        #image_path = f"image_{i}.png"
        #if os.path.exists(image_path):
        #    img = Image.open(image_path).resize((width, height))
        #    initial_images.append(np.array(img))
        #else:
        #    print(f"Warning: {image_path} not found. Using blank image instead.")
        #    initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
    return initial_images
def normalize_images(images, target_range=(-1, 1)):
    images = np.stack(images).astype(np.float32)
    if target_range == (-1, 1):
        return images / 127.5 - 1
    elif target_range == (0, 1):
        return images / 255.0
    else:
        raise ValueError(f"Unsupported target range: {target_range}")
def denormalize_image(image, source_range=(-1, 1)):
    if source_range == (-1, 1):
        return ((image + 1) * 127.5).clip(0, 255).astype(np.uint8)
    elif source_range == (0, 1):
        return (image * 255).clip(0, 255).astype(np.uint8)
    else:
        raise ValueError(f"Unsupported source range: {source_range}")
def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> Tuple[np.ndarray, np.ndarray]:
    width, height = 256, 256
    initial_images = load_initial_images(width, height)
    # Prepare the image sequence for the model
    image_sequence = previous_frames[-7:]  # Take the last 7 frames
    while len(image_sequence) < 7:
        image_sequence.insert(0, initial_images[len(image_sequence)])
    # Convert the image sequence to a tensor and concatenate in the channel dimension
    image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
    image_sequence_tensor = image_sequence_tensor.to(device)
    # Prepare the prompt based on the previous actions
    action_descriptions = []
    #initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
    initial_actions = ['0:0'] * 7
    def unnorm_coords(x, y):
        return int(x), int(y)  #int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
    # Pad with initial actions if there are not enough previous actions
    while len(previous_actions) < 8:
        x, y = map(int, initial_actions.pop(0).split(':'))
        previous_actions.insert(0, ("move", unnorm_coords(x, y)))
    prev_x = 0
    prev_y = 0
    for action_type, pos in previous_actions:  #[-8:]:
        if action_type == "move":
            x, y = pos
            norm_x = int(round(x / 256 * 1024))  #x + (1920 - 256) / 2
            norm_y = int(round(y / 256 * 640))  #y + (1080 - 256) / 2
            if DEBUG:
                norm_x = x
                norm_y = y
            action_descriptions.append(f"{(norm_x - prev_x):.0f}~{(norm_y - prev_y):.0f}")
            prev_x = norm_x
            prev_y = norm_y
        elif action_type == "left_click":
            action_descriptions.append("left_click")
        elif action_type == "right_click":
            action_descriptions.append("right_click")
    prompt = " ".join(action_descriptions[-8:])
    print(prompt)
    # Generate the next frame
    new_frame = sample_frame(model, prompt, image_sequence_tensor)
    # Convert the generated frame from CHW to HWC layout
    new_frame = new_frame.transpose(1, 2, 0)
    print(new_frame.max(), new_frame.min())
    new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
    # Draw the trace of previous actions
    new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions)
    return new_frame_with_trace, new_frame_denormalized
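# Illustrative only (never called by the app): how a short action history is
# rendered into the prompt format used above -- one "dx~dy" token per mouse
# move (coordinate deltas, using the DEBUG convention norm_x = x) and click
# names verbatim.  The coordinates below are made up.
def _example_prompt() -> str:
    actions = [("move", [100, 100]), ("move", [110, 105]), ("left_click", [110, 105])]
    parts, px, py = [], 0, 0
    for kind, pos in actions:
        if kind == "move":
            x, y = pos
            parts.append(f"{x - px}~{y - py}")
            px, py = x, y
        else:
            parts.append(kind)
    return " ".join(parts)  # -> "100~100 10~5 left_click"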
# WebSocket endpoint for continuous user interaction
@app.websocket("/ws")  # route path assumed; it must match the WebSocket URL used by static/index.html
async def websocket_endpoint(websocket: WebSocket):
    client_id = id(websocket)  # Use a unique identifier for each connection
    print(f"New WebSocket connection: {client_id}")
    await websocket.accept()
    previous_frames = []
    previous_actions = []
    # Scripted cursor positions used in DEBUG mode
    #positions = ['496~61', '815~335', '815~335', '815~335', '787~342', '749~345', '749~345', '703~346', '703~346', '654~347', '604~349', '604~349', '555~353', '509~357', '509~357']
    positions = ['815~335', '787~342', '787~342', '749~345', '703~346', '703~346', '654~347', '654~347', '604~349', '555~353', '555~353', '509~357', '509~357', '468~362', '431~368', '431~368']
    try:
        while True:
            try:
                # Receive user input with a timeout
                data = await asyncio.wait_for(websocket.receive_json(), timeout=90.0)
                if data.get("type") == "heartbeat":
                    await websocket.send_json({"type": "heartbeat_response"})
                    continue
                action_type = data.get("action_type")
                mouse_position = data.get("mouse_position")
                # Store the action (in DEBUG mode, replace the reported position with the next scripted one)
                if DEBUG:
                    position = positions[0]
                    positions = positions[1:]
                    mouse_position = position.split('~')
                    mouse_position = [int(item) for item in mouse_position]
                previous_actions.append((action_type, mouse_position))
                # Log the start time
                start_time = time.time()
                # Predict the next frame based on the previous frames and actions
                next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
                previous_frames.append(next_frame_append)
                # Convert the numpy array to a base64-encoded PNG
                img = Image.fromarray(next_frame)
                buffered = io.BytesIO()
                img.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode()
                # Log the processing time
                processing_time = time.time() - start_time
                print(f"Frame processing time: {processing_time:.2f} seconds")
                # Send the generated frame back to the client
                await websocket.send_json({"image": img_str})
            except asyncio.TimeoutError:
                print("WebSocket connection timed out")
                break
            except WebSocketDisconnect:
                print("WebSocket disconnected")
                break
    except Exception as e:
        print(f"Error in WebSocket connection {client_id}: {e}")
    finally:
        print(f"WebSocket connection closed: {client_id}")
        # No explicit websocket.close() here; the connection is closed when the handler returns
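# Minimal sketch for running the server locally.  The host and port are my
# assumptions (Hugging Face Spaces typically expects port 7860), not values
# taken from the original file.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)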