Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
e858976
1
Parent(s):
042c554
main.py
CHANGED
@@ -14,8 +14,8 @@ import time
|
|
14 |
from typing import Any, Dict
|
15 |
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
|
16 |
|
17 |
-
|
18 |
-
|
19 |
SCREEN_WIDTH = 512
|
20 |
SCREEN_HEIGHT = 384
|
21 |
NUM_SAMPLING_STEPS = 8
|
@@ -167,6 +167,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
167 |
hidden_states = None
|
168 |
keys_down = set() # Initialize as an empty set
|
169 |
frame_num = -1
|
|
|
|
|
|
|
|
|
|
|
170 |
while True:
|
171 |
try:
|
172 |
# Receive user input with a timeout
|
@@ -176,7 +181,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
176 |
await websocket.send_json({"type": "heartbeat_response"})
|
177 |
continue
|
178 |
frame_num += 1
|
|
|
179 |
start_frame = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
|
180 |
x = data.get("x")
|
181 |
y = data.get("y")
|
182 |
is_left_click = data.get("is_left_click")
|
@@ -197,6 +208,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
197 |
# Use the provided function to print timing statistics
|
198 |
print_timing_stats(timing_info, frame_num)
|
199 |
|
|
|
|
|
|
|
200 |
img = Image.fromarray(sample_img)
|
201 |
buffered = io.BytesIO()
|
202 |
img.save(buffered, format="PNG")
|
@@ -217,5 +231,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
217 |
print(f"Error in WebSocket connection {client_id}: {e}")
|
218 |
|
219 |
finally:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
print(f"WebSocket connection closed: {client_id}")
|
221 |
#await websocket.close() # Ensure the WebSocket is closed
|
|
|
14 |
from typing import Any, Dict
|
15 |
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
|
16 |
|
17 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
18 |
+
torch.backends.cudnn.allow_tf32 = True
|
19 |
SCREEN_WIDTH = 512
|
20 |
SCREEN_HEIGHT = 384
|
21 |
NUM_SAMPLING_STEPS = 8
|
|
|
167 |
hidden_states = None
|
168 |
keys_down = set() # Initialize as an empty set
|
169 |
frame_num = -1
|
170 |
+
|
171 |
+
# Start timing for global FPS calculation
|
172 |
+
connection_start_time = time.perf_counter()
|
173 |
+
frame_count = 0
|
174 |
+
|
175 |
while True:
|
176 |
try:
|
177 |
# Receive user input with a timeout
|
|
|
181 |
await websocket.send_json({"type": "heartbeat_response"})
|
182 |
continue
|
183 |
frame_num += 1
|
184 |
+
frame_count += 1 # Increment total frame counter
|
185 |
start_frame = time.perf_counter()
|
186 |
+
|
187 |
+
# Calculate global FPS
|
188 |
+
total_elapsed = start_frame - connection_start_time
|
189 |
+
global_fps = frame_count / total_elapsed if total_elapsed > 0 else 0
|
190 |
+
|
191 |
x = data.get("x")
|
192 |
y = data.get("y")
|
193 |
is_left_click = data.get("is_left_click")
|
|
|
208 |
# Use the provided function to print timing statistics
|
209 |
print_timing_stats(timing_info, frame_num)
|
210 |
|
211 |
+
# Print global FPS measurement
|
212 |
+
print(f" Global FPS: {global_fps:.2f} (total: {frame_count} frames in {total_elapsed:.2f}s)")
|
213 |
+
|
214 |
img = Image.fromarray(sample_img)
|
215 |
buffered = io.BytesIO()
|
216 |
img.save(buffered, format="PNG")
|
|
|
231 |
print(f"Error in WebSocket connection {client_id}: {e}")
|
232 |
|
233 |
finally:
|
234 |
+
# Print final FPS statistics when connection ends
|
235 |
+
if frame_num >= 0: # Only if we processed at least one frame
|
236 |
+
total_time = time.perf_counter() - connection_start_time
|
237 |
+
print(f"\nConnection {client_id} summary:")
|
238 |
+
print(f" Total frames processed: {frame_count}")
|
239 |
+
print(f" Total elapsed time: {total_time:.2f} seconds")
|
240 |
+
print(f" Average FPS: {frame_count/total_time:.2f}")
|
241 |
+
|
242 |
print(f"WebSocket connection closed: {client_id}")
|
243 |
#await websocket.close() # Ensure the WebSocket is closed
|