Spaces:
Runtime error
Runtime error
from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
from fastapi.responses import HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from typing import List, Tuple | |
import numpy as np | |
from PIL import Image, ImageDraw | |
import base64 | |
import io | |
import asyncio | |
from utils import initialize_model, sample_frame | |
import torch | |
import os | |
import time | |
# Debug switches (both disabled here).
DEBUG = DEBUG_TEACHER_FORCING = False

app = FastAPI()
# Expose the ./static directory (HTML, JavaScript, CSS) under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Click positions that draw_trace() renders as red dots, used directly as
# pixel coordinates on the output frame.
all_click_positions = []
def parse_action_string(action_str):
    """Parse a formatted action string back into coordinates and action type.

    The expected format is the one produced by ``format_action``. It starts with a
    one-character action prefix ('N' for a move, 'L' for a left click), followed by
    a signed, space-separated coordinate pair, e.g. ``'L + 0 1 9 1 : + 0 0 8 7'``.

    Args:
        action_str: Formatted action string. Padding entries such as
            ``'N N N N N N : N N N N N'`` contain 'N' in the coordinate part.

    Returns:
        tuple: ``(x, y, action_type)`` with integer coordinates and the prefix
        character, or ``(None, None, None)`` for a padding entry.

    Raises:
        ValueError: if ``action_str`` is empty, or if its coordinate part is not
            of the form ``<int> : <int>``.
    """
    if not action_str:
        raise ValueError("empty action string")
    action_type = action_str[0]
    coords = action_str[1:].strip()
    # Any 'N' in the coordinate part marks a padding entry.
    if 'N' in coords:
        return (None, None, None)
    x_part, y_part = coords.replace(' ', '').split(':')
    # int() accepts the leading sign and zero padding, e.g. '+0191' -> 191.
    return int(x_part), int(y_part), action_type
def create_position_and_click_map(pos, action_type, image_height=48, image_width=64, original_width=512, original_height=384):
    """Rasterize a cursor position (and an optional left click) onto a coarse grid.

    Args:
        pos: ``(x, y)`` cursor position in screen pixels, or ``(None, None)`` for
            a padding action.
        action_type: Action prefix. ``'L'`` marks a left click; any other value
            (e.g. ``'N'`` or ``None``) is treated as a plain move.
        image_height, image_width: Size of the output maps (default 48 x 64).
        original_width, original_height: Size of the screen that ``pos`` refers
            to (default 512 x 384).

    Returns:
        tuple: ``(pos_map, leftclick_map, x_scaled, y_scaled)``. Both maps are
        float tensors of shape ``(1, image_height, image_width)``.
        ``pos_map`` holds a single 1.0 at the scaled cursor cell.
        ``leftclick_map`` holds a 1.0 at the same cell only for left clicks.
        For padding, both maps are all zeros and ``x_scaled``/``y_scaled`` are
        ``None``.
    """
    x, y = pos
    pos_map = torch.zeros((1, image_height, image_width))
    leftclick_map = torch.zeros((1, image_height, image_width))
    if x is None:
        return pos_map, leftclick_map, None, None

    # int() truncates toward zero. Clamping keeps off-screen positions on the
    # nearest grid edge.
    x_scaled = max(0, min(int(x / original_width * image_width), image_width - 1))
    y_scaled = max(0, min(int(y / original_height * image_height), image_height - 1))

    pos_map[0, y_scaled, x_scaled] = 1.0
    if action_type == 'L':
        print('left click', x_scaled, y_scaled)
        leftclick_map[0, y_scaled, x_scaled] = 1.0
    return pos_map, leftclick_map, x_scaled, y_scaled
# Serve the index.html file at the root URL.
# NOTE(review): no route decorator (e.g. ``@app.get("/")``) is visible on this
# handler in this copy of the file. Confirm it is registered; otherwise the
# root URL is not served.
async def get():
    """Return ``static/index.html`` as an HTML response."""
    # Use a context manager so the file handle is closed promptly, and an
    # explicit encoding so the result does not depend on the server's locale.
    with open("static/index.html", encoding="utf-8") as f:
        return HTMLResponse(f.read())
def generate_random_image(width: int, height: int) -> np.ndarray:
    """Return a ``(height, width, 3)`` uint8 array of uniformly random RGB noise."""
    shape = (height, width, 3)
    return np.random.randint(low=0, high=256, size=shape, dtype=np.uint8)
def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]], x_scaled=-1, y_scaled=-1) -> np.ndarray:
    """Overlay the click history, cursor trail and current cursor on a frame.

    Args:
        image: Frame as a uint8 array accepted by ``PIL.Image.fromarray``.
        previous_actions: ``(action_type, (x, y))`` pairs in frame pixel
            coordinates. Entries at ``(0, 0)`` (the padding value) and entries
            with a ``None`` coordinate are skipped.
        x_scaled, y_scaled: Current cursor cell on the grid produced by
            ``create_position_and_click_map`` (8 px per cell at its defaults).
            ``None`` or negative values (including the defaults) mean the cell
            is unknown, and no current-position marker is drawn.

    Returns:
        np.ndarray: The annotated frame as a new array.
    """
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)

    # Historical clicks from the module-level list: red dots, radius 4.
    for click_x, click_y in all_click_positions:
        draw.ellipse([click_x - 4, click_y - 4, click_x + 4, click_y + 4], fill=(255, 0, 0))

    # Cursor trail: blue dots (radius 2) joined by 1 px green lines.
    prev_point = None
    for _action_type, (x, y) in previous_actions:
        if x is None or y is None or (x == 0 and y == 0):
            continue
        draw.ellipse([x - 2, y - 2, x + 2, y + 2], fill=(0, 0, 255))
        if prev_point is not None:
            draw.line([prev_point[0], prev_point[1], x, y], fill=(0, 255, 0), width=1)
        prev_point = (x, y)

    # Current cursor: green dot (radius 3). Grid cells are 8 px (512 / 64 and
    # 384 / 48 in create_position_and_click_map's defaults). An unknown cell
    # (e.g. the newest action was padding) is skipped instead of crashing the
    # render loop.
    if x_scaled is not None and y_scaled is not None and x_scaled >= 0 and y_scaled >= 0:
        x_current = x_scaled * 8
        y_current = y_scaled * 8
        print('x_current, y_current', x_current, y_current)
        draw.ellipse([x_current - 3, y_current - 3, x_current + 3, y_current + 3], fill=(0, 255, 0))

    return np.array(pil_image)
# Load the model once at import time; every websocket connection shares it.
model = initialize_model(
    "standard_challenging_context32_nocond_all.yaml",
    "yuntian-deng/computer-model",
)
# Prefer the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device)
def load_initial_images(width, height):
    """Build the initial frame context.

    When ``DEBUG_TEACHER_FORCING`` is set, this loads the recorded frames
    ``record_10003/image_110.png`` .. ``image_116.png`` (7 frames) as numpy
    arrays. ``width``/``height`` are not applied in that branch, because
    resizing is disabled. Otherwise it returns 32 black ``(height, width, 3)``
    uint8 frames.

    NOTE(review): the two branches return different numbers of frames (7 vs 32).
    Confirm that callers accept both lengths before enabling teacher forcing.
    """
    if DEBUG_TEACHER_FORCING:
        frames = []
        for i in range(117 - 7, 117):
            # The context manager closes the file once the pixels are copied out.
            with Image.open(f"record_10003/image_{i}.png") as img:
                frames.append(np.array(img))
        return frames
    return [np.zeros((height, width, 3), dtype=np.uint8) for _ in range(32)]
def normalize_images(images, target_range=(-1, 1)):
    """Stack a sequence of 0-255 images and rescale them to float32.

    Args:
        images: Sequence of equally-shaped arrays with values in [0, 255].
        target_range: ``(-1, 1)`` (default) or ``(0, 1)``.

    Returns:
        np.ndarray: float32 array with a new leading axis over ``images``.

    Raises:
        ValueError: for any other ``target_range``.
    """
    stacked = np.stack(images).astype(np.float32)
    if target_range == (0, 1):
        return stacked / 255.0
    if target_range == (-1, 1):
        return stacked / 127.5 - 1
    raise ValueError(f"Unsupported target range: {target_range}")
def normalize_image(image, target_range=(-1, 1)):
    """Rescale a single 0-255 image to float32 in ``target_range``.

    Args:
        image: Array with values in [0, 255].
        target_range: ``(-1, 1)`` (default) or ``(0, 1)``.

    Raises:
        ValueError: for any other ``target_range``.
    """
    as_float = image.astype(np.float32)
    if target_range == (0, 1):
        return as_float / 255.0
    if target_range == (-1, 1):
        return as_float / 127.5 - 1
    raise ValueError(f"Unsupported target range: {target_range}")
def denormalize_image(image, source_range=(-1, 1)):
    """Map a float image from ``source_range`` back to uint8 [0, 255].

    Values outside the range are clipped before the cast, so the conversion
    never wraps around.

    Args:
        image: Float array with values nominally in ``source_range``.
        source_range: ``(-1, 1)`` (default) or ``(0, 1)``.

    Raises:
        ValueError: for any other ``source_range``.
    """
    if source_range == (0, 1):
        scaled = image * 255
    elif source_range == (-1, 1):
        scaled = (image + 1) * 127.5
    else:
        raise ValueError(f"Unsupported source range: {source_range}")
    return scaled.clip(0, 255).astype(np.uint8)
def format_action(action_str, is_padding=False, is_leftclick=False):
    """Encode an ``'x~y'`` coordinate string as a spaced action token string.

    Example: ``format_action('123~45')`` -> ``'N + 0 1 2 3 : + 0 0 4 5'``.
    Each coordinate becomes a sign followed by its absolute value zero-padded
    to 4 digits, with every character separated by a space.

    Args:
        action_str: Coordinates as ``'<int>~<int>'`` (ignored when padding).
        is_padding: Return the all-'N' padding token string instead.
        is_leftclick: Use the ``'L'`` prefix instead of ``'N'``.
    """
    if is_padding:
        return "N N N N N N : N N N N N"

    raw_x, raw_y = action_str.split('~')
    x, y = int(raw_x), int(raw_y)

    def _encode(value):
        # Sign, then the 4-digit magnitude with one space between characters.
        sign = '+' if value >= 0 else '-'
        return f"{sign} {' '.join(f'{abs(value):04d}')}"

    head = 'L' if is_leftclick else 'N'
    return f"{head} {_encode(x)} : {_encode(y)}"
def predict_next_frame(previous_frames, previous_actions: List[Tuple[str, List[int]]]) -> Tuple[np.ndarray, np.ndarray, object]:
    """Generate the next screen frame from the frame and action history.

    Args:
        previous_frames: List of feedback values previously returned by this
            function (third element of the result). Only the last 32 are used.
            Missing ones are front-padded with zero tensors of shape
            ``(height // 8, width // 8, 4)``. The list itself is not modified.
        previous_actions: List of ``(action_type, (x, y))`` pairs. ``action_type``
            is ``'move'``/``'N'`` or ``'left_click'``/``'L'``. NOTE: this list is
            padded IN PLACE at the front with ``("N", (0, 0))`` entries until it
            holds at least 33 items, and callers see the padded list afterwards.

    Returns:
        tuple: ``(frame_with_trace, frame, frame_feedback)``.
        ``frame_with_trace`` is the generated frame as a uint8 channel-last
        array with the action trace drawn on it.
        ``frame`` is the same frame without the trace.
        ``frame_feedback`` is the raw feedback from ``sample_frame``, to be
        appended to ``previous_frames`` for the next call.

    Raises:
        ValueError: if an action type other than move / left click is present.
    """
    width, height = 512, 384
    num_context_frames = 32
    # One action per context frame, plus the action being rendered now.
    num_context_actions = num_context_frames + 1

    print('length of previous_frames', len(previous_frames))

    # --- Image context: last 32 frames, front-padded with zero tensors. ---
    padding_image = torch.zeros((height // 8, width // 8, 4)).to(device)
    image_sequence = previous_frames[-num_context_frames:]  # slice = copy
    missing = num_context_frames - len(image_sequence)
    image_sequence = [padding_image] * missing + image_sequence
    image_sequence_tensor = torch.cat(image_sequence, dim=1)

    # --- Action context: pad the caller's list in place to 33 entries. ---
    # (0, 0) entries are encoded as padding tokens below and skipped by draw_trace.
    while len(previous_actions) < num_context_actions:
        previous_actions.insert(0, ("N", (0, 0)))

    action_descriptions = []
    for action_type, pos in previous_actions[-num_context_actions:]:
        # Accept both the long names and the one-letter prompt prefixes.
        if action_type == 'move':
            action_type = 'N'
        elif action_type == 'left_click':
            action_type = 'L'
        if action_type not in ('N', 'L'):
            raise ValueError(f"Unsupported action type: {action_type!r}")
        x, y = pos
        action_descriptions.append(
            format_action(f'{x:.0f}~{y:.0f}', x == 0 and y == 0, action_type == 'L')
        )

    prompt = " ".join(action_descriptions)
    print(prompt)

    # --- Position / click maps, most recent action first. ---
    # NOTE(review): clicks are not recorded into the module-level
    # all_click_positions that draw_trace renders. Confirm this is intended.
    pos_maps, leftclick_maps = [], []
    x_scaled = y_scaled = None
    for j, description in enumerate(reversed(action_descriptions)):
        x, y, parsed_type = parse_action_string(description)
        pos_map, leftclick_map, x_cell, y_cell = create_position_and_click_map((x, y), parsed_type)
        pos_maps.append(pos_map)
        leftclick_maps.append(leftclick_map)
        if j == 0:
            # The current action's cell; None when it is a padding entry.
            x_scaled, y_scaled = x_cell, y_cell

    new_frame, new_frame_feedback = sample_frame(
        model, prompt, image_sequence_tensor, pos_maps=pos_maps, leftclick_maps=leftclick_maps
    )

    # Presumably channel-first from sample_frame; move channels last for PIL.
    new_frame = new_frame.transpose(1, 2, 0)
    print(new_frame.max(), new_frame.min())
    new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))

    new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions, x_scaled, y_scaled)
    return new_frame_with_trace, new_frame_denormalized, new_frame_feedback
# WebSocket endpoint for continuous user interaction.
# NOTE(review): no ``@app.websocket(...)`` route decorator is visible on this
# handler in this copy of the file. Confirm it is registered.
async def websocket_endpoint(websocket: WebSocket):
    """Drive one interactive session over a websocket.

    Message handling:
    - ``{"type": "heartbeat"}`` is answered with ``{"type": "heartbeat_response"}``.
    - Any other JSON message is read for ``action_type`` and ``mouse_position``.
      The action is appended to the session history, the next frame is
      generated with ``predict_next_frame``, and the result is sent back as
      ``{"image": <base64-encoded PNG>}``.

    NOTE(review): ``predict_next_frame`` runs synchronously inside this
    coroutine, so the event loop is blocked for the whole inference, including
    heartbeats on other connections.
    """
    client_id = id(websocket)  # identifies this connection in the logs
    print(f"New WebSocket connection: {client_id}")
    await websocket.accept()

    # predict_next_frame only reads the last 32 feedback frames. Keeping more
    # would hold (possibly GPU) tensors alive for the whole session.
    max_context_frames = 32

    # Teacher-forcing debug mode replays this fixed action script instead of
    # the client's input. Client messages still pace the loop.
    debug_positions = []
    if DEBUG_TEACHER_FORCING:
        debug_positions = [
            'L + 0 1 9 1 : + 0 0 8 7',
            'L + 0 1 9 1 : + 0 0 8 7', 'N + 0 3 4 3 : + 0 2 6 3',
            'N + 0 2 0 5 : + 0 1 3 3', 'N + 0 0 7 6 : + 0 3 4 5',
            'N + 0 3 1 8 : + 0 3 3 3', 'N + 0 2 5 4 : + 0 2 9 0',
            'N + 0 1 0 6 : + 0 1 6 4', 'N + 0 0 7 4 : + 0 2 8 4',
            'N + 0 0 2 4 : + 0 0 4 1', 'N + 0 1 5 0 : + 0 3 8 3',
            'N + 0 4 0 5 : + 0 1 6 8', 'N + 0 0 5 4 : + 0 3 2 4',
            'N + 0 2 9 0 : + 0 1 4 1', 'N + 0 4 0 2 : + 0 0 0 9',
            'N + 0 3 0 7 : + 0 3 3 2', 'N + 0 2 2 0 : + 0 3 7 1',
            'N + 0 0 8 2 : + 0 1 5 1',
        ][3:]

    # NOTE(review): previous_actions is kept in full because draw_trace draws
    # the whole trail. It grows for the lifetime of the connection.
    previous_actions = []
    previous_frames = []
    total_time = 0.0
    num_frames = 0
    try:
        while True:
            data = await websocket.receive_json()
            if data.get("type") == "heartbeat":
                await websocket.send_json({"type": "heartbeat_response"})
                continue

            action_type = data.get("action_type")
            mouse_position = data.get("mouse_position")
            if DEBUG_TEACHER_FORCING:
                x, y, action_type = parse_action_string(debug_positions.pop(0))
                mouse_position = (x, y)
            # Record each action exactly once.
            previous_actions.append((action_type, mouse_position))

            start_time = time.time()
            print('previous_actions', previous_actions)
            next_frame, _next_frame_plain, next_frame_feedback = predict_next_frame(previous_frames, previous_actions)
            # Feed the model's own output back as context for the next step,
            # keeping only the window predict_next_frame actually reads.
            previous_frames.append(next_frame_feedback)
            del previous_frames[:-max_context_frames]

            processing_time = time.time() - start_time
            print(f"Frame processing time: {processing_time:.2f} seconds")
            total_time += processing_time
            num_frames += 1
            mean_time = total_time / num_frames
            print(f"Average frame processing time: {mean_time:.2f} seconds")
            fps = 1 / mean_time if mean_time > 0 else float('inf')
            print(f"FPS: {fps:.2f}")

            # Encode the frame as a base64 PNG and send it to the client.
            buffered = io.BytesIO()
            Image.fromarray(next_frame).save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            await websocket.send_json({"image": img_str})
    except WebSocketDisconnect:
        # Normal client disconnect: leave the loop instead of calling
        # receive() again on a closed socket.
        print("WebSocket disconnected")
    except Exception as e:
        print(f"Error in WebSocket connection {client_id}: {e}")
    finally:
        print(f"WebSocket connection closed: {client_id}")