da03 committed on
Commit
b07d75e
·
1 Parent(s): a9363f4
Files changed (2) hide show
  1. main.py +13 -4
  2. utils.py +6 -4
main.py CHANGED
@@ -258,16 +258,25 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
258
  prompt = " ".join(action_descriptions[-8:])
259
  print(prompt)
260
  #prompt = "N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N + 0 3 0 7 : + 0 3 7 5"
261
- x, y, action_type = parse_action_string(action_descriptions[-1])
262
- pos_map, leftclick_map, x_scaled, y_scaled = create_position_and_click_map((x, y), action_type)
263
-
 
 
 
 
 
 
 
 
 
264
 
265
  #prompt = ''
266
  #prompt = "1~1 0~0 0~0 0~0 0~0 0~0 0~0 0~0"
267
  print(prompt)
268
 
269
  # Generate the next frame
270
- new_frame = sample_frame(model, prompt, image_sequence_tensor, pos_map=pos_map, leftclick_map=leftclick_map)
271
 
272
  # Convert the generated frame to the correct format
273
  new_frame = new_frame.transpose(1, 2, 0)
 
258
  prompt = " ".join(action_descriptions[-8:])
259
  print(prompt)
260
  #prompt = "N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N + 0 3 0 7 : + 0 3 7 5"
261
+ #x, y, action_type = parse_action_string(action_descriptions[-1])
262
+ #pos_map, leftclick_map, x_scaled, y_scaled = create_position_and_click_map((x, y), action_type)
263
+ leftclick_maps = []
264
+ pos_maps = []
265
+ for j in range(1, 9):
266
+ x, y, action_type = parse_action_string(action_descriptions[-j])
267
+ pos_map_j, leftclick_map_j, x_scaled_j, y_scaled_j = create_position_and_click_map((x, y), action_type)
268
+ leftclick_maps.append(leftclick_map_j)
269
+ pos_maps.append(pos_map_j)
270
+ if j == 1:
271
+ x_scaled = x_scaled_j
272
+ y_scaled = y_scaled_j
273
 
274
  #prompt = ''
275
  #prompt = "1~1 0~0 0~0 0~0 0~0 0~0 0~0 0~0"
276
  print(prompt)
277
 
278
  # Generate the next frame
279
+ new_frame = sample_frame(model, prompt, image_sequence_tensor, pos_maps=pos_maps, leftclick_maps=leftclick_maps)
280
 
281
  # Convert the generated frame to the correct format
282
  new_frame = new_frame.transpose(1, 2, 0)
utils.py CHANGED
@@ -28,7 +28,7 @@ def load_model_from_config(config_path, model_name, device='cuda'):
28
  model.eval()
29
  return model
30
 
31
- def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_map=None, leftclick_map=None):
32
  sampler = DDIMSampler(model)
33
 
34
  with torch.no_grad():
@@ -46,9 +46,11 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
46
  print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
47
  c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
48
 
49
- if pos_map is not None:
50
- print (pos_map.shape, c['c_concat'].shape)
51
- c['c_concat'] = torch.cat([c['c_concat'][:, :, :, :], pos_map.to(c['c_concat'].device).unsqueeze(0), leftclick_map.to(c['c_concat'].device).unsqueeze(0)], dim=1)
 
 
52
 
53
  print ('sleeping')
54
  #time.sleep(120)
 
28
  model.eval()
29
  return model
30
 
31
+ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_maps=None, leftclick_maps=None):
32
  sampler = DDIMSampler(model)
33
 
34
  with torch.no_grad():
 
46
  print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
47
  c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
48
 
49
+ if pos_maps is not None:
50
+ pos_map = pos_maps[0]
51
+ leftclick_map = torch.cat(leftclick_maps, dim=0)
52
+ print (pos_maps[0].shape, c['c_concat'].shape, leftclick_map.shape)
53
+ c['c_concat'] = torch.cat([c['c_concat'][:, :, :, :], pos_maps[0].to(c['c_concat'].device).unsqueeze(0), leftclick_maps[0].to(c['c_concat'].device).unsqueeze(0)], dim=1)
54
 
55
  print ('sleeping')
56
  #time.sleep(120)