yuntian-deng committed on
Commit
f9c716f
·
1 Parent(s): 599777e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +18 -22
main.py CHANGED
@@ -80,52 +80,48 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
80
  # Prepare the image sequence for the model
81
  image_sequence = previous_frames[-7:] # Take the last 7 frames
82
  while len(image_sequence) < 7:
83
- #image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
84
  image_sequence.insert(0, initial_images[len(image_sequence)])
85
 
86
-
87
  # Convert the image sequence to a tensor and concatenate in the channel dimension
88
  image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
89
-
90
- #image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).float() / 127.5 - 1
91
  image_sequence_tensor = image_sequence_tensor.to(device)
92
 
93
-
94
  # Prepare the prompt based on the previous actions
95
- #action_descriptions = [f"{pos[0]}:{pos[1]}" for _, pos in previous_actions[-7:]]
96
- #prompt = " ".join(action_descriptions)
97
  action_descriptions = []
98
- def norm_x(x):
99
- return x + (1920 - 256) / 2
100
- def norm_y(y):
101
- return y + (1080 - 256) / 2
 
 
 
 
 
 
 
 
 
102
  for action_type, pos in previous_actions[-7:]:
103
  if action_type == "move":
104
- print (pos[0], pos[1])
105
- action_descriptions.append(f"{norm_x(pos[0])}:{norm_y(pos[1])}")
106
-
 
107
  elif action_type == "left_click":
108
  action_descriptions.append("left_click")
109
  elif action_type == "right_click":
110
  action_descriptions.append("right_click")
111
 
112
  prompt = " ".join(action_descriptions)
113
- print (prompt)
114
 
115
  # Generate the next frame
116
  new_frame = sample_frame(model, prompt, image_sequence_tensor)
117
 
118
  # Convert the generated frame to the correct format
119
- #new_frame = (new_frame * 255).astype(np.uint8).transpose(1, 2, 0)
120
  new_frame = new_frame.transpose(1, 2, 0)
121
 
122
-
123
- # Resize the frame to 256x256 if necessary
124
- #if new_frame.shape[:2] != (height, width):
125
- # new_frame = np.array(Image.fromarray(new_frame).resize((width, height)))
126
-
127
  new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
128
-
129
 
130
  # Draw the trace of previous actions
131
  new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions)
 
80
  # Prepare the image sequence for the model
81
  image_sequence = previous_frames[-7:] # Take the last 7 frames
82
  while len(image_sequence) < 7:
 
83
  image_sequence.insert(0, initial_images[len(image_sequence)])
84
 
 
85
  # Convert the image sequence to a tensor and concatenate in the channel dimension
86
  image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
 
 
87
  image_sequence_tensor = image_sequence_tensor.to(device)
88
 
 
89
  # Prepare the prompt based on the previous actions
 
 
90
  action_descriptions = []
91
+ initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
92
+
93
+ def unnorm_coords(x, y):
94
+ return int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
95
+
96
+ # Process initial actions if there are not enough previous actions
97
+ while len(previous_actions) < 7:
98
+ if initial_actions:
99
+ x, y = map(int, initial_actions.pop(0).split(':'))
100
+ previous_actions.insert(0, ("move", unnorm_coords(x, y)))
101
+ else:
102
+ break
103
+
104
  for action_type, pos in previous_actions[-7:]:
105
  if action_type == "move":
106
+ x, y = pos
107
+ norm_x = x + (1920 - 256) / 2
108
+ norm_y = y + (1080 - 256) / 2
109
+ action_descriptions.append(f"{norm_x}:{norm_y}")
110
  elif action_type == "left_click":
111
  action_descriptions.append("left_click")
112
  elif action_type == "right_click":
113
  action_descriptions.append("right_click")
114
 
115
  prompt = " ".join(action_descriptions)
116
+ print(prompt)
117
 
118
  # Generate the next frame
119
  new_frame = sample_frame(model, prompt, image_sequence_tensor)
120
 
121
  # Convert the generated frame to the correct format
 
122
  new_frame = new_frame.transpose(1, 2, 0)
123
 
 
 
 
 
 
124
  new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
 
125
 
126
  # Draw the trace of previous actions
127
  new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions)