Commit
·
946680d
1
Parent(s):
2d4dad3
reseed splice fix
Browse files- jam_worker.py +63 -6
jam_worker.py
CHANGED
@@ -71,6 +71,7 @@ class BarClock:
|
|
71 |
# -----------------------------
|
72 |
|
73 |
class JamWorker(threading.Thread):
|
|
|
74 |
"""Generates continuous audio with MagentaRT, spools it at target SR,
|
75 |
and emits *sample-accurate*, bar-aligned chunks (no FPS drift)."""
|
76 |
|
@@ -93,6 +94,7 @@ class JamWorker(threading.Thread):
|
|
93 |
|
94 |
# codec/setup
|
95 |
self._codec_fps = float(self.mrt.codec.frame_rate)
|
|
|
96 |
self._ctx_frames = int(self.mrt.config.context_length_frames)
|
97 |
self._ctx_seconds = self._ctx_frames / self._codec_fps
|
98 |
|
@@ -121,8 +123,9 @@ class JamWorker(threading.Thread):
|
|
121 |
self._stop_event = threading.Event()
|
122 |
self._max_buffer_ahead = 5
|
123 |
|
124 |
-
# reseed
|
125 |
-
self._pending_reseed: Optional[dict] = None
|
|
|
126 |
|
127 |
# Prepare initial context from combined loop (best musical alignment)
|
128 |
if self.params.combined_loop is not None:
|
@@ -254,10 +257,49 @@ class JamWorker(threading.Thread):
|
|
254 |
self._original_context_tokens = np.copy(context_tokens)
|
255 |
|
256 |
def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
|
257 |
-
"""Queue a
|
258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
259 |
with self._lock:
|
260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
|
262 |
|
263 |
def reseed_from_waveform(self, wav: au.Waveform):
|
@@ -376,7 +418,22 @@ class JamWorker(threading.Thread):
|
|
376 |
|
377 |
# If a reseed is queued, install it *right after* we finish a chunk
|
378 |
with self._lock:
|
379 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
380 |
new_state = self.mrt.init_state()
|
381 |
new_state.context_tokens = self._pending_reseed["ctx"]
|
382 |
self.state = new_state
|
|
|
71 |
# -----------------------------
|
72 |
|
73 |
class JamWorker(threading.Thread):
|
74 |
+
FRAMES_PER_SECOND: float | None = None # filled in __init__ once codec is available
|
75 |
"""Generates continuous audio with MagentaRT, spools it at target SR,
|
76 |
and emits *sample-accurate*, bar-aligned chunks (no FPS drift)."""
|
77 |
|
|
|
94 |
|
95 |
# codec/setup
|
96 |
self._codec_fps = float(self.mrt.codec.frame_rate)
|
97 |
+
JamWorker.FRAMES_PER_SECOND = self._codec_fps
|
98 |
self._ctx_frames = int(self.mrt.config.context_length_frames)
|
99 |
self._ctx_seconds = self._ctx_frames / self._codec_fps
|
100 |
|
|
|
123 |
self._stop_event = threading.Event()
|
124 |
self._max_buffer_ahead = 5
|
125 |
|
126 |
+
# reseed queues (install at next bar boundary after emission)
|
127 |
+
self._pending_reseed: Optional[dict] = None # legacy full reset path (kept for fallback)
|
128 |
+
self._pending_token_splice: Optional[dict] = None # seamless token splice
|
129 |
|
130 |
# Prepare initial context from combined loop (best musical alignment)
|
131 |
if self.params.combined_loop is not None:
|
|
|
257 |
self._original_context_tokens = np.copy(context_tokens)
|
258 |
|
259 |
def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
|
260 |
+
"""Queue a *seamless* reseed by token splicing instead of full restart.
|
261 |
+
We compute a fresh, bar-locked context token tensor of exact length
|
262 |
+
(e.g., 250 frames), then splice only the *tail* corresponding to
|
263 |
+
`anchor_bars` so generation continues smoothly without resetting state.
|
264 |
+
"""
|
265 |
+
new_ctx = self._encode_exact_context_tokens(recent_wav) # (F,D)
|
266 |
+
F = int(self._ctx_frames)
|
267 |
+
D = int(self.mrt.config.decoder_codec_rvq_depth)
|
268 |
+
assert new_ctx.shape == (F, D), f"expected {(F, D)}, got {new_ctx.shape}"
|
269 |
+
|
270 |
+
# how many frames correspond to the requested anchor bars
|
271 |
+
spb = self._bar_clock.seconds_per_bar()
|
272 |
+
frames_per_bar = int(round(self._codec_fps * spb))
|
273 |
+
splice_frames = int(round(max(1, anchor_bars) * frames_per_bar))
|
274 |
+
splice_frames = max(1, min(splice_frames, F))
|
275 |
+
|
276 |
with self._lock:
|
277 |
+
# snapshot current context
|
278 |
+
cur = getattr(self.state, "context_tokens", None)
|
279 |
+
if cur is None:
|
280 |
+
# if state has no context yet, fall back to full reseed
|
281 |
+
self._pending_reseed = {"ctx": new_ctx}
|
282 |
+
return
|
283 |
+
if cur.shape != (F, D):
|
284 |
+
# safety: coerce by trim/pad
|
285 |
+
if cur.shape[0] > F:
|
286 |
+
cur = cur[-F:, :]
|
287 |
+
elif cur.shape[0] < F:
|
288 |
+
pad = np.repeat(cur[0:1, :], F - cur.shape[0], axis=0)
|
289 |
+
cur = np.concatenate([pad, cur], axis=0)
|
290 |
+
if cur.shape[1] != D:
|
291 |
+
cur = cur[:, :D]
|
292 |
+
|
293 |
+
# build the spliced tensor: keep left (F - splice) from cur, take right (splice) from new
|
294 |
+
left = cur[:F - splice_frames, :]
|
295 |
+
right = new_ctx[F - splice_frames:, :]
|
296 |
+
spliced = np.concatenate([left, right], axis=0)
|
297 |
+
|
298 |
+
# queue for install at the *next bar boundary* right after emission
|
299 |
+
self._pending_token_splice = {
|
300 |
+
"tokens": spliced,
|
301 |
+
"debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
|
302 |
+
}
|
303 |
|
304 |
|
305 |
def reseed_from_waveform(self, wav: au.Waveform):
|
|
|
418 |
|
419 |
# If a reseed is queued, install it *right after* we finish a chunk
|
420 |
with self._lock:
|
421 |
+
# Prefer seamless token splice when available
|
422 |
+
if self._pending_token_splice is not None:
|
423 |
+
try:
|
424 |
+
spliced = self._pending_token_splice["tokens"]
|
425 |
+
self.state.context_tokens = spliced # in-place, no reset
|
426 |
+
self._pending_token_splice = None
|
427 |
+
# do NOT reset self._model_stream — keep continuity
|
428 |
+
# leave params/style as-is
|
429 |
+
except Exception as e:
|
430 |
+
# fallback: full reseed if setter rejects
|
431 |
+
new_state = self.mrt.init_state()
|
432 |
+
new_state.context_tokens = spliced
|
433 |
+
self.state = new_state
|
434 |
+
self._model_stream = None
|
435 |
+
self._pending_token_splice = None
|
436 |
+
elif self._pending_reseed is not None:
|
437 |
new_state = self.mrt.init_state()
|
438 |
new_state.context_tokens = self._pending_reseed["ctx"]
|
439 |
self.state = new_state
|