yuntian-deng committed on
Commit
5754a1c
·
1 Parent(s): 7422728

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +60 -2
main.py CHANGED
@@ -18,6 +18,57 @@ app = FastAPI()
18
  # Mount the static directory to serve HTML, JavaScript, and CSS files
19
  app.mount("/static", StaticFiles(directory="static"), name="static")
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # Serve the index.html file at the root URL
22
  @app.get("/")
23
  async def get():
@@ -131,6 +182,8 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
131
  previous_actions.insert(0, ("move", unnorm_coords(x, y)))
132
  prev_x = 0
133
  prev_y = 0
 
 
134
  for action_type, pos in previous_actions: #[-8:]:
135
  if action_type == "move":
136
  x, y = pos
@@ -140,7 +193,8 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
140
  norm_x = x
141
  norm_y = y
142
  #action_descriptions.append(f"{(norm_x-prev_x):.0f}~{(norm_y-prev_y):.0f}")
143
- action_descriptions.append(format_action(f'{norm_x-prev_x:.0f}~{norm_y-prev_y:.0f}', x==0 and y==0))
 
144
  prev_x = norm_x
145
  prev_y = norm_y
146
  elif action_type == "left_click":
@@ -149,12 +203,16 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
149
  action_descriptions.append("right_click")
150
 
151
  prompt = " ".join(action_descriptions[-8:])
 
 
 
 
152
  #prompt = ''
153
  #prompt = "1~1 0~0 0~0 0~0 0~0 0~0 0~0 0~0"
154
  print(prompt)
155
 
156
  # Generate the next frame
157
- new_frame = sample_frame(model, prompt, image_sequence_tensor)
158
 
159
  # Convert the generated frame to the correct format
160
  new_frame = new_frame.transpose(1, 2, 0)
 
18
  # Mount the static directory to serve HTML, JavaScript, and CSS files
19
  app.mount("/static", StaticFiles(directory="static"), name="static")
20
 
21
+
22
def parse_action_string(action_str):
    """Convert a formatted action string to ``(x, y)`` coordinates.

    Args:
        action_str: Coordinate token such as ``'+ 0 2 1 3 : + 0 3 8 3'``
            (an x part and a y part separated by ``':'``, characters
            separated by spaces), or a padding token such as
            ``'N N N N N : N N N N N'``.

    Returns:
        tuple: ``(x, y)`` as signed ints, or ``(None, None)`` when the
        string contains ``'N'`` (padding).

    Raises:
        ValueError: If the string is not padding and is not of the form
            ``'<x> : <y>'`` with integer parts (for example a bare
            ``'right_click'`` token, which carries no coordinates).
    """
    # Any 'N' marks the whole action as padding.
    if 'N' in action_str:
        return (None, None)

    # Spaces only separate the individual characters; drop them so each
    # part becomes a plain signed integer literal such as '+0213'.
    parts = action_str.replace(' ', '').split(':')
    if len(parts) != 2:
        raise ValueError(
            f"expected an action of the form '<x> : <y>', got {action_str!r}"
        )
    x_part, y_part = parts

    # int() accepts an optional leading sign and leading zeros, so no manual
    # sign handling is needed.
    try:
        return (int(x_part), int(y_part))
    except ValueError as err:
        raise ValueError(
            f"non-integer coordinate in action string {action_str!r}"
        ) from err
44
+
45
def create_position_map(pos, image_size=64, original_width=1024, original_height=640):
    """Convert a cursor position to a binary position map.

    Args:
        pos: ``(x, y)`` cursor position in original screen coordinates.
            If either coordinate is ``None`` the position is treated as
            padding.
        image_size: Side length of the (square) output position map.
        original_width: Original screen width used to scale ``x``.
        original_height: Original screen height used to scale ``y``.

    Returns:
        torch.Tensor: Tensor of shape ``(1, image_size, image_size)`` that is
        all zeros except for a single ``1.0`` at the scaled cursor position;
        all zeros when the position is padding.
    """
    x, y = pos
    pos_map = torch.zeros((1, image_size, image_size))

    # Padding (no cursor position): return the empty map.
    if x is None or y is None:
        return pos_map

    # Scale the screen coordinates down to map coordinates.
    col = int((x / original_width) * image_size)
    row = int((y / original_height) * image_size)

    # Clamp so off-screen positions land on the nearest edge cell instead of
    # indexing out of bounds (or wrapping around via negative indices).
    col = max(0, min(col, image_size - 1))
    row = max(0, min(row, image_size - 1))

    # Tensor layout is (channel, row, column), hence y before x.
    pos_map[0, row, col] = 1.0
    return pos_map
71
+
72
  # Serve the index.html file at the root URL
73
  @app.get("/")
74
  async def get():
 
182
  previous_actions.insert(0, ("move", unnorm_coords(x, y)))
183
  prev_x = 0
184
  prev_y = 0
185
+
186
+
187
  for action_type, pos in previous_actions: #[-8:]:
188
  if action_type == "move":
189
  x, y = pos
 
193
  norm_x = x
194
  norm_y = y
195
  #action_descriptions.append(f"{(norm_x-prev_x):.0f}~{(norm_y-prev_y):.0f}")
196
+ #action_descriptions.append(format_action(f'{norm_x-prev_x:.0f}~{norm_y-prev_y:.0f}', x==0 and y==0))
197
+ action_descriptions.append(format_action(f'{norm_x:.0f}~{norm_y:.0f}', x==0 and y==0))
198
  prev_x = norm_x
199
  prev_y = norm_y
200
  elif action_type == "left_click":
 
203
  action_descriptions.append("right_click")
204
 
205
  prompt = " ".join(action_descriptions[-8:])
206
+
207
+ pos_map = create_position_map(parse_action_string(action_descriptions[-1]))
208
+
209
+
210
  #prompt = ''
211
  #prompt = "1~1 0~0 0~0 0~0 0~0 0~0 0~0 0~0"
212
  print(prompt)
213
 
214
  # Generate the next frame
215
+ new_frame = sample_frame(model, prompt, image_sequence_tensor, pos_map=pos_map)
216
 
217
  # Convert the generated frame to the correct format
218
  new_frame = new_frame.transpose(1, 2, 0)