yuntian-deng commited on
Commit
9df27df
·
1 Parent(s): 8bb7641

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +4 -4
main.py CHANGED
@@ -96,7 +96,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
96
  return int(x), int(y) #int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
97
 
98
  # Process initial actions if there are not enough previous actions
99
- while len(previous_actions) < 7:
100
  if initial_actions:
101
  x, y = map(int, initial_actions.pop(0).split(':'))
102
  previous_actions.insert(0, ("move", unnorm_coords(x, y)))
@@ -128,7 +128,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
128
  # Draw the trace of previous actions
129
  new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions)
130
 
131
- return new_frame_with_trace
132
 
133
  # WebSocket endpoint for continuous user interaction
134
  @app.websocket("/ws")
@@ -159,8 +159,8 @@ async def websocket_endpoint(websocket: WebSocket):
159
  start_time = time.time()
160
 
161
  # Predict the next frame based on the previous frames and actions
162
- next_frame = predict_next_frame(previous_frames, previous_actions)
163
- previous_frames.append(next_frame)
164
 
165
  # Convert the numpy array to a base64 encoded image
166
  img = Image.fromarray(next_frame)
 
96
  return int(x), int(y) #int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
97
 
98
  # Process initial actions if there are not enough previous actions
99
+ while len(previous_actions) < 8:
100
  if initial_actions:
101
  x, y = map(int, initial_actions.pop(0).split(':'))
102
  previous_actions.insert(0, ("move", unnorm_coords(x, y)))
 
128
  # Draw the trace of previous actions
129
  new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions)
130
 
131
+ return new_frame_with_trace, new_frame_denormalized
132
 
133
  # WebSocket endpoint for continuous user interaction
134
  @app.websocket("/ws")
 
159
  start_time = time.time()
160
 
161
  # Predict the next frame based on the previous frames and actions
162
+ next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
163
+ previous_frames.append(next_frame_append)
164
 
165
  # Convert the numpy array to a base64 encoded image
166
  img = Image.fromarray(next_frame)