thecollabagepatch committed on
Commit
2ab3869
·
1 Parent(s): 946680d

final reseed splice fix

Browse files
Files changed (1) hide show
  1. jam_worker.py +100 -24
jam_worker.py CHANGED
@@ -174,6 +174,91 @@ class JamWorker(threading.Thread):
174
 
175
  # ---------- context / reseed ----------
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
178
  """Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps),
179
  while ensuring the *end* of the audio lands on a bar boundary.
@@ -262,44 +347,35 @@ class JamWorker(threading.Thread):
262
  (e.g., 250 frames), then splice only the *tail* corresponding to
263
  `anchor_bars` so generation continues smoothly without resetting state.
264
  """
265
- new_ctx = self._encode_exact_context_tokens(recent_wav) # (F,D)
266
- F = int(self._ctx_frames)
267
- D = int(self.mrt.config.decoder_codec_rvq_depth)
268
- assert new_ctx.shape == (F, D), f"expected {(F, D)}, got {new_ctx.shape}"
269
 
270
  # how many frames correspond to the requested anchor bars
271
  spb = self._bar_clock.seconds_per_bar()
272
- frames_per_bar = int(round(self._codec_fps * spb))
273
- splice_frames = int(round(max(1, anchor_bars) * frames_per_bar))
274
- splice_frames = max(1, min(splice_frames, F))
275
 
276
  with self._lock:
277
  # snapshot current context
278
  cur = getattr(self.state, "context_tokens", None)
279
  if cur is None:
280
- # if state has no context yet, fall back to full reseed
281
  self._pending_reseed = {"ctx": new_ctx}
282
  return
283
- if cur.shape != (F, D):
284
- # safety: coerce by trim/pad
285
- if cur.shape[0] > F:
286
- cur = cur[-F:, :]
287
- elif cur.shape[0] < F:
288
- pad = np.repeat(cur[0:1, :], F - cur.shape[0], axis=0)
289
- cur = np.concatenate([pad, cur], axis=0)
290
- if cur.shape[1] != D:
291
- cur = cur[:, :D]
292
 
293
  # build the spliced tensor: keep left (F - splice) from cur, take right (splice) from new
294
  left = cur[:F - splice_frames, :]
295
  right = new_ctx[F - splice_frames:, :]
296
  spliced = np.concatenate([left, right], axis=0)
 
297
 
298
  # queue for install at the *next bar boundary* right after emission
299
  self._pending_token_splice = {
300
  "tokens": spliced,
301
  "debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
302
  }
 
303
 
304
 
305
  def reseed_from_waveform(self, wav: au.Waveform):
@@ -420,22 +496,22 @@ class JamWorker(threading.Thread):
420
  with self._lock:
421
  # Prefer seamless token splice when available
422
  if self._pending_token_splice is not None:
 
423
  try:
424
- spliced = self._pending_token_splice["tokens"]
425
- self.state.context_tokens = spliced # in-place, no reset
426
  self._pending_token_splice = None
427
- # do NOT reset self._model_stream — keep continuity
428
- # leave params/style as-is
429
- except Exception as e:
430
- # fallback: full reseed if setter rejects
431
  new_state = self.mrt.init_state()
432
  new_state.context_tokens = spliced
433
  self.state = new_state
434
  self._model_stream = None
435
  self._pending_token_splice = None
436
  elif self._pending_reseed is not None:
 
437
  new_state = self.mrt.init_state()
438
- new_state.context_tokens = self._pending_reseed["ctx"]
439
  self.state = new_state
440
  self._model_stream = None
441
  self._pending_reseed = None
 
174
 
175
  # ---------- context / reseed ----------
176
 
177
def _expected_token_shape(self) -> Tuple[int, int]:
    """Return the ``(frames, depth)`` shape that context tokens must have.

    ``frames`` is ``self._ctx_frames`` and ``depth`` is
    ``self.mrt.config.decoder_codec_rvq_depth``; both are cast to ``int``.
    """
    return (
        int(self._ctx_frames),
        int(self.mrt.config.decoder_codec_rvq_depth),
    )
181
+
182
def _coerce_tokens(self, toks: np.ndarray) -> np.ndarray:
    """Force *toks* to the shape returned by ``_expected_token_shape()``.

    Args:
        toks: Token array (or array-like). 0-D/1-D input is promoted with
            ``np.atleast_2d``; input with more than two axes has its leading
            axes flattened into the frame axis.

    Returns:
        An ``int32`` array of shape ``(F, D)``:

        * depth axis: extra columns are dropped; missing columns are filled
          by repeating the last available column (zeros if there is none);
        * frame axis: when too long, only the last ``F`` frames are kept;
          when too short, copies of the *last* frame are *prepended* so the
          existing frames stay right-aligned (a lone zero frame is used if
          there are no frames at all).
    """
    F, D = self._expected_token_shape()

    # Accept array-likes (lists, foreign array types) as well as ndarrays.
    toks = np.asarray(toks)
    if toks.ndim < 2:
        toks = np.atleast_2d(toks)
    elif toks.ndim > 2:
        # Collapse leading axes (e.g. a batch axis) into the frame axis so
        # the column/row logic below always sees a 2-D array.
        toks = toks.reshape(-1, toks.shape[-1])

    # Depth first.
    depth = toks.shape[1]
    if depth > D:
        toks = toks[:, :D]
    elif depth == 0:
        # Nothing to repeat from: fall back to zero tokens.
        toks = np.zeros((toks.shape[0], D), dtype=np.int32)
    elif depth < D:
        pad_cols = np.tile(toks[:, -1:], (1, D - depth))
        toks = np.concatenate([toks, pad_cols], axis=1)

    # Then frames.
    n_frames = toks.shape[0]
    if n_frames < F:
        if n_frames == 0:
            toks = np.zeros((1, D), dtype=np.int32)
            n_frames = 1
        # Repeating a real frame is preferred over zero-filling here.
        # NOTE(review): the filler is the last frame but it is placed
        # before the data — confirm this is the intended padding.
        pad = np.repeat(toks[-1:, :], F - n_frames, axis=0)
        toks = np.concatenate([pad, toks], axis=0)
    elif n_frames > F:
        # Explicit start index: ``toks[-F:]`` would keep everything if F == 0.
        toks = toks[n_frames - F:, :]

    return toks.astype(np.int32, copy=False)
205
+
206
def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
    """Encode the trailing context window of *loop* into codec tokens.

    The loop is converted to stereo, resampled to ``self._model_sr`` and
    repeated end-to-end until it covers ``self._ctx_seconds``. The last
    ``round(self._ctx_seconds * self._model_sr)`` samples of that tiling are
    encoded with ``self.mrt.codec``, limited to
    ``self.mrt.config.decoder_codec_rvq_depth`` columns, and passed through
    ``self._coerce_tokens``.

    Because the window is cut from the *end* of the tiled audio, it always
    ends exactly where the loop ends. It therefore ends on a bar boundary
    only if the loop itself does; no re-alignment is performed here.

    Args:
        loop: Source audio; an empty loop is treated as one silent stereo
            sample.

    Returns:
        ``int32`` context tokens as returned by ``self._coerce_tokens``.
    """
    # NOTE(review): another ``_encode_exact_context_tokens`` definition
    # appears after this one in the same class; Python keeps the later
    # definition, so confirm which of the two is meant to be active.
    wav = loop.as_stereo().resample(self._model_sr)
    data = wav.samples.astype(np.float32, copy=False)
    if data.ndim == 1:
        data = data[:, None]
    if data.shape[0] == 0:
        data = np.zeros((1, 2), dtype=np.float32)

    sr = int(self._model_sr)
    ctx_samps = int(round(float(self._ctx_seconds) * sr))

    # Tile just enough copies to cover the context window. Tiling is
    # end-aligned (the tiled audio ends where the loop ends), so slicing
    # from the end selects the same samples regardless of the repeat count.
    reps = max(1, int(np.ceil(ctx_samps / float(data.shape[0]))))
    tiled = np.tile(data, (reps, 1))
    ctx = tiled[max(0, tiled.shape[0] - ctx_samps):]

    exact = au.Waveform(ctx, sr)
    tokens_full = self.mrt.codec.encode(exact).astype(np.int32)
    depth = int(self.mrt.config.decoder_codec_rvq_depth)

    # Force the expected (frames, depth) shape at return time.
    return self._coerce_tokens(tokens_full[:, :depth])
261
+
262
  def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
263
  """Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps),
264
  while ensuring the *end* of the audio lands on a bar boundary.
 
347
  (e.g., 250 frames), then splice only the *tail* corresponding to
348
  `anchor_bars` so generation continues smoothly without resetting state.
349
  """
350
+ new_ctx = self._encode_exact_context_tokens(recent_wav) # coerce to (F,D)
351
+ F, D = self._expected_token_shape()
 
 
352
 
353
  # how many frames correspond to the requested anchor bars
354
  spb = self._bar_clock.seconds_per_bar()
355
+ frames_per_bar = max(1, int(round(self._codec_fps * spb)))
356
+ splice_frames = max(1, min(int(round(max(1.0, float(anchor_bars)) * frames_per_bar)), F))
 
357
 
358
  with self._lock:
359
  # snapshot current context
360
  cur = getattr(self.state, "context_tokens", None)
361
  if cur is None:
362
+ # fall back to full reseed (still coerced)
363
  self._pending_reseed = {"ctx": new_ctx}
364
  return
365
+ cur = self._coerce_tokens(cur)
 
 
 
 
 
 
 
 
366
 
367
  # build the spliced tensor: keep left (F - splice) from cur, take right (splice) from new
368
  left = cur[:F - splice_frames, :]
369
  right = new_ctx[F - splice_frames:, :]
370
  spliced = np.concatenate([left, right], axis=0)
371
+ spliced = self._coerce_tokens(spliced)
372
 
373
  # queue for install at the *next bar boundary* right after emission
374
  self._pending_token_splice = {
375
  "tokens": spliced,
376
  "debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
377
  }
378
+ }
379
 
380
 
381
  def reseed_from_waveform(self, wav: au.Waveform):
 
496
  with self._lock:
497
  # Prefer seamless token splice when available
498
  if self._pending_token_splice is not None:
499
+ spliced = self._coerce_tokens(self._pending_token_splice["tokens"])
500
  try:
501
+ # inplace update (no reset)
502
+ self.state.context_tokens = spliced
503
  self._pending_token_splice = None
504
+ except Exception:
505
+ # fallback: full reseed using spliced tokens
 
 
506
  new_state = self.mrt.init_state()
507
  new_state.context_tokens = spliced
508
  self.state = new_state
509
  self._model_stream = None
510
  self._pending_token_splice = None
511
  elif self._pending_reseed is not None:
512
+ ctx = self._coerce_tokens(self._pending_reseed["ctx"])
513
  new_state = self.mrt.init_state()
514
+ new_state.context_tokens = ctx
515
  self.state = new_state
516
  self._model_stream = None
517
  self._pending_reseed = None