Commit · a9330a3
1 Parent(s): c188e5c
Bar-aligned context in reseed
Browse files
- jam_worker.py +30 -32
jam_worker.py
CHANGED
@@ -233,18 +233,20 @@ class JamWorker(threading.Thread):
|
|
233 |
|
234 |
def _make_recent_tokens_from_wave(self, wav) -> np.ndarray:
|
235 |
"""
|
236 |
-
Encode
|
237 |
-
as state.context_tokens). Uses your existing codec depth.
|
238 |
"""
|
239 |
-
tokens_full = self.mrt.codec.encode(wav).astype(np.int32)
|
240 |
-
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
|
|
|
|
|
|
248 |
|
249 |
def _bar_aligned_tail(self, tokens: np.ndarray, bars: float) -> np.ndarray:
|
250 |
"""
|
@@ -260,42 +262,38 @@ class JamWorker(threading.Thread):
|
|
260 |
return tokens[-want:]
|
261 |
|
262 |
def _splice_context(self, original_tokens: np.ndarray, recent_tokens: np.ndarray,
|
263 |
-
|
264 |
-
|
265 |
-
Build new context by concatenating:
|
266 |
-
anchor = tail from originals (anchor_bars)
|
267 |
-
recent = tail from recent_tokens filling the remainder
|
268 |
-
Then clamp to ctx_frames from the tail (safety).
|
269 |
-
"""
|
270 |
ctx_frames = self._ctx_frames()
|
271 |
depth = original_tokens.shape[1]
|
|
|
272 |
|
273 |
-
# 1)
|
274 |
-
|
|
|
275 |
|
276 |
-
# 2)
|
277 |
a = anchor.shape[0]
|
278 |
remain = max(ctx_frames - a, 0)
|
279 |
-
|
280 |
-
# 3) Take bar-aligned recent tail not exceeding 'remain' (rounded to bars)
|
281 |
if remain > 0:
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
# if we can’t fit even one bar, just take the exact frame remainder
|
286 |
-
if recent_bars_fit >= 1:
|
287 |
-
want_recent_frames = recent_bars_fit * frames_per_bar
|
288 |
recent = recent_tokens[-want_recent_frames:] if recent_tokens.shape[0] > want_recent_frames else recent_tokens
|
289 |
else:
|
290 |
recent = recent_tokens[-remain:] if recent_tokens.shape[0] > remain else recent_tokens
|
291 |
else:
|
292 |
recent = recent_tokens[:0]
|
293 |
|
294 |
-
|
295 |
-
out = np.concatenate([anchor, recent], axis=0) if anchor.size or recent.size else recent_tokens[-ctx_frames:]
|
296 |
if out.shape[0] > ctx_frames:
|
297 |
out = out[-ctx_frames:]
|
298 |
-
|
|
|
|
|
|
|
|
|
|
|
299 |
if out.shape[1] != depth:
|
300 |
out = out[:, :depth]
|
301 |
return out
|
|
|
233 |
|
234 |
def _make_recent_tokens_from_wave(self, wav) -> np.ndarray:
|
235 |
"""
|
236 |
+
Encode waveform and produce a BAR-ALIGNED context token window.
|
|
|
237 |
"""
|
238 |
+
tokens_full = self.mrt.codec.encode(wav).astype(np.int32) # [T, rvq_total]
|
239 |
+
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
240 |
+
|
241 |
+
from utils import make_bar_aligned_context
|
242 |
+
ctx = make_bar_aligned_context(
|
243 |
+
tokens,
|
244 |
+
bpm=self.params.bpm,
|
245 |
+
fps=int(self.mrt.codec.frame_rate),
|
246 |
+
ctx_frames=self.mrt.config.context_length_frames,
|
247 |
+
beats_per_bar=self.params.beats_per_bar
|
248 |
+
)
|
249 |
+
return ctx
|
250 |
|
251 |
def _bar_aligned_tail(self, tokens: np.ndarray, bars: float) -> np.ndarray:
|
252 |
"""
|
|
|
262 |
return tokens[-want:]
|
263 |
|
264 |
def _splice_context(self, original_tokens: np.ndarray, recent_tokens: np.ndarray,
|
265 |
+
anchor_bars: float) -> np.ndarray:
|
266 |
+
import math
|
|
|
|
|
|
|
|
|
|
|
267 |
ctx_frames = self._ctx_frames()
|
268 |
depth = original_tokens.shape[1]
|
269 |
+
frames_per_bar = self._frames_per_bar()
|
270 |
|
271 |
+
# 1) Anchor tail
|
272 |
+
# Use floor, not round, to avoid grabbing an extra bar.
|
273 |
+
anchor = self._bar_aligned_tail(original_tokens, math.floor(anchor_bars))
|
274 |
|
275 |
+
# 2) Fill remainder with recent (in whole bars when possible)
|
276 |
a = anchor.shape[0]
|
277 |
remain = max(ctx_frames - a, 0)
|
|
|
|
|
278 |
if remain > 0:
|
279 |
+
bars_fit = remain // frames_per_bar
|
280 |
+
if bars_fit >= 1:
|
281 |
+
want_recent_frames = int(bars_fit * frames_per_bar)
|
|
|
|
|
|
|
282 |
recent = recent_tokens[-want_recent_frames:] if recent_tokens.shape[0] > want_recent_frames else recent_tokens
|
283 |
else:
|
284 |
recent = recent_tokens[-remain:] if recent_tokens.shape[0] > remain else recent_tokens
|
285 |
else:
|
286 |
recent = recent_tokens[:0]
|
287 |
|
288 |
+
out = np.concatenate([anchor, recent], axis=0) if (anchor.size or recent.size) else recent_tokens[-ctx_frames:]
|
|
|
289 |
if out.shape[0] > ctx_frames:
|
290 |
out = out[-ctx_frames:]
|
291 |
+
|
292 |
+
# --- NEW: force total length to a whole number of bars
|
293 |
+
max_bar_aligned = (out.shape[0] // frames_per_bar) * frames_per_bar
|
294 |
+
if max_bar_aligned > 0 and out.shape[0] != max_bar_aligned:
|
295 |
+
out = out[-max_bar_aligned:]
|
296 |
+
|
297 |
if out.shape[1] != depth:
|
298 |
out = out[:, :depth]
|
299 |
return out
|