Spaces:
Runtime error
Runtime error
da03
committed on
Commit
·
313bb52
1
Parent(s):
94b146f
main.py
CHANGED
@@ -203,23 +203,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
203 |
prev_y = 0
|
204 |
#print ('here')
|
205 |
|
206 |
-
|
207 |
-
#print ('here2')
|
208 |
-
# Use the predefined actions for image_81
|
209 |
-
debug_actions = [
|
210 |
-
'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
|
211 |
-
'N + 0 8 9 0 : + 0 4 7 5', 'N + 0 9 0 8 : + 0 4 8 8',
|
212 |
-
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
213 |
-
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
214 |
-
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
215 |
-
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
216 |
-
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
217 |
-
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1'
|
218 |
-
]
|
219 |
-
previous_actions = []
|
220 |
-
for action in debug_actions[-8:]:
|
221 |
-
x, y, action_type = parse_action_string(action)
|
222 |
-
previous_actions.append((action_type, (x, y)))
|
223 |
|
224 |
for action_type, pos in previous_actions: #[-8:]:
|
225 |
print ('here3', action_type, pos)
|
@@ -302,6 +286,31 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
302 |
positions = ['815~335']
|
303 |
#positions = ['787~342']
|
304 |
positions = ['300~800']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
#positions = positions[:4]
|
306 |
try:
|
307 |
while True:
|
@@ -325,9 +334,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
325 |
#mouse_position = position.split('~')
|
326 |
#mouse_position = [int(item) for item in mouse_position]
|
327 |
#mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
|
328 |
-
|
|
|
|
|
|
|
|
|
329 |
previous_actions.append((action_type, mouse_position))
|
330 |
-
previous_actions = [(action_type, mouse_position)]
|
331 |
|
332 |
# Log the start time
|
333 |
start_time = time.time()
|
@@ -336,7 +349,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
336 |
next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
|
337 |
# Load and append the corresponding ground truth image instead of model output
|
338 |
#img = Image.open(f"image_{len(previous_frames)%7}.png")
|
339 |
-
|
340 |
|
341 |
# Convert the numpy array to a base64 encoded image
|
342 |
img = Image.fromarray(next_frame)
|
|
|
203 |
prev_y = 0
|
204 |
#print ('here')
|
205 |
|
206 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
|
208 |
for action_type, pos in previous_actions: #[-8:]:
|
209 |
print ('here3', action_type, pos)
|
|
|
286 |
positions = ['815~335']
|
287 |
#positions = ['787~342']
|
288 |
positions = ['300~800']
|
289 |
+
|
290 |
+
if DEBUG_TEACHER_FORCING:
|
291 |
+
#print ('here2')
|
292 |
+
# Use the predefined actions for image_81
|
293 |
+
debug_actions = [
|
294 |
+
'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
|
295 |
+
'N + 0 8 9 0 : + 0 4 7 5', 'N + 0 9 0 8 : + 0 4 8 8',
|
296 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
297 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
298 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
299 |
+
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
300 |
+
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
301 |
+
'N + 0 9 2 7 : + 0 5 0 1', #'N + 0 9 2 7 : + 0 5 0 1'
|
302 |
+
]
|
303 |
+
previous_actions = []
|
304 |
+
for action in debug_actions[-8:]:
|
305 |
+
x, y, action_type = parse_action_string(action)
|
306 |
+
previous_actions.append((action_type, (x, y)))
|
307 |
+
positions = [
|
308 |
+
'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 1 8 : + 0 4 9 2',
|
309 |
+
'N + 0 9 0 8 : + 0 4 8 3', 'N + 0 8 9 8 : + 0 4 7 4',
|
310 |
+
'N + 0 8 8 9 : + 0 4 6 5', 'N + 0 8 8 0 : + 0 4 5 6',
|
311 |
+
'N + 0 8 7 0 : + 0 4 4 7', 'N + 0 8 6 0 : + 0 4 3 8',
|
312 |
+
'N + 0 8 5 1 : + 0 4 2 9', 'N + 0 8 4 2 : + 0 4 2 0',
|
313 |
+
'N + 0 8 3 2 : + 0 4 1 1', 'N + 0 8 3 2 : + 0 4 1 1']
|
314 |
#positions = positions[:4]
|
315 |
try:
|
316 |
while True:
|
|
|
334 |
#mouse_position = position.split('~')
|
335 |
#mouse_position = [int(item) for item in mouse_position]
|
336 |
#mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
|
337 |
+
if DEBUG_TEACHER_FORCING:
|
338 |
+
position = positions[0]
|
339 |
+
positions = positions[1:]
|
340 |
+
x, y, action_type = parse_action_string(position)
|
341 |
+
mouse_position = (x, y)
|
342 |
previous_actions.append((action_type, mouse_position))
|
343 |
+
#previous_actions = [(action_type, mouse_position)]
|
344 |
|
345 |
# Log the start time
|
346 |
start_time = time.time()
|
|
|
349 |
next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
|
350 |
# Load and append the corresponding ground truth image instead of model output
|
351 |
#img = Image.open(f"image_{len(previous_frames)%7}.png")
|
352 |
+
previous_frames.append(next_frame_append)
|
353 |
|
354 |
# Convert the numpy array to a base64 encoded image
|
355 |
img = Image.fromarray(next_frame)
|