from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from typing import List, Tuple
import numpy as np
from PIL import Image, ImageDraw
import base64
import io
import asyncio
from utils import initialize_model, sample_frame, device
import torch

app = FastAPI()

# Mount the static directory to serve HTML, JavaScript, and CSS files
app.mount("/static", StaticFiles(directory="static"), name="static")

# Serve the index.html file at the root URL
@app.get("/")
async def get():
    return HTMLResponse(open("static/index.html").read())
def generate_random_image(width: int, height: int) -> np.ndarray:
    return np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)

def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    for i, (action_type, position) in enumerate(previous_actions):
        color = (255, 0, 0) if action_type == "move" else (0, 255, 0)
        x, y = position
        draw.ellipse([x - 2, y - 2, x + 2, y + 2], fill=color)
        if i > 0:
            prev_x, prev_y = previous_actions[i - 1][1]
            draw.line([prev_x, prev_y, x, y], fill=color, width=1)
    return np.array(pil_image)
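# Illustrative usage of draw_trace (the frame and action values below are
# made up for demonstration; "move" points render red, other actions green):
#   frame = generate_random_image(256, 256)
#   traced = draw_trace(frame, [("move", [10, 10]), ("click", [50, 60])])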
# Initialize the model at the start of the application.
# Note: `model` must be bound here because sample_frame(model, ...) is called
# below; this assumes initialize_model returns the loaded model.
model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
    width, height = 256, 256

    # Prepare the image sequence for the model: take the last 7 frames and
    # left-pad with black frames if fewer are available
    image_sequence = previous_frames[-7:]
    while len(image_sequence) < 7:
        image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))

    # Convert the image sequence to a tensor normalized to [-1, 1]
    image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).permute(0, 3, 1, 2).float() / 127.5 - 1
    image_sequence_tensor = image_sequence_tensor.unsqueeze(0).to(device)

    # Prepare the prompt based on the previous actions
    action_descriptions = [f"{action} at ({pos[0]}, {pos[1]})" for action, pos in previous_actions[-7:]]
    prompt = "A sequence of actions: " + ", ".join(action_descriptions)

    # Generate the next frame
    new_frame = sample_frame(model, prompt, image_sequence_tensor)

    # Convert the generated frame from (C, H, W) floats to (H, W, C) uint8
    new_frame = (new_frame * 255).astype(np.uint8).transpose(1, 2, 0)

    # Resize the frame to 256x256 if necessary
    if new_frame.shape[:2] != (height, width):
        new_frame = np.array(Image.fromarray(new_frame).resize((width, height)))

    # Draw the trace of previous actions
    new_frame_with_trace = draw_trace(new_frame, previous_actions)
    return new_frame_with_trace
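# Shape walkthrough (inferred from the code above): seven (256, 256, 3) uint8
# frames stack to (7, 256, 256, 3), permute to (7, 3, 256, 256), and unsqueeze
# to a (1, 7, 3, 256, 256) batch in [-1, 1]. sample_frame is assumed to return
# a (C, H, W) array in [0, 1]; the `* 255` conversion above relies on that.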
# WebSocket endpoint for continuous user interaction
@app.websocket("/ws")  # route path assumed; must match the client-side connection URL
async def websocket_endpoint(websocket: WebSocket):
    client_id = id(websocket)  # Use a unique identifier for each connection
    print(f"New WebSocket connection: {client_id}")
    await websocket.accept()

    previous_frames = []
    previous_actions = []

    try:
        while True:
            try:
                # Receive user input with a timeout
                data = await asyncio.wait_for(websocket.receive_json(), timeout=30.0)

                if data.get("type") == "heartbeat":
                    await websocket.send_json({"type": "heartbeat_response"})
                    continue

                action_type = data.get("action_type")
                mouse_position = data.get("mouse_position")

                # Store the action
                previous_actions.append((action_type, mouse_position))

                # Predict the next frame based on the previous frames and actions
                next_frame = predict_next_frame(previous_frames, previous_actions)
                previous_frames.append(next_frame)

                # Convert the numpy array to a base64-encoded PNG
                img = Image.fromarray(next_frame)
                buffered = io.BytesIO()
                img.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode()

                # Send the generated frame back to the client
                await websocket.send_json({"image": img_str})
            except asyncio.TimeoutError:
                print("WebSocket connection timed out")
                break
            except WebSocketDisconnect:
                print("WebSocket disconnected")
                break
    except Exception as e:
        print(f"Error in WebSocket connection {client_id}: {e}")
    finally:
        print(f"WebSocket connection closed: {client_id}")
        # No explicit websocket.close() here; FastAPI closes the socket
        # when the handler returns
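# Expected client messages over the WebSocket (inferred from the handler above):
#   {"type": "heartbeat"}                              -> {"type": "heartbeat_response"}
#   {"action_type": "move", "mouse_position": [x, y]}  -> {"image": "<base64 PNG>"}

# Minimal sketch for serving the app locally; the host and port are
# assumptions (7860 is the conventional Hugging Face Spaces port).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)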