thecollabagepatch committed on
Commit
b1bc032
·
1 Parent(s): bf8ae4c

incremental change for bpm sync

Browse files
Files changed (3) hide show
  1. Dockerfile +2 -0
  2. app.py +2 -2
  3. jam_worker.py +68 -28
Dockerfile CHANGED
@@ -128,7 +128,9 @@ RUN python -m pip install --no-cache-dir --force-reinstall "protobuf==4.25.3"
128
 
129
  RUN python -m pip install gradio
130
 
 
131
 
 
132
 
133
  # Switch to Spaces’ preferred user
134
  # Switch to Spaces’ preferred user
 
128
 
129
  RUN python -m pip install gradio
130
 
131
+ RUN python -m pip install soxr
132
 
133
+ RUN python -m pip install samplerate
134
 
135
  # Switch to Spaces’ preferred user
136
  # Switch to Spaces’ preferred user
app.py CHANGED
@@ -308,7 +308,7 @@ def generate_loop_continuation_with_mrt(
308
 
309
  # Bar-aligned token window (unchanged)
310
  context_tokens = make_bar_aligned_context(
311
- tokens, bpm=bpm, fps=int(mrt.codec.frame_rate),
312
  ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
313
  )
314
  state = mrt.init_state()
@@ -441,7 +441,7 @@ def _mrt_warmup():
441
  context_tokens = make_bar_aligned_context(
442
  tokens,
443
  bpm=bpm,
444
- fps=int(mrt.codec.frame_rate),
445
  ctx_frames=mrt.config.context_length_frames,
446
  beats_per_bar=beats_per_bar,
447
  )
 
308
 
309
  # Bar-aligned token window (unchanged)
310
  context_tokens = make_bar_aligned_context(
311
+ tokens, bpm=bpm, fps=float(mrt.codec.frame_rate),
312
  ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
313
  )
314
  state = mrt.init_state()
 
441
  context_tokens = make_bar_aligned_context(
442
  tokens,
443
  bpm=bpm,
444
+ fps=float(mrt.codec.frame_rate),
445
  ctx_frames=mrt.config.context_length_frames,
446
  beats_per_bar=beats_per_bar,
447
  )
jam_worker.py CHANGED
@@ -10,6 +10,7 @@ from utils import (
10
  apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
11
  resample_and_snap, wav_bytes_base64
12
  )
 
13
 
14
  @dataclass
15
  class JamParams:
@@ -61,6 +62,10 @@ class JamWorker(threading.Thread):
61
  self.last_chunk_started_at = None
62
  self.last_chunk_completed_at = None
63
 
 
 
 
 
64
 
65
  def _setup_context_from_combined_loop(self):
66
  """Set up MRT context tokens from the combined loop audio"""
@@ -382,19 +387,18 @@ class JamWorker(threading.Thread):
382
 
383
  def reseed_splice(self, recent_wav, anchor_bars: float):
384
  """
385
- Token-splice reseed:
386
- - original = the context we captured when the jam started
387
- - recent = tokens from the provided recent waveform (usually Swift-combined mix)
388
- - anchor_bars controls how much of the original vibe we re-inject
389
  """
390
  with self._lock:
391
  if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
392
- # Fallback: if we somehow don’t have originals, treat current as originals
393
  self._original_context_tokens = np.copy(self.state.context_tokens)
394
 
395
- recent_tokens = self._make_recent_tokens_from_wave(recent_wav) # [T, depth]
396
  new_ctx = self._splice_context(self._original_context_tokens, recent_tokens, anchor_bars)
397
 
 
 
 
398
  # install the new context window
399
  new_state = self.mrt.init_state()
400
  new_state.context_tokens = new_ctx
@@ -411,15 +415,31 @@ class JamWorker(threading.Thread):
411
  chunk_secs = self.params.bars_per_chunk * spb
412
  xfade = float(self.mrt.config.crossfade_length) # seconds
413
  sr = int(self.mrt.sample_rate)
414
- chunk_samps = int(round(chunk_secs * sr))
415
-
416
- def _need(first_chunk_extra=False):
417
- """How many more samples we still need in the stream to emit next slice."""
418
- have = 0 if getattr(self, "_stream", None) is None else self._stream.shape[0] - getattr(self, "_next_emit_start", 0)
419
- want = chunk_samps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  if first_chunk_extra:
421
- # reserve two bars extra so first-chunk onset alignment has material
422
- want += int(round(2 * spb * sr))
 
 
423
  return max(0, want - have)
424
 
425
  def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
@@ -445,7 +465,6 @@ class JamWorker(threading.Thread):
445
  return 0
446
 
447
  # envelopes + z-score
448
- import numpy as np
449
  def _z(a):
450
  m, s = float(a.mean()), float(a.std() or 1.0); return (a - m) / s
451
  e_ref = _z(_mono_env(ref_tail, sr)).astype(np.float32)
@@ -500,22 +519,26 @@ class JamWorker(threading.Thread):
500
  break
501
 
502
  # 2) One-time: align the emit pointer to the groove
503
- if self.idx == 0 and self.params.combined_loop is not None:
504
- # Compare ref tail vs the head of what we're about to emit
505
- head_len = min(self._stream.shape[0] - self._next_emit_start, int(round(2 * spb * sr)))
506
- seg = self._stream[self._next_emit_start : self._next_emit_start + head_len]
507
- gen_head = au.Waveform(seg.astype(np.float32, copy=False), sr).as_stereo()
508
- offs = _estimate_first_offset_samples(self.params.combined_loop, gen_head, sr, spb)
509
- if offs != 0:
510
- # positive => model late: skip some samples; negative => model early: "rewind" by padding
511
- self._next_emit_start = max(0, self._next_emit_start + offs)
512
- print(f"🎯 First-chunk offset compensation: {offs/sr:+.3f}s")
513
- # snap to next bar boundary
514
- self._realign_emit_pointer_to_bar(sr)
 
515
 
516
  # 3) Emit exactly bars_per_chunk × spb from the stream
517
  start = self._next_emit_start
518
- end = start + chunk_samps
 
 
 
519
  if end > self._stream.shape[0]:
520
  # shouldn't happen often; generate a bit more and loop
521
  continue
@@ -549,6 +572,23 @@ class JamWorker(threading.Thread):
549
  cutoff = self._last_delivered_index - 5
550
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552
  print(f"✅ Completed chunk {self.idx}")
553
 
554
  print("🛑 JamWorker stopped")
 
10
  apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
11
  resample_and_snap, wav_bytes_base64
12
  )
13
+ from math import floor, ceil
14
 
15
  @dataclass
16
  class JamParams:
 
62
  self.last_chunk_started_at = None
63
  self.last_chunk_completed_at = None
64
 
65
+ self._pending_reseed = None # {"ctx": np.ndarray, "ref": au.Waveform|None}
66
+ self._needs_bar_realign = False # request a one-shot downbeat alignment
67
+ self._reseed_ref_loop = None # which loop to align against after reseed
68
+
69
 
70
  def _setup_context_from_combined_loop(self):
71
  """Set up MRT context tokens from the combined loop audio"""
 
387
 
388
  def reseed_splice(self, recent_wav, anchor_bars: float):
389
  """
390
+ Token-splice reseed queued for the next bar boundary between chunks.
 
 
 
391
  """
392
  with self._lock:
393
  if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
 
394
  self._original_context_tokens = np.copy(self.state.context_tokens)
395
 
396
+ recent_tokens = self._make_recent_tokens_from_wave(recent_wav) # [T, depth]
397
  new_ctx = self._splice_context(self._original_context_tokens, recent_tokens, anchor_bars)
398
 
399
+ # Queue it; the run loop will install right after we finish the current slice
400
+ self._pending_reseed = {"ctx": new_ctx, "ref": recent_wav}
401
+
402
  # install the new context window
403
  new_state = self.mrt.init_state()
404
  new_state.context_tokens = new_ctx
 
415
  chunk_secs = self.params.bars_per_chunk * spb
416
  xfade = float(self.mrt.config.crossfade_length) # seconds
417
  sr = int(self.mrt.sample_rate)
418
+ chunk_step_f = chunk_secs * sr # float samples per chunk
419
+ self._emit_phase = getattr(self, "_emit_phase", 0.0)
420
+
421
+ def _need(first_chunk_extra: bool = False) -> int:
422
+ """
423
+ How many more samples we still need in the stream to emit the next slice.
424
+ Uses the fractional step (chunk_step_f) + current _emit_phase to compute
425
+ the *integer* number of samples required for the next chunk, without
426
+ mutating _emit_phase here.
427
+ """
428
+ start = getattr(self, "_next_emit_start", 0)
429
+ total = 0 if getattr(self, "_stream", None) is None else self._stream.shape[0]
430
+ have = max(0, total - start)
431
+
432
+ # Compute the integer step we'd use for the next emit, non-mutating.
433
+ emit_phase = float(getattr(self, "_emit_phase", 0.0))
434
+ step_int = int(floor(chunk_step_f + emit_phase)) # matches the logic used when advancing
435
+
436
+ # How much we want available beyond 'start' for this emit.
437
+ want = step_int
438
  if first_chunk_extra:
439
+ # Reserve two extra bars so the first-chunk onset alignment has material.
440
+ # Use ceil to be conservative so we don't under-request.
441
+ want += int(ceil(2.0 * spb * sr))
442
+
443
  return max(0, want - have)
444
 
445
  def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
 
465
  return 0
466
 
467
  # envelopes + z-score
 
468
  def _z(a):
469
  m, s = float(a.mean()), float(a.std() or 1.0); return (a - m) / s
470
  e_ref = _z(_mono_env(ref_tail, sr)).astype(np.float32)
 
519
  break
520
 
521
  # 2) One-time: align the emit pointer to the groove
522
+ if (self.idx == 0 and self.params.combined_loop is not None) or self._needs_bar_realign:
523
+ ref_loop = self._reseed_ref_loop or self.params.combined_loop
524
+ if ref_loop is not None:
525
+ head_len = min(self._stream.shape[0] - self._next_emit_start, int(round(2 * spb * sr)))
526
+ seg = self._stream[self._next_emit_start : self._next_emit_start + head_len]
527
+ gen_head = au.Waveform(seg.astype(np.float32, copy=False), sr).as_stereo()
528
+ offs = _estimate_first_offset_samples(ref_loop, gen_head, sr, spb)
529
+ if offs != 0:
530
+ self._next_emit_start = max(0, self._next_emit_start + offs)
531
+ print(f"🎯 Offset compensation: {offs/sr:+.3f}s")
532
+ self._realign_emit_pointer_to_bar(sr)
533
+ self._needs_bar_realign = False
534
+ self._reseed_ref_loop = None
535
 
536
  # 3) Emit exactly bars_per_chunk × spb from the stream
537
  start = self._next_emit_start
538
+ step_total = chunk_step_f + self._emit_phase
539
+ step_int = int(np.floor(step_total))
540
+ self._emit_phase = float(step_total - step_int)
541
+ end = start + step_int
542
  if end > self._stream.shape[0]:
543
  # shouldn't happen often; generate a bit more and loop
544
  continue
 
572
  cutoff = self._last_delivered_index - 5
573
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
574
 
575
+ # 👉 If a reseed was requested, apply it *now*, between chunks
576
+ if self._pending_reseed is not None:
577
+ pkg = self._pending_reseed
578
+ self._pending_reseed = None
579
+
580
+ new_state = self.mrt.init_state()
581
+ new_state.context_tokens = pkg["ctx"] # exact (ctx_frames, depth)
582
+ self.state = new_state
583
+
584
+ # start a fresh stream and schedule one-time alignment
585
+ self._stream = None
586
+ self._next_emit_start = 0
587
+ self._reseed_ref_loop = pkg.get("ref") or self.params.combined_loop
588
+ self._needs_bar_realign = True
589
+
590
+ print("🔁 Reseed installed at bar boundary; will realign before next slice")
591
+
592
  print(f"✅ Completed chunk {self.idx}")
593
 
594
  print("🛑 JamWorker stopped")