da03 committed
Commit 9d0127f
Parent(s): 487aaae
Files changed (2):
  1. main.py +21 -13
  2. utils.py +8 -7
main.py CHANGED
@@ -168,6 +168,15 @@ def normalize_images(images, target_range=(-1, 1)):
         return images / 255.0
     else:
         raise ValueError(f"Unsupported target range: {target_range}")
+
+def normalize_image(image, target_range=(-1, 1)):
+    image = image.astype(np.float32)
+    if target_range == (-1, 1):
+        return image / 127.5 - 1
+    elif target_range == (0, 1):
+        return image / 255.0
+    else:
+        raise ValueError(f"Unsupported target range: {target_range}")
 
 def denormalize_image(image, source_range=(-1, 1)):
     if source_range == (-1, 1):
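The added normalize_image mirrors the batched normalize_images above but scales a single frame: for the (-1, 1) target, pixel value 0 maps to -1.0 and 255 to +1.0. A minimal sketch of that arithmetic, assuming only NumPy (the frame array is illustrative, not from the repo):

import numpy as np

frame = np.array([0, 128, 255], dtype=np.uint8)
normalized = frame.astype(np.float32) / 127.5 - 1  # same arithmetic as normalize_image
# normalized is approximately [-1.0, 0.0039, 1.0]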
@@ -195,28 +204,27 @@ def format_action(action_str, is_padding=False, is_leftclick=False):
     # Format with sign and proper spacing
     return prefix + " " + f"{'+ ' if x >= 0 else '- '}{x_spaced} : {'+ ' if y >= 0 else '- '}{y_spaced}"
 
-def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
+def predict_next_frame(previous_frames: List[Union[np.ndarray, Tuple[str, np.ndarray]]], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
     width, height = 512, 384
     all_click_positions = []
     initial_images = load_initial_images(width, height)
     print ('length of previous_frames', len(previous_frames))
+    padding_image = torch.zeros((height//8, width//8, 4)).to(device)
 
     # Prepare the image sequence for the model
     assert len(initial_images) == 32
     image_sequence = previous_frames[-32:] # Take the last 32 frames
     i = 1
     while len(image_sequence) < 32:
-        image_sequence.insert(0, initial_images[-i])
+        image_sequence.insert(0, padding_image)
         i += 1
     #image_sequence.append(initial_images[len(image_sequence)])
-
+
     # Convert the image sequence to a tensor and concatenate in the channel dimension
-    image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
-    image_sequence_tensor = image_sequence_tensor.to(device)
-    data_mean = -0.54
-    data_std = 6.78
-    data_min = -27.681446075439453
-    data_max = 30.854148864746094
+    #image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence_list, target_range=(-1, 1)))
+    #image_sequence_tensor = image_sequence_tensor.to(device)
+    image_sequence_tensor = torch.cat(image_sequence, dim=1)
+
     #image_sequence_tensor = (image_sequence_tensor - data_mean) / data_std
 
     # Prepare the prompt based on the previous actions
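With this hunk the frame history moves into latent space: short histories are padded with zero latents of shape (height//8, width//8, 4) instead of encoded initial images, and the list is joined with torch.cat rather than re-normalized from pixels. A minimal shape sketch under those assumptions (the 8x downsample and 4 channels are read off the diff; note that dim=1 of an (H, W, C) tensor is the width axis, which is what the code as written concatenates over, despite the "channel dimension" comment):

import torch

height, width = 384, 512
padding_image = torch.zeros((height // 8, width // 8, 4))    # (48, 64, 4)
image_sequence = [padding_image.clone() for _ in range(32)]  # an all-padding history
image_sequence_tensor = torch.cat(image_sequence, dim=1)
print(image_sequence_tensor.shape)                           # torch.Size([48, 2048, 4])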
@@ -318,7 +326,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     #print ('changing L to N')
 
     # Generate the next frame
-    new_frame = sample_frame(model, prompt, image_sequence_tensor, pos_maps=pos_maps, leftclick_maps=leftclick_maps)
+    new_frame, new_frame_feedback = sample_frame(model, prompt, image_sequence_tensor, pos_maps=pos_maps, leftclick_maps=leftclick_maps)
 
     # Convert the generated frame to the correct format
     new_frame = new_frame.transpose(1, 2, 0)
@@ -333,7 +341,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     #x, y, action_type = parse_action_string(action_descriptions[-1])
 
 
-    return new_frame_with_trace, new_frame_denormalized
+    return new_frame_with_trace, new_frame_denormalized, new_frame_feedback
 
 # WebSocket endpoint for continuous user interaction
 @app.websocket("/ws")
@@ -513,10 +521,10 @@ async def websocket_endpoint(websocket: WebSocket):
             #if DEBUG_TEACHER_FORCING:
             #    print ('predicting', f"record_10003/image_{117+len(previous_frames)}.png")
             print ('previous_actions', previous_actions)
-            next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
+            next_frame, next_frame_append, next_frame_feedback = predict_next_frame(previous_frames, previous_actions)
             feedback = True
             if feedback:
-                previous_frames.append(next_frame_append)
+                previous_frames.append(next_frame_feedback)
             else:
                 #previous_frames = []
                 previous_actions = []
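The WebSocket handler now threads the extra return value back into the history. A hypothetical driver loop showing the feedback path (incoming_actions and send_to_client are illustrative stand-ins, not names from the repo):

# The latent returned by predict_next_frame, not the decoded RGB frame,
# is appended to the history, so the model conditions on its own latents.
previous_frames, previous_actions = [], []
for action in incoming_actions:
    previous_actions.append(action)
    next_frame, next_frame_append, next_frame_feedback = predict_next_frame(
        previous_frames, previous_actions)
    previous_frames.append(next_frame_feedback)  # latent feedback, as in the diff
    send_to_client(next_frame)                   # frame with click trace, for display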
 
utils.py CHANGED
@@ -45,13 +45,13 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
     #print (c['c_crossattn'][0])
     print (prompt)
     c = {}
-    c = model.enc_concat_seq(c, c_dict, 'c_concat')
+    #c = model.enc_concat_seq(c, c_dict, 'c_concat')
     # Zero out the corresponding subtensors in c_concat for padding images
-    padding_mask = torch.isclose(image_sequence, torch.tensor(-1.0), rtol=1e-5, atol=1e-5).all(dim=(1, 2, 3)).unsqueeze(0)
-    print (padding_mask)
-    padding_mask = padding_mask.repeat(1, 4) # Repeat mask 4 times for each projected channel
-    print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
-    c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
+    #padding_mask = torch.isclose(image_sequence, torch.tensor(-1.0), rtol=1e-5, atol=1e-5).all(dim=(1, 2, 3)).unsqueeze(0)
+    #print (padding_mask)
+    #padding_mask = padding_mask.repeat(1, 4) # Repeat mask 4 times for each projected channel
+    #print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
+    #c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
     data_mean = -0.54
     data_std = 6.78
     data_min = -27.681446075439453
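The masking block disabled here detected padding frames by their all -1.0 pixels and zeroed the matching projected channels in c['c_concat']; with zero latents now injected upstream in main.py, it is commented out rather than deleted. A standalone sketch of what that logic did, with illustrative shapes (flatten(1) keeps the reduction version-agnostic; the original passes dim=(1, 2, 3) directly):

import torch

image_sequence = torch.full((32, 3, 384, 512), -1.0)  # 32 frames, all padding
image_sequence[5] = 0.3                                # pretend frame 5 is real

padding_mask = torch.isclose(image_sequence, torch.tensor(-1.0),
                             rtol=1e-5, atol=1e-5).flatten(1).all(dim=1).unsqueeze(0)
padding_mask = padding_mask.repeat(1, 4)               # one flag per projected channel

c_concat = torch.randn(1, 32 * 4, 48, 64)              # stand-in for c['c_concat']
c_concat = c_concat * (~padding_mask.unsqueeze(-1).unsqueeze(-1))
print(c_concat[0, 5].abs().sum() > 0)                  # frame 5's channels survive the mask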
@@ -108,6 +108,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
     data_max = 30.854148864746094
     x_samples_ddim = samples_ddim
     x_samples_ddim = x_samples_ddim * data_std + data_mean
+    x_samples_ddim_feedback = x_samples_ddim
     x_samples_ddim = model.decode_first_stage(x_samples_ddim)
     print ('dfsf3')
     #x_samples_ddim = pos_map.to(c['c_concat'].device).unsqueeze(0).expand(-1, 3, -1, -1)
@@ -115,7 +116,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
     #x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
     x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0)
 
-    return x_samples_ddim.squeeze(0).cpu().numpy()
+    return x_samples_ddim.squeeze(0).cpu().numpy(), x_samples_ddim_feedback.squeeze(0)
 
 # Global variables for model and device
 #model = None
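At the call site, the two return values serve different purposes: the decoded, clamped array is for rendering, while the un-decoded latent (captured before decode_first_stage, still scaled by data_std and data_mean) is what gets stored for the next round of conditioning. A hypothetical unpacking, with shapes inferred from the 512x384 frames and 4-channel latent space rather than confirmed by the repo:

decoded, latent = sample_frame(model, prompt, image_sequence_tensor,
                               pos_maps=pos_maps, leftclick_maps=leftclick_maps)
print(decoded.shape)  # e.g. (3, 384, 512) NumPy array clamped to [-1, 1]
print(latent.shape)   # e.g. torch.Size([4, 48, 64]), still on the sampling device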
 