thecollabagepatch commited on
Commit
4bdf506
·
1 Parent(s): db53efe

ok one last try

Browse files
Files changed (2) hide show
  1. jam_worker.py +404 -355
  2. utils.py +62 -36
jam_worker.py CHANGED
@@ -1,5 +1,5 @@
1
- # jam_worker.py - SIMPLE FIX VERSION
2
- import threading, time, base64, io, uuid
3
  from dataclasses import dataclass, field
4
  import numpy as np
5
  import soundfile as sf
@@ -8,7 +8,7 @@ from threading import RLock
8
  from utils import (
9
  match_loudness_to_reference, stitch_generated, hard_trim_seconds,
10
  apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
11
- resample_and_snap, wav_bytes_base64
12
  )
13
 
14
  @dataclass
@@ -32,6 +32,34 @@ class JamChunk:
32
  audio_base64: str
33
  metadata: dict
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  class JamWorker(threading.Thread):
36
  def __init__(self, mrt, params: JamParams):
37
  super().__init__(daemon=True)
@@ -39,9 +67,32 @@ class JamWorker(threading.Thread):
39
  self.params = params
40
  self.state = mrt.init_state()
41
 
42
- # init synchronization + placeholders FIRST
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  self._lock = threading.Lock()
44
- self._original_context_tokens = None # so hasattr checks are cheap/clear
45
 
46
  if params.combined_loop is not None:
47
  self._setup_context_from_combined_loop()
@@ -50,28 +101,39 @@ class JamWorker(threading.Thread):
50
  self.outbox: list[JamChunk] = []
51
  self._stop_event = threading.Event()
52
 
 
53
  self._stream = None
54
- self._next_emit_start = 0
55
-
56
- # NEW: Track delivery state
57
  self._last_delivered_index = 0
58
  self._max_buffer_ahead = 5
59
 
 
 
 
 
 
 
 
 
 
 
60
  # Timing info
61
  self.last_chunk_started_at = None
62
  self.last_chunk_completed_at = None
63
 
64
- self._pending_reseed = None # {"ctx": np.ndarray, "ref": au.Waveform|None}
65
- self._needs_bar_realign = False # request a one-shot downbeat alignment
66
- self._reseed_ref_loop = None # which loop to align against after reseed
67
-
68
 
69
  def _setup_context_from_combined_loop(self):
70
  """Set up MRT context tokens from the combined loop audio"""
71
  try:
72
  from utils import make_bar_aligned_context, take_bar_aligned_tail
73
 
74
- codec_fps = float(self.mrt.codec.frame_rate)
75
  ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
76
 
77
  loop_for_context = take_bar_aligned_tail(
@@ -84,452 +146,381 @@ class JamWorker(threading.Thread):
84
  tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
85
  tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
86
 
 
87
  context_tokens = make_bar_aligned_context(
88
  tokens,
89
  bpm=self.params.bpm,
90
- fps=float(self.mrt.codec.frame_rate), # keep fractional fps
91
  ctx_frames=self.mrt.config.context_length_frames,
92
- beats_per_bar=self.params.beats_per_bar
 
93
  )
94
 
95
- # Install fresh context
96
  self.state.context_tokens = context_tokens
97
- print(f" JamWorker: Set up fresh context from combined loop")
98
 
99
- # NEW: keep a copy of the *original* context tokens for future splice-reseed
100
- # (guard so we only set this once, at jam start)
101
  with self._lock:
102
  if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
103
- self._original_context_tokens = np.copy(context_tokens) # shape: [T, depth]
104
 
105
  except Exception as e:
106
- print(f"Failed to setup context from combined loop: {e}")
107
 
108
  def stop(self):
109
  self._stop_event.set()
110
 
111
  def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
112
  with self._lock:
113
- if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
114
- if temperature is not None: self.params.temperature = float(temperature)
115
- if topk is not None: self.params.topk = int(topk)
 
 
 
116
 
117
  def get_next_chunk(self) -> JamChunk | None:
118
  """Get the next sequential chunk (blocks/waits if not ready)"""
119
  target_index = self._last_delivered_index + 1
120
 
121
- # Wait for the target chunk to be ready (with timeout)
122
- max_wait = 30.0 # seconds
123
  start_time = time.time()
124
 
125
  while time.time() - start_time < max_wait and not self._stop_event.is_set():
126
  with self._lock:
127
- # Look for the exact chunk we need
128
  for chunk in self.outbox:
129
  if chunk.index == target_index:
130
  self._last_delivered_index = target_index
131
- print(f"📦 Delivered chunk {target_index}")
132
  return chunk
133
-
134
- # Not ready yet, wait a bit
135
  time.sleep(0.1)
136
 
137
- # Timeout or stopped
138
  return None
139
 
140
  def mark_chunk_consumed(self, chunk_index: int):
141
  """Mark a chunk as consumed by the frontend"""
142
  with self._lock:
143
  self._last_delivered_index = max(self._last_delivered_index, chunk_index)
144
- print(f"✅ Chunk {chunk_index} consumed")
145
 
146
  def _should_generate_next_chunk(self) -> bool:
147
- """Check if we should generate the next chunk (don't get too far ahead)"""
148
  with self._lock:
149
- # Don't generate if we're already too far ahead
150
- if self.idx > self._last_delivered_index + self._max_buffer_ahead:
151
- return False
152
- return True
153
-
154
- def _seconds_per_bar(self) -> float:
155
- return self.params.beats_per_bar * (60.0 / self.params.bpm)
156
-
157
- def _snap_and_encode(self, y, seconds, target_sr, bars):
158
- cur_sr = int(self.mrt.sample_rate)
159
- x = y.samples if y.samples.ndim == 2 else y.samples[:, None]
160
- x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
161
- b64, total_samples, channels = wav_bytes_base64(x, target_sr)
162
- meta = {
163
- "bpm": int(round(self.params.bpm)),
164
- "bars": int(bars),
165
- "beats_per_bar": int(self.params.beats_per_bar),
166
- "sample_rate": int(target_sr),
167
- "channels": channels,
168
- "total_samples": total_samples,
169
- "seconds_per_bar": self._seconds_per_bar(),
170
- "loop_duration_seconds": bars * self._seconds_per_bar(),
171
- "guidance_weight": self.params.guidance_weight,
172
- "temperature": self.params.temperature,
173
- "topk": self.params.topk,
174
- }
175
- return b64, meta
176
 
177
  def _append_model_chunk_to_stream(self, wav):
178
- """Incrementally append a model chunk with equal-power crossfade."""
179
  xfade_s = float(self.mrt.config.crossfade_length)
180
- sr = int(self.mrt.sample_rate)
181
  xfade_n = int(round(xfade_s * sr))
182
 
183
  s = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
184
 
185
- if getattr(self, "_stream", None) is None:
186
- # First chunk: drop model pre-roll (xfade head)
187
  if s.shape[0] > xfade_n:
188
  self._stream = s[xfade_n:].astype(np.float32, copy=True)
189
  else:
190
  self._stream = np.zeros((0, s.shape[1]), dtype=np.float32)
191
- self._next_emit_start = 0 # pointer into _stream (model SR samples)
192
  return
193
 
194
- # Crossfade last xfade_n samples of _stream with head of new s
195
  if s.shape[0] <= xfade_n or self._stream.shape[0] < xfade_n:
196
- # Degenerate safeguard
197
  self._stream = np.concatenate([self._stream, s], axis=0)
 
198
  return
199
 
 
200
  tail = self._stream[-xfade_n:]
201
  head = s[:xfade_n]
202
 
203
- # Equal-power envelopes
204
  t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)[:, None]
205
  eq_in, eq_out = np.sin(t), np.cos(t)
206
  mixed = tail * eq_out + head * eq_in
207
 
208
  self._stream = np.concatenate([self._stream[:-xfade_n], mixed, s[xfade_n:]], axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  def reseed_from_waveform(self, wav):
211
- # 1) Re-init state
212
  new_state = self.mrt.init_state()
213
-
214
- # 2) Build bar-aligned context tokens from provided audio
215
- codec_fps = float(self.mrt.codec.frame_rate)
216
  ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
217
- from utils import take_bar_aligned_tail, make_bar_aligned_context
218
-
219
  tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, ctx_seconds)
220
  tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
221
  tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
222
- context_tokens = make_bar_aligned_context(tokens,
223
- bpm=self.params.bpm, fps=float(self.mrt.codec.frame_rate),
 
 
 
224
  ctx_frames=self.mrt.config.context_length_frames,
225
- beats_per_bar=self.params.beats_per_bar
 
226
  )
 
227
  new_state.context_tokens = context_tokens
228
  self.state = new_state
229
- self._prepare_stream_for_reseed_handoff()
230
-
231
- def _frames_per_bar(self) -> int:
232
- # codec frame-rate (frames/s) -> frames per musical bar
233
- fps = float(self.mrt.codec.frame_rate)
234
- sec_per_bar = (60.0 / float(self.params.bpm)) * float(self.params.beats_per_bar)
235
- return int(round(fps * sec_per_bar))
236
-
237
- def _ctx_frames(self) -> int:
238
- # how many codec frames fit in the model’s conditioning window
239
- return int(self.mrt.config.context_length_frames)
240
-
241
- def _make_recent_tokens_from_wave(self, wav) -> np.ndarray:
242
- """
243
- Encode waveform and produce a BAR-ALIGNED context token window.
244
- """
245
- tokens_full = self.mrt.codec.encode(wav).astype(np.int32) # [T, rvq_total]
246
- tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
247
-
248
- from utils import make_bar_aligned_context
249
- ctx = make_bar_aligned_context(
250
- tokens,
251
- bpm=self.params.bpm,
252
- fps=float(self.mrt.codec.frame_rate), # keep fractional fps
253
- ctx_frames=self.mrt.config.context_length_frames,
254
- beats_per_bar=self.params.beats_per_bar
255
- )
256
- return ctx
257
-
258
- def _bar_aligned_tail(self, tokens: np.ndarray, bars: float) -> np.ndarray:
259
- """
260
- Take a tail slice that is an integer number of codec frames corresponding to `bars`.
261
- We round to nearest frame to stay phase-consistent with codec grid.
262
- """
263
- frames_per_bar = self._frames_per_bar()
264
- want = max(frames_per_bar * int(round(bars)), 0)
265
- if want == 0:
266
- return tokens[:0] # empty
267
- if tokens.shape[0] <= want:
268
- return tokens
269
- return tokens[-want:]
270
-
271
- def _splice_context(self, original_tokens: np.ndarray, recent_tokens: np.ndarray,
272
- anchor_bars: float) -> np.ndarray:
273
- import math
274
- ctx_frames = self._ctx_frames()
275
- depth = original_tokens.shape[1]
276
- frames_per_bar = self._frames_per_bar()
277
-
278
- # 1) Anchor tail (whole bars)
279
- anchor = self._bar_aligned_tail(original_tokens, math.floor(anchor_bars))
280
-
281
- # 2) Fill remainder with recent (prefer whole bars)
282
- a = anchor.shape[0]
283
- remain = max(ctx_frames - a, 0)
284
-
285
- recent = recent_tokens[:0]
286
- used_recent = 0 # frames taken from the END of recent_tokens
287
- if remain > 0:
288
- bars_fit = remain // frames_per_bar
289
- if bars_fit >= 1:
290
- want_recent_frames = int(bars_fit * frames_per_bar)
291
- used_recent = min(want_recent_frames, recent_tokens.shape[0])
292
- recent = recent_tokens[-used_recent:] if used_recent > 0 else recent_tokens[:0]
293
- else:
294
- used_recent = min(remain, recent_tokens.shape[0])
295
- recent = recent_tokens[-used_recent:] if used_recent > 0 else recent_tokens[:0]
296
-
297
- # 3) Concat in order [anchor, recent]
298
- if anchor.size or recent.size:
299
- out = np.concatenate([anchor, recent], axis=0)
300
- else:
301
- # fallback: just take the last ctx window from recent
302
- out = recent_tokens[-ctx_frames:]
303
-
304
- # 4) Trim if we overshot
305
- if out.shape[0] > ctx_frames:
306
- out = out[-ctx_frames:]
307
-
308
- # 5) Snap the **END** to the nearest LOWER bar boundary
309
- if frames_per_bar > 0:
310
- max_bar_aligned = (out.shape[0] // frames_per_bar) * frames_per_bar
311
- else:
312
- max_bar_aligned = out.shape[0]
313
- if max_bar_aligned > 0 and out.shape[0] != max_bar_aligned:
314
- out = out[-max_bar_aligned:]
315
-
316
- # 6) Left-fill to reach ctx_frames **without moving the END**
317
- deficit = ctx_frames - out.shape[0]
318
- if deficit > 0:
319
- left_parts = []
320
-
321
- # Prefer frames immediately BEFORE the region we used from 'recent_tokens'
322
- if used_recent < recent_tokens.shape[0]:
323
- take = min(deficit, recent_tokens.shape[0] - used_recent)
324
- if used_recent > 0:
325
- left_parts.append(recent_tokens[-(used_recent + take) : -used_recent])
326
- else:
327
- left_parts.append(recent_tokens[-take:])
328
-
329
- # Then take frames immediately BEFORE the 'anchor' in original_tokens
330
- if sum(p.shape[0] for p in left_parts) < deficit and anchor.shape[0] > 0:
331
- need = deficit - sum(p.shape[0] for p in left_parts)
332
- a_len = anchor.shape[0]
333
- avail = max(original_tokens.shape[0] - a_len, 0)
334
- take2 = min(need, avail)
335
- if take2 > 0:
336
- left_parts.append(original_tokens[-(a_len + take2) : -a_len])
337
-
338
- # Still short? tile from what's available
339
- have = sum(p.shape[0] for p in left_parts)
340
- if have < deficit:
341
- base = out if out.shape[0] > 0 else (recent_tokens if recent_tokens.shape[0] > 0 else original_tokens)
342
- reps = int(np.ceil((deficit - have) / max(1, base.shape[0])))
343
- left_parts.append(np.tile(base, (reps, 1))[: (deficit - have)])
344
-
345
- left = np.concatenate(left_parts, axis=0)
346
- out = np.concatenate([left[-deficit:], out], axis=0)
347
-
348
- # 7) Final guard to exact length
349
- if out.shape[0] > ctx_frames:
350
- out = out[-ctx_frames:]
351
- elif out.shape[0] < ctx_frames:
352
- reps = int(np.ceil(ctx_frames / max(1, out.shape[0])))
353
- out = np.tile(out, (reps, 1))[-ctx_frames:]
354
-
355
- # 8) Depth guard
356
- if out.shape[1] != depth:
357
- out = out[:, :depth]
358
- return out
359
-
360
-
361
- def _realign_emit_pointer_to_bar(self, sr_model: int):
362
- """Advance _next_emit_start to the next bar boundary in model-sample space."""
363
- bar_samps = int(round(self._seconds_per_bar() * sr_model))
364
- if bar_samps <= 0:
365
- return
366
- phase = self._next_emit_start % bar_samps
367
- if phase != 0:
368
- self._next_emit_start += (bar_samps - phase)
369
-
370
- def _prepare_stream_for_reseed_handoff(self):
371
- # OLD: keep crossfade tail -> causes phase offset
372
- # sr = int(self.mrt.sample_rate)
373
- # xfade_s = float(self.mrt.config.crossfade_length)
374
- # xfade_n = int(round(xfade_s * sr))
375
- # if getattr(self, "_stream", None) is not None and self._stream.shape[0] > 0:
376
- # tail = self._stream[-xfade_n:] if self._stream.shape[0] > xfade_n else self._stream
377
- # self._stream = tail.copy()
378
- # else:
379
- # self._stream = None
380
-
381
- # NEW: throw away the tail completely; start fresh
382
  self._stream = None
383
-
384
- self._next_emit_start = 0
 
 
 
385
  self._needs_bar_realign = True
 
386
 
387
  def reseed_splice(self, recent_wav, anchor_bars: float):
388
- """
389
- Token-splice reseed queued for the next bar boundary between chunks.
390
- """
391
  with self._lock:
392
  if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
393
  self._original_context_tokens = np.copy(self.state.context_tokens)
394
 
395
- recent_tokens = self._make_recent_tokens_from_wave(recent_wav) # [T, depth]
 
396
  new_ctx = self._splice_context(self._original_context_tokens, recent_tokens, anchor_bars)
397
 
398
- # Queue it; the run loop will install right after we finish the current slice
399
  self._pending_reseed = {"ctx": new_ctx, "ref": recent_wav}
400
-
401
- # install the new context window
402
  new_state = self.mrt.init_state()
403
  new_state.context_tokens = new_ctx
404
  self.state = new_state
405
 
406
- self._prepare_stream_for_reseed_handoff()
 
 
 
 
 
 
 
407
 
408
- # optional: ask streamer to drop an intro crossfade worth of audio right after reseed
409
- self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
  def run(self):
412
- """Main worker loop generate into a continuous stream, then emit bar-aligned slices."""
413
- spb = self._seconds_per_bar() # seconds per bar
414
- chunk_secs = self.params.bars_per_chunk * spb
415
- xfade = float(self.mrt.config.crossfade_length) # seconds
416
- sr = int(self.mrt.sample_rate)
417
- chunk_samps = int(round(chunk_secs * sr))
418
-
419
- def _need(first_chunk_extra=False):
420
- """How many more samples we still need in the stream to emit next slice."""
421
- have = 0 if getattr(self, "_stream", None) is None else self._stream.shape[0] - getattr(self, "_next_emit_start", 0)
422
- want = chunk_samps
423
  if first_chunk_extra:
424
- # reserve two bars extra so first-chunk onset alignment has material
425
- want += int(round(2 * spb * sr))
426
- return max(0, want - have)
427
-
428
- def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
429
- if x.ndim == 2: x = x.mean(axis=1)
430
- x = np.abs(x).astype(np.float32)
431
- w = max(1, int(round(win_ms * 1e-3 * sr)))
432
- if w > 1:
433
- kern = np.ones(w, dtype=np.float32) / float(w)
434
- x = np.convolve(x, kern, mode="same")
435
- d = np.diff(x, prepend=x[:1])
436
- d[d < 0] = 0.0
437
- return d
438
-
439
- def _estimate_first_offset_samples(ref_loop_wav, gen_head_wav, sr: int, spb: float) -> int:
440
- """Tempo-aware first-downbeat offset (positive => model late)."""
441
- try:
442
- max_ms = int(max(160.0, min(0.25 * spb * 1000.0, 450.0)))
443
- ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
444
- n_bar = int(round(spb * sr))
445
- ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
446
- gen_head = gen_head_wav.samples[: int(2 * n_bar), :]
447
- if ref_tail.size == 0 or gen_head.size == 0:
448
- return 0
449
-
450
- # envelopes + z-score
451
- import numpy as np
452
- def _z(a):
453
- m, s = float(a.mean()), float(a.std() or 1.0); return (a - m) / s
454
- e_ref = _z(_mono_env(ref_tail, sr)).astype(np.float32)
455
- e_gen = _z(_mono_env(gen_head, sr)).astype(np.float32)
456
-
457
- # upsample x4 for finer lag
458
- def _upsample(a, r=4):
459
- n = len(a); grid = np.arange(n, dtype=np.float32)
460
- fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
461
- return np.interp(fine, grid, a).astype(np.float32)
462
- up = 4
463
- e_ref_u, e_gen_u = _upsample(e_ref, up), _upsample(e_gen, up)
464
-
465
- max_lag_u = int(round((max_ms / 1000.0) * sr * up))
466
- seg = min(len(e_ref_u), len(e_gen_u))
467
- e_ref_u = e_ref_u[-seg:]
468
- pad = np.zeros(max_lag_u, dtype=np.float32)
469
- e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])
470
-
471
- best_lag_u, best_score = 0, -1e9
472
- for lag_u in range(-max_lag_u, max_lag_u + 1):
473
- start = max_lag_u + lag_u
474
- b = e_gen_u_pad[start : start + seg]
475
- denom = (np.linalg.norm(e_ref_u) * np.linalg.norm(b)) or 1.0
476
- score = float(np.dot(e_ref_u, b) / denom)
477
- if score > best_score:
478
- best_score, best_lag_u = score, lag_u
479
- return int(round(best_lag_u / up))
480
- except Exception:
481
- return 0
482
-
483
- print("🚀 JamWorker started (bar-aligned streaming)…")
484
 
485
  while not self._stop_event.is_set():
486
  if not self._should_generate_next_chunk():
487
  time.sleep(0.25)
488
  continue
489
 
490
- # 1) Generate until we have enough material in the stream
491
- need = _need(first_chunk_extra=(self.idx == 0))
492
- while need > 0 and not self._stop_event.is_set():
493
  with self._lock:
494
  style_vec = self.params.style_vec
495
  self.mrt.guidance_weight = float(self.params.guidance_weight)
496
- self.mrt.temperature = float(self.params.temperature)
497
- self.mrt.topk = int(self.params.topk)
 
498
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
499
- self._append_model_chunk_to_stream(wav) # equal-power xfade into a persistent stream
500
- need = _need(first_chunk_extra=(self.idx == 0))
501
 
502
  if self._stop_event.is_set():
503
  break
504
 
505
- # 2) One-time: align the emit pointer to the groove
506
  if (self.idx == 0 and self.params.combined_loop is not None) or self._needs_bar_realign:
507
  ref_loop = self._reseed_ref_loop or self.params.combined_loop
508
  if ref_loop is not None:
509
- head_len = min(self._stream.shape[0] - self._next_emit_start, int(round(2 * spb * sr)))
510
- seg = self._stream[self._next_emit_start : self._next_emit_start + head_len]
511
- gen_head = au.Waveform(seg.astype(np.float32, copy=False), sr).as_stereo()
512
- offs = _estimate_first_offset_samples(ref_loop, gen_head, sr, spb)
513
- if offs != 0:
514
- self._next_emit_start = max(0, self._next_emit_start + offs)
515
- print(f"🎯 Offset compensation: {offs/sr:+.3f}s")
516
- self._realign_emit_pointer_to_bar(sr)
517
  self._needs_bar_realign = False
518
  self._reseed_ref_loop = None
519
 
520
- # 3) Emit exactly bars_per_chunk × spb from the stream
521
- start = self._next_emit_start
522
- end = start + chunk_samps
523
- if end > self._stream.shape[0]:
524
- # shouldn't happen often; generate a bit more and loop
525
- continue
526
 
527
- slice_ = self._stream[start:end]
528
- self._next_emit_start = end
 
 
 
 
529
 
530
- y = au.Waveform(slice_.astype(np.float32, copy=False), sr).as_stereo()
 
531
 
532
- # 4) Post-processing / loudness
533
  if self.idx == 0 and self.params.ref_loop is not None:
534
  y, _ = match_loudness_to_reference(
535
  self.params.ref_loop, y,
@@ -539,38 +530,96 @@ class JamWorker(threading.Thread):
539
  else:
540
  apply_micro_fades(y, 3)
541
 
542
- # 5) Resample + exact-length snap + encode
543
- b64, meta = self._snap_and_encode(
544
- y, seconds=chunk_secs, target_sr=self.params.target_sr, bars=self.params.bars_per_chunk
545
- )
546
- meta["xfade_seconds"] = xfade
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
 
548
- # 6) Publish
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
  with self._lock:
550
  self.idx += 1
551
- self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
 
 
 
552
  if len(self.outbox) > 10:
553
  cutoff = self._last_delivered_index - 5
554
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
555
 
556
- # 👉 If a reseed was requested, apply it *now*, between chunks
557
  if self._pending_reseed is not None:
558
  pkg = self._pending_reseed
559
  self._pending_reseed = None
560
 
561
  new_state = self.mrt.init_state()
562
- new_state.context_tokens = pkg["ctx"] # exact (ctx_frames, depth)
563
  self.state = new_state
564
 
565
- # start a fresh stream and schedule one-time alignment
566
  self._stream = None
567
- self._next_emit_start = 0
568
- self._reseed_ref_loop = pkg.get("ref") or self.params.combined_loop
 
 
 
 
569
  self._needs_bar_realign = True
570
 
571
- print("🔁 Reseed installed at bar boundary; will realign before next slice")
572
 
573
- print(f"✅ Completed chunk {self.idx}")
574
-
575
- print("🛑 JamWorker stopped")
576
 
 
 
 
 
 
 
 
 
 
1
+ # jam_worker.py - COMPREHENSIVE REWRITE FOR PRECISE TIMING
2
+ import threading, time, base64, io, uuid, math
3
  from dataclasses import dataclass, field
4
  import numpy as np
5
  import soundfile as sf
 
8
  from utils import (
9
  match_loudness_to_reference, stitch_generated, hard_trim_seconds,
10
  apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
11
+ resample_and_snap, wav_bytes_base64, StreamingResampler
12
  )
13
 
14
  @dataclass
 
32
  audio_base64: str
33
  metadata: dict
34
 
35
+ @dataclass
36
+ class TimingState:
37
+ """Precise timing state tracking"""
38
+ # Fractional bar position (never rounded until final emission)
39
+ emit_position_bars: float = 0.0
40
+
41
+ # Sample-accurate positions in the stream
42
+ stream_position_samples: int = 0
43
+
44
+ # Accumulated timing error for correction
45
+ fractional_error_bars: float = 0.0
46
+
47
+ # Codec frame timing
48
+ frames_per_bar: float = 0.0
49
+ samples_per_bar: float = 0.0
50
+
51
+ def advance_by_bars(self, bars: float):
52
+ """Advance timing by exact fractional bars"""
53
+ self.emit_position_bars += bars
54
+ self.fractional_error_bars += bars - int(bars)
55
+
56
+ # Correct for accumulated error when it gets significant
57
+ if abs(self.fractional_error_bars) > 0.5:
58
+ correction = int(round(self.fractional_error_bars))
59
+ self.fractional_error_bars -= correction
60
+ return correction # bars to skip/rewind
61
+ return 0
62
+
63
  class JamWorker(threading.Thread):
64
  def __init__(self, mrt, params: JamParams):
65
  super().__init__(daemon=True)
 
67
  self.params = params
68
  self.state = mrt.init_state()
69
 
70
+ # Core timing calculations (keep as floats for precision)
71
+ self._codec_fps = float(self.mrt.codec.frame_rate) # 25.0
72
+ self._model_sr = int(self.mrt.sample_rate) # 48000
73
+ self._target_sr = int(params.target_sr)
74
+
75
+ # Critical: these stay as floats to preserve fractional precision
76
+ self._seconds_per_bar = float(params.beats_per_bar * 60.0 / params.bpm)
77
+ self._frames_per_bar = self._seconds_per_bar * self._codec_fps
78
+ self._samples_per_bar_model = self._seconds_per_bar * self._model_sr
79
+ self._samples_per_bar_target = self._seconds_per_bar * self._target_sr
80
+
81
+ # Timing state
82
+ self._timing = TimingState(
83
+ frames_per_bar=self._frames_per_bar,
84
+ samples_per_bar=self._samples_per_bar_model
85
+ )
86
+
87
+ # Warn about problematic BPMs
88
+ frame_error = abs(self._frames_per_bar - round(self._frames_per_bar))
89
+ if frame_error > 0.01:
90
+ print(f"⚠️ Warning: {params.bpm} BPM creates {frame_error:.3f} frame drift per bar")
91
+ print(f" This may cause gradual timing drift in long jams")
92
+
93
+ # Synchronization + placeholders
94
  self._lock = threading.Lock()
95
+ self._original_context_tokens = None
96
 
97
  if params.combined_loop is not None:
98
  self._setup_context_from_combined_loop()
 
101
  self.outbox: list[JamChunk] = []
102
  self._stop_event = threading.Event()
103
 
104
+ # Stream state
105
  self._stream = None
106
+ self._stream_write_pos = 0 # Where we append new model output
107
+
108
+ # Delivery tracking
109
  self._last_delivered_index = 0
110
  self._max_buffer_ahead = 5
111
 
112
+ # Streaming resampler for precise SR conversion
113
+ self._resampler = None
114
+ if self._target_sr != self._model_sr:
115
+ self._resampler = StreamingResampler(
116
+ in_sr=self._model_sr,
117
+ out_sr=self._target_sr,
118
+ channels=2,
119
+ quality="VHQ"
120
+ )
121
+
122
  # Timing info
123
  self.last_chunk_started_at = None
124
  self.last_chunk_completed_at = None
125
 
126
+ # Control flags
127
+ self._pending_reseed = None
128
+ self._needs_bar_realign = False
129
+ self._reseed_ref_loop = None
130
 
131
  def _setup_context_from_combined_loop(self):
132
  """Set up MRT context tokens from the combined loop audio"""
133
  try:
134
  from utils import make_bar_aligned_context, take_bar_aligned_tail
135
 
136
+ codec_fps = self._codec_fps
137
  ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
138
 
139
  loop_for_context = take_bar_aligned_tail(
 
146
  tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
147
  tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
148
 
149
+ # Use enhanced context alignment for fractional BPMs
150
  context_tokens = make_bar_aligned_context(
151
  tokens,
152
  bpm=self.params.bpm,
153
+ fps=self._codec_fps,
154
  ctx_frames=self.mrt.config.context_length_frames,
155
+ beats_per_bar=self.params.beats_per_bar,
156
+ precise_timing=True # Use new precise mode
157
  )
158
 
 
159
  self.state.context_tokens = context_tokens
160
+ print(f"Context setup: {context_tokens.shape[0]} frames, {self._frames_per_bar:.3f} frames/bar")
161
 
162
+ # Store original context for splice reseeding
 
163
  with self._lock:
164
  if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
165
+ self._original_context_tokens = np.copy(context_tokens)
166
 
167
  except Exception as e:
168
+ print(f"Failed to setup context from combined loop: {e}")
169
 
170
  def stop(self):
171
  self._stop_event.set()
172
 
173
  def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
174
  with self._lock:
175
+ if guidance_weight is not None:
176
+ self.params.guidance_weight = float(guidance_weight)
177
+ if temperature is not None:
178
+ self.params.temperature = float(temperature)
179
+ if topk is not None:
180
+ self.params.topk = int(topk)
181
 
182
  def get_next_chunk(self) -> JamChunk | None:
183
  """Get the next sequential chunk (blocks/waits if not ready)"""
184
  target_index = self._last_delivered_index + 1
185
 
186
+ max_wait = 30.0
 
187
  start_time = time.time()
188
 
189
  while time.time() - start_time < max_wait and not self._stop_event.is_set():
190
  with self._lock:
 
191
  for chunk in self.outbox:
192
  if chunk.index == target_index:
193
  self._last_delivered_index = target_index
194
+ print(f"Delivered chunk {target_index} (bars {chunk.metadata.get('bar_range', 'unknown')})")
195
  return chunk
 
 
196
  time.sleep(0.1)
197
 
 
198
  return None
199
 
200
  def mark_chunk_consumed(self, chunk_index: int):
201
  """Mark a chunk as consumed by the frontend"""
202
  with self._lock:
203
  self._last_delivered_index = max(self._last_delivered_index, chunk_index)
 
204
 
205
  def _should_generate_next_chunk(self) -> bool:
206
+ """Check if we should generate the next chunk"""
207
  with self._lock:
208
+ return self.idx <= self._last_delivered_index + self._max_buffer_ahead
209
+
210
+ def _get_precise_chunk_samples(self, bars: float) -> int:
211
+ """Get exact sample count for fractional bars at model SR"""
212
+ exact_seconds = bars * self._seconds_per_bar
213
+ return int(round(exact_seconds * self._model_sr))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  def _append_model_chunk_to_stream(self, wav):
216
+ """Append model output to continuous stream with crossfading"""
217
  xfade_s = float(self.mrt.config.crossfade_length)
218
+ sr = self._model_sr
219
  xfade_n = int(round(xfade_s * sr))
220
 
221
  s = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
222
 
223
+ if self._stream is None:
224
+ # First chunk: drop model pre-roll
225
  if s.shape[0] > xfade_n:
226
  self._stream = s[xfade_n:].astype(np.float32, copy=True)
227
  else:
228
  self._stream = np.zeros((0, s.shape[1]), dtype=np.float32)
229
+ self._stream_write_pos = self._stream.shape[0]
230
  return
231
 
232
+ # Crossfade with equal-power curves
233
  if s.shape[0] <= xfade_n or self._stream.shape[0] < xfade_n:
234
+ # Degenerate case
235
  self._stream = np.concatenate([self._stream, s], axis=0)
236
+ self._stream_write_pos = self._stream.shape[0]
237
  return
238
 
239
+ # Standard crossfade
240
  tail = self._stream[-xfade_n:]
241
  head = s[:xfade_n]
242
 
 
243
  t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)[:, None]
244
  eq_in, eq_out = np.sin(t), np.cos(t)
245
  mixed = tail * eq_out + head * eq_in
246
 
247
  self._stream = np.concatenate([self._stream[:-xfade_n], mixed, s[xfade_n:]], axis=0)
248
+ self._stream_write_pos = self._stream.shape[0]
249
+
250
+ def _extract_precise_chunk(self, start_bars: float, chunk_bars: float) -> np.ndarray:
251
+ """Extract exactly chunk_bars worth of audio starting at start_bars"""
252
+ start_samples = self._get_precise_chunk_samples(start_bars)
253
+ chunk_samples = self._get_precise_chunk_samples(chunk_bars)
254
+ end_samples = start_samples + chunk_samples
255
+
256
+ if end_samples > self._stream.shape[0]:
257
+ return None # Not enough audio generated yet
258
+
259
+ return self._stream[start_samples:end_samples]
260
+
261
+ def _perform_onset_alignment(self, ref_loop: au.Waveform) -> float:
262
+ """Estimate timing offset between generated audio and reference"""
263
+ if self._stream is None or self._stream.shape[0] < self._model_sr:
264
+ return 0.0
265
+
266
+ try:
267
+ # Take first ~2 seconds of generated audio
268
+ gen_samples = min(int(2.0 * self._model_sr), self._stream.shape[0])
269
+ gen_head = au.Waveform(
270
+ self._stream[:gen_samples].astype(np.float32, copy=False),
271
+ self._model_sr
272
+ ).as_stereo()
273
+
274
+ # Reference: last bar of the loop
275
+ ref_samples = int(self._seconds_per_bar * ref_loop.sample_rate)
276
+ if ref_loop.samples.shape[0] >= ref_samples:
277
+ ref_tail = au.Waveform(
278
+ ref_loop.samples[-ref_samples:],
279
+ ref_loop.sample_rate
280
+ ).resample(self._model_sr).as_stereo()
281
+ else:
282
+ ref_tail = ref_loop.resample(self._model_sr).as_stereo()
283
+
284
+ # Cross-correlation based alignment
285
+ def envelope(x, sr):
286
+ if x.ndim == 2:
287
+ x = x.mean(axis=1)
288
+ x = np.abs(x).astype(np.float32)
289
+ # Simple smoothing
290
+ win = max(1, int(0.01 * sr)) # 10ms window
291
+ if win > 1:
292
+ kernel = np.ones(win) / win
293
+ x = np.convolve(x, kernel, mode='same')
294
+ return x
295
+
296
+ env_ref = envelope(ref_tail.samples, self._model_sr)
297
+ env_gen = envelope(gen_head.samples, self._model_sr)
298
+
299
+ # Limit search range to reasonable offset
300
+ max_offset_samples = int(0.2 * self._model_sr) # 200ms max
301
+
302
+ # Normalize for correlation
303
+ env_ref = (env_ref - env_ref.mean()) / (env_ref.std() + 1e-8)
304
+ env_gen = (env_gen - env_gen.mean()) / (env_gen.std() + 1e-8)
305
+
306
+ # Find best correlation
307
+ best_offset = 0
308
+ best_corr = -1.0
309
+
310
+ search_len = min(len(env_ref), len(env_gen) - max_offset_samples)
311
+ if search_len > 0:
312
+ for offset in range(0, max_offset_samples, 4): # subsample for speed
313
+ if offset + search_len >= len(env_gen):
314
+ break
315
+ corr = np.corrcoef(env_ref[:search_len], env_gen[offset:offset+search_len])[0,1]
316
+ if not np.isnan(corr) and corr > best_corr:
317
+ best_corr = corr
318
+ best_offset = offset
319
+
320
+ offset_seconds = best_offset / self._model_sr
321
+ print(f"Onset alignment: {offset_seconds:.3f}s offset (correlation: {best_corr:.3f})")
322
+ return offset_seconds
323
+
324
+ except Exception as e:
325
+ print(f"Onset alignment failed: {e}")
326
+ return 0.0
327
+
328
+ def _align_to_bar_boundary(self):
329
+ """Align timing state to next bar boundary"""
330
+ current_bar = self._timing.emit_position_bars
331
+ next_bar = math.ceil(current_bar)
332
+
333
+ if abs(next_bar - current_bar) > 1e-6:
334
+ skip_bars = next_bar - current_bar
335
+ skip_samples = self._get_precise_chunk_samples(skip_bars)
336
+ self._timing.stream_position_samples += skip_samples
337
+ self._timing.emit_position_bars = next_bar
338
+ print(f"Aligned to bar {next_bar:.0f}, skipped {skip_bars:.4f} bars")
339
 
340
  def reseed_from_waveform(self, wav):
341
+ """Full context replacement reseed"""
342
  new_state = self.mrt.init_state()
343
+
344
+ # Build new context from waveform
345
+ codec_fps = self._codec_fps
346
  ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
347
+
 
348
  tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, ctx_seconds)
349
  tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
350
  tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
351
+
352
+ context_tokens = make_bar_aligned_context(
353
+ tokens,
354
+ bpm=self.params.bpm,
355
+ fps=self._codec_fps,
356
  ctx_frames=self.mrt.config.context_length_frames,
357
+ beats_per_bar=self.params.beats_per_bar,
358
+ precise_timing=True
359
  )
360
+
361
  new_state.context_tokens = context_tokens
362
  self.state = new_state
363
+
364
+ # Reset stream
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  self._stream = None
366
+ self._stream_write_pos = 0
367
+ self._timing = TimingState(
368
+ frames_per_bar=self._frames_per_bar,
369
+ samples_per_bar=self._samples_per_bar_model
370
+ )
371
  self._needs_bar_realign = True
372
+ self._reseed_ref_loop = wav
373
 
374
  def reseed_splice(self, recent_wav, anchor_bars: float):
375
+ """Token-splice reseed"""
 
 
376
  with self._lock:
377
  if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
378
  self._original_context_tokens = np.copy(self.state.context_tokens)
379
 
380
+ # Build new context via splicing
381
+ recent_tokens = self._make_recent_tokens_from_wave(recent_wav)
382
  new_ctx = self._splice_context(self._original_context_tokens, recent_tokens, anchor_bars)
383
 
 
384
  self._pending_reseed = {"ctx": new_ctx, "ref": recent_wav}
385
+
386
+ # Install immediately
387
  new_state = self.mrt.init_state()
388
  new_state.context_tokens = new_ctx
389
  self.state = new_state
390
 
391
+ # Reset stream state
392
+ self._stream = None
393
+ self._stream_write_pos = 0
394
+ self._timing = TimingState(
395
+ frames_per_bar=self._frames_per_bar,
396
+ samples_per_bar=self._samples_per_bar_model
397
+ )
398
+ self._needs_bar_realign = True
399
 
400
+ def _make_recent_tokens_from_wave(self, wav) -> np.ndarray:
401
+ """Encode waveform to context tokens with precise alignment"""
402
+ tokens_full = self.mrt.codec.encode(wav).astype(np.int32)
403
+ tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
404
+
405
+ context_tokens = make_bar_aligned_context(
406
+ tokens,
407
+ bpm=self.params.bpm,
408
+ fps=self._codec_fps,
409
+ ctx_frames=self.mrt.config.context_length_frames,
410
+ beats_per_bar=self.params.beats_per_bar,
411
+ precise_timing=True
412
+ )
413
+ return context_tokens
414
+
415
+ def _splice_context(self, original_tokens: np.ndarray, recent_tokens: np.ndarray, anchor_bars: float) -> np.ndarray:
416
+ """Enhanced context splicing with fractional bar handling"""
417
+ ctx_frames = int(self.mrt.config.context_length_frames)
418
+
419
+ # Convert anchor bars to codec frames (keep fractional precision)
420
+ anchor_frames_f = anchor_bars * self._frames_per_bar
421
+ anchor_frames = int(round(anchor_frames_f))
422
+
423
+ # Take anchor from original
424
+ anchor = original_tokens[-anchor_frames:] if anchor_frames <= original_tokens.shape[0] else original_tokens
425
+
426
+ # Fill remainder with recent tokens
427
+ remain_frames = ctx_frames - anchor.shape[0]
428
+ if remain_frames > 0:
429
+ recent = recent_tokens[-remain_frames:] if remain_frames <= recent_tokens.shape[0] else recent_tokens
430
+ else:
431
+ recent = recent_tokens[:0] # empty
432
+
433
+ # Combine
434
+ if anchor.size > 0 and recent.size > 0:
435
+ spliced = np.concatenate([recent, anchor], axis=0)
436
+ elif anchor.size > 0:
437
+ spliced = anchor
438
+ else:
439
+ spliced = recent_tokens[-ctx_frames:]
440
+
441
+ # Ensure exact length
442
+ if spliced.shape[0] > ctx_frames:
443
+ spliced = spliced[-ctx_frames:]
444
+ elif spliced.shape[0] < ctx_frames:
445
+ # Tile to fill
446
+ reps = int(np.ceil(ctx_frames / max(1, spliced.shape[0])))
447
+ spliced = np.tile(spliced, (reps, 1))[-ctx_frames:]
448
+
449
+ return spliced
450
 
451
  def run(self):
452
+ """Main generation loop with precise timing"""
453
+ chunk_bars = float(self.params.bars_per_chunk)
454
+ chunk_samples = self._get_precise_chunk_samples(chunk_bars)
455
+ xfade_s = float(self.mrt.config.crossfade_length)
456
+
457
+ def _samples_needed(first_chunk_extra=False):
458
+ """Calculate samples needed in stream for next emission"""
459
+ available = 0 if self._stream is None else (
460
+ self._stream.shape[0] - self._timing.stream_position_samples
461
+ )
462
+ required = chunk_samples
463
  if first_chunk_extra:
464
+ # Extra material for onset alignment
465
+ extra_samples = self._get_precise_chunk_samples(2.0)
466
+ required += extra_samples
467
+ return max(0, required - available)
468
+
469
+ print(f"JamWorker started: {self.params.bpm} BPM, {self._frames_per_bar:.3f} frames/bar, {chunk_bars} bars/chunk")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
 
471
  while not self._stop_event.is_set():
472
  if not self._should_generate_next_chunk():
473
  time.sleep(0.25)
474
  continue
475
 
476
+ # 1) Generate until we have enough audio
477
+ needed = _samples_needed(first_chunk_extra=(self.idx == 0))
478
+ while needed > 0 and not self._stop_event.is_set():
479
  with self._lock:
480
  style_vec = self.params.style_vec
481
  self.mrt.guidance_weight = float(self.params.guidance_weight)
482
+ self.mrt.temperature = float(self.params.temperature)
483
+ self.mrt.topk = int(self.params.topk)
484
+
485
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
486
+ self._append_model_chunk_to_stream(wav)
487
+ needed = _samples_needed(first_chunk_extra=(self.idx == 0))
488
 
489
  if self._stop_event.is_set():
490
  break
491
 
492
+ # 2) First chunk: perform onset alignment
493
  if (self.idx == 0 and self.params.combined_loop is not None) or self._needs_bar_realign:
494
  ref_loop = self._reseed_ref_loop or self.params.combined_loop
495
  if ref_loop is not None:
496
+ offset_seconds = self._perform_onset_alignment(ref_loop)
497
+ if abs(offset_seconds) > 0.01: # More than 10ms
498
+ offset_samples = int(round(offset_seconds * self._model_sr))
499
+ self._timing.stream_position_samples = max(0, offset_samples)
500
+ print(f"Applied onset offset: {offset_seconds:.3f}s")
501
+
502
+ self._align_to_bar_boundary()
 
503
  self._needs_bar_realign = False
504
  self._reseed_ref_loop = None
505
 
506
+ # 3) Extract precise chunk
507
+ chunk_start_bars = self._timing.emit_position_bars
508
+ slice_audio = self._extract_precise_chunk(chunk_start_bars, chunk_bars)
509
+
510
+ if slice_audio is None:
511
+ continue # Need more generation
512
 
513
+ # Update timing state
514
+ correction = self._timing.advance_by_bars(chunk_bars)
515
+ if correction != 0:
516
+ print(f"Applied {correction} bar timing correction")
517
+
518
+ self._timing.stream_position_samples += chunk_samples
519
 
520
+ # 4) Create waveform and process
521
+ y = au.Waveform(slice_audio.astype(np.float32, copy=False), self._model_sr).as_stereo()
522
 
523
+ # Loudness matching and fades
524
  if self.idx == 0 and self.params.ref_loop is not None:
525
  y, _ = match_loudness_to_reference(
526
  self.params.ref_loop, y,
 
530
  else:
531
  apply_micro_fades(y, 3)
532
 
533
+ # 5) Sample rate conversion
534
+ if self._resampler is not None:
535
+ # Use streaming resampler for precise conversion
536
+ resampled = self._resampler.process(y.samples, final=False)
537
+
538
+ # Ensure exact target length
539
+ target_samples = int(round(chunk_bars * self._samples_per_bar_target))
540
+ if resampled.shape[0] != target_samples:
541
+ if resampled.shape[0] < target_samples:
542
+ pad_samples = target_samples - resampled.shape[0]
543
+ pad = np.zeros((pad_samples, resampled.shape[1]), dtype=resampled.dtype)
544
+ resampled = np.vstack([resampled, pad])
545
+ else:
546
+ resampled = resampled[:target_samples]
547
+
548
+ final_audio = resampled
549
+ final_sr = self._target_sr
550
+ else:
551
+ # No resampling needed
552
+ final_audio = y.samples
553
+ final_sr = self._model_sr
554
 
555
+ # 6) Encode to base64
556
+ b64, total_samples, channels = wav_bytes_base64(final_audio, final_sr)
557
+
558
+ # 7) Create metadata with timing info
559
+ actual_duration = total_samples / final_sr
560
+ bar_range = f"{chunk_start_bars:.2f}-{self._timing.emit_position_bars:.2f}"
561
+
562
+ meta = {
563
+ "bpm": int(round(self.params.bpm)),
564
+ "bars": int(self.params.bars_per_chunk),
565
+ "beats_per_bar": int(self.params.beats_per_bar),
566
+ "sample_rate": int(final_sr),
567
+ "channels": int(channels),
568
+ "total_samples": int(total_samples),
569
+ "seconds_per_bar": self._seconds_per_bar,
570
+ "loop_duration_seconds": actual_duration,
571
+ "bar_range": bar_range,
572
+ "timing_state": {
573
+ "emit_position_bars": self._timing.emit_position_bars,
574
+ "frames_per_bar": self._frames_per_bar,
575
+ "fractional_error": self._timing.fractional_error_bars,
576
+ },
577
+ "xfade_seconds": xfade_s,
578
+ "guidance_weight": self.params.guidance_weight,
579
+ "temperature": self.params.temperature,
580
+ "topk": self.params.topk,
581
+ }
582
+
583
+ # 8) Publish chunk
584
  with self._lock:
585
  self.idx += 1
586
+ chunk = JamChunk(index=self.idx, audio_base64=b64, metadata=meta)
587
+ self.outbox.append(chunk)
588
+
589
+ # Cleanup old chunks
590
  if len(self.outbox) > 10:
591
  cutoff = self._last_delivered_index - 5
592
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
593
 
594
+ # Handle pending reseeds
595
  if self._pending_reseed is not None:
596
  pkg = self._pending_reseed
597
  self._pending_reseed = None
598
 
599
  new_state = self.mrt.init_state()
600
+ new_state.context_tokens = pkg["ctx"]
601
  self.state = new_state
602
 
603
+ # Reset timing and stream
604
  self._stream = None
605
+ self._stream_write_pos = 0
606
+ self._timing = TimingState(
607
+ frames_per_bar=self._frames_per_bar,
608
+ samples_per_bar=self._samples_per_bar_model
609
+ )
610
+ self._reseed_ref_loop = pkg.get("ref")
611
  self._needs_bar_realign = True
612
 
613
+ print("Reseed applied at bar boundary")
614
 
615
+ drift_ms = abs(self._timing.fractional_error_bars) * self._seconds_per_bar * 1000
616
+ print(f"Completed chunk {self.idx} ({bar_range} bars, {drift_ms:.1f}ms drift)")
 
617
 
618
+ print("JamWorker stopped")
619
+
620
+ # Clean up resampler
621
+ if self._resampler is not None:
622
+ try:
623
+ self._resampler.flush()
624
+ except:
625
+ pass
utils.py CHANGED
@@ -109,55 +109,81 @@ def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
109
 
110
 
111
  # ---------- Token context helpers ----------
112
- def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
113
  """
114
  Return a ctx_frames-long slice of `tokens` whose **end** lands on the nearest
115
- whole-bar boundary in codec-frame space, even when frames_per_bar is fractional.
116
-
117
- tokens: np.ndarray of shape (T, D) or (T,) where T = codec frames
118
- bpm: float
119
- fps: float (codec frames per second; keep this as float)
120
- ctx_frames: int (length of context window in codec frames)
121
- beats_per_bar: int
122
  """
123
-
124
-
125
  if tokens is None:
126
  raise ValueError("tokens is None")
127
  tokens = np.asarray(tokens)
128
  if tokens.ndim == 1:
129
- tokens = tokens[:, None] # promote to (T, 1) for uniform tiling
130
 
131
  T = tokens.shape[0]
132
  if T == 0:
133
  return tokens
134
 
135
  fps = float(fps)
136
- frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps # float frames per bar
137
-
138
- # Tile a little more than we need so we can always snap the END to a bar boundary
139
- reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
140
- tiled = np.tile(tokens, (reps, 1))
141
- total = tiled.shape[0]
142
-
143
- # How many whole bars fit?
144
- k_bars = int(np.floor(total / frames_per_bar_f))
145
- if k_bars <= 0:
146
- # Fallback: just take the last ctx_frames
147
- window = tiled[-ctx_frames:]
148
- return window
149
-
150
- # Snap END index to the nearest integer frame at a whole-bar boundary
151
- end_idx = int(round(k_bars * frames_per_bar_f))
152
- end_idx = min(max(end_idx, ctx_frames), total)
153
- start_idx = end_idx - ctx_frames
154
- if start_idx < 0:
155
- start_idx = 0
156
- end_idx = ctx_frames
157
-
158
- window = tiled[start_idx:end_idx]
159
-
160
- # Guard against rare off-by-one due to rounding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  if window.shape[0] < ctx_frames:
162
  pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
163
  window = np.vstack([window, pad])[:ctx_frames]
 
109
 
110
 
111
  # ---------- Token context helpers ----------
112
+ def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4, precise_timing=False):
113
  """
114
  Return a ctx_frames-long slice of `tokens` whose **end** lands on the nearest
115
+ whole-bar boundary in codec-frame space.
116
+
117
+ NEW: precise_timing mode handles fractional frames per bar more carefully.
 
 
 
 
118
  """
 
 
119
  if tokens is None:
120
  raise ValueError("tokens is None")
121
  tokens = np.asarray(tokens)
122
  if tokens.ndim == 1:
123
+ tokens = tokens[:, None]
124
 
125
  T = tokens.shape[0]
126
  if T == 0:
127
  return tokens
128
 
129
  fps = float(fps)
130
+ frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps
131
+
132
+ if precise_timing and abs(frames_per_bar_f - round(frames_per_bar_f)) > 1e-6:
133
+ # We have fractional frames per bar - use a different strategy
134
+ # Instead of trying to align to exact bar boundaries, align to the closest
135
+ # multiple of frames_per_bar_f that gives us integer frame positions
136
+
137
+ # Tile enough to work with
138
+ reps = max(2, int(np.ceil((ctx_frames + T) / float(T))))
139
+ tiled = np.tile(tokens, (reps, 1))
140
+ total = tiled.shape[0]
141
+
142
+ # Find the best integer end position that's close to a bar boundary
143
+ best_end = ctx_frames
144
+ best_error = float('inf')
145
+
146
+ # Check positions around the naive ctx_frames endpoint
147
+ for candidate_end in range(max(ctx_frames - 50, ctx_frames), min(total, ctx_frames + 50)):
148
+ # How many fractional bars does this represent?
149
+ fractional_bars = candidate_end / frames_per_bar_f
150
+ # How far from an integer number of bars?
151
+ bar_error = abs(fractional_bars - round(fractional_bars))
152
+
153
+ if bar_error < best_error:
154
+ best_error = bar_error
155
+ best_end = candidate_end
156
+
157
+ end_idx = best_end
158
+ start_idx = max(0, end_idx - ctx_frames)
159
+
160
+ window = tiled[start_idx:end_idx]
161
+
162
+ # Report timing info for debugging
163
+ actual_bars = end_idx / frames_per_bar_f
164
+ print(f"Context aligned to {actual_bars:.3f} bars (error: {best_error:.4f})")
165
+
166
+ else:
167
+ # Original logic for integer frames per bar
168
+ reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
169
+ tiled = np.tile(tokens, (reps, 1))
170
+ total = tiled.shape[0]
171
+
172
+ k_bars = int(np.floor(total / frames_per_bar_f))
173
+ if k_bars <= 0:
174
+ window = tiled[-ctx_frames:]
175
+ return window
176
+
177
+ end_idx = int(round(k_bars * frames_per_bar_f))
178
+ end_idx = min(max(end_idx, ctx_frames), total)
179
+ start_idx = end_idx - ctx_frames
180
+ if start_idx < 0:
181
+ start_idx = 0
182
+ end_idx = ctx_frames
183
+
184
+ window = tiled[start_idx:end_idx]
185
+
186
+ # Ensure exact length
187
  if window.shape[0] < ctx_frames:
188
  pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
189
  window = np.vstack([window, pad])[:ctx_frames]