da03 committed
Commit a9d9852 · 1 Parent(s): 76598de
Files changed (2):
  1. main.py +33 -17
  2. utils.py +9 -2
main.py CHANGED

@@ -26,8 +26,10 @@ def parse_action_string(action_str):
     Returns:
         tuple: (x, y) coordinates or None if action is padding
     """
+    action_type = action_str[0]
+    action_str = action_str[1:].strip()
     if 'N' in action_str:
-        return (None, None)
+        return (None, None, None)
 
     # Split into x and y parts
     action_str = action_str.replace(' ', '')
@@ -40,9 +42,9 @@ def parse_action_string(action_str):
     # Parse y: remove sign, join digits, convert to int, apply sign
     y = int(y_part)
 
-    return (x, y)
+    return x, y, action_type
 
-def create_position_map(pos, image_size=64, original_width=1024, original_height=640):
+def create_position_and_click_map(pos, action_type, image_size=64, original_width=1024, original_height=640):
     """Convert cursor position to a binary position map
     Args:
         x, y: Original cursor positions
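Note: parse_action_string now strips a leading action-type token ('N' for no click, 'L' for left click, per the click-map hunk below) before parsing the coordinates. A sketch of the expected round trip, assuming the elided lines between these hunks split the remainder on ':' as before:

# Illustration only, not part of the commit.
x, y, action_type = parse_action_string("L + 0 3 0 7 : + 0 3 7 5")
assert (x, y, action_type) == (307, 375, 'L')
# Padding actions still short-circuit, now as a 3-tuple:
assert parse_action_string("N N N N N N : N N N N N") == (None, None, None)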
@@ -53,12 +55,18 @@ def create_position_map(pos, image_size=64, original_width=1024, original_height=640):
         torch.Tensor: Binary position map of shape (1, image_size, image_size)
     """
     x, y = pos
-    #x, y = 307, 375
     if x is None:
-        return torch.zeros((1, image_size, image_size))
+        return torch.zeros((1, image_size, image_size)), torch.zeros((1, image_size, image_size))
     # Scale the positions to new size
-    x_scaled = int((x / original_width) * image_size)
-    y_scaled = int((y / original_height) * image_size)
+    #x_scaled = int((x / original_width) * image_size)
+    #y_scaled = int((y / original_height) * image_size)
+    screen_width, screen_height = 1920, 1080
+    video_width, video_height = 512, 512
+
+    x_scaled = x - (screen_width / 2 - video_width / 2)
+    y_scaled = y - (screen_height / 2 - video_height / 2)
+    x_scaled = int(x_scaled / video_width * image_size)
+    y_scaled = int(y_scaled / video_height * image_size)
 
     # Clamp values to ensure they're within bounds
     x_scaled = max(0, min(x_scaled, image_size - 1))
@@ -67,8 +75,13 @@ def create_position_map(pos, image_size=64, original_width=1024, original_height=640):
     # Create binary position map
     pos_map = torch.zeros((1, image_size, image_size))
     pos_map[0, y_scaled, x_scaled] = 1.0
+
+    leftclick_map = torch.zeros((1, image_size, image_size))
+    if action_type == 'L':
+        leftclick_map[0, y_scaled, x_scaled] = 1.0
+
 
-    return pos_map, x_scaled, y_scaled
+    return pos_map, leftclick_map, x_scaled, y_scaled
 
 # Serve the index.html file at the root URL
 @app.get("/")
@@ -145,11 +158,11 @@ def denormalize_image(image, source_range=(-1, 1)):
 
 def format_action(action_str, is_padding=False):
     if is_padding:
-        return "N N N N N : N N N N N"
+        return "N N N N N N : N N N N N"
 
     # Split the x~y coordinates
     x, y = map(int, action_str.split('~'))
-
+    prefix = 'N'
     # Convert numbers to padded strings and add spaces between digits
     x_str = f"{abs(x):04d}"
     y_str = f"{abs(y):04d}"
@@ -157,10 +170,10 @@ def format_action(action_str, is_padding=False):
     y_spaced = ' '.join(y_str)
 
     # Format with sign and proper spacing
-    return f"{'+ ' if x >= 0 else '- '}{x_spaced} : {'+ ' if y >= 0 else '- '}{y_spaced}"
+    return prefix + " " + f"{'+ ' if x >= 0 else '- '}{x_spaced} : {'+ ' if y >= 0 else '- '}{y_spaced}"
 
 def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
-    width, height = 256, 256
+    width, height = 512, 512
     initial_images = load_initial_images(width, height)
 
     # Prepare the image sequence for the model
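Note: format_action now prepends a hard-coded 'N' (no-click) token, matching the leading action-type character that parse_action_string strips, and the padding string grows by one 'N' to stay aligned. For example:

# Illustration only, not part of the commit.
assert format_action("307~375") == "N + 0 3 0 7 : + 0 3 7 5"
assert format_action("", is_padding=True) == "N N N N N N : N N N N N"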
@@ -174,7 +187,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
 
     # Prepare the prompt based on the previous actions
     action_descriptions = []
-    initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
+    #initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
     initial_actions = ['0:0'] * 7
     #initial_actions = ['N N N N N : N N N N N'] * 7
     def unnorm_coords(x, y):
  def unnorm_coords(x, y):
@@ -191,8 +204,10 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
191
  for action_type, pos in previous_actions: #[-8:]:
192
  if action_type == "move":
193
  x, y = pos
194
- norm_x = int(round(x / 256 * 1024)) #x + (1920 - 256) / 2
195
- norm_y = int(round(y / 256 * 640)) #y + (1080 - 256) / 2
 
 
196
  #if DEBUG:
197
  # norm_x = x
198
  # norm_y = y
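Note: the new normalization is the inverse of the centering in create_position_and_click_map: instead of rescaling a 256-pixel canvas into 1024×640 (the now-commented lines), a click inside the 512×512 video is shifted back into 1920×1080 screen space. For example:

# Illustration only, not part of the commit.
norm_x = 256 + (1920 - 512) / 2   # 256 + 704.0 = 960.0, the screen center
norm_y = 256 + (1080 - 512) / 2   # 256 + 284.0 = 540.0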
@@ -207,9 +222,10 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
         action_descriptions.append("right_click")
 
     prompt = " ".join(action_descriptions[-8:])
+    print(prompt)
     #prompt = "N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N + 0 3 0 7 : + 0 3 7 5"
 
-    pos_map, x_scaled, y_scaled = create_position_map(parse_action_string(action_descriptions[-1]))
+    pos_map, leftclick_map, x_scaled, y_scaled = create_position_and_click_map(parse_action_string(action_descriptions[-1]))
 
 
     #prompt = ''
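Note: as committed, this call looks inconsistent: parse_action_string now returns a flat (x, y, action_type) 3-tuple, while create_position_and_click_map takes pos and action_type as two separate parameters (action_type has no default), so passing the parsed result as a single argument would raise a TypeError. A hypothetical unpacking that matches the new signature:

# Hypothetical fix, not in the commit:
x, y, action_type = parse_action_string(action_descriptions[-1])
pos_map, leftclick_map, x_scaled, y_scaled = create_position_and_click_map((x, y), action_type)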
@@ -217,7 +233,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
     print(prompt)
 
     # Generate the next frame
-    new_frame = sample_frame(model, prompt, image_sequence_tensor, pos_map=pos_map)
+    new_frame = sample_frame(model, prompt, image_sequence_tensor, pos_map=pos_map, leftclick_map=leftclick_map)
 
     # Convert the generated frame to the correct format
     new_frame = new_frame.transpose(1, 2, 0)

utils.py CHANGED
@@ -28,7 +28,7 @@ def load_model_from_config(config_path, model_name, device='cuda'):
     model.eval()
     return model
 
-def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_map=None):
+def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_map=None, leftclick_map=None):
     sampler = DDIMSampler(model)
 
     with torch.no_grad():
@@ -39,9 +39,16 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_map=None):
         c_dict = {'c_crossattn': prompt, 'c_concat': image_sequence}
         c = model.get_learned_conditioning(c_dict)
         c = model.enc_concat_seq(c, c_dict, 'c_concat')
+        # Zero out the corresponding subtensors in c_concat for padding images
+        padding_mask = torch.isclose(image_sequence, torch.tensor(-1.0), rtol=1e-5, atol=1e-5).all(dim=(1, 2, 3)).unsqueeze(1)
+        print (padding_mask)
+        padding_mask = padding_mask.repeat(1, 4) # Repeat mask 4 times for each projected channel
+        print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
+        c['c_concat'] = c['c_concat'] * (~padding_mask) # Zero out the corresponding features
+
         if pos_map is not None:
             print (pos_map.shape, c['c_concat'].shape)
-            c['c_concat'] = torch.cat([c['c_concat'][:, :, :, :], pos_map.to(c['c_concat'].device).unsqueeze(0)], dim=1)
+            c['c_concat'] = torch.cat([c['c_concat'][:, :, :, :], pos_map.to(c['c_concat'].device).unsqueeze(0), leftclick_map.to(c['c_concat'].device).unsqueeze(0)], dim=1)
 
         print ('sleeping')
         #time.sleep(120)
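Note: the masking step treats any frame whose pixels are all exactly -1 (the bottom of the model's (-1, 1) range) as padding and zeroes its projected features in c_concat, and the concat now appends both the cursor-position map and the left-click map as extra conditioning channels. A self-contained toy version of the masking idea, assuming four projected channels per frame (all shapes here are illustrative, not taken from the repo):

import torch

seq_len = 3
image_sequence = torch.randn(seq_len, 3, 64, 64)
image_sequence[1] = -1.0                   # frame 1 is all -1, i.e. padding
c_concat = torch.randn(seq_len, 4, 8, 8)   # 4 projected channels per frame

# A frame is padding iff every pixel equals -1; flatten(1).all(dim=1)
# is a version-portable equivalent of the commit's .all(dim=(1, 2, 3)).
padding_mask = torch.isclose(image_sequence, torch.tensor(-1.0)).flatten(1).all(dim=1)

# Zero every channel of padded frames, broadcasting over channels and pixels:
c_concat = c_concat * (~padding_mask).view(seq_len, 1, 1, 1)
assert c_concat[1].abs().sum() == 0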