thecollabagepatch committed on
Commit
946680d
·
1 Parent(s): 2d4dad3

reseed splice fix

Browse files
Files changed (1) hide show
  1. jam_worker.py +63 -6
jam_worker.py CHANGED
@@ -71,6 +71,7 @@ class BarClock:
71
  # -----------------------------
72
 
73
  class JamWorker(threading.Thread):
 
74
  """Generates continuous audio with MagentaRT, spools it at target SR,
75
  and emits *sample-accurate*, bar-aligned chunks (no FPS drift)."""
76
 
@@ -93,6 +94,7 @@ class JamWorker(threading.Thread):
93
 
94
  # codec/setup
95
  self._codec_fps = float(self.mrt.codec.frame_rate)
 
96
  self._ctx_frames = int(self.mrt.config.context_length_frames)
97
  self._ctx_seconds = self._ctx_frames / self._codec_fps
98
 
@@ -121,8 +123,9 @@ class JamWorker(threading.Thread):
121
  self._stop_event = threading.Event()
122
  self._max_buffer_ahead = 5
123
 
124
- # reseed queue (install at next safe point)
125
- self._pending_reseed: Optional[dict] = None
 
126
 
127
  # Prepare initial context from combined loop (best musical alignment)
128
  if self.params.combined_loop is not None:
@@ -254,10 +257,49 @@ class JamWorker(threading.Thread):
254
  self._original_context_tokens = np.copy(context_tokens)
255
 
256
  def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
257
- """Queue a splice reseed to be applied right after the next emitted loop."""
258
- new_ctx = self._encode_exact_context_tokens(recent_wav)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  with self._lock:
260
- self._pending_reseed = {"ctx": new_ctx}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
 
263
  def reseed_from_waveform(self, wav: au.Waveform):
@@ -376,7 +418,22 @@ class JamWorker(threading.Thread):
376
 
377
  # If a reseed is queued, install it *right after* we finish a chunk
378
  with self._lock:
379
- if self._pending_reseed is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  new_state = self.mrt.init_state()
381
  new_state.context_tokens = self._pending_reseed["ctx"]
382
  self.state = new_state
 
71
  # -----------------------------
72
 
73
  class JamWorker(threading.Thread):
74
+ FRAMES_PER_SECOND: float | None = None # filled in __init__ once codec is available
75
  """Generates continuous audio with MagentaRT, spools it at target SR,
76
  and emits *sample-accurate*, bar-aligned chunks (no FPS drift)."""
77
 
 
94
 
95
  # codec/setup
96
  self._codec_fps = float(self.mrt.codec.frame_rate)
97
+ JamWorker.FRAMES_PER_SECOND = self._codec_fps
98
  self._ctx_frames = int(self.mrt.config.context_length_frames)
99
  self._ctx_seconds = self._ctx_frames / self._codec_fps
100
 
 
123
  self._stop_event = threading.Event()
124
  self._max_buffer_ahead = 5
125
 
126
+ # reseed queues (install at next bar boundary after emission)
127
+ self._pending_reseed: Optional[dict] = None # legacy full reset path (kept for fallback)
128
+ self._pending_token_splice: Optional[dict] = None # seamless token splice
129
 
130
  # Prepare initial context from combined loop (best musical alignment)
131
  if self.params.combined_loop is not None:
 
257
  self._original_context_tokens = np.copy(context_tokens)
258
 
259
  def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
260
+ """Queue a *seamless* reseed by token splicing instead of full restart.
261
+ We compute a fresh, bar-locked context token tensor of exact length
262
+ (e.g., 250 frames), then splice only the *tail* corresponding to
263
+ `anchor_bars` so generation continues smoothly without resetting state.
264
+ """
265
+ new_ctx = self._encode_exact_context_tokens(recent_wav) # (F,D)
266
+ F = int(self._ctx_frames)
267
+ D = int(self.mrt.config.decoder_codec_rvq_depth)
268
+ assert new_ctx.shape == (F, D), f"expected {(F, D)}, got {new_ctx.shape}"
269
+
270
+ # how many frames correspond to the requested anchor bars
271
+ spb = self._bar_clock.seconds_per_bar()
272
+ frames_per_bar = int(round(self._codec_fps * spb))
273
+ splice_frames = int(round(max(1, anchor_bars) * frames_per_bar))
274
+ splice_frames = max(1, min(splice_frames, F))
275
+
276
  with self._lock:
277
+ # snapshot current context
278
+ cur = getattr(self.state, "context_tokens", None)
279
+ if cur is None:
280
+ # if state has no context yet, fall back to full reseed
281
+ self._pending_reseed = {"ctx": new_ctx}
282
+ return
283
+ if cur.shape != (F, D):
284
+ # safety: coerce by trim/pad
285
+ if cur.shape[0] > F:
286
+ cur = cur[-F:, :]
287
+ elif cur.shape[0] < F:
288
+ pad = np.repeat(cur[0:1, :], F - cur.shape[0], axis=0)
289
+ cur = np.concatenate([pad, cur], axis=0)
290
+ if cur.shape[1] != D:
291
+ cur = cur[:, :D]
292
+
293
+ # build the spliced tensor: keep left (F - splice) from cur, take right (splice) from new
294
+ left = cur[:F - splice_frames, :]
295
+ right = new_ctx[F - splice_frames:, :]
296
+ spliced = np.concatenate([left, right], axis=0)
297
+
298
+ # queue for install at the *next bar boundary* right after emission
299
+ self._pending_token_splice = {
300
+ "tokens": spliced,
301
+ "debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
302
+ }
303
 
304
 
305
  def reseed_from_waveform(self, wav: au.Waveform):
 
418
 
419
  # If a reseed is queued, install it *right after* we finish a chunk
420
  with self._lock:
421
+ # Prefer seamless token splice when available
422
+ if self._pending_token_splice is not None:
423
+ try:
424
+ spliced = self._pending_token_splice["tokens"]
425
+ self.state.context_tokens = spliced # in-place, no reset
426
+ self._pending_token_splice = None
427
+ # do NOT reset self._model_stream — keep continuity
428
+ # leave params/style as-is
429
+ except Exception as e:
430
+ # fallback: full reseed if setter rejects
431
+ new_state = self.mrt.init_state()
432
+ new_state.context_tokens = spliced
433
+ self.state = new_state
434
+ self._model_stream = None
435
+ self._pending_token_splice = None
436
+ elif self._pending_reseed is not None:
437
  new_state = self.mrt.init_state()
438
  new_state.context_tokens = self._pending_reseed["ctx"]
439
  self.state = new_state