Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
fc0af52
1
Parent(s):
4450019
main.py
CHANGED
@@ -19,6 +19,8 @@ app = FastAPI()
|
|
19 |
# Mount the static directory to serve HTML, JavaScript, and CSS files
|
20 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
21 |
|
|
|
|
|
22 |
|
23 |
def parse_action_string(action_str):
|
24 |
"""Convert formatted action string to x, y coordinates
|
@@ -79,6 +81,7 @@ def create_position_and_click_map(pos,action_type, image_height=48, image_width=
|
|
79 |
|
80 |
leftclick_map = torch.zeros((1, image_height, image_width))
|
81 |
if action_type == 'L':
|
|
|
82 |
leftclick_map[0, y_scaled, x_scaled] = 1.0
|
83 |
|
84 |
|
@@ -94,29 +97,38 @@ def generate_random_image(width: int, height: int) -> np.ndarray:
|
|
94 |
|
95 |
def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]], x_scaled=-1, y_scaled=-1) -> np.ndarray:
|
96 |
pil_image = Image.fromarray(image)
|
97 |
-
#pil_image = Image.open('image_3.png')
|
98 |
draw = ImageDraw.Draw(pil_image)
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
prev_x, prev_y = None, None
|
101 |
for i, (action_type, position) in enumerate(previous_actions):
|
102 |
-
color = (255, 0, 0) if action_type == "move" else (0, 255, 0)
|
103 |
x, y = position
|
104 |
-
if x == 0 and y == 0
|
105 |
continue
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
# x = x * 256 / 1024
|
110 |
-
# y = y * 256 / 640
|
111 |
-
#draw.ellipse([x-2, y-2, x+2, y+2], fill=color)
|
112 |
|
|
|
|
|
113 |
|
114 |
-
#
|
115 |
-
|
116 |
-
|
117 |
-
prev_x, prev_y =
|
118 |
-
|
119 |
-
#
|
|
|
|
|
|
|
|
|
120 |
|
121 |
return np.array(pil_image)
|
122 |
|
@@ -267,6 +279,8 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
267 |
if j == 1:
|
268 |
x_scaled = x_scaled_j
|
269 |
y_scaled = y_scaled_j
|
|
|
|
|
270 |
|
271 |
#prompt = ''
|
272 |
#prompt = "1~1 0~0 0~0 0~0 0~0 0~0 0~0 0~0"
|
@@ -283,11 +297,18 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
283 |
# Draw the trace of previous actions
|
284 |
new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions, x_scaled, y_scaled)
|
285 |
|
|
|
|
|
|
|
|
|
286 |
return new_frame_with_trace, new_frame_denormalized
|
287 |
|
288 |
# WebSocket endpoint for continuous user interaction
|
289 |
@app.websocket("/ws")
|
290 |
async def websocket_endpoint(websocket: WebSocket):
|
|
|
|
|
|
|
291 |
client_id = id(websocket) # Use a unique identifier for each connection
|
292 |
print(f"New WebSocket connection: {client_id}")
|
293 |
await websocket.accept()
|
|
|
19 |
# Mount the static directory to serve HTML, JavaScript, and CSS files
|
20 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
21 |
|
22 |
+
# Add this at the top with other global variables
|
23 |
+
all_click_positions = [] # Store all historical click positions
|
24 |
|
25 |
def parse_action_string(action_str):
|
26 |
"""Convert formatted action string to x, y coordinates
|
|
|
81 |
|
82 |
leftclick_map = torch.zeros((1, image_height, image_width))
|
83 |
if action_type == 'L':
|
84 |
+
print ('left click', x_scaled, y_scaled)
|
85 |
leftclick_map[0, y_scaled, x_scaled] = 1.0
|
86 |
|
87 |
|
|
|
97 |
|
98 |
def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]], x_scaled=-1, y_scaled=-1) -> np.ndarray:
|
99 |
pil_image = Image.fromarray(image)
|
|
|
100 |
draw = ImageDraw.Draw(pil_image)
|
101 |
+
|
102 |
+
# Draw all historical click positions
|
103 |
+
for click_x, click_y in all_click_positions:
|
104 |
+
x_draw = click_x # Scale factor for display
|
105 |
+
y_draw = click_y
|
106 |
+
# Draw historical clicks as red circles
|
107 |
+
draw.ellipse([x_draw-4, y_draw-4, x_draw+4, y_draw+4], fill=(255, 0, 0))
|
108 |
+
|
109 |
+
# Draw current trace
|
110 |
prev_x, prev_y = None, None
|
111 |
for i, (action_type, position) in enumerate(previous_actions):
|
|
|
112 |
x, y = position
|
113 |
+
if x == 0 and y == 0:
|
114 |
continue
|
115 |
+
|
116 |
+
x_draw = x
|
117 |
+
y_draw = y
|
|
|
|
|
|
|
118 |
|
119 |
+
# Draw movement positions as blue dots
|
120 |
+
draw.ellipse([x_draw-2, y_draw-2, x_draw+2, y_draw+2], fill=(0, 0, 255))
|
121 |
|
122 |
+
# Draw connecting lines
|
123 |
+
if prev_x is not None:
|
124 |
+
draw.line([prev_x, prev_y, x_draw, y_draw], fill=(0, 255, 0), width=1)
|
125 |
+
prev_x, prev_y = x_draw, y_draw
|
126 |
+
|
127 |
+
# Draw current position
|
128 |
+
if x_scaled >= 0 and y_scaled >= 0:
|
129 |
+
x_current = x_scaled * 8
|
130 |
+
y_current = y_scaled * 8
|
131 |
+
draw.ellipse([x_current-3, y_current-3, x_current+3, y_current+3], fill=(0, 255, 0))
|
132 |
|
133 |
return np.array(pil_image)
|
134 |
|
|
|
279 |
if j == 1:
|
280 |
x_scaled = x_scaled_j
|
281 |
y_scaled = y_scaled_j
|
282 |
+
if action_type == 'L':
|
283 |
+
all_click_positions.append((x, y))
|
284 |
|
285 |
#prompt = ''
|
286 |
#prompt = "1~1 0~0 0~0 0~0 0~0 0~0 0~0 0~0"
|
|
|
297 |
# Draw the trace of previous actions
|
298 |
new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions, x_scaled, y_scaled)
|
299 |
|
300 |
+
# Track click positions
|
301 |
+
#x, y, action_type = parse_action_string(action_descriptions[-1])
|
302 |
+
|
303 |
+
|
304 |
return new_frame_with_trace, new_frame_denormalized
|
305 |
|
306 |
# WebSocket endpoint for continuous user interaction
|
307 |
@app.websocket("/ws")
|
308 |
async def websocket_endpoint(websocket: WebSocket):
|
309 |
+
global all_click_positions # Add this line
|
310 |
+
all_click_positions = [] # Reset at the start of each connection
|
311 |
+
|
312 |
client_id = id(websocket) # Use a unique identifier for each connection
|
313 |
print(f"New WebSocket connection: {client_id}")
|
314 |
await websocket.accept()
|