Spaces:
Runtime error
Runtime error
Commit
·
f9c716f
1
Parent(s):
599777e
Update main.py
Browse files
main.py
CHANGED
|
@@ -80,52 +80,48 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
| 80 |
# Prepare the image sequence for the model
|
| 81 |
image_sequence = previous_frames[-7:] # Take the last 7 frames
|
| 82 |
while len(image_sequence) < 7:
|
| 83 |
-
#image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
|
| 84 |
image_sequence.insert(0, initial_images[len(image_sequence)])
|
| 85 |
|
| 86 |
-
|
| 87 |
# Convert the image sequence to a tensor and concatenate in the channel dimension
|
| 88 |
image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
|
| 89 |
-
|
| 90 |
-
#image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).float() / 127.5 - 1
|
| 91 |
image_sequence_tensor = image_sequence_tensor.to(device)
|
| 92 |
|
| 93 |
-
|
| 94 |
# Prepare the prompt based on the previous actions
|
| 95 |
-
#action_descriptions = [f"{pos[0]}:{pos[1]}" for _, pos in previous_actions[-7:]]
|
| 96 |
-
#prompt = " ".join(action_descriptions)
|
| 97 |
action_descriptions = []
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
def
|
| 101 |
-
return y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
for action_type, pos in previous_actions[-7:]:
|
| 103 |
if action_type == "move":
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
| 107 |
elif action_type == "left_click":
|
| 108 |
action_descriptions.append("left_click")
|
| 109 |
elif action_type == "right_click":
|
| 110 |
action_descriptions.append("right_click")
|
| 111 |
|
| 112 |
prompt = " ".join(action_descriptions)
|
| 113 |
-
print
|
| 114 |
|
| 115 |
# Generate the next frame
|
| 116 |
new_frame = sample_frame(model, prompt, image_sequence_tensor)
|
| 117 |
|
| 118 |
# Convert the generated frame to the correct format
|
| 119 |
-
#new_frame = (new_frame * 255).astype(np.uint8).transpose(1, 2, 0)
|
| 120 |
new_frame = new_frame.transpose(1, 2, 0)
|
| 121 |
|
| 122 |
-
|
| 123 |
-
# Resize the frame to 256x256 if necessary
|
| 124 |
-
#if new_frame.shape[:2] != (height, width):
|
| 125 |
-
# new_frame = np.array(Image.fromarray(new_frame).resize((width, height)))
|
| 126 |
-
|
| 127 |
new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
|
| 128 |
-
|
| 129 |
|
| 130 |
# Draw the trace of previous actions
|
| 131 |
new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions)
|
|
|
|
| 80 |
# Prepare the image sequence for the model
|
| 81 |
image_sequence = previous_frames[-7:] # Take the last 7 frames
|
| 82 |
while len(image_sequence) < 7:
|
|
|
|
| 83 |
image_sequence.insert(0, initial_images[len(image_sequence)])
|
| 84 |
|
|
|
|
| 85 |
# Convert the image sequence to a tensor and concatenate in the channel dimension
|
| 86 |
image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
|
|
|
|
|
|
|
| 87 |
image_sequence_tensor = image_sequence_tensor.to(device)
|
| 88 |
|
|
|
|
| 89 |
# Prepare the prompt based on the previous actions
|
|
|
|
|
|
|
| 90 |
action_descriptions = []
|
| 91 |
+
initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
|
| 92 |
+
|
| 93 |
+
def unnorm_coords(x, y):
|
| 94 |
+
return int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
|
| 95 |
+
|
| 96 |
+
# Process initial actions if there are not enough previous actions
|
| 97 |
+
while len(previous_actions) < 7:
|
| 98 |
+
if initial_actions:
|
| 99 |
+
x, y = map(int, initial_actions.pop(0).split(':'))
|
| 100 |
+
previous_actions.insert(0, ("move", unnorm_coords(x, y)))
|
| 101 |
+
else:
|
| 102 |
+
break
|
| 103 |
+
|
| 104 |
for action_type, pos in previous_actions[-7:]:
|
| 105 |
if action_type == "move":
|
| 106 |
+
x, y = pos
|
| 107 |
+
norm_x = x + (1920 - 256) / 2
|
| 108 |
+
norm_y = y + (1080 - 256) / 2
|
| 109 |
+
action_descriptions.append(f"{norm_x}:{norm_y}")
|
| 110 |
elif action_type == "left_click":
|
| 111 |
action_descriptions.append("left_click")
|
| 112 |
elif action_type == "right_click":
|
| 113 |
action_descriptions.append("right_click")
|
| 114 |
|
| 115 |
prompt = " ".join(action_descriptions)
|
| 116 |
+
print(prompt)
|
| 117 |
|
| 118 |
# Generate the next frame
|
| 119 |
new_frame = sample_frame(model, prompt, image_sequence_tensor)
|
| 120 |
|
| 121 |
# Convert the generated frame to the correct format
|
|
|
|
| 122 |
new_frame = new_frame.transpose(1, 2, 0)
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
|
|
|
|
| 125 |
|
| 126 |
# Draw the trace of previous actions
|
| 127 |
new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions)
|