Commit fc0bb07 · da03 committed · 1 Parent(s): 9ced953

main.py CHANGED
@@ -13,6 +13,7 @@ import os
 import time
 from typing import Any, Dict
 from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
+import concurrent.futures
 
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
@@ -74,6 +75,9 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
 
 # Add this at the top with other global variables
 
+# Create a thread pool executor
+thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+
 def prepare_model_inputs(
     previous_frame: torch.Tensor,
     hidden_states: Any,
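Why max_workers=1: a single-worker pool serializes every submitted job, so only one inference can touch the model (and the GPU) at a time while the event loop stays free. A minimal sketch of the same pattern, with an illustrative blocking_job that is not from this repo:

    import concurrent.futures

    # One worker thread: submitted jobs run strictly one at a time.
    thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def blocking_job(n: int) -> int:
        return n * n

    future = thread_executor.submit(blocking_job, 3)
    print(future.result())  # prints 9 once the worker finishes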
@@ -110,11 +114,20 @@ def prepare_model_inputs(
     return inputs
 
 @torch.no_grad()
-def process_frame(
+async def process_frame(
     model: LatentDiffusion,
     inputs: Dict[str, torch.Tensor]
 ) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
     """Process a single frame through the model."""
+    # Run the heavy computation in a separate thread
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(
+        thread_executor,
+        lambda: _process_frame_sync(model, inputs)
+    )
+
+def _process_frame_sync(model, inputs):
+    """Synchronous version of process_frame that runs in a thread"""
     timing = {}
     # Temporal encoding
     start = time.perf_counter()
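The new process_frame is only a thin async wrapper: the tensor work stays synchronous in _process_frame_sync, and loop.run_in_executor hands it to thread_executor so the event loop keeps servicing other connections during inference. A self-contained sketch of that offloading pattern, with slow_step and heartbeat as hypothetical stand-ins (not from this codebase):

    import asyncio
    import concurrent.futures
    import time

    thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def slow_step(x: int) -> int:
        time.sleep(1)  # stands in for blocking model/GPU work
        return x + 1

    async def heartbeat() -> None:
        while True:
            print("event loop is alive")
            await asyncio.sleep(0.25)

    async def main() -> None:
        hb = asyncio.create_task(heartbeat())
        loop = asyncio.get_running_loop()
        # The await yields control; heartbeat keeps printing
        # while slow_step blocks its worker thread.
        result = await loop.run_in_executor(thread_executor, lambda: slow_step(41))
        print(result)  # 42
        hb.cancel()

    asyncio.run(main())

One caveat worth flagging: PyTorch's grad mode is thread-local, so the @torch.no_grad() left on the async wrapper does not reach _process_frame_sync running in the worker thread; decorating the sync function directly (or wrapping its body in torch.no_grad()) would restore the original behavior.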
@@ -136,7 +149,10 @@ def process_frame(
     # Decoding
     start = time.perf_counter()
     sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
+
+    # Use time.sleep(10) here since it's in a separate thread
     time.sleep(10)
+
     sample = model.decode_first_stage(sample)
     sample = sample.squeeze(0).clamp(-1, 1)
     timing['decode'] = time.perf_counter() - start
@@ -212,7 +228,7 @@ async def websocket_endpoint(websocket: WebSocket):
 
     inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
     print(f"[{time.perf_counter():.3f}] Starting model inference...")
-    previous_frame, sample_img, hidden_states, timing_info = process_frame(model, inputs)
+    previous_frame, sample_img, hidden_states, timing_info = await process_frame(model, inputs)
     timing_info['full_frame'] = time.perf_counter() - process_start_time
 
     print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}")
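At the call site the only change needed is the await: the websocket handler now yields to the event loop for the whole inference (including the deliberate time.sleep(10), which blocks only the worker thread), so pings and other clients keep being served instead of freezing along with the loop as before.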