da03 commited on
Commit
fc0af52
·
1 Parent(s): 4450019
Files changed (1) hide show
  1. main.py +37 -16
main.py CHANGED
@@ -19,6 +19,8 @@ app = FastAPI()
19
  # Mount the static directory to serve HTML, JavaScript, and CSS files
20
  app.mount("/static", StaticFiles(directory="static"), name="static")
21
 
 
 
22
 
23
  def parse_action_string(action_str):
24
  """Convert formatted action string to x, y coordinates
@@ -79,6 +81,7 @@ def create_position_and_click_map(pos,action_type, image_height=48, image_width=
79
 
80
  leftclick_map = torch.zeros((1, image_height, image_width))
81
  if action_type == 'L':
 
82
  leftclick_map[0, y_scaled, x_scaled] = 1.0
83
 
84
 
@@ -94,29 +97,38 @@ def generate_random_image(width: int, height: int) -> np.ndarray:
94
 
95
  def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]], x_scaled=-1, y_scaled=-1) -> np.ndarray:
96
  pil_image = Image.fromarray(image)
97
- #pil_image = Image.open('image_3.png')
98
  draw = ImageDraw.Draw(pil_image)
99
- flag = True
 
 
 
 
 
 
 
 
100
  prev_x, prev_y = None, None
101
  for i, (action_type, position) in enumerate(previous_actions):
102
- color = (255, 0, 0) if action_type == "move" else (0, 255, 0)
103
  x, y = position
104
- if x == 0 and y == 0 and flag:
105
  continue
106
- else:
107
- flag = False
108
- #if DEBUG:
109
- # x = x * 256 / 1024
110
- # y = y * 256 / 640
111
- #draw.ellipse([x-2, y-2, x+2, y+2], fill=color)
112
 
 
 
113
 
114
- #if prev_x is not None:
115
- # #prev_x, prev_y = previous_actions[i-1][1]
116
- # draw.line([prev_x, prev_y, x, y], fill=color, width=1)
117
- prev_x, prev_y = x, y
118
- draw.ellipse([x_scaled*8-2, y_scaled*8-2, x_scaled*8+2, y_scaled*8+2], fill=(0, 255, 0))
119
- #pil_image = pil_image.convert("RGB")
 
 
 
 
120
 
121
  return np.array(pil_image)
122
 
@@ -267,6 +279,8 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
267
  if j == 1:
268
  x_scaled = x_scaled_j
269
  y_scaled = y_scaled_j
 
 
270
 
271
  #prompt = ''
272
  #prompt = "1~1 0~0 0~0 0~0 0~0 0~0 0~0 0~0"
@@ -283,11 +297,18 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
283
  # Draw the trace of previous actions
284
  new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions, x_scaled, y_scaled)
285
 
 
 
 
 
286
  return new_frame_with_trace, new_frame_denormalized
287
 
288
  # WebSocket endpoint for continuous user interaction
289
  @app.websocket("/ws")
290
  async def websocket_endpoint(websocket: WebSocket):
 
 
 
291
  client_id = id(websocket) # Use a unique identifier for each connection
292
  print(f"New WebSocket connection: {client_id}")
293
  await websocket.accept()
 
19
  # Mount the static directory to serve HTML, JavaScript, and CSS files
20
  app.mount("/static", StaticFiles(directory="static"), name="static")
21
 
22
+ # Add this at the top with other global variables
23
+ all_click_positions = [] # Store all historical click positions
24
 
25
  def parse_action_string(action_str):
26
  """Convert formatted action string to x, y coordinates
 
81
 
82
  leftclick_map = torch.zeros((1, image_height, image_width))
83
  if action_type == 'L':
84
+ print ('left click', x_scaled, y_scaled)
85
  leftclick_map[0, y_scaled, x_scaled] = 1.0
86
 
87
 
 
97
 
98
  def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]], x_scaled=-1, y_scaled=-1) -> np.ndarray:
99
  pil_image = Image.fromarray(image)
 
100
  draw = ImageDraw.Draw(pil_image)
101
+
102
+ # Draw all historical click positions
103
+ for click_x, click_y in all_click_positions:
104
+ x_draw = click_x # Scale factor for display
105
+ y_draw = click_y
106
+ # Draw historical clicks as red circles
107
+ draw.ellipse([x_draw-4, y_draw-4, x_draw+4, y_draw+4], fill=(255, 0, 0))
108
+
109
+ # Draw current trace
110
  prev_x, prev_y = None, None
111
  for i, (action_type, position) in enumerate(previous_actions):
 
112
  x, y = position
113
+ if x == 0 and y == 0:
114
  continue
115
+
116
+ x_draw = x
117
+ y_draw = y
 
 
 
118
 
119
+ # Draw movement positions as blue dots
120
+ draw.ellipse([x_draw-2, y_draw-2, x_draw+2, y_draw+2], fill=(0, 0, 255))
121
 
122
+ # Draw connecting lines
123
+ if prev_x is not None:
124
+ draw.line([prev_x, prev_y, x_draw, y_draw], fill=(0, 255, 0), width=1)
125
+ prev_x, prev_y = x_draw, y_draw
126
+
127
+ # Draw current position
128
+ if x_scaled >= 0 and y_scaled >= 0:
129
+ x_current = x_scaled * 8
130
+ y_current = y_scaled * 8
131
+ draw.ellipse([x_current-3, y_current-3, x_current+3, y_current+3], fill=(0, 255, 0))
132
 
133
  return np.array(pil_image)
134
 
 
279
  if j == 1:
280
  x_scaled = x_scaled_j
281
  y_scaled = y_scaled_j
282
+ if action_type == 'L':
283
+ all_click_positions.append((x, y))
284
 
285
  #prompt = ''
286
  #prompt = "1~1 0~0 0~0 0~0 0~0 0~0 0~0 0~0"
 
297
  # Draw the trace of previous actions
298
  new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions, x_scaled, y_scaled)
299
 
300
+ # Track click positions
301
+ #x, y, action_type = parse_action_string(action_descriptions[-1])
302
+
303
+
304
  return new_frame_with_trace, new_frame_denormalized
305
 
306
  # WebSocket endpoint for continuous user interaction
307
  @app.websocket("/ws")
308
  async def websocket_endpoint(websocket: WebSocket):
309
+ global all_click_positions # Add this line
310
+ all_click_positions = [] # Reset at the start of each connection
311
+
312
  client_id = id(websocket) # Use a unique identifier for each connection
313
  print(f"New WebSocket connection: {client_id}")
314
  await websocket.accept()