Spaces:
Runtime error
Runtime error
da03
committed on
Commit
·
50eea75
1
Parent(s):
a10a91e
main.py
CHANGED
@@ -13,6 +13,7 @@ import os
|
|
13 |
import time
|
14 |
|
15 |
DEBUG = True
|
|
|
16 |
app = FastAPI()
|
17 |
|
18 |
# Mount the static directory to serve HTML, JavaScript, and CSS files
|
@@ -128,15 +129,14 @@ model = model.to(device)
|
|
128 |
|
129 |
def load_initial_images(width, height):
|
130 |
initial_images = []
|
131 |
-
|
132 |
-
|
133 |
-
#
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
# initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
|
140 |
return initial_images
|
141 |
|
142 |
def normalize_images(images, target_range=(-1, 1)):
|
@@ -156,13 +156,15 @@ def denormalize_image(image, source_range=(-1, 1)):
|
|
156 |
else:
|
157 |
raise ValueError(f"Unsupported source range: {source_range}")
|
158 |
|
159 |
-
def format_action(action_str, is_padding=False):
|
160 |
if is_padding:
|
161 |
return "N N N N N N : N N N N N"
|
162 |
|
163 |
# Split the x~y coordinates
|
164 |
x, y = map(int, action_str.split('~'))
|
165 |
prefix = 'N'
|
|
|
|
|
166 |
# Convert numbers to padded strings and add spaces between digits
|
167 |
x_str = f"{abs(x):04d}"
|
168 |
y_str = f"{abs(y):04d}"
|
@@ -200,6 +202,22 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
200 |
prev_x = 0
|
201 |
prev_y = 0
|
202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
for action_type, pos in previous_actions: #[-8:]:
|
205 |
if action_type == "move":
|
@@ -217,7 +235,17 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
217 |
prev_x = norm_x
|
218 |
prev_y = norm_y
|
219 |
elif action_type == "left_click":
|
220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
elif action_type == "right_click":
|
222 |
action_descriptions.append("right_click")
|
223 |
|
|
|
13 |
import time
|
14 |
|
15 |
DEBUG = True
|
16 |
+
DEBUG_TEACHER_FORCING = True
|
17 |
app = FastAPI()
|
18 |
|
19 |
# Mount the static directory to serve HTML, JavaScript, and CSS files
|
|
|
129 |
|
130 |
def load_initial_images(width, height):
    """Return the seven initial frames.

    When ``DEBUG_TEACHER_FORCING`` is truthy, images 74-80 of the recording
    in ``record_100/`` (the seven frames preceding ``image_81``) are loaded
    and resized to ``(width, height)``. Otherwise seven all-zero frames are
    returned.

    Args:
        width: Target frame width in pixels.
        height: Target frame height in pixels.

    Returns:
        A list of seven ``numpy`` arrays, each of shape
        ``(height, width, 3)``.

    Raises:
        FileNotFoundError: In teacher-forcing mode, if one of the
            ``record_100/image_{i}.png`` files is missing.
    """
    if not DEBUG_TEACHER_FORCING:
        # Black placeholder frames, one fresh array per slot.
        return [np.zeros((height, width, 3), dtype=np.uint8) for _ in range(7)]

    initial_images = []
    # Load the previous 7 frames for image_81 (images 74-80).
    for i in range(74, 81):
        # The context manager closes the underlying file promptly;
        # ``resize`` returns an independent image that outlives it.
        with Image.open(f"record_100/image_{i}.png") as img:
            # Force three channels so the array shape matches the
            # (height, width, 3) placeholder frames even when the PNG is
            # stored as RGBA, palette, or grayscale.
            frame = img.resize((width, height)).convert("RGB")
        initial_images.append(np.array(frame))
    return initial_images
|
141 |
|
142 |
def normalize_images(images, target_range=(-1, 1)):
|
|
|
156 |
else:
|
157 |
raise ValueError(f"Unsupported source range: {source_range}")
|
158 |
|
159 |
+
def format_action(action_str, is_padding=False, is_leftclick=False):
|
160 |
if is_padding:
|
161 |
return "N N N N N N : N N N N N"
|
162 |
|
163 |
# Split the x~y coordinates
|
164 |
x, y = map(int, action_str.split('~'))
|
165 |
prefix = 'N'
|
166 |
+
if is_leftclick:
|
167 |
+
prefix = 'L'
|
168 |
# Convert numbers to padded strings and add spaces between digits
|
169 |
x_str = f"{abs(x):04d}"
|
170 |
y_str = f"{abs(y):04d}"
|
|
|
202 |
prev_x = 0
|
203 |
prev_y = 0
|
204 |
|
205 |
+
if DEBUG_TEACHER_FORCING:
|
206 |
+
# Use the predefined actions for image_81
|
207 |
+
debug_actions = [
|
208 |
+
'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
|
209 |
+
'N + 0 8 9 0 : + 0 4 7 5', 'N + 0 9 0 8 : + 0 4 8 8',
|
210 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
211 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
212 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
213 |
+
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
214 |
+
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
215 |
+
'N + 0 9 2 7 : + 0 5 0 1'
|
216 |
+
]
|
217 |
+
previous_actions = []
|
218 |
+
for action in debug_actions:
|
219 |
+
x, y, action_type = parse_action_string(action)
|
220 |
+
previous_actions.append((action_type, (x, y)))
|
221 |
|
222 |
for action_type, pos in previous_actions: #[-8:]:
|
223 |
if action_type == "move":
|
|
|
235 |
prev_x = norm_x
|
236 |
prev_y = norm_y
|
237 |
elif action_type == "left_click":
|
238 |
+
x, y = pos
|
239 |
+
#norm_x = int(round(x / 256 * 1024)) #x + (1920 - 256) / 2
|
240 |
+
#norm_y = int(round(y / 256 * 640)) #y + (1080 - 256) / 2
|
241 |
+
norm_x = x + (1920 - 512) / 2
|
242 |
+
norm_y = y + (1080 - 512) / 2
|
243 |
+
#if DEBUG:
|
244 |
+
# norm_x = x
|
245 |
+
# norm_y = y
|
246 |
+
#action_descriptions.append(f"{(norm_x-prev_x):.0f}~{(norm_y-prev_y):.0f}")
|
247 |
+
#action_descriptions.append(format_action(f'{norm_x-prev_x:.0f}~{norm_y-prev_y:.0f}', x==0 and y==0))
|
248 |
+
action_descriptions.append(format_action(f'{norm_x:.0f}~{norm_y:.0f}', x==0 and y==0, True))
|
249 |
elif action_type == "right_click":
|
250 |
action_descriptions.append("right_click")
|
251 |
|