yuntian-deng commited on
Commit
a2d3df0
·
1 Parent(s): dab8c4c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +9 -1
main.py CHANGED
@@ -10,6 +10,7 @@ import asyncio
10
  from utils import initialize_model, sample_frame
11
  import torch
12
  import os
 
13
 
14
  app = FastAPI()
15
 
@@ -106,7 +107,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
106
  x, y = pos
107
  norm_x = x + (1920 - 256) / 2
108
  norm_y = y + (1080 - 256) / 2
109
- action_descriptions.append(f"{norm_x}:{norm_y}")
110
  elif action_type == "left_click":
111
  action_descriptions.append("left_click")
112
  elif action_type == "right_click":
@@ -153,6 +154,9 @@ async def websocket_endpoint(websocket: WebSocket):
153
  # Store the actions
154
  previous_actions.append((action_type, mouse_position))
155
 
 
 
 
156
  # Predict the next frame based on the previous frames and actions
157
  next_frame = predict_next_frame(previous_frames, previous_actions)
158
  previous_frames.append(next_frame)
@@ -163,6 +167,10 @@ async def websocket_endpoint(websocket: WebSocket):
163
  img.save(buffered, format="PNG")
164
  img_str = base64.b64encode(buffered.getvalue()).decode()
165
 
 
 
 
 
166
  # Send the generated frame back to the client
167
  await websocket.send_json({"image": img_str})
168
 
 
10
  from utils import initialize_model, sample_frame
11
  import torch
12
  import os
13
+ import time
14
 
15
  app = FastAPI()
16
 
 
107
  x, y = pos
108
  norm_x = x + (1920 - 256) / 2
109
  norm_y = y + (1080 - 256) / 2
110
+ action_descriptions.append(f"{norm_x:.0f}:{norm_y:.0f}")
111
  elif action_type == "left_click":
112
  action_descriptions.append("left_click")
113
  elif action_type == "right_click":
 
154
  # Store the actions
155
  previous_actions.append((action_type, mouse_position))
156
 
157
+ # Log the start time
158
+ start_time = time.time()
159
+
160
  # Predict the next frame based on the previous frames and actions
161
  next_frame = predict_next_frame(previous_frames, previous_actions)
162
  previous_frames.append(next_frame)
 
167
  img.save(buffered, format="PNG")
168
  img_str = base64.b64encode(buffered.getvalue()).decode()
169
 
170
+ # Log the processing time
171
+ processing_time = time.time() - start_time
172
+ print(f"Frame processing time: {processing_time:.2f} seconds")
173
+
174
  # Send the generated frame back to the client
175
  await websocket.send_json({"image": img_str})
176