thecollabagepatch committed on
Commit
7fe8be5
·
1 Parent(s): 4bdf506

ok reverting one more time

Browse files
Files changed (2) hide show
  1. jam_worker.py +355 -404
  2. utils.py +36 -62
jam_worker.py CHANGED
@@ -1,5 +1,5 @@
1
- # jam_worker.py - COMPREHENSIVE REWRITE FOR PRECISE TIMING
2
- import threading, time, base64, io, uuid, math
3
  from dataclasses import dataclass, field
4
  import numpy as np
5
  import soundfile as sf
@@ -8,7 +8,7 @@ from threading import RLock
8
  from utils import (
9
  match_loudness_to_reference, stitch_generated, hard_trim_seconds,
10
  apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
11
- resample_and_snap, wav_bytes_base64, StreamingResampler
12
  )
13
 
14
  @dataclass
@@ -32,34 +32,6 @@ class JamChunk:
32
  audio_base64: str
33
  metadata: dict
34
 
35
- @dataclass
36
- class TimingState:
37
- """Precise timing state tracking"""
38
- # Fractional bar position (never rounded until final emission)
39
- emit_position_bars: float = 0.0
40
-
41
- # Sample-accurate positions in the stream
42
- stream_position_samples: int = 0
43
-
44
- # Accumulated timing error for correction
45
- fractional_error_bars: float = 0.0
46
-
47
- # Codec frame timing
48
- frames_per_bar: float = 0.0
49
- samples_per_bar: float = 0.0
50
-
51
- def advance_by_bars(self, bars: float):
52
- """Advance timing by exact fractional bars"""
53
- self.emit_position_bars += bars
54
- self.fractional_error_bars += bars - int(bars)
55
-
56
- # Correct for accumulated error when it gets significant
57
- if abs(self.fractional_error_bars) > 0.5:
58
- correction = int(round(self.fractional_error_bars))
59
- self.fractional_error_bars -= correction
60
- return correction # bars to skip/rewind
61
- return 0
62
-
63
  class JamWorker(threading.Thread):
64
  def __init__(self, mrt, params: JamParams):
65
  super().__init__(daemon=True)
@@ -67,32 +39,9 @@ class JamWorker(threading.Thread):
67
  self.params = params
68
  self.state = mrt.init_state()
69
 
70
- # Core timing calculations (keep as floats for precision)
71
- self._codec_fps = float(self.mrt.codec.frame_rate) # 25.0
72
- self._model_sr = int(self.mrt.sample_rate) # 48000
73
- self._target_sr = int(params.target_sr)
74
-
75
- # Critical: these stay as floats to preserve fractional precision
76
- self._seconds_per_bar = float(params.beats_per_bar * 60.0 / params.bpm)
77
- self._frames_per_bar = self._seconds_per_bar * self._codec_fps
78
- self._samples_per_bar_model = self._seconds_per_bar * self._model_sr
79
- self._samples_per_bar_target = self._seconds_per_bar * self._target_sr
80
-
81
- # Timing state
82
- self._timing = TimingState(
83
- frames_per_bar=self._frames_per_bar,
84
- samples_per_bar=self._samples_per_bar_model
85
- )
86
-
87
- # Warn about problematic BPMs
88
- frame_error = abs(self._frames_per_bar - round(self._frames_per_bar))
89
- if frame_error > 0.01:
90
- print(f"⚠️ Warning: {params.bpm} BPM creates {frame_error:.3f} frame drift per bar")
91
- print(f" This may cause gradual timing drift in long jams")
92
-
93
- # Synchronization + placeholders
94
  self._lock = threading.Lock()
95
- self._original_context_tokens = None
96
 
97
  if params.combined_loop is not None:
98
  self._setup_context_from_combined_loop()
@@ -101,39 +50,28 @@ class JamWorker(threading.Thread):
101
  self.outbox: list[JamChunk] = []
102
  self._stop_event = threading.Event()
103
 
104
- # Stream state
105
  self._stream = None
106
- self._stream_write_pos = 0 # Where we append new model output
107
-
108
- # Delivery tracking
109
  self._last_delivered_index = 0
110
  self._max_buffer_ahead = 5
111
 
112
- # Streaming resampler for precise SR conversion
113
- self._resampler = None
114
- if self._target_sr != self._model_sr:
115
- self._resampler = StreamingResampler(
116
- in_sr=self._model_sr,
117
- out_sr=self._target_sr,
118
- channels=2,
119
- quality="VHQ"
120
- )
121
-
122
  # Timing info
123
  self.last_chunk_started_at = None
124
  self.last_chunk_completed_at = None
125
 
126
- # Control flags
127
- self._pending_reseed = None
128
- self._needs_bar_realign = False
129
- self._reseed_ref_loop = None
130
 
131
  def _setup_context_from_combined_loop(self):
132
  """Set up MRT context tokens from the combined loop audio"""
133
  try:
134
  from utils import make_bar_aligned_context, take_bar_aligned_tail
135
 
136
- codec_fps = self._codec_fps
137
  ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
138
 
139
  loop_for_context = take_bar_aligned_tail(
@@ -146,381 +84,452 @@ class JamWorker(threading.Thread):
146
  tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
147
  tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
148
 
149
- # Use enhanced context alignment for fractional BPMs
150
  context_tokens = make_bar_aligned_context(
151
  tokens,
152
  bpm=self.params.bpm,
153
- fps=self._codec_fps,
154
  ctx_frames=self.mrt.config.context_length_frames,
155
- beats_per_bar=self.params.beats_per_bar,
156
- precise_timing=True # Use new precise mode
157
  )
158
 
 
159
  self.state.context_tokens = context_tokens
160
- print(f"Context setup: {context_tokens.shape[0]} frames, {self._frames_per_bar:.3f} frames/bar")
161
 
162
- # Store original context for splice reseeding
 
163
  with self._lock:
164
  if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
165
- self._original_context_tokens = np.copy(context_tokens)
166
 
167
  except Exception as e:
168
- print(f"Failed to setup context from combined loop: {e}")
169
 
170
  def stop(self):
171
  self._stop_event.set()
172
 
173
  def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
174
  with self._lock:
175
- if guidance_weight is not None:
176
- self.params.guidance_weight = float(guidance_weight)
177
- if temperature is not None:
178
- self.params.temperature = float(temperature)
179
- if topk is not None:
180
- self.params.topk = int(topk)
181
 
182
  def get_next_chunk(self) -> JamChunk | None:
183
  """Get the next sequential chunk (blocks/waits if not ready)"""
184
  target_index = self._last_delivered_index + 1
185
 
186
- max_wait = 30.0
 
187
  start_time = time.time()
188
 
189
  while time.time() - start_time < max_wait and not self._stop_event.is_set():
190
  with self._lock:
 
191
  for chunk in self.outbox:
192
  if chunk.index == target_index:
193
  self._last_delivered_index = target_index
194
- print(f"Delivered chunk {target_index} (bars {chunk.metadata.get('bar_range', 'unknown')})")
195
  return chunk
 
 
196
  time.sleep(0.1)
197
 
 
198
  return None
199
 
200
  def mark_chunk_consumed(self, chunk_index: int):
201
  """Mark a chunk as consumed by the frontend"""
202
  with self._lock:
203
  self._last_delivered_index = max(self._last_delivered_index, chunk_index)
 
204
 
205
  def _should_generate_next_chunk(self) -> bool:
206
- """Check if we should generate the next chunk"""
207
  with self._lock:
208
- return self.idx <= self._last_delivered_index + self._max_buffer_ahead
209
-
210
- def _get_precise_chunk_samples(self, bars: float) -> int:
211
- """Get exact sample count for fractional bars at model SR"""
212
- exact_seconds = bars * self._seconds_per_bar
213
- return int(round(exact_seconds * self._model_sr))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  def _append_model_chunk_to_stream(self, wav):
216
- """Append model output to continuous stream with crossfading"""
217
  xfade_s = float(self.mrt.config.crossfade_length)
218
- sr = self._model_sr
219
  xfade_n = int(round(xfade_s * sr))
220
 
221
  s = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
222
 
223
- if self._stream is None:
224
- # First chunk: drop model pre-roll
225
  if s.shape[0] > xfade_n:
226
  self._stream = s[xfade_n:].astype(np.float32, copy=True)
227
  else:
228
  self._stream = np.zeros((0, s.shape[1]), dtype=np.float32)
229
- self._stream_write_pos = self._stream.shape[0]
230
  return
231
 
232
- # Crossfade with equal-power curves
233
  if s.shape[0] <= xfade_n or self._stream.shape[0] < xfade_n:
234
- # Degenerate case
235
  self._stream = np.concatenate([self._stream, s], axis=0)
236
- self._stream_write_pos = self._stream.shape[0]
237
  return
238
 
239
- # Standard crossfade
240
  tail = self._stream[-xfade_n:]
241
  head = s[:xfade_n]
242
 
 
243
  t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)[:, None]
244
  eq_in, eq_out = np.sin(t), np.cos(t)
245
  mixed = tail * eq_out + head * eq_in
246
 
247
  self._stream = np.concatenate([self._stream[:-xfade_n], mixed, s[xfade_n:]], axis=0)
248
- self._stream_write_pos = self._stream.shape[0]
249
-
250
- def _extract_precise_chunk(self, start_bars: float, chunk_bars: float) -> np.ndarray:
251
- """Extract exactly chunk_bars worth of audio starting at start_bars"""
252
- start_samples = self._get_precise_chunk_samples(start_bars)
253
- chunk_samples = self._get_precise_chunk_samples(chunk_bars)
254
- end_samples = start_samples + chunk_samples
255
-
256
- if end_samples > self._stream.shape[0]:
257
- return None # Not enough audio generated yet
258
-
259
- return self._stream[start_samples:end_samples]
260
-
261
- def _perform_onset_alignment(self, ref_loop: au.Waveform) -> float:
262
- """Estimate timing offset between generated audio and reference"""
263
- if self._stream is None or self._stream.shape[0] < self._model_sr:
264
- return 0.0
265
-
266
- try:
267
- # Take first ~2 seconds of generated audio
268
- gen_samples = min(int(2.0 * self._model_sr), self._stream.shape[0])
269
- gen_head = au.Waveform(
270
- self._stream[:gen_samples].astype(np.float32, copy=False),
271
- self._model_sr
272
- ).as_stereo()
273
-
274
- # Reference: last bar of the loop
275
- ref_samples = int(self._seconds_per_bar * ref_loop.sample_rate)
276
- if ref_loop.samples.shape[0] >= ref_samples:
277
- ref_tail = au.Waveform(
278
- ref_loop.samples[-ref_samples:],
279
- ref_loop.sample_rate
280
- ).resample(self._model_sr).as_stereo()
281
- else:
282
- ref_tail = ref_loop.resample(self._model_sr).as_stereo()
283
-
284
- # Cross-correlation based alignment
285
- def envelope(x, sr):
286
- if x.ndim == 2:
287
- x = x.mean(axis=1)
288
- x = np.abs(x).astype(np.float32)
289
- # Simple smoothing
290
- win = max(1, int(0.01 * sr)) # 10ms window
291
- if win > 1:
292
- kernel = np.ones(win) / win
293
- x = np.convolve(x, kernel, mode='same')
294
- return x
295
-
296
- env_ref = envelope(ref_tail.samples, self._model_sr)
297
- env_gen = envelope(gen_head.samples, self._model_sr)
298
-
299
- # Limit search range to reasonable offset
300
- max_offset_samples = int(0.2 * self._model_sr) # 200ms max
301
-
302
- # Normalize for correlation
303
- env_ref = (env_ref - env_ref.mean()) / (env_ref.std() + 1e-8)
304
- env_gen = (env_gen - env_gen.mean()) / (env_gen.std() + 1e-8)
305
-
306
- # Find best correlation
307
- best_offset = 0
308
- best_corr = -1.0
309
-
310
- search_len = min(len(env_ref), len(env_gen) - max_offset_samples)
311
- if search_len > 0:
312
- for offset in range(0, max_offset_samples, 4): # subsample for speed
313
- if offset + search_len >= len(env_gen):
314
- break
315
- corr = np.corrcoef(env_ref[:search_len], env_gen[offset:offset+search_len])[0,1]
316
- if not np.isnan(corr) and corr > best_corr:
317
- best_corr = corr
318
- best_offset = offset
319
-
320
- offset_seconds = best_offset / self._model_sr
321
- print(f"Onset alignment: {offset_seconds:.3f}s offset (correlation: {best_corr:.3f})")
322
- return offset_seconds
323
-
324
- except Exception as e:
325
- print(f"Onset alignment failed: {e}")
326
- return 0.0
327
-
328
- def _align_to_bar_boundary(self):
329
- """Align timing state to next bar boundary"""
330
- current_bar = self._timing.emit_position_bars
331
- next_bar = math.ceil(current_bar)
332
-
333
- if abs(next_bar - current_bar) > 1e-6:
334
- skip_bars = next_bar - current_bar
335
- skip_samples = self._get_precise_chunk_samples(skip_bars)
336
- self._timing.stream_position_samples += skip_samples
337
- self._timing.emit_position_bars = next_bar
338
- print(f"Aligned to bar {next_bar:.0f}, skipped {skip_bars:.4f} bars")
339
 
340
  def reseed_from_waveform(self, wav):
341
- """Full context replacement reseed"""
342
  new_state = self.mrt.init_state()
343
-
344
- # Build new context from waveform
345
- codec_fps = self._codec_fps
346
  ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
347
-
 
348
  tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, ctx_seconds)
349
  tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
350
  tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
351
-
352
- context_tokens = make_bar_aligned_context(
353
- tokens,
354
- bpm=self.params.bpm,
355
- fps=self._codec_fps,
356
  ctx_frames=self.mrt.config.context_length_frames,
357
- beats_per_bar=self.params.beats_per_bar,
358
- precise_timing=True
359
  )
360
-
361
  new_state.context_tokens = context_tokens
362
  self.state = new_state
363
-
364
- # Reset stream
365
- self._stream = None
366
- self._stream_write_pos = 0
367
- self._timing = TimingState(
368
- frames_per_bar=self._frames_per_bar,
369
- samples_per_bar=self._samples_per_bar_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  self._needs_bar_realign = True
372
- self._reseed_ref_loop = wav
373
 
374
  def reseed_splice(self, recent_wav, anchor_bars: float):
375
- """Token-splice reseed"""
 
 
376
  with self._lock:
377
  if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
378
  self._original_context_tokens = np.copy(self.state.context_tokens)
379
 
380
- # Build new context via splicing
381
- recent_tokens = self._make_recent_tokens_from_wave(recent_wav)
382
  new_ctx = self._splice_context(self._original_context_tokens, recent_tokens, anchor_bars)
383
 
 
384
  self._pending_reseed = {"ctx": new_ctx, "ref": recent_wav}
385
-
386
- # Install immediately
387
  new_state = self.mrt.init_state()
388
  new_state.context_tokens = new_ctx
389
  self.state = new_state
390
 
391
- # Reset stream state
392
- self._stream = None
393
- self._stream_write_pos = 0
394
- self._timing = TimingState(
395
- frames_per_bar=self._frames_per_bar,
396
- samples_per_bar=self._samples_per_bar_model
397
- )
398
- self._needs_bar_realign = True
399
 
400
- def _make_recent_tokens_from_wave(self, wav) -> np.ndarray:
401
- """Encode waveform to context tokens with precise alignment"""
402
- tokens_full = self.mrt.codec.encode(wav).astype(np.int32)
403
- tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
404
-
405
- context_tokens = make_bar_aligned_context(
406
- tokens,
407
- bpm=self.params.bpm,
408
- fps=self._codec_fps,
409
- ctx_frames=self.mrt.config.context_length_frames,
410
- beats_per_bar=self.params.beats_per_bar,
411
- precise_timing=True
412
- )
413
- return context_tokens
414
-
415
- def _splice_context(self, original_tokens: np.ndarray, recent_tokens: np.ndarray, anchor_bars: float) -> np.ndarray:
416
- """Enhanced context splicing with fractional bar handling"""
417
- ctx_frames = int(self.mrt.config.context_length_frames)
418
-
419
- # Convert anchor bars to codec frames (keep fractional precision)
420
- anchor_frames_f = anchor_bars * self._frames_per_bar
421
- anchor_frames = int(round(anchor_frames_f))
422
-
423
- # Take anchor from original
424
- anchor = original_tokens[-anchor_frames:] if anchor_frames <= original_tokens.shape[0] else original_tokens
425
-
426
- # Fill remainder with recent tokens
427
- remain_frames = ctx_frames - anchor.shape[0]
428
- if remain_frames > 0:
429
- recent = recent_tokens[-remain_frames:] if remain_frames <= recent_tokens.shape[0] else recent_tokens
430
- else:
431
- recent = recent_tokens[:0] # empty
432
-
433
- # Combine
434
- if anchor.size > 0 and recent.size > 0:
435
- spliced = np.concatenate([recent, anchor], axis=0)
436
- elif anchor.size > 0:
437
- spliced = anchor
438
- else:
439
- spliced = recent_tokens[-ctx_frames:]
440
-
441
- # Ensure exact length
442
- if spliced.shape[0] > ctx_frames:
443
- spliced = spliced[-ctx_frames:]
444
- elif spliced.shape[0] < ctx_frames:
445
- # Tile to fill
446
- reps = int(np.ceil(ctx_frames / max(1, spliced.shape[0])))
447
- spliced = np.tile(spliced, (reps, 1))[-ctx_frames:]
448
-
449
- return spliced
450
 
451
  def run(self):
452
- """Main generation loop with precise timing"""
453
- chunk_bars = float(self.params.bars_per_chunk)
454
- chunk_samples = self._get_precise_chunk_samples(chunk_bars)
455
- xfade_s = float(self.mrt.config.crossfade_length)
456
-
457
- def _samples_needed(first_chunk_extra=False):
458
- """Calculate samples needed in stream for next emission"""
459
- available = 0 if self._stream is None else (
460
- self._stream.shape[0] - self._timing.stream_position_samples
461
- )
462
- required = chunk_samples
463
  if first_chunk_extra:
464
- # Extra material for onset alignment
465
- extra_samples = self._get_precise_chunk_samples(2.0)
466
- required += extra_samples
467
- return max(0, required - available)
468
-
469
- print(f"JamWorker started: {self.params.bpm} BPM, {self._frames_per_bar:.3f} frames/bar, {chunk_bars} bars/chunk")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
 
471
  while not self._stop_event.is_set():
472
  if not self._should_generate_next_chunk():
473
  time.sleep(0.25)
474
  continue
475
 
476
- # 1) Generate until we have enough audio
477
- needed = _samples_needed(first_chunk_extra=(self.idx == 0))
478
- while needed > 0 and not self._stop_event.is_set():
479
  with self._lock:
480
  style_vec = self.params.style_vec
481
  self.mrt.guidance_weight = float(self.params.guidance_weight)
482
- self.mrt.temperature = float(self.params.temperature)
483
- self.mrt.topk = int(self.params.topk)
484
-
485
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
486
- self._append_model_chunk_to_stream(wav)
487
- needed = _samples_needed(first_chunk_extra=(self.idx == 0))
488
 
489
  if self._stop_event.is_set():
490
  break
491
 
492
- # 2) First chunk: perform onset alignment
493
  if (self.idx == 0 and self.params.combined_loop is not None) or self._needs_bar_realign:
494
  ref_loop = self._reseed_ref_loop or self.params.combined_loop
495
  if ref_loop is not None:
496
- offset_seconds = self._perform_onset_alignment(ref_loop)
497
- if abs(offset_seconds) > 0.01: # More than 10ms
498
- offset_samples = int(round(offset_seconds * self._model_sr))
499
- self._timing.stream_position_samples = max(0, offset_samples)
500
- print(f"Applied onset offset: {offset_seconds:.3f}s")
501
-
502
- self._align_to_bar_boundary()
 
503
  self._needs_bar_realign = False
504
  self._reseed_ref_loop = None
505
 
506
- # 3) Extract precise chunk
507
- chunk_start_bars = self._timing.emit_position_bars
508
- slice_audio = self._extract_precise_chunk(chunk_start_bars, chunk_bars)
509
-
510
- if slice_audio is None:
511
- continue # Need more generation
512
 
513
- # Update timing state
514
- correction = self._timing.advance_by_bars(chunk_bars)
515
- if correction != 0:
516
- print(f"Applied {correction} bar timing correction")
517
-
518
- self._timing.stream_position_samples += chunk_samples
519
 
520
- # 4) Create waveform and process
521
- y = au.Waveform(slice_audio.astype(np.float32, copy=False), self._model_sr).as_stereo()
522
 
523
- # Loudness matching and fades
524
  if self.idx == 0 and self.params.ref_loop is not None:
525
  y, _ = match_loudness_to_reference(
526
  self.params.ref_loop, y,
@@ -530,96 +539,38 @@ class JamWorker(threading.Thread):
530
  else:
531
  apply_micro_fades(y, 3)
532
 
533
- # 5) Sample rate conversion
534
- if self._resampler is not None:
535
- # Use streaming resampler for precise conversion
536
- resampled = self._resampler.process(y.samples, final=False)
537
-
538
- # Ensure exact target length
539
- target_samples = int(round(chunk_bars * self._samples_per_bar_target))
540
- if resampled.shape[0] != target_samples:
541
- if resampled.shape[0] < target_samples:
542
- pad_samples = target_samples - resampled.shape[0]
543
- pad = np.zeros((pad_samples, resampled.shape[1]), dtype=resampled.dtype)
544
- resampled = np.vstack([resampled, pad])
545
- else:
546
- resampled = resampled[:target_samples]
547
-
548
- final_audio = resampled
549
- final_sr = self._target_sr
550
- else:
551
- # No resampling needed
552
- final_audio = y.samples
553
- final_sr = self._model_sr
554
 
555
- # 6) Encode to base64
556
- b64, total_samples, channels = wav_bytes_base64(final_audio, final_sr)
557
-
558
- # 7) Create metadata with timing info
559
- actual_duration = total_samples / final_sr
560
- bar_range = f"{chunk_start_bars:.2f}-{self._timing.emit_position_bars:.2f}"
561
-
562
- meta = {
563
- "bpm": int(round(self.params.bpm)),
564
- "bars": int(self.params.bars_per_chunk),
565
- "beats_per_bar": int(self.params.beats_per_bar),
566
- "sample_rate": int(final_sr),
567
- "channels": int(channels),
568
- "total_samples": int(total_samples),
569
- "seconds_per_bar": self._seconds_per_bar,
570
- "loop_duration_seconds": actual_duration,
571
- "bar_range": bar_range,
572
- "timing_state": {
573
- "emit_position_bars": self._timing.emit_position_bars,
574
- "frames_per_bar": self._frames_per_bar,
575
- "fractional_error": self._timing.fractional_error_bars,
576
- },
577
- "xfade_seconds": xfade_s,
578
- "guidance_weight": self.params.guidance_weight,
579
- "temperature": self.params.temperature,
580
- "topk": self.params.topk,
581
- }
582
-
583
- # 8) Publish chunk
584
  with self._lock:
585
  self.idx += 1
586
- chunk = JamChunk(index=self.idx, audio_base64=b64, metadata=meta)
587
- self.outbox.append(chunk)
588
-
589
- # Cleanup old chunks
590
  if len(self.outbox) > 10:
591
  cutoff = self._last_delivered_index - 5
592
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
593
 
594
- # Handle pending reseeds
595
  if self._pending_reseed is not None:
596
  pkg = self._pending_reseed
597
  self._pending_reseed = None
598
 
599
  new_state = self.mrt.init_state()
600
- new_state.context_tokens = pkg["ctx"]
601
  self.state = new_state
602
 
603
- # Reset timing and stream
604
  self._stream = None
605
- self._stream_write_pos = 0
606
- self._timing = TimingState(
607
- frames_per_bar=self._frames_per_bar,
608
- samples_per_bar=self._samples_per_bar_model
609
- )
610
- self._reseed_ref_loop = pkg.get("ref")
611
  self._needs_bar_realign = True
612
 
613
- print("Reseed applied at bar boundary")
614
 
615
- drift_ms = abs(self._timing.fractional_error_bars) * self._seconds_per_bar * 1000
616
- print(f"Completed chunk {self.idx} ({bar_range} bars, {drift_ms:.1f}ms drift)")
 
617
 
618
- print("JamWorker stopped")
619
-
620
- # Clean up resampler
621
- if self._resampler is not None:
622
- try:
623
- self._resampler.flush()
624
- except:
625
- pass
 
1
+ # jam_worker.py - SIMPLE FIX VERSION
2
+ import threading, time, base64, io, uuid
3
  from dataclasses import dataclass, field
4
  import numpy as np
5
  import soundfile as sf
 
8
  from utils import (
9
  match_loudness_to_reference, stitch_generated, hard_trim_seconds,
10
  apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
11
+ resample_and_snap, wav_bytes_base64
12
  )
13
 
14
  @dataclass
 
32
  audio_base64: str
33
  metadata: dict
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  class JamWorker(threading.Thread):
36
  def __init__(self, mrt, params: JamParams):
37
  super().__init__(daemon=True)
 
39
  self.params = params
40
  self.state = mrt.init_state()
41
 
42
+ # init synchronization + placeholders FIRST
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  self._lock = threading.Lock()
44
+ self._original_context_tokens = None # so hasattr checks are cheap/clear
45
 
46
  if params.combined_loop is not None:
47
  self._setup_context_from_combined_loop()
 
50
  self.outbox: list[JamChunk] = []
51
  self._stop_event = threading.Event()
52
 
 
53
  self._stream = None
54
+ self._next_emit_start = 0
55
+
56
+ # NEW: Track delivery state
57
  self._last_delivered_index = 0
58
  self._max_buffer_ahead = 5
59
 
 
 
 
 
 
 
 
 
 
 
60
  # Timing info
61
  self.last_chunk_started_at = None
62
  self.last_chunk_completed_at = None
63
 
64
+ self._pending_reseed = None # {"ctx": np.ndarray, "ref": au.Waveform|None}
65
+ self._needs_bar_realign = False # request a one-shot downbeat alignment
66
+ self._reseed_ref_loop = None # which loop to align against after reseed
67
+
68
 
69
  def _setup_context_from_combined_loop(self):
70
  """Set up MRT context tokens from the combined loop audio"""
71
  try:
72
  from utils import make_bar_aligned_context, take_bar_aligned_tail
73
 
74
+ codec_fps = float(self.mrt.codec.frame_rate)
75
  ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
76
 
77
  loop_for_context = take_bar_aligned_tail(
 
84
  tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
85
  tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
86
 
 
87
  context_tokens = make_bar_aligned_context(
88
  tokens,
89
  bpm=self.params.bpm,
90
+ fps=float(self.mrt.codec.frame_rate), # keep fractional fps
91
  ctx_frames=self.mrt.config.context_length_frames,
92
+ beats_per_bar=self.params.beats_per_bar
 
93
  )
94
 
95
+ # Install fresh context
96
  self.state.context_tokens = context_tokens
97
+ print(f" JamWorker: Set up fresh context from combined loop")
98
 
99
+ # NEW: keep a copy of the *original* context tokens for future splice-reseed
100
+ # (guard so we only set this once, at jam start)
101
  with self._lock:
102
  if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
103
+ self._original_context_tokens = np.copy(context_tokens) # shape: [T, depth]
104
 
105
  except Exception as e:
106
+ print(f"Failed to setup context from combined loop: {e}")
107
 
108
def stop(self):
    """Request shutdown by setting the worker's stop event.

    Both the generation loop and ``get_next_chunk`` check this event and
    exit once it is set.
    """
    self._stop_event.set()
110
 
111
def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
    """Update sampling parameters on ``self.params`` while holding the lock.

    Every keyword is optional; passing ``None`` leaves that parameter as it
    is. ``guidance_weight`` and ``temperature`` are stored as ``float``,
    ``topk`` as ``int``.
    """
    # (attribute on self.params, requested value, conversion) in update order.
    requested = (
        ("guidance_weight", guidance_weight, float),
        ("temperature", temperature, float),
        ("topk", topk, int),
    )
    with self._lock:
        for attr, value, convert in requested:
            if value is None:
                continue
            setattr(self.params, attr, convert(value))
 
 
 
116
 
117
def get_next_chunk(self) -> "JamChunk | None":
    """Return the next sequential chunk, waiting up to 30 seconds for it.

    The chunk wanted is the one whose ``index`` is one past
    ``_last_delivered_index``. The outbox is re-checked roughly every
    100 ms until that chunk appears, the timeout elapses, or the worker is
    stopped.

    Returns:
        The matching chunk (after ``_last_delivered_index`` has been
        advanced to its index), or ``None`` on timeout or stop.
    """
    max_wait = 30.0  # seconds
    poll_interval = 0.1  # seconds

    # Monotonic clock: a wall-clock adjustment (NTP step, DST, manual change)
    # must not stretch or cut short the timeout.
    deadline = time.monotonic() + max_wait

    # Read the delivery watermark under the same lock that guards its writes.
    with self._lock:
        target_index = self._last_delivered_index + 1

    while not self._stop_event.is_set():
        with self._lock:
            # Look for the exact chunk we need
            for chunk in self.outbox:
                if chunk.index == target_index:
                    self._last_delivered_index = target_index
                    print(f"📦 Delivered chunk {target_index}")
                    return chunk

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break

        # Event.wait() returns as soon as stop() sets the event, so a stopped
        # worker releases the caller without finishing a full sleep.
        self._stop_event.wait(min(poll_interval, remaining))

    # Timeout or stopped
    return None
139
 
140
def mark_chunk_consumed(self, chunk_index: int):
    """Record that the frontend has consumed ``chunk_index``.

    The delivered-index watermark only moves forward: an index at or below
    the current watermark leaves it unchanged. The update is made while
    holding ``self._lock``.
    """
    with self._lock:
        if chunk_index > self._last_delivered_index:
            self._last_delivered_index = chunk_index
        print(f"✅ Chunk {chunk_index} consumed")
145
 
146
  def _should_generate_next_chunk(self) -> bool:
147
+ """Check if we should generate the next chunk (don't get too far ahead)"""
148
  with self._lock:
149
+ # Don't generate if we're already too far ahead
150
+ if self.idx > self._last_delivered_index + self._max_buffer_ahead:
151
+ return False
152
+ return True
153
+
154
+ def _seconds_per_bar(self) -> float:
155
+ return self.params.beats_per_bar * (60.0 / self.params.bpm)
156
+
157
def _snap_and_encode(self, y, seconds, target_sr, bars):
    """Resample a chunk, encode it to base64 WAV, and build its metadata.

    Args:
        y: Object exposing a ``samples`` array; a non-2-D array gets a
            trailing axis added before resampling.
        seconds: Forwarded to ``resample_and_snap`` as ``seconds``
            (presumably the exact duration to snap to - confirm in utils).
        target_sr: Output sample rate, used for resampling and encoding.
        bars: Bar count reported in the metadata.

    Returns:
        ``(b64, meta)``: the base64 string from ``wav_bytes_base64`` and a
        metadata dict describing the chunk and the current sampling knobs.
    """
    samples = y.samples
    if samples.ndim != 2:
        samples = samples[:, None]

    snapped = resample_and_snap(
        samples,
        cur_sr=int(self.mrt.sample_rate),
        target_sr=target_sr,
        seconds=seconds,
    )
    b64, total_samples, channels = wav_bytes_base64(snapped, target_sr)

    params = self.params
    bar_seconds = self._seconds_per_bar()
    meta = {
        "bpm": int(round(params.bpm)),
        "bars": int(bars),
        "beats_per_bar": int(params.beats_per_bar),
        "sample_rate": int(target_sr),
        "channels": channels,
        "total_samples": total_samples,
        "seconds_per_bar": bar_seconds,
        "loop_duration_seconds": bars * bar_seconds,
        "guidance_weight": params.guidance_weight,
        "temperature": params.temperature,
        "topk": params.topk,
    }
    return b64, meta
176
 
177
  def _append_model_chunk_to_stream(self, wav):
178
+ """Incrementally append a model chunk with equal-power crossfade."""
179
  xfade_s = float(self.mrt.config.crossfade_length)
180
+ sr = int(self.mrt.sample_rate)
181
  xfade_n = int(round(xfade_s * sr))
182
 
183
  s = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
184
 
185
+ if getattr(self, "_stream", None) is None:
186
+ # First chunk: drop model pre-roll (xfade head)
187
  if s.shape[0] > xfade_n:
188
  self._stream = s[xfade_n:].astype(np.float32, copy=True)
189
  else:
190
  self._stream = np.zeros((0, s.shape[1]), dtype=np.float32)
191
+ self._next_emit_start = 0 # pointer into _stream (model SR samples)
192
  return
193
 
194
+ # Crossfade last xfade_n samples of _stream with head of new s
195
  if s.shape[0] <= xfade_n or self._stream.shape[0] < xfade_n:
196
+ # Degenerate safeguard
197
  self._stream = np.concatenate([self._stream, s], axis=0)
 
198
  return
199
 
 
200
  tail = self._stream[-xfade_n:]
201
  head = s[:xfade_n]
202
 
203
+ # Equal-power envelopes
204
  t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)[:, None]
205
  eq_in, eq_out = np.sin(t), np.cos(t)
206
  mixed = tail * eq_out + head * eq_in
207
 
208
  self._stream = np.concatenate([self._stream[:-xfade_n], mixed, s[xfade_n:]], axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
def reseed_from_waveform(self, wav):
    """Replace the model state with one whose context is built from ``wav``.

    Steps, in order: create a fresh state via ``mrt.init_state()``; take a
    tail of ``wav`` sized to the context window with
    ``take_bar_aligned_tail``; encode it with the codec and keep the first
    ``decoder_codec_rvq_depth`` columns; run the tokens through
    ``make_bar_aligned_context``; install the result as the new state's
    ``context_tokens``; then call ``_prepare_stream_for_reseed_handoff``.
    """
    # 1) Fresh model state
    fresh_state = self.mrt.init_state()

    # 2) Context window length in seconds, from codec frames and frame rate
    cfg = self.mrt.config
    fps = float(self.mrt.codec.frame_rate)
    window_seconds = float(cfg.context_length_frames) / fps
    from utils import take_bar_aligned_tail, make_bar_aligned_context

    # 3) Encode the tail and keep only the decoder's RVQ depth
    bpm = self.params.bpm
    beats_per_bar = self.params.beats_per_bar
    loop_tail = take_bar_aligned_tail(wav, bpm, beats_per_bar, window_seconds)
    encoded = self.mrt.codec.encode(loop_tail).astype(np.int32)
    depth_limited = encoded[:, :cfg.decoder_codec_rvq_depth]

    # 4) Build the context tokens and install them on the fresh state
    fresh_state.context_tokens = make_bar_aligned_context(
        depth_limited,
        bpm=bpm,
        fps=fps,
        ctx_frames=cfg.context_length_frames,
        beats_per_bar=beats_per_bar,
    )
    self.state = fresh_state
    self._prepare_stream_for_reseed_handoff()
230
+
231
+ def _frames_per_bar(self) -> int:
232
+ # codec frame-rate (frames/s) -> frames per musical bar
233
+ fps = float(self.mrt.codec.frame_rate)
234
+ sec_per_bar = (60.0 / float(self.params.bpm)) * float(self.params.beats_per_bar)
235
+ return int(round(fps * sec_per_bar))
236
+
237
+ def _ctx_frames(self) -> int:
238
+ # how many codec frames fit in the model’s conditioning window
239
+ return int(self.mrt.config.context_length_frames)
240
+
241
def _make_recent_tokens_from_wave(self, wav) -> np.ndarray:
    """Encode ``wav`` and return context tokens from ``make_bar_aligned_context``.

    The codec output is cast to int32 and cut down to the first
    ``decoder_codec_rvq_depth`` columns before being passed, together with
    the current bpm, codec frame rate, context length and meter, to
    ``make_bar_aligned_context``.
    """
    cfg = self.mrt.config
    encoded = self.mrt.codec.encode(wav).astype(np.int32)  # [T, rvq_total]
    depth_limited = encoded[:, :cfg.decoder_codec_rvq_depth]

    from utils import make_bar_aligned_context

    return make_bar_aligned_context(
        depth_limited,
        bpm=self.params.bpm,
        fps=float(self.mrt.codec.frame_rate),  # keep fractional fps
        ctx_frames=cfg.context_length_frames,
        beats_per_bar=self.params.beats_per_bar,
    )
257
+
258
+ def _bar_aligned_tail(self, tokens: np.ndarray, bars: float) -> np.ndarray:
259
+ """
260
+ Take a tail slice that is an integer number of codec frames corresponding to `bars`.
261
+ We round to nearest frame to stay phase-consistent with codec grid.
262
+ """
263
+ frames_per_bar = self._frames_per_bar()
264
+ want = max(frames_per_bar * int(round(bars)), 0)
265
+ if want == 0:
266
+ return tokens[:0] # empty
267
+ if tokens.shape[0] <= want:
268
+ return tokens
269
+ return tokens[-want:]
270
+
271
+ def _splice_context(self, original_tokens: np.ndarray, recent_tokens: np.ndarray,
272
+ anchor_bars: float) -> np.ndarray:
273
+ import math
274
+ ctx_frames = self._ctx_frames()
275
+ depth = original_tokens.shape[1]
276
+ frames_per_bar = self._frames_per_bar()
277
+
278
+ # 1) Anchor tail (whole bars)
279
+ anchor = self._bar_aligned_tail(original_tokens, math.floor(anchor_bars))
280
+
281
+ # 2) Fill remainder with recent (prefer whole bars)
282
+ a = anchor.shape[0]
283
+ remain = max(ctx_frames - a, 0)
284
+
285
+ recent = recent_tokens[:0]
286
+ used_recent = 0 # frames taken from the END of recent_tokens
287
+ if remain > 0:
288
+ bars_fit = remain // frames_per_bar
289
+ if bars_fit >= 1:
290
+ want_recent_frames = int(bars_fit * frames_per_bar)
291
+ used_recent = min(want_recent_frames, recent_tokens.shape[0])
292
+ recent = recent_tokens[-used_recent:] if used_recent > 0 else recent_tokens[:0]
293
+ else:
294
+ used_recent = min(remain, recent_tokens.shape[0])
295
+ recent = recent_tokens[-used_recent:] if used_recent > 0 else recent_tokens[:0]
296
+
297
+ # 3) Concat in order [anchor, recent]
298
+ if anchor.size or recent.size:
299
+ out = np.concatenate([anchor, recent], axis=0)
300
+ else:
301
+ # fallback: just take the last ctx window from recent
302
+ out = recent_tokens[-ctx_frames:]
303
+
304
+ # 4) Trim if we overshot
305
+ if out.shape[0] > ctx_frames:
306
+ out = out[-ctx_frames:]
307
+
308
+ # 5) Snap the **END** to the nearest LOWER bar boundary
309
+ if frames_per_bar > 0:
310
+ max_bar_aligned = (out.shape[0] // frames_per_bar) * frames_per_bar
311
+ else:
312
+ max_bar_aligned = out.shape[0]
313
+ if max_bar_aligned > 0 and out.shape[0] != max_bar_aligned:
314
+ out = out[-max_bar_aligned:]
315
+
316
+ # 6) Left-fill to reach ctx_frames **without moving the END**
317
+ deficit = ctx_frames - out.shape[0]
318
+ if deficit > 0:
319
+ left_parts = []
320
+
321
+ # Prefer frames immediately BEFORE the region we used from 'recent_tokens'
322
+ if used_recent < recent_tokens.shape[0]:
323
+ take = min(deficit, recent_tokens.shape[0] - used_recent)
324
+ if used_recent > 0:
325
+ left_parts.append(recent_tokens[-(used_recent + take) : -used_recent])
326
+ else:
327
+ left_parts.append(recent_tokens[-take:])
328
+
329
+ # Then take frames immediately BEFORE the 'anchor' in original_tokens
330
+ if sum(p.shape[0] for p in left_parts) < deficit and anchor.shape[0] > 0:
331
+ need = deficit - sum(p.shape[0] for p in left_parts)
332
+ a_len = anchor.shape[0]
333
+ avail = max(original_tokens.shape[0] - a_len, 0)
334
+ take2 = min(need, avail)
335
+ if take2 > 0:
336
+ left_parts.append(original_tokens[-(a_len + take2) : -a_len])
337
+
338
+ # Still short? tile from what's available
339
+ have = sum(p.shape[0] for p in left_parts)
340
+ if have < deficit:
341
+ base = out if out.shape[0] > 0 else (recent_tokens if recent_tokens.shape[0] > 0 else original_tokens)
342
+ reps = int(np.ceil((deficit - have) / max(1, base.shape[0])))
343
+ left_parts.append(np.tile(base, (reps, 1))[: (deficit - have)])
344
+
345
+ left = np.concatenate(left_parts, axis=0)
346
+ out = np.concatenate([left[-deficit:], out], axis=0)
347
+
348
+ # 7) Final guard to exact length
349
+ if out.shape[0] > ctx_frames:
350
+ out = out[-ctx_frames:]
351
+ elif out.shape[0] < ctx_frames:
352
+ reps = int(np.ceil(ctx_frames / max(1, out.shape[0])))
353
+ out = np.tile(out, (reps, 1))[-ctx_frames:]
354
+
355
+ # 8) Depth guard
356
+ if out.shape[1] != depth:
357
+ out = out[:, :depth]
358
+ return out
359
+
360
+
361
+ def _realign_emit_pointer_to_bar(self, sr_model: int):
362
+ """Advance _next_emit_start to the next bar boundary in model-sample space."""
363
+ bar_samps = int(round(self._seconds_per_bar() * sr_model))
364
+ if bar_samps <= 0:
365
+ return
366
+ phase = self._next_emit_start % bar_samps
367
+ if phase != 0:
368
+ self._next_emit_start += (bar_samps - phase)
369
+
370
+ def _prepare_stream_for_reseed_handoff(self):
371
+ # OLD: keep crossfade tail -> causes phase offset
372
+ # sr = int(self.mrt.sample_rate)
373
+ # xfade_s = float(self.mrt.config.crossfade_length)
374
+ # xfade_n = int(round(xfade_s * sr))
375
+ # if getattr(self, "_stream", None) is not None and self._stream.shape[0] > 0:
376
+ # tail = self._stream[-xfade_n:] if self._stream.shape[0] > xfade_n else self._stream
377
+ # self._stream = tail.copy()
378
+ # else:
379
+ # self._stream = None
380
+
381
+ # NEW: throw away the tail completely; start fresh
382
+ self._stream = None
383
+
384
+ self._next_emit_start = 0
385
  self._needs_bar_realign = True
 
386
 
387
  def reseed_splice(self, recent_wav, anchor_bars: float):
388
+ """
389
+ Token-splice reseed queued for the next bar boundary between chunks.
390
+ """
391
  with self._lock:
392
  if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
393
  self._original_context_tokens = np.copy(self.state.context_tokens)
394
 
395
+ recent_tokens = self._make_recent_tokens_from_wave(recent_wav) # [T, depth]
 
396
  new_ctx = self._splice_context(self._original_context_tokens, recent_tokens, anchor_bars)
397
 
398
+ # Queue it; the run loop will install right after we finish the current slice
399
  self._pending_reseed = {"ctx": new_ctx, "ref": recent_wav}
400
+
401
+ # install the new context window
402
  new_state = self.mrt.init_state()
403
  new_state.context_tokens = new_ctx
404
  self.state = new_state
405
 
406
+ self._prepare_stream_for_reseed_handoff()
 
 
 
 
 
 
 
407
 
408
+ # optional: ask streamer to drop an intro crossfade worth of audio right after reseed
409
+ self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
  def run(self):
412
+ """Main worker loop generate into a continuous stream, then emit bar-aligned slices."""
413
+ spb = self._seconds_per_bar() # seconds per bar
414
+ chunk_secs = self.params.bars_per_chunk * spb
415
+ xfade = float(self.mrt.config.crossfade_length) # seconds
416
+ sr = int(self.mrt.sample_rate)
417
+ chunk_samps = int(round(chunk_secs * sr))
418
+
419
def _need(first_chunk_extra=False):
    """Return how many more stream samples are required before the next slice.

    With `first_chunk_extra` set, two extra bars are requested so the
    first-chunk onset alignment has material to work with.
    """
    stream = getattr(self, "_stream", None)
    if stream is None:
        available = 0
    else:
        available = stream.shape[0] - getattr(self, "_next_emit_start", 0)

    required = chunk_samps
    if first_chunk_extra:
        required += int(round(2 * spb * sr))
    return max(0, required - available)
427
+
428
+ def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
429
+ if x.ndim == 2: x = x.mean(axis=1)
430
+ x = np.abs(x).astype(np.float32)
431
+ w = max(1, int(round(win_ms * 1e-3 * sr)))
432
+ if w > 1:
433
+ kern = np.ones(w, dtype=np.float32) / float(w)
434
+ x = np.convolve(x, kern, mode="same")
435
+ d = np.diff(x, prepend=x[:1])
436
+ d[d < 0] = 0.0
437
+ return d
438
+
439
def _estimate_first_offset_samples(ref_loop_wav, gen_head_wav, sr: int, spb: float) -> int:
    """Estimate the first-downbeat offset in samples (positive => model late).

    The last bar of the reference loop is compared with the first two bars
    of the generated audio: both are turned into z-scored onset envelopes,
    upsampled x4, and the lag with the highest normalised correlation
    inside a tempo-dependent search range (160-450 ms) is returned.

    Best effort: any exception yields an offset of 0.
    """
    try:
        max_ms = int(max(160.0, min(0.25 * spb * 1000.0, 450.0)))
        ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
        n_bar = int(round(spb * sr))
        ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
        gen_head = gen_head_wav.samples[: int(2 * n_bar), :]
        if ref_tail.size == 0 or gen_head.size == 0:
            return 0

        import numpy as np

        def _z(a):
            # z-score; a zero standard deviation is replaced by 1.0
            m, s = float(a.mean()), float(a.std() or 1.0)
            return (a - m) / s

        def _upsample(a, r=4):
            # linear interpolation onto a grid r times denser
            n = len(a)
            grid = np.arange(n, dtype=np.float32)
            fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
            return np.interp(fine, grid, a).astype(np.float32)

        e_ref = _z(_mono_env(ref_tail, sr)).astype(np.float32)
        e_gen = _z(_mono_env(gen_head, sr)).astype(np.float32)

        up = 4
        e_ref_u, e_gen_u = _upsample(e_ref, up), _upsample(e_gen, up)

        max_lag_u = int(round((max_ms / 1000.0) * sr * up))
        seg = min(len(e_ref_u), len(e_gen_u))
        e_ref_u = e_ref_u[-seg:]
        pad = np.zeros(max_lag_u, dtype=np.float32)
        e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])

        # The reference norm does not depend on the lag; compute it once
        # instead of once per candidate lag.
        ref_norm = np.linalg.norm(e_ref_u)

        best_lag_u, best_score = 0, -1e9
        for lag_u in range(-max_lag_u, max_lag_u + 1):
            start = max_lag_u + lag_u
            b = e_gen_u_pad[start : start + seg]
            denom = (ref_norm * np.linalg.norm(b)) or 1.0
            score = float(np.dot(e_ref_u, b) / denom)
            if score > best_score:
                best_score, best_lag_u = score, lag_u
        return int(round(best_lag_u / up))
    except Exception:
        # Deliberate best effort: alignment is optional, so fall back to no offset.
        return 0
482
+
483
+ print("🚀 JamWorker started (bar-aligned streaming)…")
484
 
485
  while not self._stop_event.is_set():
486
  if not self._should_generate_next_chunk():
487
  time.sleep(0.25)
488
  continue
489
 
490
+ # 1) Generate until we have enough material in the stream
491
+ need = _need(first_chunk_extra=(self.idx == 0))
492
+ while need > 0 and not self._stop_event.is_set():
493
  with self._lock:
494
  style_vec = self.params.style_vec
495
  self.mrt.guidance_weight = float(self.params.guidance_weight)
496
+ self.mrt.temperature = float(self.params.temperature)
497
+ self.mrt.topk = int(self.params.topk)
 
498
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
499
+ self._append_model_chunk_to_stream(wav) # equal-power xfade into a persistent stream
500
+ need = _need(first_chunk_extra=(self.idx == 0))
501
 
502
  if self._stop_event.is_set():
503
  break
504
 
505
+ # 2) One-time: align the emit pointer to the groove
506
  if (self.idx == 0 and self.params.combined_loop is not None) or self._needs_bar_realign:
507
  ref_loop = self._reseed_ref_loop or self.params.combined_loop
508
  if ref_loop is not None:
509
+ head_len = min(self._stream.shape[0] - self._next_emit_start, int(round(2 * spb * sr)))
510
+ seg = self._stream[self._next_emit_start : self._next_emit_start + head_len]
511
+ gen_head = au.Waveform(seg.astype(np.float32, copy=False), sr).as_stereo()
512
+ offs = _estimate_first_offset_samples(ref_loop, gen_head, sr, spb)
513
+ if offs != 0:
514
+ self._next_emit_start = max(0, self._next_emit_start + offs)
515
+ print(f"🎯 Offset compensation: {offs/sr:+.3f}s")
516
+ self._realign_emit_pointer_to_bar(sr)
517
  self._needs_bar_realign = False
518
  self._reseed_ref_loop = None
519
 
520
+ # 3) Emit exactly bars_per_chunk × spb from the stream
521
+ start = self._next_emit_start
522
+ end = start + chunk_samps
523
+ if end > self._stream.shape[0]:
524
+ # shouldn't happen often; generate a bit more and loop
525
+ continue
526
 
527
+ slice_ = self._stream[start:end]
528
+ self._next_emit_start = end
 
 
 
 
529
 
530
+ y = au.Waveform(slice_.astype(np.float32, copy=False), sr).as_stereo()
 
531
 
532
+ # 4) Post-processing / loudness
533
  if self.idx == 0 and self.params.ref_loop is not None:
534
  y, _ = match_loudness_to_reference(
535
  self.params.ref_loop, y,
 
539
  else:
540
  apply_micro_fades(y, 3)
541
 
542
+ # 5) Resample + exact-length snap + encode
543
+ b64, meta = self._snap_and_encode(
544
+ y, seconds=chunk_secs, target_sr=self.params.target_sr, bars=self.params.bars_per_chunk
545
+ )
546
+ meta["xfade_seconds"] = xfade
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
 
548
+ # 6) Publish
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
  with self._lock:
550
  self.idx += 1
551
+ self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
 
 
 
552
  if len(self.outbox) > 10:
553
  cutoff = self._last_delivered_index - 5
554
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
555
 
556
+ # 👉 If a reseed was requested, apply it *now*, between chunks
557
  if self._pending_reseed is not None:
558
  pkg = self._pending_reseed
559
  self._pending_reseed = None
560
 
561
  new_state = self.mrt.init_state()
562
+ new_state.context_tokens = pkg["ctx"] # exact (ctx_frames, depth)
563
  self.state = new_state
564
 
565
+ # start a fresh stream and schedule one-time alignment
566
  self._stream = None
567
+ self._next_emit_start = 0
568
+ self._reseed_ref_loop = pkg.get("ref") or self.params.combined_loop
 
 
 
 
569
  self._needs_bar_realign = True
570
 
571
+ print("🔁 Reseed installed at bar boundary; will realign before next slice")
572
 
573
+ print(f"✅ Completed chunk {self.idx}")
574
+
575
+ print("🛑 JamWorker stopped")
576
 
 
 
 
 
 
 
 
 
utils.py CHANGED
@@ -109,81 +109,55 @@ def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
109
 
110
 
111
  # ---------- Token context helpers ----------
112
- def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4, precise_timing=False):
113
  """
114
  Return a ctx_frames-long slice of `tokens` whose **end** lands on the nearest
115
- whole-bar boundary in codec-frame space.
116
-
117
- NEW: precise_timing mode handles fractional frames per bar more carefully.
 
 
 
 
118
  """
 
 
119
  if tokens is None:
120
  raise ValueError("tokens is None")
121
  tokens = np.asarray(tokens)
122
  if tokens.ndim == 1:
123
- tokens = tokens[:, None]
124
 
125
  T = tokens.shape[0]
126
  if T == 0:
127
  return tokens
128
 
129
  fps = float(fps)
130
- frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps
131
-
132
- if precise_timing and abs(frames_per_bar_f - round(frames_per_bar_f)) > 1e-6:
133
- # We have fractional frames per bar - use a different strategy
134
- # Instead of trying to align to exact bar boundaries, align to the closest
135
- # multiple of frames_per_bar_f that gives us integer frame positions
136
-
137
- # Tile enough to work with
138
- reps = max(2, int(np.ceil((ctx_frames + T) / float(T))))
139
- tiled = np.tile(tokens, (reps, 1))
140
- total = tiled.shape[0]
141
-
142
- # Find the best integer end position that's close to a bar boundary
143
- best_end = ctx_frames
144
- best_error = float('inf')
145
-
146
- # Check positions around the naive ctx_frames endpoint
147
- for candidate_end in range(max(ctx_frames - 50, ctx_frames), min(total, ctx_frames + 50)):
148
- # How many fractional bars does this represent?
149
- fractional_bars = candidate_end / frames_per_bar_f
150
- # How far from an integer number of bars?
151
- bar_error = abs(fractional_bars - round(fractional_bars))
152
-
153
- if bar_error < best_error:
154
- best_error = bar_error
155
- best_end = candidate_end
156
-
157
- end_idx = best_end
158
- start_idx = max(0, end_idx - ctx_frames)
159
-
160
- window = tiled[start_idx:end_idx]
161
-
162
- # Report timing info for debugging
163
- actual_bars = end_idx / frames_per_bar_f
164
- print(f"Context aligned to {actual_bars:.3f} bars (error: {best_error:.4f})")
165
-
166
- else:
167
- # Original logic for integer frames per bar
168
- reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
169
- tiled = np.tile(tokens, (reps, 1))
170
- total = tiled.shape[0]
171
-
172
- k_bars = int(np.floor(total / frames_per_bar_f))
173
- if k_bars <= 0:
174
- window = tiled[-ctx_frames:]
175
- return window
176
-
177
- end_idx = int(round(k_bars * frames_per_bar_f))
178
- end_idx = min(max(end_idx, ctx_frames), total)
179
- start_idx = end_idx - ctx_frames
180
- if start_idx < 0:
181
- start_idx = 0
182
- end_idx = ctx_frames
183
-
184
- window = tiled[start_idx:end_idx]
185
-
186
- # Ensure exact length
187
  if window.shape[0] < ctx_frames:
188
  pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
189
  window = np.vstack([window, pad])[:ctx_frames]
 
109
 
110
 
111
  # ---------- Token context helpers ----------
112
+ def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
113
  """
114
  Return a ctx_frames-long slice of `tokens` whose **end** lands on the nearest
115
+ whole-bar boundary in codec-frame space, even when frames_per_bar is fractional.
116
+
117
+ tokens: np.ndarray of shape (T, D) or (T,) where T = codec frames
118
+ bpm: float
119
+ fps: float (codec frames per second; keep this as float)
120
+ ctx_frames: int (length of context window in codec frames)
121
+ beats_per_bar: int
122
  """
123
+
124
+
125
  if tokens is None:
126
  raise ValueError("tokens is None")
127
  tokens = np.asarray(tokens)
128
  if tokens.ndim == 1:
129
+ tokens = tokens[:, None] # promote to (T, 1) for uniform tiling
130
 
131
  T = tokens.shape[0]
132
  if T == 0:
133
  return tokens
134
 
135
  fps = float(fps)
136
+ frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps # float frames per bar
137
+
138
+ # Tile a little more than we need so we can always snap the END to a bar boundary
139
+ reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
140
+ tiled = np.tile(tokens, (reps, 1))
141
+ total = tiled.shape[0]
142
+
143
+ # How many whole bars fit?
144
+ k_bars = int(np.floor(total / frames_per_bar_f))
145
+ if k_bars <= 0:
146
+ # Fallback: just take the last ctx_frames
147
+ window = tiled[-ctx_frames:]
148
+ return window
149
+
150
+ # Snap END index to the nearest integer frame at a whole-bar boundary
151
+ end_idx = int(round(k_bars * frames_per_bar_f))
152
+ end_idx = min(max(end_idx, ctx_frames), total)
153
+ start_idx = end_idx - ctx_frames
154
+ if start_idx < 0:
155
+ start_idx = 0
156
+ end_idx = ctx_frames
157
+
158
+ window = tiled[start_idx:end_idx]
159
+
160
+ # Guard against rare off-by-one due to rounding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  if window.shape[0] < ctx_frames:
162
  pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
163
  window = np.vstack([window, pad])[:ctx_frames]