Spaces:
Runtime error
Runtime error
Commit
·
f9c716f
1
Parent(s):
599777e
Update main.py
Browse files
main.py
CHANGED
@@ -80,52 +80,48 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
80 |
# Prepare the image sequence for the model
|
81 |
image_sequence = previous_frames[-7:] # Take the last 7 frames
|
82 |
while len(image_sequence) < 7:
|
83 |
-
#image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
|
84 |
image_sequence.insert(0, initial_images[len(image_sequence)])
|
85 |
|
86 |
-
|
87 |
# Convert the image sequence to a tensor and concatenate in the channel dimension
|
88 |
image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
|
89 |
-
|
90 |
-
#image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).float() / 127.5 - 1
|
91 |
image_sequence_tensor = image_sequence_tensor.to(device)
|
92 |
|
93 |
-
|
94 |
# Prepare the prompt based on the previous actions
|
95 |
-
#action_descriptions = [f"{pos[0]}:{pos[1]}" for _, pos in previous_actions[-7:]]
|
96 |
-
#prompt = " ".join(action_descriptions)
|
97 |
action_descriptions = []
|
98 |
-
|
99 |
-
|
100 |
-
def
|
101 |
-
return y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
for action_type, pos in previous_actions[-7:]:
|
103 |
if action_type == "move":
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
107 |
elif action_type == "left_click":
|
108 |
action_descriptions.append("left_click")
|
109 |
elif action_type == "right_click":
|
110 |
action_descriptions.append("right_click")
|
111 |
|
112 |
prompt = " ".join(action_descriptions)
|
113 |
-
print
|
114 |
|
115 |
# Generate the next frame
|
116 |
new_frame = sample_frame(model, prompt, image_sequence_tensor)
|
117 |
|
118 |
# Convert the generated frame to the correct format
|
119 |
-
#new_frame = (new_frame * 255).astype(np.uint8).transpose(1, 2, 0)
|
120 |
new_frame = new_frame.transpose(1, 2, 0)
|
121 |
|
122 |
-
|
123 |
-
# Resize the frame to 256x256 if necessary
|
124 |
-
#if new_frame.shape[:2] != (height, width):
|
125 |
-
# new_frame = np.array(Image.fromarray(new_frame).resize((width, height)))
|
126 |
-
|
127 |
new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
|
128 |
-
|
129 |
|
130 |
# Draw the trace of previous actions
|
131 |
new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions)
|
|
|
80 |
# Prepare the image sequence for the model
|
81 |
image_sequence = previous_frames[-7:] # Take the last 7 frames
|
82 |
while len(image_sequence) < 7:
|
|
|
83 |
image_sequence.insert(0, initial_images[len(image_sequence)])
|
84 |
|
|
|
85 |
# Convert the image sequence to a tensor and concatenate in the channel dimension
|
86 |
image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
|
|
|
|
|
87 |
image_sequence_tensor = image_sequence_tensor.to(device)
|
88 |
|
|
|
89 |
# Prepare the prompt based on the previous actions
|
|
|
|
|
90 |
action_descriptions = []
|
91 |
+
initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
|
92 |
+
|
93 |
+
def unnorm_coords(x, y):
|
94 |
+
return int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
|
95 |
+
|
96 |
+
# Process initial actions if there are not enough previous actions
|
97 |
+
while len(previous_actions) < 7:
|
98 |
+
if initial_actions:
|
99 |
+
x, y = map(int, initial_actions.pop(0).split(':'))
|
100 |
+
previous_actions.insert(0, ("move", unnorm_coords(x, y)))
|
101 |
+
else:
|
102 |
+
break
|
103 |
+
|
104 |
for action_type, pos in previous_actions[-7:]:
|
105 |
if action_type == "move":
|
106 |
+
x, y = pos
|
107 |
+
norm_x = x + (1920 - 256) / 2
|
108 |
+
norm_y = y + (1080 - 256) / 2
|
109 |
+
action_descriptions.append(f"{norm_x}:{norm_y}")
|
110 |
elif action_type == "left_click":
|
111 |
action_descriptions.append("left_click")
|
112 |
elif action_type == "right_click":
|
113 |
action_descriptions.append("right_click")
|
114 |
|
115 |
prompt = " ".join(action_descriptions)
|
116 |
+
print(prompt)
|
117 |
|
118 |
# Generate the next frame
|
119 |
new_frame = sample_frame(model, prompt, image_sequence_tensor)
|
120 |
|
121 |
# Convert the generated frame to the correct format
|
|
|
122 |
new_frame = new_frame.transpose(1, 2, 0)
|
123 |
|
|
|
|
|
|
|
|
|
|
|
124 |
new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
|
|
|
125 |
|
126 |
# Draw the trace of previous actions
|
127 |
new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions)
|