Spaces:
Runtime error
Runtime error
Commit
·
a2d3df0
1
Parent(s):
dab8c4c
Update main.py
Browse files
main.py
CHANGED
|
@@ -10,6 +10,7 @@ import asyncio
|
|
| 10 |
from utils import initialize_model, sample_frame
|
| 11 |
import torch
|
| 12 |
import os
|
|
|
|
| 13 |
|
| 14 |
app = FastAPI()
|
| 15 |
|
|
@@ -106,7 +107,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
| 106 |
x, y = pos
|
| 107 |
norm_x = x + (1920 - 256) / 2
|
| 108 |
norm_y = y + (1080 - 256) / 2
|
| 109 |
-
action_descriptions.append(f"{norm_x}:{norm_y}")
|
| 110 |
elif action_type == "left_click":
|
| 111 |
action_descriptions.append("left_click")
|
| 112 |
elif action_type == "right_click":
|
|
@@ -153,6 +154,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 153 |
# Store the actions
|
| 154 |
previous_actions.append((action_type, mouse_position))
|
| 155 |
|
|
|
|
|
|
|
|
|
|
| 156 |
# Predict the next frame based on the previous frames and actions
|
| 157 |
next_frame = predict_next_frame(previous_frames, previous_actions)
|
| 158 |
previous_frames.append(next_frame)
|
|
@@ -163,6 +167,10 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 163 |
img.save(buffered, format="PNG")
|
| 164 |
img_str = base64.b64encode(buffered.getvalue()).decode()
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
# Send the generated frame back to the client
|
| 167 |
await websocket.send_json({"image": img_str})
|
| 168 |
|
|
|
|
| 10 |
from utils import initialize_model, sample_frame
|
| 11 |
import torch
|
| 12 |
import os
|
| 13 |
+
import time
|
| 14 |
|
| 15 |
app = FastAPI()
|
| 16 |
|
|
|
|
| 107 |
x, y = pos
|
| 108 |
norm_x = x + (1920 - 256) / 2
|
| 109 |
norm_y = y + (1080 - 256) / 2
|
| 110 |
+
action_descriptions.append(f"{norm_x:.0f}:{norm_y:.0f}")
|
| 111 |
elif action_type == "left_click":
|
| 112 |
action_descriptions.append("left_click")
|
| 113 |
elif action_type == "right_click":
|
|
|
|
| 154 |
# Store the actions
|
| 155 |
previous_actions.append((action_type, mouse_position))
|
| 156 |
|
| 157 |
+
# Log the start time
|
| 158 |
+
start_time = time.time()
|
| 159 |
+
|
| 160 |
# Predict the next frame based on the previous frames and actions
|
| 161 |
next_frame = predict_next_frame(previous_frames, previous_actions)
|
| 162 |
previous_frames.append(next_frame)
|
|
|
|
| 167 |
img.save(buffered, format="PNG")
|
| 168 |
img_str = base64.b64encode(buffered.getvalue()).decode()
|
| 169 |
|
| 170 |
+
# Log the processing time
|
| 171 |
+
processing_time = time.time() - start_time
|
| 172 |
+
print(f"Frame processing time: {processing_time:.2f} seconds")
|
| 173 |
+
|
| 174 |
# Send the generated frame back to the client
|
| 175 |
await websocket.send_json({"image": img_str})
|
| 176 |
|