da03 commited on
Commit
e519ed0
·
1 Parent(s): 8594846
Files changed (1) hide show
  1. main.py +18 -4
main.py CHANGED
@@ -23,6 +23,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
 
24
 
25
  DEBUG_MODE = False
 
26
 
27
  SCREEN_WIDTH = 512
28
  SCREEN_HEIGHT = 384
@@ -101,6 +102,10 @@ def prepare_model_inputs(
101
  if DEBUG_MODE:
102
  print ('DEBUG MODE, SETTING TIME STEP TO 0')
103
  time_step = 0
 
 
 
 
104
 
105
  inputs = {
106
  'image_features': previous_frame.to(device),
@@ -121,6 +126,11 @@ def prepare_model_inputs(
121
  print ('DEBUG MODE, REMOVING INPUTS')
122
  if 'hidden_states' in inputs:
123
  del inputs['hidden_states']
 
 
 
 
 
124
  return inputs
125
 
126
  @torch.no_grad()
@@ -244,13 +254,17 @@ async def websocket_endpoint(websocket: WebSocket):
244
  for key in keys_up_list:
245
  if key in keys_down: # Check if key exists to avoid KeyError
246
  keys_down.remove(key)
247
-
248
- inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
249
- print(f"[{time.perf_counter():.3f}] Starting model inference...")
250
- previous_frame, sample_img, hidden_states, timing_info = await process_frame(model, inputs)
251
  if DEBUG_MODE:
252
  print (f"DEBUG MODE, REMOVING HIDDEN STATES")
253
  previous_frame = padding_image
 
 
 
 
 
 
 
 
254
  timing_info['full_frame'] = time.perf_counter() - process_start_time
255
 
256
  print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}")
 
23
 
24
 
25
  DEBUG_MODE = False
26
+ DEBUG_MODE_2 = True
27
 
28
  SCREEN_WIDTH = 512
29
  SCREEN_HEIGHT = 384
 
102
  if DEBUG_MODE:
103
  print ('DEBUG MODE, SETTING TIME STEP TO 0')
104
  time_step = 0
105
+ if DEBUG_MODE_2:
106
+ if time_step > 1:
107
+ print ('DEBUG MODE_2, SETTING TIME STEP TO 0')
108
+ time_step = 0
109
 
110
  inputs = {
111
  'image_features': previous_frame.to(device),
 
126
  print ('DEBUG MODE, REMOVING INPUTS')
127
  if 'hidden_states' in inputs:
128
  del inputs['hidden_states']
129
+ if DEBUG_MODE_2:
130
+ if time_step > 1:
131
+ print ('DEBUG MODE_2, REMOVING HIDDEN STATES')
132
+ if 'hidden_states' in inputs:
133
+ del inputs['hidden_states']
134
  return inputs
135
 
136
  @torch.no_grad()
 
254
  for key in keys_up_list:
255
  if key in keys_down: # Check if key exists to avoid KeyError
256
  keys_down.remove(key)
 
 
 
 
257
  if DEBUG_MODE:
258
  print (f"DEBUG MODE, REMOVING HIDDEN STATES")
259
  previous_frame = padding_image
260
+ if DEBUG_MODE_2:
261
+ if frame_num > 1:
262
+ print (f"DEBUG MODE_2, REMOVING HIDDEN STATES")
263
+ previous_frame = padding_image
264
+ inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
265
+ print(f"[{time.perf_counter():.3f}] Starting model inference...")
266
+ previous_frame, sample_img, hidden_states, timing_info = await process_frame(model, inputs)
267
+
268
  timing_info['full_frame'] = time.perf_counter() - process_start_time
269
 
270
  print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}")