da03 commited on
Commit
fc0bb07
·
1 Parent(s): 9ced953
Files changed (1) hide show
  1. main.py +18 -2
main.py CHANGED
@@ -13,6 +13,7 @@ import os
13
  import time
14
  from typing import Any, Dict
15
  from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
 
16
 
17
  torch.backends.cuda.matmul.allow_tf32 = True
18
  torch.backends.cudnn.allow_tf32 = True
@@ -74,6 +75,9 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
74
 
75
  # Add this at the top with other global variables
76
 
 
 
 
77
  def prepare_model_inputs(
78
  previous_frame: torch.Tensor,
79
  hidden_states: Any,
@@ -110,11 +114,20 @@ def prepare_model_inputs(
110
  return inputs
111
 
112
  @torch.no_grad()
113
- def process_frame(
114
  model: LatentDiffusion,
115
  inputs: Dict[str, torch.Tensor]
116
  ) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
117
  """Process a single frame through the model."""
 
 
 
 
 
 
 
 
 
118
  timing = {}
119
  # Temporal encoding
120
  start = time.perf_counter()
@@ -136,7 +149,10 @@ def process_frame(
136
  # Decoding
137
  start = time.perf_counter()
138
  sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
 
 
139
  time.sleep(10)
 
140
  sample = model.decode_first_stage(sample)
141
  sample = sample.squeeze(0).clamp(-1, 1)
142
  timing['decode'] = time.perf_counter() - start
@@ -212,7 +228,7 @@ async def websocket_endpoint(websocket: WebSocket):
212
 
213
  inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
214
  print(f"[{time.perf_counter():.3f}] Starting model inference...")
215
- previous_frame, sample_img, hidden_states, timing_info = process_frame(model, inputs)
216
  timing_info['full_frame'] = time.perf_counter() - process_start_time
217
 
218
  print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}")
 
13
  import time
14
  from typing import Any, Dict
15
  from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
16
+ import concurrent.futures
17
 
18
  torch.backends.cuda.matmul.allow_tf32 = True
19
  torch.backends.cudnn.allow_tf32 = True
 
75
 
76
  # Add this at the top with other global variables
77
 
78
+ # Create a thread pool executor
79
+ thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
80
+
81
  def prepare_model_inputs(
82
  previous_frame: torch.Tensor,
83
  hidden_states: Any,
 
114
  return inputs
115
 
116
  @torch.no_grad()
117
+ async def process_frame(
118
  model: LatentDiffusion,
119
  inputs: Dict[str, torch.Tensor]
120
  ) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
121
  """Process a single frame through the model."""
122
+ # Run the heavy computation in a separate thread
123
+ loop = asyncio.get_running_loop()
124
+ return await loop.run_in_executor(
125
+ thread_executor,
126
+ lambda: _process_frame_sync(model, inputs)
127
+ )
128
+
129
+ def _process_frame_sync(model, inputs):
130
+ """Synchronous version of process_frame that runs in a thread"""
131
  timing = {}
132
  # Temporal encoding
133
  start = time.perf_counter()
 
149
  # Decoding
150
  start = time.perf_counter()
151
  sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
152
+
153
+ # Use time.sleep(10) here since it's in a separate thread
154
  time.sleep(10)
155
+
156
  sample = model.decode_first_stage(sample)
157
  sample = sample.squeeze(0).clamp(-1, 1)
158
  timing['decode'] = time.perf_counter() - start
 
228
 
229
  inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
230
  print(f"[{time.perf_counter():.3f}] Starting model inference...")
231
+ previous_frame, sample_img, hidden_states, timing_info = await process_frame(model, inputs)
232
  timing_info['full_frame'] = time.perf_counter() - process_start_time
233
 
234
  print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}")