da03 committed on
Commit 7ea045b · 1 Parent(s): facfd46
Files changed (2)
  1. main.py +64 -21
  2. static/index.html +11 -0
main.py CHANGED
@@ -170,17 +170,19 @@ def prepare_model_inputs(
 @torch.no_grad()
 async def process_frame(
     model: LatentDiffusion,
-    inputs: Dict[str, torch.Tensor]
+    inputs: Dict[str, torch.Tensor],
+    use_rnn: bool = False,
+    num_sampling_steps: int = 32
 ) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
     """Process a single frame through the model."""
     # Run the heavy computation in a separate thread
     loop = asyncio.get_running_loop()
     return await loop.run_in_executor(
         thread_executor,
-        lambda: _process_frame_sync(model, inputs)
+        lambda: _process_frame_sync(model, inputs, use_rnn, num_sampling_steps)
     )
 
-def _process_frame_sync(model, inputs):
+def _process_frame_sync(model, inputs, use_rnn, num_sampling_steps):
     """Synchronous version of process_frame that runs in a thread"""
     timing = {}
     # Temporal encoding
@@ -190,17 +192,17 @@ def _process_frame_sync(model, inputs):
 
     # UNet sampling
     start = time.perf_counter()
-    print (f"USE_RNN: {USE_RNN}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}")
-    if USE_RNN:
+    print (f"USE_RNN: {use_rnn}, NUM_SAMPLING_STEPS: {num_sampling_steps}")
+    if use_rnn:
         sample_latent = output_from_rnn[:, :16]
     else:
         #NUM_SAMPLING_STEPS = 8
-        if NUM_SAMPLING_STEPS >= 1000:
+        if num_sampling_steps >= 1000:
             sample_latent = model.p_sample_loop(cond={'c_concat': output_from_rnn}, shape=[1, *LATENT_DIMS], return_intermediates=False, verbose=True)
         else:
             sampler = DDIMSampler(model)
             sample_latent, _ = sampler.sample(
-                S=NUM_SAMPLING_STEPS,
+                S=num_sampling_steps,
                 conditioning={'c_concat': output_from_rnn},
                 batch_size=1,
                 shape=LATENT_DIMS,
@@ -253,6 +255,12 @@ async def websocket_endpoint(websocket: WebSocket):
     keys_down = set()  # Initialize as an empty set
     frame_num = -1
 
+    # Client-specific settings
+    client_settings = {
+        "use_rnn": USE_RNN,  # Start with default global value
+        "sampling_steps": NUM_SAMPLING_STEPS  # Start with default global value
+    }
+
     # Start timing for global FPS calculation
     connection_start_time = time.perf_counter()
     frame_count = 0
@@ -264,6 +272,8 @@ async def websocket_endpoint(websocket: WebSocket):
     # Add a function to reset the simulation
     async def reset_simulation():
         nonlocal previous_frame, hidden_states, keys_down, frame_num, is_processing, input_queue
+        # Keep the client settings during reset
+        temp_client_settings = client_settings.copy()
 
         # Log the reset action
         log_interaction(
@@ -288,14 +298,24 @@ async def websocket_endpoint(websocket: WebSocket):
         frame_num = -1
         is_processing = False
 
-        print(f"[{time.perf_counter():.3f}] Simulation reset to initial state")
+        # Restore client settings
+        client_settings.update(temp_client_settings)
+
+        print(f"[{time.perf_counter():.3f}] Simulation reset to initial state (preserved settings: USE_RNN={client_settings['use_rnn']}, SAMPLING_STEPS={client_settings['sampling_steps']})")
 
         # Send confirmation to client
         await websocket.send_json({"type": "reset_confirmed"})
+
+        # Also send the current settings to update the UI
+        await websocket.send_json({
+            "type": "settings",
+            "sampling_steps": client_settings["sampling_steps"],
+            "use_rnn": client_settings["use_rnn"]
+        })
 
     # Add a function to update sampling steps
     async def update_sampling_steps(steps):
-        global NUM_SAMPLING_STEPS
+        nonlocal client_settings
 
         # Validate the input
         if steps < 1:
@@ -303,24 +323,24 @@ async def websocket_endpoint(websocket: WebSocket):
             await websocket.send_json({"type": "error", "message": "Invalid sampling steps value"})
             return
 
-        # Update the global variable
-        old_steps = NUM_SAMPLING_STEPS
-        NUM_SAMPLING_STEPS = steps
+        # Update the client-specific setting
+        old_steps = client_settings["sampling_steps"]
+        client_settings["sampling_steps"] = steps
 
-        print(f"[{time.perf_counter():.3f}] Updated NUM_SAMPLING_STEPS from {old_steps} to {steps}")
+        print(f"[{time.perf_counter():.3f}] Updated sampling steps for client {client_id} from {old_steps} to {steps}")
 
         # Send confirmation to client
         await websocket.send_json({"type": "steps_updated", "steps": steps})
 
     # Add a function to update USE_RNN setting
     async def update_use_rnn(use_rnn):
-        global USE_RNN
+        nonlocal client_settings
 
-        # Update the global variable
-        old_setting = USE_RNN
-        USE_RNN = use_rnn
+        # Update the client-specific setting
+        old_setting = client_settings["use_rnn"]
+        client_settings["use_rnn"] = use_rnn
 
-        print(f"[{time.perf_counter():.3f}] Updated USE_RNN from {old_setting} to {use_rnn}")
+        print(f"[{time.perf_counter():.3f}] Updated USE_RNN for client {client_id} from {old_setting} to {use_rnn}")
 
         # Send confirmation to client
         await websocket.send_json({"type": "rnn_updated", "use_rnn": use_rnn})
@@ -369,9 +389,22 @@ async def websocket_endpoint(websocket: WebSocket):
             previous_frame = padding_image
             frame_num = 0
             inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
-            print(f"[{time.perf_counter():.3f}] Starting model inference...")
-            previous_frame, sample_img, hidden_states, timing_info = await process_frame(model, inputs)
-            print (f'aaa setting: DEBUG_MODE: {DEBUG_MODE}, DEBUG_MODE_2: {DEBUG_MODE_2}, NUM_MAX_FRAMES: {NUM_MAX_FRAMES}, NUM_SAMPLING_STEPS: {NUM_SAMPLING_STEPS}')
+
+            # Use client-specific settings
+            client_use_rnn = client_settings["use_rnn"]
+            client_sampling_steps = client_settings["sampling_steps"]
+
+            print(f"[{time.perf_counter():.3f}] Starting model inference with client settings - USE_RNN: {client_use_rnn}, SAMPLING_STEPS: {client_sampling_steps}...")
+
+            # Pass client-specific settings to process_frame
+            previous_frame, sample_img, hidden_states, timing_info = await process_frame(
+                model,
+                inputs,
+                use_rnn=client_use_rnn,
+                num_sampling_steps=client_sampling_steps
+            )
+
+            print (f'Client {client_id} settings: USE_RNN: {client_use_rnn}, SAMPLING_STEPS: {client_sampling_steps}')
 
 
             timing_info['full_frame'] = time.perf_counter() - process_start_time
@@ -493,6 +526,16 @@ async def websocket_endpoint(websocket: WebSocket):
                     await update_use_rnn(data.get("use_rnn", False))
                     continue
 
+                # Handle settings request
+                if data.get("type") == "get_settings":
+                    print(f"[{receive_time:.3f}] Received request for current settings")
+                    await websocket.send_json({
+                        "type": "settings",
+                        "sampling_steps": client_settings["sampling_steps"],
+                        "use_rnn": client_settings["use_rnn"]
+                    })
+                    continue
+
                 # Add the input to our queue
                 await input_queue.put(data)
                 print(f"[{receive_time:.3f}] Received input. Queue size now: {input_queue.qsize()}")
static/index.html CHANGED
@@ -164,6 +164,12 @@
             console.log("WebSocket connection established");
             isConnected = true;
             reconnectAttempts = 0;
+
+            // Request current settings from server to sync UI
+            socket.send(JSON.stringify({
+                type: "get_settings"
+            }));
+
             //startHeartbeat();
         };
 
@@ -198,6 +204,11 @@
                 console.log(`USE_RNN setting updated to: ${data.use_rnn}`);
                 // Update the toggle to match the server state
                 document.getElementById('useRnnToggle').checked = data.use_rnn;
+            } else if (data.type === "settings") {
+                // Update UI elements to match server settings
+                console.log(`Received settings from server: SAMPLING_STEPS=${data.sampling_steps}, USE_RNN=${data.use_rnn}`);
+                document.getElementById('samplingSteps').value = data.sampling_steps;
+                document.getElementById('useRnnToggle').checked = data.use_rnn;
             }
         };
     }
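
The front-end change above is a simple request/reply handshake: on connect the page sends a "get_settings" message and updates its controls when a "settings" message arrives. The same handshake can be exercised outside the browser; the sketch below is an assumed test client using the third-party websockets package, with a placeholder URL, and only the message shapes are taken from the diff.

# Illustrative sketch of the settings handshake, not part of the repo.
import asyncio
import json
import websockets

async def sync_settings(url: str = "ws://localhost:8000/ws"):  # placeholder URL
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps({"type": "get_settings"}))
        while True:
            msg = json.loads(await ws.recv())
            if msg.get("type") == "settings":
                print(f"sampling_steps={msg['sampling_steps']}, use_rnn={msg['use_rnn']}")
                return msg

if __name__ == "__main__":
    asyncio.run(sync_settings())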