from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from typing import List, Tuple
import numpy as np
from PIL import Image, ImageDraw
import base64
import io
import asyncio
from utils import initialize_model, sample_frame
import torch
import os
import time

DEBUG = True
DEBUG_TEACHER_FORCING = True

app = FastAPI()

# Mount the static directory to serve HTML, JavaScript, and CSS files
app.mount("/static", StaticFiles(directory="static"), name="static")

def parse_action_string(action_str):
    """Convert a formatted action string to x, y coordinates and an action type.

    Args:
        action_str: String like 'N N N N N N : N N N N N' (padding) or
            'N + 0 2 1 3 : + 0 3 8 3' (move) or 'L + 0 2 1 3 : + 0 3 8 3' (left click)

    Returns:
        tuple: (x, y, action_type), or (None, None, None) if the action is padding
    """
    action_type = action_str[0]
    action_str = action_str[1:].strip()
    if 'N' in action_str:
        return (None, None, None)
    # Split into x and y parts
    action_str = action_str.replace(' ', '')
    x_part, y_part = action_str.split(':')
    # Join the digits and convert to ints (the sign stays part of the string)
    x = int(x_part)
    y = int(y_part)
    return x, y, action_type
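
# Illustrative examples (not executed) of how the formatted action strings map back
# to coordinates; values follow directly from the parsing logic above:
#   parse_action_string('N + 0 2 1 3 : + 0 3 8 3')  -> (213, 383, 'N')
#   parse_action_string('L + 0 9 2 7 : + 0 5 0 1')  -> (927, 501, 'L')
#   parse_action_string('N N N N N N : N N N N N')  -> (None, None, None)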

def create_position_and_click_map(pos, action_type, image_size=64, original_width=1024, original_height=640):
    """Convert a cursor position into binary position and left-click maps.

    Args:
        pos: (x, y) cursor position in screen coordinates, or (None, None) for padding
        action_type: 'L' for a left click, any other value for a plain move
        image_size: Size of the output maps (square)
        original_width: Original screen width (kept for compatibility, currently unused)
        original_height: Original screen height (kept for compatibility, currently unused)

    Returns:
        tuple: (pos_map, leftclick_map, x_scaled, y_scaled), where both maps are
        torch.Tensors of shape (1, image_size, image_size)
    """
    x, y = pos
    if x is None:
        # Padding action: empty maps and sentinel coordinates
        return torch.zeros((1, image_size, image_size)), torch.zeros((1, image_size, image_size)), -1, -1
    # Scale the positions to the new size
    #x_scaled = int((x / original_width) * image_size)
    #y_scaled = int((y / original_height) * image_size)
    screen_width, screen_height = 1920, 1080
    video_width, video_height = 512, 512
    x_scaled = x - (screen_width / 2 - video_width / 2)
    y_scaled = y - (screen_height / 2 - video_height / 2)
    x_scaled = int(x_scaled / video_width * image_size)
    y_scaled = int(y_scaled / video_height * image_size)
    # Clamp values to ensure they're within bounds
    x_scaled = max(0, min(x_scaled, image_size - 1))
    y_scaled = max(0, min(y_scaled, image_size - 1))
    # Create binary position map
    pos_map = torch.zeros((1, image_size, image_size))
    pos_map[0, y_scaled, x_scaled] = 1.0
    leftclick_map = torch.zeros((1, image_size, image_size))
    if action_type == 'L':
        leftclick_map[0, y_scaled, x_scaled] = 1.0
    return pos_map, leftclick_map, x_scaled, y_scaled
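
# Worked example (illustrative): the 512x512 video region is assumed to be centered
# on a 1920x1080 screen, so a cursor at screen position (927, 501) maps to
#   x_scaled = int((927 - 704) / 512 * 64) = 27,  y_scaled = int((501 - 284) / 512 * 64) = 27
# and pos_map[0, 27, 27] == 1.0 (leftclick_map[0, 27, 27] too when action_type == 'L').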

# Serve the index.html file at the root URL
@app.get("/")
async def get():
    return HTMLResponse(open("static/index.html").read())

def generate_random_image(width: int, height: int) -> np.ndarray:
    return np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)

def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]], x_scaled=-1, y_scaled=-1) -> np.ndarray:
    pil_image = Image.fromarray(image)
    #pil_image = Image.open('image_3.png')
    draw = ImageDraw.Draw(pil_image)
    flag = True
    prev_x, prev_y = None, None
    for i, (action_type, position) in enumerate(previous_actions):
        color = (255, 0, 0) if action_type == "move" else (0, 255, 0)
        x, y = position
        if x == 0 and y == 0 and flag:
            continue
        else:
            flag = False
        #if DEBUG:
        #    x = x * 256 / 1024
        #    y = y * 256 / 640
        #draw.ellipse([x-2, y-2, x+2, y+2], fill=color)
        #if prev_x is not None:
        #    #prev_x, prev_y = previous_actions[i-1][1]
        #    draw.line([prev_x, prev_y, x, y], fill=color, width=1)
        prev_x, prev_y = x, y
    draw.ellipse([x_scaled*8-2, y_scaled*8-2, x_scaled*8+2, y_scaled*8+2], fill=(0, 255, 0))
    #pil_image = pil_image.convert("RGB")
    return np.array(pil_image)
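
# Note on the final ellipse (explanatory, based on the constants used elsewhere in this
# file): positions are on a 64x64 grid while frames are 512x512, so a grid coordinate of
# 27 is drawn around pixel 27 * 8 = 216; the default (-1, -1) falls outside the canvas
# and is effectively a no-op.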

# Initialize the model at the start of your application
#model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
model = initialize_model("pssearch_bsz64_acc1_lr8e5_512_leftclick.yaml", "yuntian-deng/computer-model")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

def load_initial_images(width, height):
    initial_images = []
    if DEBUG_TEACHER_FORCING:
        # Load the seven context frames used for teacher forcing
        for i in range(75, 82):  # loads images 75-81
            img = Image.open(f"record_100/image_{i}.png").resize((width, height))
            initial_images.append(np.array(img))
    else:
        for i in range(7):
            initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
    return initial_images

def normalize_images(images, target_range=(-1, 1)):
    images = np.stack(images).astype(np.float32)
    if target_range == (-1, 1):
        return images / 127.5 - 1
    elif target_range == (0, 1):
        return images / 255.0
    else:
        raise ValueError(f"Unsupported target range: {target_range}")


def denormalize_image(image, source_range=(-1, 1)):
    if source_range == (-1, 1):
        return ((image + 1) * 127.5).clip(0, 255).astype(np.uint8)
    elif source_range == (0, 1):
        return (image * 255).clip(0, 255).astype(np.uint8)
    else:
        raise ValueError(f"Unsupported source range: {source_range}")

def format_action(action_str, is_padding=False, is_leftclick=False):
    if is_padding:
        return "N N N N N N : N N N N N"
    # Split the x~y coordinates
    x, y = map(int, action_str.split('~'))
    prefix = 'N'
    if is_leftclick:
        prefix = 'L'
    # Convert numbers to zero-padded strings and add spaces between digits
    x_str = f"{abs(x):04d}"
    y_str = f"{abs(y):04d}"
    x_spaced = ' '.join(x_str)
    y_spaced = ' '.join(y_str)
    # Format with sign and proper spacing
    return prefix + " " + f"{'+ ' if x >= 0 else '- '}{x_spaced} : {'+ ' if y >= 0 else '- '}{y_spaced}"
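
# Illustrative examples (not executed); these are roughly the inverse of parse_action_string above:
#   format_action('927~501')                     -> 'N + 0 9 2 7 : + 0 5 0 1'
#   format_action('927~501', is_leftclick=True)  -> 'L + 0 9 2 7 : + 0 5 0 1'
#   format_action('0~0', is_padding=True)        -> 'N N N N N N : N N N N N'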

def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> Tuple[np.ndarray, np.ndarray]:
    width, height = 512, 512
    initial_images = load_initial_images(width, height)
    # Prepare the image sequence for the model
    image_sequence = previous_frames[-7:]  # Take the last 7 frames
    # Pad the front with initial frames (kept in chronological order) if there are fewer than 7
    if len(image_sequence) < 7:
        image_sequence = initial_images[len(image_sequence):] + image_sequence
    # Convert the image sequence to a tensor and concatenate in the channel dimension
    image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
    image_sequence_tensor = image_sequence_tensor.to(device)
    # Prepare the prompt based on the previous actions
    action_descriptions = []
    #initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
    initial_actions = ['0:0'] * 7
    #initial_actions = ['N N N N N : N N N N N'] * 7
    def unnorm_coords(x, y):
        return int(x), int(y)  #int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
    # Prepend initial actions if there are not enough previous actions
    while len(previous_actions) < 8:
        x, y = map(int, initial_actions.pop(0).split(':'))
        previous_actions.insert(0, ("N", unnorm_coords(x, y)))
    prev_x = 0
    prev_y = 0
    #print('here')
    for action_type, pos in previous_actions:  #[-8:]:
        print('here3', action_type, pos)
        if action_type == "N":
            x, y = pos
            #norm_x = int(round(x / 256 * 1024))  #x + (1920 - 256) / 2
            #norm_y = int(round(y / 256 * 640))  #y + (1080 - 256) / 2
            norm_x = x + (1920 - 512) / 2
            norm_y = y + (1080 - 512) / 2
            #if DEBUG:
            #    norm_x = x
            #    norm_y = y
            #action_descriptions.append(f"{(norm_x-prev_x):.0f}~{(norm_y-prev_y):.0f}")
            #action_descriptions.append(format_action(f'{norm_x-prev_x:.0f}~{norm_y-prev_y:.0f}', x==0 and y==0))
            action_descriptions.append(format_action(f'{norm_x:.0f}~{norm_y:.0f}', x==0 and y==0))
            prev_x = norm_x
            prev_y = norm_y
        elif action_type == "L":
            x, y = pos
            #norm_x = int(round(x / 256 * 1024))  #x + (1920 - 256) / 2
            #norm_y = int(round(y / 256 * 640))  #y + (1080 - 256) / 2
            norm_x = x + (1920 - 512) / 2
            norm_y = y + (1080 - 512) / 2
            #if DEBUG:
            #    norm_x = x
            #    norm_y = y
            #action_descriptions.append(f"{(norm_x-prev_x):.0f}~{(norm_y-prev_y):.0f}")
            #action_descriptions.append(format_action(f'{norm_x-prev_x:.0f}~{norm_y-prev_y:.0f}', x==0 and y==0))
            action_descriptions.append(format_action(f'{norm_x:.0f}~{norm_y:.0f}', x==0 and y==0, True))
        elif action_type == "right_click":
            assert False
            action_descriptions.append("right_click")
        else:
            assert False
    prompt = " ".join(action_descriptions[-8:])
    print(prompt)
    #prompt = "N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N + 0 3 0 7 : + 0 3 7 5"
    #x, y, action_type = parse_action_string(action_descriptions[-1])
    #pos_map, leftclick_map, x_scaled, y_scaled = create_position_and_click_map((x, y), action_type)
    leftclick_maps = []
    pos_maps = []
    for j in range(1, 9):
        x, y, action_type = parse_action_string(action_descriptions[-j])
        pos_map_j, leftclick_map_j, x_scaled_j, y_scaled_j = create_position_and_click_map((x, y), action_type)
        leftclick_maps.append(leftclick_map_j)
        pos_maps.append(pos_map_j)
        if j == 1:
            x_scaled = x_scaled_j
            y_scaled = y_scaled_j
    #prompt = ''
    #prompt = "1~1 0~0 0~0 0~0 0~0 0~0 0~0 0~0"
    print(prompt)
    # Generate the next frame
    new_frame = sample_frame(model, prompt, image_sequence_tensor, pos_maps=pos_maps, leftclick_maps=leftclick_maps)
    # Convert the generated frame to the correct format
    new_frame = new_frame.transpose(1, 2, 0)
    print(new_frame.max(), new_frame.min())
    new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
    # Draw the trace of previous actions
    new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions, x_scaled, y_scaled)
    return new_frame_with_trace, new_frame_denormalized

# WebSocket endpoint for continuous user interaction
@app.websocket("/ws")  # NOTE: the route path is an assumption; match it to the URL used by the client in static/index.html
async def websocket_endpoint(websocket: WebSocket):
    client_id = id(websocket)  # Use a unique identifier for each connection
    print(f"New WebSocket connection: {client_id}")
    await websocket.accept()
    previous_frames = []
    previous_actions = []
    positions = ['815~335', '787~342', '787~342', '749~345', '703~346', '703~346', '654~347', '654~347', '604~349', '555~353', '555~353', '509~357', '509~357', '468~362', '431~368', '431~368']
    #positions = ['815~335', '787~342', '749~345', '703~346', '703~346', '654~347', '654~347', '604~349', '555~353', '555~353', '509~357', '509~357', '468~362', '431~368', '431~368']
    positions = ['307~375']
    positions = ['815~335']
    #positions = ['787~342']
    positions = ['300~800']
    if DEBUG_TEACHER_FORCING:
        #print('here2')
        # Use the predefined actions for image_81
        debug_actions = [
            'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
            'N + 0 8 9 0 : + 0 4 7 5', 'N + 0 9 0 8 : + 0 4 8 8',
            'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
            'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
            'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
            'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
            'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
            'N + 0 9 2 7 : + 0 5 0 1', #'N + 0 9 2 7 : + 0 5 0 1'
        ]
        previous_actions = []
        for action in debug_actions[-8:]:
            x, y, action_type = parse_action_string(action)
            previous_actions.append((action_type, (x, y)))
        positions = [
            'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 1 8 : + 0 4 9 2',
            'N + 0 9 0 8 : + 0 4 8 3', 'N + 0 8 9 8 : + 0 4 7 4',
            'N + 0 8 8 9 : + 0 4 6 5', 'N + 0 8 8 0 : + 0 4 5 6',
            'N + 0 8 7 0 : + 0 4 4 7', 'N + 0 8 6 0 : + 0 4 3 8',
            'N + 0 8 5 1 : + 0 4 2 9', 'N + 0 8 4 2 : + 0 4 2 0',
            'N + 0 8 3 2 : + 0 4 1 1', 'N + 0 8 3 2 : + 0 4 1 1']
        #positions = positions[:4]
    try:
        while True:
            try:
                # Receive user input with a timeout
                #data = await asyncio.wait_for(websocket.receive_json(), timeout=90000.0)
                data = await websocket.receive_json()
                if data.get("type") == "heartbeat":
                    await websocket.send_json({"type": "heartbeat_response"})
                    continue
                action_type = data.get("action_type")
                mouse_position = data.get("mouse_position")
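                # Assumed client message shapes, inferred from the fields read here
                # (the exact payload is defined by static/index.html, which is not shown):
                #   {"type": "heartbeat"}
                #   {"action_type": "N", "mouse_position": [x, y]}   # move
                #   {"action_type": "L", "mouse_position": [x, y]}   # left click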
                # Store the actions
                if DEBUG:
                    position = positions[0]
                    #positions = positions[1:]
                    #mouse_position = position.split('~')
                    #mouse_position = [int(item) for item in mouse_position]
                    #mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
                if DEBUG_TEACHER_FORCING:
                    position = positions[0]
                    positions = positions[1:]
                    x, y, action_type = parse_action_string(position)
                    mouse_position = (x, y)
                previous_actions.append((action_type, mouse_position))
                #previous_actions = [(action_type, mouse_position)]
                # Log the start time
                start_time = time.time()
                # Predict the next frame based on the previous frames and actions
                next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
                # Load and append the corresponding ground truth image instead of model output
                #img = Image.open(f"image_{len(previous_frames)%7}.png")
                previous_frames.append(next_frame_append)
                # Convert the numpy array to a base64 encoded image
                img = Image.fromarray(next_frame)
                buffered = io.BytesIO()
                img.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode()
                # Log the processing time
                processing_time = time.time() - start_time
                print(f"Frame processing time: {processing_time:.2f} seconds")
                # Send the generated frame back to the client
                await websocket.send_json({"image": img_str})
            except asyncio.TimeoutError:
                print("WebSocket connection timed out")
                #break  # Exit the loop on timeout
            except WebSocketDisconnect:
                print("WebSocket disconnected")
                #break  # Exit the loop on disconnect
    except Exception as e:
        print(f"Error in WebSocket connection {client_id}: {e}")
    finally:
        print(f"WebSocket connection closed: {client_id}")
        #await websocket.close()  # Ensure the WebSocket is closed
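
# If this file is run directly rather than by the Space's own start command (an
# assumption; the hosting environment may already launch the ASGI app for you),
# a conventional uvicorn entry point would look like this:
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)  # port 7860 assumed; adjust as needed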