thecollabagepatch commited on
Commit
9fb3c06
·
1 Parent(s): c4dc2c2

aligning to grid better

Browse files
Files changed (1) hide show
  1. jam_worker.py +68 -49
jam_worker.py CHANGED
@@ -355,61 +355,77 @@ class JamWorker(threading.Thread):
355
  chunk_secs = self.params.bars_per_chunk * spb
356
  xfade = float(self.mrt.config.crossfade_length) # seconds
357
 
358
- # ---- tiny helper: mono + simple envelope ----
359
- def _mono_env(x: np.ndarray, sr: int, win_ms: float = 20.0) -> np.ndarray:
360
  if x.ndim == 2:
361
  x = x.mean(axis=1)
362
  x = np.abs(x).astype(np.float32)
363
  w = max(1, int(round(win_ms * 1e-3 * sr)))
364
- if w == 1:
365
- return x
366
- kern = np.ones(w, dtype=np.float32) / float(w)
367
- # moving average (same length)
368
- return np.convolve(x, kern, mode="same")
369
-
370
- # ---- estimate how late the first downbeat is (<= max_ms) ----
371
- def _estimate_first_offset_samples(ref_loop_wav, gen_wav, sr: int, max_ms: int = 120) -> int:
 
 
 
 
 
 
372
  try:
373
- # resample ref to model SR if needed
374
- ref = ref_loop_wav
375
- if ref.sample_rate != sr:
376
- ref = ref.resample(sr)
377
- # last 1 bar of the reference (what the model just "heard")
378
  n_bar = int(round(spb * sr))
 
379
  ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
380
- # first 2 bars of the generated chunk (search window)
381
  gen_head = gen_wav.samples[: int(2 * n_bar), :]
382
  if ref_tail.size == 0 or gen_head.size == 0:
383
  return 0
384
 
385
- # envelopes
386
- e_ref = _mono_env(ref_tail, sr)
387
- e_gen = _mono_env(gen_head, sr)
388
-
389
- max_lag = int(round((max_ms / 1000.0) * sr))
390
- # ensure the window is long enough
391
- seg = min(len(e_ref), len(e_gen))
392
- e_ref = e_ref[-seg:]
393
- e_gen = e_gen[: seg + max_lag] # allow positive lag (gen late)
394
-
395
- if len(e_gen) < seg:
396
- return 0
397
-
398
- # brute-force short-range correlation (gen late => positive lag)
399
- best_lag = 0
400
- best_score = -1e9
401
- for lag in range(0, max_lag + 1):
402
- a = e_ref
403
- b = e_gen[lag : lag + seg]
404
- if len(b) != seg:
405
- break
406
- # normalized dot to be robust-ish
407
- denom = (np.linalg.norm(a) * np.linalg.norm(b)) or 1.0
408
- score = float(np.dot(a, b) / denom)
 
 
 
 
 
 
 
 
 
 
 
 
409
  if score > best_score:
410
- best_score = score
411
- best_lag = lag
412
- return int(best_lag)
 
 
413
  except Exception:
414
  return 0
415
 
@@ -453,13 +469,16 @@ class JamWorker(threading.Thread):
453
  # ---- ONE-TIME: grid-align the very first jam chunk to kill the flam ----
454
  if next_idx == 1 and self.params.combined_loop is not None:
455
  offset = _estimate_first_offset_samples(
456
- self.params.combined_loop, y, int(self.mrt.sample_rate), max_ms=120
457
  )
458
- if offset > 0:
459
- # Trim the head by the detected offset; we'll snap length later
460
- y.samples = y.samples[offset:, :]
461
- print(f"🎯 First-chunk offset compensation: -{offset/self.mrt.sample_rate:.3f}s")
462
- # hard trim again (defensive), remaining length exactness happens in _snap_and_encode
 
 
 
463
  y = hard_trim_seconds(y, chunk_secs)
464
 
465
  # ---- Post-processing ----
 
355
  chunk_secs = self.params.bars_per_chunk * spb
356
  xfade = float(self.mrt.config.crossfade_length) # seconds
357
 
358
+ def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
359
+ """Rectified moving-average envelope, then a simple onset-y novelty (half-wave diff)."""
360
  if x.ndim == 2:
361
  x = x.mean(axis=1)
362
  x = np.abs(x).astype(np.float32)
363
  w = max(1, int(round(win_ms * 1e-3 * sr)))
364
+ if w > 1:
365
+ kern = np.ones(w, dtype=np.float32) / float(w)
366
+ x = np.convolve(x, kern, mode="same")
367
+ # onset-ish novelty: positive first difference (half-wave)
368
+ d = np.diff(x, prepend=x[:1])
369
+ d[d < 0] = 0.0
370
+ return d
371
+
372
+ def _estimate_first_offset_samples(ref_loop_wav, gen_wav, sr: int, spb: float, max_ms: int = 180) -> int:
373
+ """
374
+ Estimate how late/early the first downbeat is by correlating
375
+ the last bar of the reference vs the first two bars of the generated chunk.
376
+ Allows small +/- offsets; upsample envelopes x4 for sub-sample precision then round.
377
+ """
378
  try:
379
+ ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
 
 
 
 
380
  n_bar = int(round(spb * sr))
381
+
382
  ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
 
383
  gen_head = gen_wav.samples[: int(2 * n_bar), :]
384
  if ref_tail.size == 0 or gen_head.size == 0:
385
  return 0
386
 
387
+ e_ref = _mono_env(ref_tail, sr) # length ~ n_bar
388
+ e_gen = _mono_env(gen_head, sr) # length ~ 2*n_bar
389
+
390
+ # z-score for scale invariance
391
+ def _z(a):
392
+ m, s = float(a.mean()), float(a.std() or 1.0)
393
+ return (a - m) / s
394
+ e_ref = _z(e_ref).astype(np.float32)
395
+ e_gen = _z(e_gen).astype(np.float32)
396
+
397
+ # Light upsampling for finer lag resolution (x4)
398
+ def _upsample(a, r=4):
399
+ n = len(a)
400
+ grid = np.arange(n, dtype=np.float32)
401
+ fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
402
+ return np.interp(fine, grid, a).astype(np.float32)
403
+ up = 4
404
+ e_ref_u = _upsample(e_ref, up)
405
+ e_gen_u = _upsample(e_gen, up)
406
+
407
+ # Correlate in a tight window
408
+ max_lag_u = int(round((max_ms / 1000.0) * sr * up))
409
+ seg = min(len(e_ref_u), len(e_gen_u))
410
+ e_ref_u = e_ref_u[-seg:]
411
+ # pad head so we can slide +/- lags
412
+ pad = np.zeros(max_lag_u, dtype=np.float32)
413
+ e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])
414
+
415
+ best_lag_u, best_score = 0, -1e9
416
+ # allow tiny early OR late (negative = model early, positive = late)
417
+ for lag_u in range(-max_lag_u, max_lag_u + 1):
418
+ start = max_lag_u + lag_u
419
+ b = e_gen_u_pad[start : start + seg]
420
+ # normalized dot (already z-scored, but keep it consistent)
421
+ denom = (np.linalg.norm(e_ref_u) * np.linalg.norm(b)) or 1.0
422
+ score = float(np.dot(e_ref_u, b) / denom)
423
  if score > best_score:
424
+ best_score, best_lag_u = score, lag_u
425
+
426
+ # convert envelope-lag back to audio samples and round
427
+ lag_samples = int(round(best_lag_u / up))
428
+ return lag_samples
429
  except Exception:
430
  return 0
431
 
 
469
  # ---- ONE-TIME: grid-align the very first jam chunk to kill the flam ----
470
  if next_idx == 1 and self.params.combined_loop is not None:
471
  offset = _estimate_first_offset_samples(
472
+ self.params.combined_loop, y, int(self.mrt.sample_rate), spb, max_ms=180 # try 160–200
473
  )
474
+ if offset != 0:
475
+ # positive => model late: trim head; negative => model early: pad head (rare)
476
+ if offset > 0:
477
+ y.samples = y.samples[offset:, :]
478
+ else:
479
+ pad = np.zeros((abs(offset), y.samples.shape[1]), dtype=y.samples.dtype)
480
+ y.samples = np.concatenate([pad, y.samples], axis=0)
481
+ print(f"🎯 First-chunk offset compensation: {offset/self.mrt.sample_rate:+.3f}s")
482
  y = hard_trim_seconds(y, chunk_secs)
483
 
484
  # ---- Post-processing ----