Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
7ea045b
1
Parent(s):
facfd46
- main.py +64 -21
- static/index.html +11 -0
main.py
CHANGED
@@ -170,17 +170,19 @@ def prepare_model_inputs(
|
|
170 |
@torch.no_grad()
|
171 |
async def process_frame(
|
172 |
model: LatentDiffusion,
|
173 |
-
inputs: Dict[str, torch.Tensor]
|
|
|
|
|
174 |
) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
|
175 |
"""Process a single frame through the model."""
|
176 |
# Run the heavy computation in a separate thread
|
177 |
loop = asyncio.get_running_loop()
|
178 |
return await loop.run_in_executor(
|
179 |
thread_executor,
|
180 |
-
lambda: _process_frame_sync(model, inputs)
|
181 |
)
|
182 |
|
183 |
-
def _process_frame_sync(model, inputs):
|
184 |
"""Synchronous version of process_frame that runs in a thread"""
|
185 |
timing = {}
|
186 |
# Temporal encoding
|
@@ -190,17 +192,17 @@ def _process_frame_sync(model, inputs):
|
|
190 |
|
191 |
# UNet sampling
|
192 |
start = time.perf_counter()
|
193 |
-
print (f"USE_RNN: {
|
194 |
-
if
|
195 |
sample_latent = output_from_rnn[:, :16]
|
196 |
else:
|
197 |
#NUM_SAMPLING_STEPS = 8
|
198 |
-
if
|
199 |
sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
|
200 |
else:
|
201 |
sampler = DDIMSampler(model)
|
202 |
sample_latent, _ = sampler.sample(
|
203 |
-
S=
|
204 |
conditioning={'c_concat': output_from_rnn},
|
205 |
batch_size=1,
|
206 |
shape=LATENT_DIMS,
|
@@ -253,6 +255,12 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
253 |
keys_down = set() # Initialize as an empty set
|
254 |
frame_num = -1
|
255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
# Start timing for global FPS calculation
|
257 |
connection_start_time = time.perf_counter()
|
258 |
frame_count = 0
|
@@ -264,6 +272,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
264 |
# Add a function to reset the simulation
|
265 |
async def reset_simulation():
|
266 |
nonlocal previous_frame, hidden_states, keys_down, frame_num, is_processing, input_queue
|
|
|
|
|
267 |
|
268 |
# Log the reset action
|
269 |
log_interaction(
|
@@ -288,14 +298,24 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
288 |
frame_num = -1
|
289 |
is_processing = False
|
290 |
|
291 |
-
|
|
|
|
|
|
|
292 |
|
293 |
# Send confirmation to client
|
294 |
await websocket.send_json({"type": "reset_confirmed"})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
|
296 |
# Add a function to update sampling steps
|
297 |
async def update_sampling_steps(steps):
|
298 |
-
|
299 |
|
300 |
# Validate the input
|
301 |
if steps < 1:
|
@@ -303,24 +323,24 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
303 |
await websocket.send_json({"type": "error", "message": "Invalid sampling steps value"})
|
304 |
return
|
305 |
|
306 |
-
# Update the
|
307 |
-
old_steps =
|
308 |
-
|
309 |
|
310 |
-
print(f"[{time.perf_counter():.3f}] Updated
|
311 |
|
312 |
# Send confirmation to client
|
313 |
await websocket.send_json({"type": "steps_updated", "steps": steps})
|
314 |
|
315 |
# Add a function to update USE_RNN setting
|
316 |
async def update_use_rnn(use_rnn):
|
317 |
-
|
318 |
|
319 |
-
# Update the
|
320 |
-
old_setting =
|
321 |
-
|
322 |
|
323 |
-
print(f"[{time.perf_counter():.3f}] Updated USE_RNN from {old_setting} to {use_rnn}")
|
324 |
|
325 |
# Send confirmation to client
|
326 |
await websocket.send_json({"type": "rnn_updated", "use_rnn": use_rnn})
|
@@ -369,9 +389,22 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
369 |
previous_frame = padding_image
|
370 |
frame_num = 0
|
371 |
inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
|
372 |
-
|
373 |
-
|
374 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
375 |
|
376 |
|
377 |
timing_info['full_frame'] = time.perf_counter() - process_start_time
|
@@ -493,6 +526,16 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
493 |
await update_use_rnn(data.get("use_rnn", False))
|
494 |
continue
|
495 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
# Add the input to our queue
|
497 |
await input_queue.put(data)
|
498 |
print(f"[{receive_time:.3f}] Received input. Queue size now: {input_queue.qsize()}")
|
|
|
170 |
@torch.no_grad()
|
171 |
async def process_frame(
|
172 |
model: LatentDiffusion,
|
173 |
+
inputs: Dict[str, torch.Tensor],
|
174 |
+
use_rnn: bool = False,
|
175 |
+
num_sampling_steps: int = 32
|
176 |
) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
|
177 |
"""Process a single frame through the model."""
|
178 |
# Run the heavy computation in a separate thread
|
179 |
loop = asyncio.get_running_loop()
|
180 |
return await loop.run_in_executor(
|
181 |
thread_executor,
|
182 |
+
lambda: _process_frame_sync(model, inputs, use_rnn, num_sampling_steps)
|
183 |
)
|
184 |
|
185 |
+
def _process_frame_sync(model, inputs, use_rnn, num_sampling_steps):
|
186 |
"""Synchronous version of process_frame that runs in a thread"""
|
187 |
timing = {}
|
188 |
# Temporal encoding
|
|
|
192 |
|
193 |
# UNet sampling
|
194 |
start = time.perf_counter()
|
195 |
+
print (f"USE_RNN: {use_rnn}, NUM_SAMPLING_STEPS: {num_sampling_steps}")
|
196 |
+
if use_rnn:
|
197 |
sample_latent = output_from_rnn[:, :16]
|
198 |
else:
|
199 |
#NUM_SAMPLING_STEPS = 8
|
200 |
+
if num_sampling_steps >= 1000:
|
201 |
sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
|
202 |
else:
|
203 |
sampler = DDIMSampler(model)
|
204 |
sample_latent, _ = sampler.sample(
|
205 |
+
S=num_sampling_steps,
|
206 |
conditioning={'c_concat': output_from_rnn},
|
207 |
batch_size=1,
|
208 |
shape=LATENT_DIMS,
|
|
|
255 |
keys_down = set() # Initialize as an empty set
|
256 |
frame_num = -1
|
257 |
|
258 |
+
# Client-specific settings
|
259 |
+
client_settings = {
|
260 |
+
"use_rnn": USE_RNN, # Start with default global value
|
261 |
+
"sampling_steps": NUM_SAMPLING_STEPS # Start with default global value
|
262 |
+
}
|
263 |
+
|
264 |
# Start timing for global FPS calculation
|
265 |
connection_start_time = time.perf_counter()
|
266 |
frame_count = 0
|
|
|
272 |
# Add a function to reset the simulation
|
273 |
async def reset_simulation():
|
274 |
nonlocal previous_frame, hidden_states, keys_down, frame_num, is_processing, input_queue
|
275 |
+
# Keep the client settings during reset
|
276 |
+
temp_client_settings = client_settings.copy()
|
277 |
|
278 |
# Log the reset action
|
279 |
log_interaction(
|
|
|
298 |
frame_num = -1
|
299 |
is_processing = False
|
300 |
|
301 |
+
# Restore client settings
|
302 |
+
client_settings.update(temp_client_settings)
|
303 |
+
|
304 |
+
print(f"[{time.perf_counter():.3f}] Simulation reset to initial state (preserved settings: USE_RNN={client_settings['use_rnn']}, SAMPLING_STEPS={client_settings['sampling_steps']})")
|
305 |
|
306 |
# Send confirmation to client
|
307 |
await websocket.send_json({"type": "reset_confirmed"})
|
308 |
+
|
309 |
+
# Also send the current settings to update the UI
|
310 |
+
await websocket.send_json({
|
311 |
+
"type": "settings",
|
312 |
+
"sampling_steps": client_settings["sampling_steps"],
|
313 |
+
"use_rnn": client_settings["use_rnn"]
|
314 |
+
})
|
315 |
|
316 |
# Add a function to update sampling steps
|
317 |
async def update_sampling_steps(steps):
|
318 |
+
nonlocal client_settings
|
319 |
|
320 |
# Validate the input
|
321 |
if steps < 1:
|
|
|
323 |
await websocket.send_json({"type": "error", "message": "Invalid sampling steps value"})
|
324 |
return
|
325 |
|
326 |
+
# Update the client-specific setting
|
327 |
+
old_steps = client_settings["sampling_steps"]
|
328 |
+
client_settings["sampling_steps"] = steps
|
329 |
|
330 |
+
print(f"[{time.perf_counter():.3f}] Updated sampling steps for client {client_id} from {old_steps} to {steps}")
|
331 |
|
332 |
# Send confirmation to client
|
333 |
await websocket.send_json({"type": "steps_updated", "steps": steps})
|
334 |
|
335 |
# Add a function to update USE_RNN setting
|
336 |
async def update_use_rnn(use_rnn):
|
337 |
+
nonlocal client_settings
|
338 |
|
339 |
+
# Update the client-specific setting
|
340 |
+
old_setting = client_settings["use_rnn"]
|
341 |
+
client_settings["use_rnn"] = use_rnn
|
342 |
|
343 |
+
print(f"[{time.perf_counter():.3f}] Updated USE_RNN for client {client_id} from {old_setting} to {use_rnn}")
|
344 |
|
345 |
# Send confirmation to client
|
346 |
await websocket.send_json({"type": "rnn_updated", "use_rnn": use_rnn})
|
|
|
389 |
previous_frame = padding_image
|
390 |
frame_num = 0
|
391 |
inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
|
392 |
+
|
393 |
+
# Use client-specific settings
|
394 |
+
client_use_rnn = client_settings["use_rnn"]
|
395 |
+
client_sampling_steps = client_settings["sampling_steps"]
|
396 |
+
|
397 |
+
print(f"[{time.perf_counter():.3f}] Starting model inference with client settings - USE_RNN: {client_use_rnn}, SAMPLING_STEPS: {client_sampling_steps}...")
|
398 |
+
|
399 |
+
# Pass client-specific settings to process_frame
|
400 |
+
previous_frame, sample_img, hidden_states, timing_info = await process_frame(
|
401 |
+
model,
|
402 |
+
inputs,
|
403 |
+
use_rnn=client_use_rnn,
|
404 |
+
num_sampling_steps=client_sampling_steps
|
405 |
+
)
|
406 |
+
|
407 |
+
print (f'Client {client_id} settings: USE_RNN: {client_use_rnn}, SAMPLING_STEPS: {client_sampling_steps}')
|
408 |
|
409 |
|
410 |
timing_info['full_frame'] = time.perf_counter() - process_start_time
|
|
|
526 |
await update_use_rnn(data.get("use_rnn", False))
|
527 |
continue
|
528 |
|
529 |
+
# Handle settings request
|
530 |
+
if data.get("type") == "get_settings":
|
531 |
+
print(f"[{receive_time:.3f}] Received request for current settings")
|
532 |
+
await websocket.send_json({
|
533 |
+
"type": "settings",
|
534 |
+
"sampling_steps": client_settings["sampling_steps"],
|
535 |
+
"use_rnn": client_settings["use_rnn"]
|
536 |
+
})
|
537 |
+
continue
|
538 |
+
|
539 |
# Add the input to our queue
|
540 |
await input_queue.put(data)
|
541 |
print(f"[{receive_time:.3f}] Received input. Queue size now: {input_queue.qsize()}")
|
static/index.html
CHANGED
@@ -164,6 +164,12 @@
|
|
164 |
console.log("WebSocket connection established");
|
165 |
isConnected = true;
|
166 |
reconnectAttempts = 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
//startHeartbeat();
|
168 |
};
|
169 |
|
@@ -198,6 +204,11 @@
|
|
198 |
console.log(`USE_RNN setting updated to: ${data.use_rnn}`);
|
199 |
// Update the toggle to match the server state
|
200 |
document.getElementById('useRnnToggle').checked = data.use_rnn;
|
|
|
|
|
|
|
|
|
|
|
201 |
}
|
202 |
};
|
203 |
}
|
|
|
164 |
console.log("WebSocket connection established");
|
165 |
isConnected = true;
|
166 |
reconnectAttempts = 0;
|
167 |
+
|
168 |
+
// Request current settings from server to sync UI
|
169 |
+
socket.send(JSON.stringify({
|
170 |
+
type: "get_settings"
|
171 |
+
}));
|
172 |
+
|
173 |
//startHeartbeat();
|
174 |
};
|
175 |
|
|
|
204 |
console.log(`USE_RNN setting updated to: ${data.use_rnn}`);
|
205 |
// Update the toggle to match the server state
|
206 |
document.getElementById('useRnnToggle').checked = data.use_rnn;
|
207 |
+
} else if (data.type === "settings") {
|
208 |
+
// Update UI elements to match server settings
|
209 |
+
console.log(`Received settings from server: SAMPLING_STEPS=${data.sampling_steps}, USE_RNN=${data.use_rnn}`);
|
210 |
+
document.getElementById('samplingSteps').value = data.sampling_steps;
|
211 |
+
document.getElementById('useRnnToggle').checked = data.use_rnn;
|
212 |
}
|
213 |
};
|
214 |
}
|