Spaces:
Runtime error
Runtime error
Commit
·
a2d3df0
1
Parent(s):
dab8c4c
Update main.py
Browse files
main.py
CHANGED
@@ -10,6 +10,7 @@ import asyncio
|
|
10 |
from utils import initialize_model, sample_frame
|
11 |
import torch
|
12 |
import os
|
|
|
13 |
|
14 |
app = FastAPI()
|
15 |
|
@@ -106,7 +107,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
106 |
x, y = pos
|
107 |
norm_x = x + (1920 - 256) / 2
|
108 |
norm_y = y + (1080 - 256) / 2
|
109 |
-
action_descriptions.append(f"{norm_x}:{norm_y}")
|
110 |
elif action_type == "left_click":
|
111 |
action_descriptions.append("left_click")
|
112 |
elif action_type == "right_click":
|
@@ -153,6 +154,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
153 |
# Store the actions
|
154 |
previous_actions.append((action_type, mouse_position))
|
155 |
|
|
|
|
|
|
|
156 |
# Predict the next frame based on the previous frames and actions
|
157 |
next_frame = predict_next_frame(previous_frames, previous_actions)
|
158 |
previous_frames.append(next_frame)
|
@@ -163,6 +167,10 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
163 |
img.save(buffered, format="PNG")
|
164 |
img_str = base64.b64encode(buffered.getvalue()).decode()
|
165 |
|
|
|
|
|
|
|
|
|
166 |
# Send the generated frame back to the client
|
167 |
await websocket.send_json({"image": img_str})
|
168 |
|
|
|
10 |
from utils import initialize_model, sample_frame
|
11 |
import torch
|
12 |
import os
|
13 |
+
import time
|
14 |
|
15 |
app = FastAPI()
|
16 |
|
|
|
107 |
x, y = pos
|
108 |
norm_x = x + (1920 - 256) / 2
|
109 |
norm_y = y + (1080 - 256) / 2
|
110 |
+
action_descriptions.append(f"{norm_x:.0f}:{norm_y:.0f}")
|
111 |
elif action_type == "left_click":
|
112 |
action_descriptions.append("left_click")
|
113 |
elif action_type == "right_click":
|
|
|
154 |
# Store the actions
|
155 |
previous_actions.append((action_type, mouse_position))
|
156 |
|
157 |
+
# Log the start time
|
158 |
+
start_time = time.time()
|
159 |
+
|
160 |
# Predict the next frame based on the previous frames and actions
|
161 |
next_frame = predict_next_frame(previous_frames, previous_actions)
|
162 |
previous_frames.append(next_frame)
|
|
|
167 |
img.save(buffered, format="PNG")
|
168 |
img_str = base64.b64encode(buffered.getvalue()).decode()
|
169 |
|
170 |
+
# Log the processing time
|
171 |
+
processing_time = time.time() - start_time
|
172 |
+
print(f"Frame processing time: {processing_time:.2f} seconds")
|
173 |
+
|
174 |
# Send the generated frame back to the client
|
175 |
await websocket.send_json({"image": img_str})
|
176 |
|