da03 committed on
Commit
dfecf95
·
1 Parent(s): 03af1a4
Files changed (1) hide show
  1. main.py +4 -1
main.py CHANGED
@@ -22,6 +22,7 @@ torch.backends.cudnn.allow_tf32 = True
22
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
 
24
 
 
25
 
26
  SCREEN_WIDTH = 512
27
  SCREEN_HEIGHT = 384
@@ -113,7 +114,6 @@ def prepare_model_inputs(
113
 
114
  if hidden_states is not None:
115
  inputs['hidden_states'] = hidden_states
116
- DEBUG_MODE = True
117
  if DEBUG_MODE:
118
  print ('DEBUG MODE, REMOVING INPUTS')
119
  if 'hidden_states' in inputs:
@@ -245,6 +245,9 @@ async def websocket_endpoint(websocket: WebSocket):
245
  inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
246
  print(f"[{time.perf_counter():.3f}] Starting model inference...")
247
  previous_frame, sample_img, hidden_states, timing_info = await process_frame(model, inputs)
 
 
 
248
  timing_info['full_frame'] = time.perf_counter() - process_start_time
249
 
250
  print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}")
 
22
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
 
24
 
25
+ DEBUG_MODE = True
26
 
27
  SCREEN_WIDTH = 512
28
  SCREEN_HEIGHT = 384
 
114
 
115
  if hidden_states is not None:
116
  inputs['hidden_states'] = hidden_states
 
117
  if DEBUG_MODE:
118
  print ('DEBUG MODE, REMOVING INPUTS')
119
  if 'hidden_states' in inputs:
 
245
  inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
246
  print(f"[{time.perf_counter():.3f}] Starting model inference...")
247
  previous_frame, sample_img, hidden_states, timing_info = await process_frame(model, inputs)
248
+ if DEBUG_MODE:
249
+ print (f"DEBUG MODE, REMOVING HIDDEN STATES")
250
+ previous_frame = padding_image
251
  timing_info['full_frame'] = time.perf_counter() - process_start_time
252
 
253
  print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}")