Commit · 9fb3c06
Parent(s): c4dc2c2
aligning to grid better

jam_worker.py CHANGED  (+68 -49)
@@ -355,61 +355,77 @@ class JamWorker(threading.Thread):
         chunk_secs = self.params.bars_per_chunk * spb
         xfade = float(self.mrt.config.crossfade_length)  # seconds

+        def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
+            """Rectified moving-average envelope, then a simple onset-y novelty (half-wave diff)."""
             if x.ndim == 2:
                 x = x.mean(axis=1)
             x = np.abs(x).astype(np.float32)
             w = max(1, int(round(win_ms * 1e-3 * sr)))
+            if w > 1:
+                kern = np.ones(w, dtype=np.float32) / float(w)
+                x = np.convolve(x, kern, mode="same")
+            # onset-ish novelty: positive first difference (half-wave)
+            d = np.diff(x, prepend=x[:1])
+            d[d < 0] = 0.0
+            return d
+
+        def _estimate_first_offset_samples(ref_loop_wav, gen_wav, sr: int, spb: float, max_ms: int = 180) -> int:
+            """
+            Estimate how late/early the first downbeat is by correlating
+            the last bar of the reference vs the first two bars of the generated chunk.
+            Allows small +/- offsets; upsample envelopes x4 for sub-sample precision then round.
+            """
             try:
-                ref = ref_loop_wav
-                if ref.sample_rate != sr:
-                    ref = ref.resample(sr)
-                # last 1 bar of the reference (what the model just "heard")
+                ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
                 n_bar = int(round(spb * sr))
+
                 ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
-                # first 2 bars of the generated chunk (search window)
                 gen_head = gen_wav.samples[: int(2 * n_bar), :]
                 if ref_tail.size == 0 or gen_head.size == 0:
                     return 0

+                e_ref = _mono_env(ref_tail, sr)  # length ~ n_bar
+                e_gen = _mono_env(gen_head, sr)  # length ~ 2*n_bar
+
+                # z-score for scale invariance
+                def _z(a):
+                    m, s = float(a.mean()), float(a.std() or 1.0)
+                    return (a - m) / s
+                e_ref = _z(e_ref).astype(np.float32)
+                e_gen = _z(e_gen).astype(np.float32)
+
+                # Light upsampling for finer lag resolution (x4)
+                def _upsample(a, r=4):
+                    n = len(a)
+                    grid = np.arange(n, dtype=np.float32)
+                    fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
+                    return np.interp(fine, grid, a).astype(np.float32)
+                up = 4
+                e_ref_u = _upsample(e_ref, up)
+                e_gen_u = _upsample(e_gen, up)
+
+                # Correlate in a tight window
+                max_lag_u = int(round((max_ms / 1000.0) * sr * up))
+                seg = min(len(e_ref_u), len(e_gen_u))
+                e_ref_u = e_ref_u[-seg:]
+                # pad head so we can slide +/- lags
+                pad = np.zeros(max_lag_u, dtype=np.float32)
+                e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])
+
+                best_lag_u, best_score = 0, -1e9
+                # allow tiny early OR late (negative = model early, positive = late)
+                for lag_u in range(-max_lag_u, max_lag_u + 1):
+                    start = max_lag_u + lag_u
+                    b = e_gen_u_pad[start : start + seg]
+                    # normalized dot (already z-scored, but keep it consistent)
+                    denom = (np.linalg.norm(e_ref_u) * np.linalg.norm(b)) or 1.0
+                    score = float(np.dot(e_ref_u, b) / denom)
                     if score > best_score:
-                        best_score = score
+                        best_score, best_lag_u = score, lag_u
+
+                # convert envelope-lag back to audio samples and round
+                lag_samples = int(round(best_lag_u / up))
+                return lag_samples
             except Exception:
                 return 0

@@ -453,13 +469,16 @@ class JamWorker(threading.Thread):
         # ---- ONE-TIME: grid-align the very first jam chunk to kill the flam ----
         if next_idx == 1 and self.params.combined_loop is not None:
             offset = _estimate_first_offset_samples(
-                self.params.combined_loop, y, int(self.mrt.sample_rate), max_ms=
+                self.params.combined_loop, y, int(self.mrt.sample_rate), spb, max_ms=180  # try 160–200
             )
+            if offset != 0:
+                # positive => model late: trim head; negative => model early: pad head (rare)
+                if offset > 0:
+                    y.samples = y.samples[offset:, :]
+                else:
+                    pad = np.zeros((abs(offset), y.samples.shape[1]), dtype=y.samples.dtype)
+                    y.samples = np.concatenate([pad, y.samples], axis=0)
+                print(f"🎯 First-chunk offset compensation: {offset/self.mrt.sample_rate:+.3f}s")
         y = hard_trim_seconds(y, chunk_secs)

         # ---- Post-processing ----
|