thecollabagepatch committed on
Commit
a9330a3
·
1 Parent(s): c188e5c

bar aligned context in reseed

Browse files
Files changed (1) hide show
  1. jam_worker.py +30 -32
jam_worker.py CHANGED
@@ -233,18 +233,20 @@ class JamWorker(threading.Thread):
233
 
234
  def _make_recent_tokens_from_wave(self, wav) -> np.ndarray:
235
  """
236
- Encode a waveform and produce a bar-aligned context token window (same shape/depth
237
- as state.context_tokens). Uses your existing codec depth.
238
  """
239
- tokens_full = self.mrt.codec.encode(wav).astype(np.int32) # [T, rvq_total]
240
- tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth] # [T, depth]
241
- # If you already have a utility that builds bar-aligned context windows, prefer it.
242
- # Otherwise clamp to ctx_frames from the tail (bar-aligned trimming happens in splicer).
243
- t = tokens.shape[0]
244
- ctx = self._ctx_frames()
245
- if t > ctx:
246
- tokens = tokens[-ctx:]
247
- return tokens
 
 
 
248
 
249
  def _bar_aligned_tail(self, tokens: np.ndarray, bars: float) -> np.ndarray:
250
  """
@@ -260,42 +262,38 @@ class JamWorker(threading.Thread):
260
  return tokens[-want:]
261
 
262
  def _splice_context(self, original_tokens: np.ndarray, recent_tokens: np.ndarray,
263
- anchor_bars: float) -> np.ndarray:
264
- """
265
- Build new context by concatenating:
266
- anchor = tail from originals (anchor_bars)
267
- recent = tail from recent_tokens filling the remainder
268
- Then clamp to ctx_frames from the tail (safety).
269
- """
270
  ctx_frames = self._ctx_frames()
271
  depth = original_tokens.shape[1]
 
272
 
273
- # 1) Take bar-aligned tail from original
274
- anchor = self._bar_aligned_tail(original_tokens, anchor_bars) # [A, depth]
 
275
 
276
- # 2) Compute how many frames remain for recent
277
  a = anchor.shape[0]
278
  remain = max(ctx_frames - a, 0)
279
-
280
- # 3) Take bar-aligned recent tail not exceeding 'remain' (rounded to bars)
281
  if remain > 0:
282
- # how many bars fit in remain?
283
- frames_per_bar = self._frames_per_bar()
284
- recent_bars_fit = int(remain // frames_per_bar)
285
- # if we can’t fit even one bar, just take the exact frame remainder
286
- if recent_bars_fit >= 1:
287
- want_recent_frames = recent_bars_fit * frames_per_bar
288
  recent = recent_tokens[-want_recent_frames:] if recent_tokens.shape[0] > want_recent_frames else recent_tokens
289
  else:
290
  recent = recent_tokens[-remain:] if recent_tokens.shape[0] > remain else recent_tokens
291
  else:
292
  recent = recent_tokens[:0]
293
 
294
- # 4) Concat and clamp again (exact)
295
- out = np.concatenate([anchor, recent], axis=0) if anchor.size or recent.size else recent_tokens[-ctx_frames:]
296
  if out.shape[0] > ctx_frames:
297
  out = out[-ctx_frames:]
298
- # safety on depth
 
 
 
 
 
299
  if out.shape[1] != depth:
300
  out = out[:, :depth]
301
  return out
 
233
 
234
  def _make_recent_tokens_from_wave(self, wav) -> np.ndarray:
235
  """
236
+ Encode waveform and produce a BAR-ALIGNED context token window.
 
237
  """
238
+ tokens_full = self.mrt.codec.encode(wav).astype(np.int32) # [T, rvq_total]
239
+ tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
240
+
241
+ from utils import make_bar_aligned_context
242
+ ctx = make_bar_aligned_context(
243
+ tokens,
244
+ bpm=self.params.bpm,
245
+ fps=int(self.mrt.codec.frame_rate),
246
+ ctx_frames=self.mrt.config.context_length_frames,
247
+ beats_per_bar=self.params.beats_per_bar
248
+ )
249
+ return ctx
250
 
251
  def _bar_aligned_tail(self, tokens: np.ndarray, bars: float) -> np.ndarray:
252
  """
 
262
  return tokens[-want:]
263
 
264
  def _splice_context(self, original_tokens: np.ndarray, recent_tokens: np.ndarray,
265
+ anchor_bars: float) -> np.ndarray:
266
+ import math
 
 
 
 
 
267
  ctx_frames = self._ctx_frames()
268
  depth = original_tokens.shape[1]
269
+ frames_per_bar = self._frames_per_bar()
270
 
271
+ # 1) Anchor tail
272
+ # Use floor, not round, to avoid grabbing an extra bar.
273
+ anchor = self._bar_aligned_tail(original_tokens, math.floor(anchor_bars))
274
 
275
+ # 2) Fill remainder with recent (in whole bars when possible)
276
  a = anchor.shape[0]
277
  remain = max(ctx_frames - a, 0)
 
 
278
  if remain > 0:
279
+ bars_fit = remain // frames_per_bar
280
+ if bars_fit >= 1:
281
+ want_recent_frames = int(bars_fit * frames_per_bar)
 
 
 
282
  recent = recent_tokens[-want_recent_frames:] if recent_tokens.shape[0] > want_recent_frames else recent_tokens
283
  else:
284
  recent = recent_tokens[-remain:] if recent_tokens.shape[0] > remain else recent_tokens
285
  else:
286
  recent = recent_tokens[:0]
287
 
288
+ out = np.concatenate([anchor, recent], axis=0) if (anchor.size or recent.size) else recent_tokens[-ctx_frames:]
 
289
  if out.shape[0] > ctx_frames:
290
  out = out[-ctx_frames:]
291
+
292
+ # --- NEW: force total length to a whole number of bars
293
+ max_bar_aligned = (out.shape[0] // frames_per_bar) * frames_per_bar
294
+ if max_bar_aligned > 0 and out.shape[0] != max_bar_aligned:
295
+ out = out[-max_bar_aligned:]
296
+
297
  if out.shape[1] != depth:
298
  out = out[:, :depth]
299
  return out