thecollabagepatch committed on
Commit
eb5d99b
·
1 Parent(s): d41a575

rewritten run method in jam worker

Browse files
Files changed (1) hide show
  1. jam_worker.py +77 -90
jam_worker.py CHANGED
@@ -50,6 +50,9 @@ class JamWorker(threading.Thread):
50
  self.outbox: list[JamChunk] = []
51
  self._stop_event = threading.Event()
52
 
 
 
 
53
  # NEW: Track delivery state
54
  self._last_delivered_index = 0
55
  self._max_buffer_ahead = 5
@@ -350,139 +353,127 @@ class JamWorker(threading.Thread):
350
  self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
351
 
352
  def run(self):
353
- """Main worker loop - generate chunks continuously but don't get too far ahead"""
354
- spb = self._seconds_per_bar()
355
  chunk_secs = self.params.bars_per_chunk * spb
356
- xfade = float(self.mrt.config.crossfade_length) # seconds
 
 
 
 
 
 
 
 
 
 
 
357
 
358
  def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
359
- """Rectified moving-average envelope, then a simple onset-y novelty (half-wave diff)."""
360
- if x.ndim == 2:
361
- x = x.mean(axis=1)
362
  x = np.abs(x).astype(np.float32)
363
  w = max(1, int(round(win_ms * 1e-3 * sr)))
364
  if w > 1:
365
  kern = np.ones(w, dtype=np.float32) / float(w)
366
  x = np.convolve(x, kern, mode="same")
367
- # onset-ish novelty: positive first difference (half-wave)
368
  d = np.diff(x, prepend=x[:1])
369
  d[d < 0] = 0.0
370
  return d
371
 
372
- def _estimate_first_offset_samples(ref_loop_wav, gen_wav, sr: int, spb: float, max_ms: int = 180) -> int:
373
- """
374
- Estimate how late/early the first downbeat is by correlating
375
- the last bar of the reference vs the first two bars of the generated chunk.
376
- Allows small +/- offsets; upsample envelopes x4 for sub-sample precision then round.
377
- """
378
  try:
 
379
  ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
380
  n_bar = int(round(spb * sr))
381
-
382
  ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
383
- gen_head = gen_wav.samples[: int(2 * n_bar), :]
384
  if ref_tail.size == 0 or gen_head.size == 0:
385
  return 0
386
 
387
- e_ref = _mono_env(ref_tail, sr) # length ~ n_bar
388
- e_gen = _mono_env(gen_head, sr) # length ~ 2*n_bar
389
-
390
- # z-score for scale invariance
391
  def _z(a):
392
- m, s = float(a.mean()), float(a.std() or 1.0)
393
- return (a - m) / s
394
- e_ref = _z(e_ref).astype(np.float32)
395
- e_gen = _z(e_gen).astype(np.float32)
396
 
397
- # Light upsampling for finer lag resolution (x4)
398
  def _upsample(a, r=4):
399
- n = len(a)
400
- grid = np.arange(n, dtype=np.float32)
401
  fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
402
  return np.interp(fine, grid, a).astype(np.float32)
403
  up = 4
404
- e_ref_u = _upsample(e_ref, up)
405
- e_gen_u = _upsample(e_gen, up)
406
 
407
- # Correlate in a tight window
408
  max_lag_u = int(round((max_ms / 1000.0) * sr * up))
409
  seg = min(len(e_ref_u), len(e_gen_u))
410
  e_ref_u = e_ref_u[-seg:]
411
- # pad head so we can slide +/- lags
412
  pad = np.zeros(max_lag_u, dtype=np.float32)
413
  e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])
414
 
415
  best_lag_u, best_score = 0, -1e9
416
- # allow tiny early OR late (negative = model early, positive = late)
417
  for lag_u in range(-max_lag_u, max_lag_u + 1):
418
  start = max_lag_u + lag_u
419
  b = e_gen_u_pad[start : start + seg]
420
- # normalized dot (already z-scored, but keep it consistent)
421
  denom = (np.linalg.norm(e_ref_u) * np.linalg.norm(b)) or 1.0
422
  score = float(np.dot(e_ref_u, b) / denom)
423
  if score > best_score:
424
  best_score, best_lag_u = score, lag_u
425
-
426
- # convert envelope-lag back to audio samples and round
427
- lag_samples = int(round(best_lag_u / up))
428
- return lag_samples
429
  except Exception:
430
  return 0
431
 
432
- print("πŸš€ JamWorker started with flow control...")
433
 
434
  while not self._stop_event.is_set():
435
- # Don’t get too far ahead of the consumer
436
  if not self._should_generate_next_chunk():
437
- print("⏸️ Buffer full, waiting for consumption...")
438
- time.sleep(0.5)
439
  continue
440
 
441
- # Snapshot knobs + compute index
442
- with self._lock:
443
- style_vec = self.params.style_vec
444
- self.mrt.guidance_weight = float(self.params.guidance_weight)
445
- self.mrt.temperature = float(self.params.temperature)
446
- self.mrt.topk = int(self.params.topk)
447
- next_idx = self.idx + 1
448
-
449
- print(f"🎹 Generating chunk {next_idx}...")
450
- self.last_chunk_started_at = time.time()
451
-
452
- # ---- Generate enough model sub-chunks to yield *audible* chunk_secs ----
453
- # First sub-chunk contributes full L; subsequent contribute (L - xfade)
454
- assembled = 0.0
455
- chunks = []
456
- while assembled < chunk_secs and not self._stop_event.is_set():
457
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
458
- chunks.append(wav)
459
- L = wav.samples.shape[0] / float(self.mrt.sample_rate)
460
- assembled += L if len(chunks) == 1 else max(0.0, L - xfade)
461
 
462
  if self._stop_event.is_set():
463
  break
464
 
465
- # ---- Stitch (utils drops the very first model pre-roll) & trim at model SR ----
466
- y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
467
- y = hard_trim_seconds(y, chunk_secs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
 
469
- # ---- ONE-TIME: grid-align the very first jam chunk to kill the flam ----
470
- if next_idx == 1 and self.params.combined_loop is not None:
471
- offset = _estimate_first_offset_samples(
472
- self.params.combined_loop, y, int(self.mrt.sample_rate), spb, max_ms=180 # try 160–200
473
- )
474
- if offset != 0:
475
- # positive => model late: trim head; negative => model early: pad head (rare)
476
- if offset > 0:
477
- y.samples = y.samples[offset:, :]
478
- else:
479
- pad = np.zeros((abs(offset), y.samples.shape[1]), dtype=y.samples.dtype)
480
- y.samples = np.concatenate([pad, y.samples], axis=0)
481
- print(f"🎯 First-chunk offset compensation: {offset/self.mrt.sample_rate:+.3f}s")
482
- y = hard_trim_seconds(y, chunk_secs)
483
-
484
- # ---- Post-processing ----
485
- if next_idx == 1 and self.params.ref_loop is not None:
486
  y, _ = match_loudness_to_reference(
487
  self.params.ref_loop, y,
488
  method=self.params.loudness_mode,
@@ -491,25 +482,21 @@ class JamWorker(threading.Thread):
491
  else:
492
  apply_micro_fades(y, 3)
493
 
494
- # ---- Resample + bar-snap + encode ----
495
  b64, meta = self._snap_and_encode(
496
- y,
497
- seconds=chunk_secs,
498
- target_sr=self.params.target_sr,
499
- bars=self.params.bars_per_chunk
500
  )
501
- meta["xfade_seconds"] = xfade # tiny hint for client if you want butter at chunk joins
502
 
503
- # ---- Publish ----
504
  with self._lock:
505
- self.idx = next_idx
506
- self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
507
  if len(self.outbox) > 10:
508
  cutoff = self._last_delivered_index - 5
509
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
510
 
511
- self.last_chunk_completed_at = time.time()
512
- print(f"βœ… Completed chunk {next_idx}")
513
 
514
  print("πŸ›‘ JamWorker stopped")
515
 
 
50
  self.outbox: list[JamChunk] = []
51
  self._stop_event = threading.Event()
52
 
53
+ self._stream = None
54
+ self._next_emit_start = 0
55
+
56
  # NEW: Track delivery state
57
  self._last_delivered_index = 0
58
  self._max_buffer_ahead = 5
 
353
  self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
354
 
355
  def run(self):
356
+ """Main worker loop — generate into a continuous stream, then emit bar-aligned slices."""
357
+ spb = self._seconds_per_bar() # seconds per bar
358
  chunk_secs = self.params.bars_per_chunk * spb
359
+ xfade = float(self.mrt.config.crossfade_length) # seconds
360
+ sr = int(self.mrt.sample_rate)
361
+ chunk_samps = int(round(chunk_secs * sr))
362
+
363
+ def _need(first_chunk_extra=False):
364
+ """How many more samples we still need in the stream to emit next slice."""
365
+ have = 0 if getattr(self, "_stream", None) is None else self._stream.shape[0] - getattr(self, "_next_emit_start", 0)
366
+ want = chunk_samps
367
+ if first_chunk_extra:
368
+ # reserve two bars extra so first-chunk onset alignment has material
369
+ want += int(round(2 * spb * sr))
370
+ return max(0, want - have)
371
 
372
  def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
373
+ if x.ndim == 2: x = x.mean(axis=1)
 
 
374
  x = np.abs(x).astype(np.float32)
375
  w = max(1, int(round(win_ms * 1e-3 * sr)))
376
  if w > 1:
377
  kern = np.ones(w, dtype=np.float32) / float(w)
378
  x = np.convolve(x, kern, mode="same")
 
379
  d = np.diff(x, prepend=x[:1])
380
  d[d < 0] = 0.0
381
  return d
382
 
383
+ def _estimate_first_offset_samples(ref_loop_wav, gen_head_wav, sr: int, spb: float) -> int:
384
+ """Tempo-aware first-downbeat offset (positive => model late)."""
 
 
 
 
385
  try:
386
+ max_ms = int(max(160.0, min(0.25 * spb * 1000.0, 450.0)))
387
  ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
388
  n_bar = int(round(spb * sr))
 
389
  ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
390
+ gen_head = gen_head_wav.samples[: int(2 * n_bar), :]
391
  if ref_tail.size == 0 or gen_head.size == 0:
392
  return 0
393
 
394
+ # envelopes + z-score
395
+ import numpy as np
 
 
396
  def _z(a):
397
+ m, s = float(a.mean()), float(a.std() or 1.0); return (a - m) / s
398
+ e_ref = _z(_mono_env(ref_tail, sr)).astype(np.float32)
399
+ e_gen = _z(_mono_env(gen_head, sr)).astype(np.float32)
 
400
 
401
+ # upsample x4 for finer lag
402
  def _upsample(a, r=4):
403
+ n = len(a); grid = np.arange(n, dtype=np.float32)
 
404
  fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
405
  return np.interp(fine, grid, a).astype(np.float32)
406
  up = 4
407
+ e_ref_u, e_gen_u = _upsample(e_ref, up), _upsample(e_gen, up)
 
408
 
 
409
  max_lag_u = int(round((max_ms / 1000.0) * sr * up))
410
  seg = min(len(e_ref_u), len(e_gen_u))
411
  e_ref_u = e_ref_u[-seg:]
 
412
  pad = np.zeros(max_lag_u, dtype=np.float32)
413
  e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])
414
 
415
  best_lag_u, best_score = 0, -1e9
 
416
  for lag_u in range(-max_lag_u, max_lag_u + 1):
417
  start = max_lag_u + lag_u
418
  b = e_gen_u_pad[start : start + seg]
 
419
  denom = (np.linalg.norm(e_ref_u) * np.linalg.norm(b)) or 1.0
420
  score = float(np.dot(e_ref_u, b) / denom)
421
  if score > best_score:
422
  best_score, best_lag_u = score, lag_u
423
+ return int(round(best_lag_u / up))
 
 
 
424
  except Exception:
425
  return 0
426
 
427
+ print("πŸš€ JamWorker started (bar-aligned streaming)…")
428
 
429
  while not self._stop_event.is_set():
 
430
  if not self._should_generate_next_chunk():
431
+ time.sleep(0.25)
 
432
  continue
433
 
434
+ # 1) Generate until we have enough material in the stream
435
+ need = _need(first_chunk_extra=(self.idx == 0))
436
+ while need > 0 and not self._stop_event.is_set():
437
+ with self._lock:
438
+ style_vec = self.params.style_vec
439
+ self.mrt.guidance_weight = float(self.params.guidance_weight)
440
+ self.mrt.temperature = float(self.params.temperature)
441
+ self.mrt.topk = int(self.params.topk)
 
 
 
 
 
 
 
 
442
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
443
+ self._append_model_chunk_to_stream(wav) # equal-power xfade into a persistent stream
444
+ need = _need(first_chunk_extra=(self.idx == 0))
 
445
 
446
  if self._stop_event.is_set():
447
  break
448
 
449
+ # 2) One-time: align the emit pointer to the groove
450
+ if self.idx == 0 and self.params.combined_loop is not None:
451
+ # Compare ref tail vs the head of what we're about to emit
452
+ head_len = min(self._stream.shape[0] - self._next_emit_start, int(round(2 * spb * sr)))
453
+ seg = self._stream[self._next_emit_start : self._next_emit_start + head_len]
454
+ gen_head = au.Waveform(seg.astype(np.float32, copy=False), sr).as_stereo()
455
+ offs = _estimate_first_offset_samples(self.params.combined_loop, gen_head, sr, spb)
456
+ if offs != 0:
457
+ # positive => model late: skip some samples; negative => model early: rewind the emit pointer (clamped at 0)
458
+ self._next_emit_start = max(0, self._next_emit_start + offs)
459
+ print(f"🎯 First-chunk offset compensation: {offs/sr:+.3f}s")
460
+ # snap to next bar boundary
461
+ self._realign_emit_pointer_to_bar(sr)
462
+
463
+ # 3) Emit exactly bars_per_chunk × spb from the stream
464
+ start = self._next_emit_start
465
+ end = start + chunk_samps
466
+ if end > self._stream.shape[0]:
467
+ # shouldn't happen often; generate a bit more and loop
468
+ continue
469
+
470
+ slice_ = self._stream[start:end]
471
+ self._next_emit_start = end
472
 
473
+ y = au.Waveform(slice_.astype(np.float32, copy=False), sr).as_stereo()
474
+
475
+ # 4) Post-processing / loudness
476
+ if self.idx == 0 and self.params.ref_loop is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  y, _ = match_loudness_to_reference(
478
  self.params.ref_loop, y,
479
  method=self.params.loudness_mode,
 
482
  else:
483
  apply_micro_fades(y, 3)
484
 
485
+ # 5) Resample + exact-length snap + encode
486
  b64, meta = self._snap_and_encode(
487
+ y, seconds=chunk_secs, target_sr=self.params.target_sr, bars=self.params.bars_per_chunk
 
 
 
488
  )
489
+ meta["xfade_seconds"] = xfade
490
 
491
+ # 6) Publish
492
  with self._lock:
493
+ self.idx += 1
494
+ self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
495
  if len(self.outbox) > 10:
496
  cutoff = self._last_delivered_index - 5
497
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
498
 
499
+ print(f"βœ… Completed chunk {self.idx}")
 
500
 
501
  print("πŸ›‘ JamWorker stopped")
502