da03 commited on
Commit
e858976
·
1 Parent(s): 042c554
Files changed (1) hide show
  1. main.py +24 -2
main.py CHANGED
@@ -14,8 +14,8 @@ import time
14
  from typing import Any, Dict
15
  from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
16
 
17
- #torch.backends.cuda.matmul.allow_tf32 = True
18
- #torch.backends.cudnn.allow_tf32 = True
19
  SCREEN_WIDTH = 512
20
  SCREEN_HEIGHT = 384
21
  NUM_SAMPLING_STEPS = 8
@@ -167,6 +167,11 @@ async def websocket_endpoint(websocket: WebSocket):
167
  hidden_states = None
168
  keys_down = set() # Initialize as an empty set
169
  frame_num = -1
 
 
 
 
 
170
  while True:
171
  try:
172
  # Receive user input with a timeout
@@ -176,7 +181,13 @@ async def websocket_endpoint(websocket: WebSocket):
176
  await websocket.send_json({"type": "heartbeat_response"})
177
  continue
178
  frame_num += 1
 
179
  start_frame = time.perf_counter()
 
 
 
 
 
180
  x = data.get("x")
181
  y = data.get("y")
182
  is_left_click = data.get("is_left_click")
@@ -197,6 +208,9 @@ async def websocket_endpoint(websocket: WebSocket):
197
  # Use the provided function to print timing statistics
198
  print_timing_stats(timing_info, frame_num)
199
 
 
 
 
200
  img = Image.fromarray(sample_img)
201
  buffered = io.BytesIO()
202
  img.save(buffered, format="PNG")
@@ -217,5 +231,13 @@ async def websocket_endpoint(websocket: WebSocket):
217
  print(f"Error in WebSocket connection {client_id}: {e}")
218
 
219
  finally:
 
 
 
 
 
 
 
 
220
  print(f"WebSocket connection closed: {client_id}")
221
  #await websocket.close() # Ensure the WebSocket is closed
 
14
  from typing import Any, Dict
15
  from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
16
 
17
+ torch.backends.cuda.matmul.allow_tf32 = True
18
+ torch.backends.cudnn.allow_tf32 = True
19
  SCREEN_WIDTH = 512
20
  SCREEN_HEIGHT = 384
21
  NUM_SAMPLING_STEPS = 8
 
167
  hidden_states = None
168
  keys_down = set() # Initialize as an empty set
169
  frame_num = -1
170
+
171
+ # Start timing for global FPS calculation
172
+ connection_start_time = time.perf_counter()
173
+ frame_count = 0
174
+
175
  while True:
176
  try:
177
  # Receive user input with a timeout
 
181
  await websocket.send_json({"type": "heartbeat_response"})
182
  continue
183
  frame_num += 1
184
+ frame_count += 1 # Increment total frame counter
185
  start_frame = time.perf_counter()
186
+
187
+ # Calculate global FPS
188
+ total_elapsed = start_frame - connection_start_time
189
+ global_fps = frame_count / total_elapsed if total_elapsed > 0 else 0
190
+
191
  x = data.get("x")
192
  y = data.get("y")
193
  is_left_click = data.get("is_left_click")
 
208
  # Use the provided function to print timing statistics
209
  print_timing_stats(timing_info, frame_num)
210
 
211
+ # Print global FPS measurement
212
+ print(f" Global FPS: {global_fps:.2f} (total: {frame_count} frames in {total_elapsed:.2f}s)")
213
+
214
  img = Image.fromarray(sample_img)
215
  buffered = io.BytesIO()
216
  img.save(buffered, format="PNG")
 
231
  print(f"Error in WebSocket connection {client_id}: {e}")
232
 
233
  finally:
234
+ # Print final FPS statistics when connection ends
235
+ if frame_num >= 0: # Only if we processed at least one frame
236
+ total_time = time.perf_counter() - connection_start_time
237
+ print(f"\nConnection {client_id} summary:")
238
+ print(f" Total frames processed: {frame_count}")
239
+ print(f" Total elapsed time: {total_time:.2f} seconds")
240
+ print(f" Average FPS: {frame_count/total_time:.2f}")
241
+
242
  print(f"WebSocket connection closed: {client_id}")
243
  #await websocket.close() # Ensure the WebSocket is closed