Commit
Β·
eb5d99b
1
Parent(s):
d41a575
rewritten run method in jam worker
Browse files- jam_worker.py +77 -90
jam_worker.py
CHANGED
@@ -50,6 +50,9 @@ class JamWorker(threading.Thread):
|
|
50 |
self.outbox: list[JamChunk] = []
|
51 |
self._stop_event = threading.Event()
|
52 |
|
|
|
|
|
|
|
53 |
# NEW: Track delivery state
|
54 |
self._last_delivered_index = 0
|
55 |
self._max_buffer_ahead = 5
|
@@ -350,139 +353,127 @@ class JamWorker(threading.Thread):
|
|
350 |
self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
|
351 |
|
352 |
def run(self):
|
353 |
-
"""Main worker loop
|
354 |
-
spb = self._seconds_per_bar()
|
355 |
chunk_secs = self.params.bars_per_chunk * spb
|
356 |
-
xfade = float(self.mrt.config.crossfade_length)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
357 |
|
358 |
def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
|
359 |
-
|
360 |
-
if x.ndim == 2:
|
361 |
-
x = x.mean(axis=1)
|
362 |
x = np.abs(x).astype(np.float32)
|
363 |
w = max(1, int(round(win_ms * 1e-3 * sr)))
|
364 |
if w > 1:
|
365 |
kern = np.ones(w, dtype=np.float32) / float(w)
|
366 |
x = np.convolve(x, kern, mode="same")
|
367 |
-
# onset-ish novelty: positive first difference (half-wave)
|
368 |
d = np.diff(x, prepend=x[:1])
|
369 |
d[d < 0] = 0.0
|
370 |
return d
|
371 |
|
372 |
-
def _estimate_first_offset_samples(ref_loop_wav,
|
373 |
-
"""
|
374 |
-
Estimate how late/early the first downbeat is by correlating
|
375 |
-
the last bar of the reference vs the first two bars of the generated chunk.
|
376 |
-
Allows small +/- offsets; upsample envelopes x4 for sub-sample precision then round.
|
377 |
-
"""
|
378 |
try:
|
|
|
379 |
ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
|
380 |
n_bar = int(round(spb * sr))
|
381 |
-
|
382 |
ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
|
383 |
-
gen_head =
|
384 |
if ref_tail.size == 0 or gen_head.size == 0:
|
385 |
return 0
|
386 |
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
# z-score for scale invariance
|
391 |
def _z(a):
|
392 |
-
m, s = float(a.mean()), float(a.std() or 1.0)
|
393 |
-
|
394 |
-
|
395 |
-
e_gen = _z(e_gen).astype(np.float32)
|
396 |
|
397 |
-
#
|
398 |
def _upsample(a, r=4):
|
399 |
-
n = len(a)
|
400 |
-
grid = np.arange(n, dtype=np.float32)
|
401 |
fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
|
402 |
return np.interp(fine, grid, a).astype(np.float32)
|
403 |
up = 4
|
404 |
-
e_ref_u = _upsample(e_ref, up)
|
405 |
-
e_gen_u = _upsample(e_gen, up)
|
406 |
|
407 |
-
# Correlate in a tight window
|
408 |
max_lag_u = int(round((max_ms / 1000.0) * sr * up))
|
409 |
seg = min(len(e_ref_u), len(e_gen_u))
|
410 |
e_ref_u = e_ref_u[-seg:]
|
411 |
-
# pad head so we can slide +/- lags
|
412 |
pad = np.zeros(max_lag_u, dtype=np.float32)
|
413 |
e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])
|
414 |
|
415 |
best_lag_u, best_score = 0, -1e9
|
416 |
-
# allow tiny early OR late (negative = model early, positive = late)
|
417 |
for lag_u in range(-max_lag_u, max_lag_u + 1):
|
418 |
start = max_lag_u + lag_u
|
419 |
b = e_gen_u_pad[start : start + seg]
|
420 |
-
# normalized dot (already z-scored, but keep it consistent)
|
421 |
denom = (np.linalg.norm(e_ref_u) * np.linalg.norm(b)) or 1.0
|
422 |
score = float(np.dot(e_ref_u, b) / denom)
|
423 |
if score > best_score:
|
424 |
best_score, best_lag_u = score, lag_u
|
425 |
-
|
426 |
-
# convert envelope-lag back to audio samples and round
|
427 |
-
lag_samples = int(round(best_lag_u / up))
|
428 |
-
return lag_samples
|
429 |
except Exception:
|
430 |
return 0
|
431 |
|
432 |
-
print("π JamWorker started
|
433 |
|
434 |
while not self._stop_event.is_set():
|
435 |
-
# Donβt get too far ahead of the consumer
|
436 |
if not self._should_generate_next_chunk():
|
437 |
-
|
438 |
-
time.sleep(0.5)
|
439 |
continue
|
440 |
|
441 |
-
#
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
print(f"πΉ Generating chunk {next_idx}...")
|
450 |
-
self.last_chunk_started_at = time.time()
|
451 |
-
|
452 |
-
# ---- Generate enough model sub-chunks to yield *audible* chunk_secs ----
|
453 |
-
# First sub-chunk contributes full L; subsequent contribute (L - xfade)
|
454 |
-
assembled = 0.0
|
455 |
-
chunks = []
|
456 |
-
while assembled < chunk_secs and not self._stop_event.is_set():
|
457 |
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
458 |
-
|
459 |
-
|
460 |
-
assembled += L if len(chunks) == 1 else max(0.0, L - xfade)
|
461 |
|
462 |
if self._stop_event.is_set():
|
463 |
break
|
464 |
|
465 |
-
#
|
466 |
-
|
467 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
468 |
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
)
|
474 |
-
if offset != 0:
|
475 |
-
# positive => model late: trim head; negative => model early: pad head (rare)
|
476 |
-
if offset > 0:
|
477 |
-
y.samples = y.samples[offset:, :]
|
478 |
-
else:
|
479 |
-
pad = np.zeros((abs(offset), y.samples.shape[1]), dtype=y.samples.dtype)
|
480 |
-
y.samples = np.concatenate([pad, y.samples], axis=0)
|
481 |
-
print(f"π― First-chunk offset compensation: {offset/self.mrt.sample_rate:+.3f}s")
|
482 |
-
y = hard_trim_seconds(y, chunk_secs)
|
483 |
-
|
484 |
-
# ---- Post-processing ----
|
485 |
-
if next_idx == 1 and self.params.ref_loop is not None:
|
486 |
y, _ = match_loudness_to_reference(
|
487 |
self.params.ref_loop, y,
|
488 |
method=self.params.loudness_mode,
|
@@ -491,25 +482,21 @@ class JamWorker(threading.Thread):
|
|
491 |
else:
|
492 |
apply_micro_fades(y, 3)
|
493 |
|
494 |
-
#
|
495 |
b64, meta = self._snap_and_encode(
|
496 |
-
y,
|
497 |
-
seconds=chunk_secs,
|
498 |
-
target_sr=self.params.target_sr,
|
499 |
-
bars=self.params.bars_per_chunk
|
500 |
)
|
501 |
-
meta["xfade_seconds"] = xfade
|
502 |
|
503 |
-
#
|
504 |
with self._lock:
|
505 |
-
self.idx
|
506 |
-
self.outbox.append(JamChunk(index=
|
507 |
if len(self.outbox) > 10:
|
508 |
cutoff = self._last_delivered_index - 5
|
509 |
self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
|
510 |
|
511 |
-
|
512 |
-
print(f"β
Completed chunk {next_idx}")
|
513 |
|
514 |
print("π JamWorker stopped")
|
515 |
|
|
|
50 |
self.outbox: list[JamChunk] = []
|
51 |
self._stop_event = threading.Event()
|
52 |
|
53 |
+
self._stream = None
|
54 |
+
self._next_emit_start = 0
|
55 |
+
|
56 |
# NEW: Track delivery state
|
57 |
self._last_delivered_index = 0
|
58 |
self._max_buffer_ahead = 5
|
|
|
353 |
self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
|
354 |
|
355 |
def run(self):
|
356 |
+
"""Main worker loop β generate into a continuous stream, then emit bar-aligned slices."""
|
357 |
+
spb = self._seconds_per_bar() # seconds per bar
|
358 |
chunk_secs = self.params.bars_per_chunk * spb
|
359 |
+
xfade = float(self.mrt.config.crossfade_length) # seconds
|
360 |
+
sr = int(self.mrt.sample_rate)
|
361 |
+
chunk_samps = int(round(chunk_secs * sr))
|
362 |
+
|
363 |
+
def _need(first_chunk_extra=False):
|
364 |
+
"""How many more samples we still need in the stream to emit next slice."""
|
365 |
+
have = 0 if getattr(self, "_stream", None) is None else self._stream.shape[0] - getattr(self, "_next_emit_start", 0)
|
366 |
+
want = chunk_samps
|
367 |
+
if first_chunk_extra:
|
368 |
+
# reserve two bars extra so first-chunk onset alignment has material
|
369 |
+
want += int(round(2 * spb * sr))
|
370 |
+
return max(0, want - have)
|
371 |
|
372 |
def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
|
373 |
+
if x.ndim == 2: x = x.mean(axis=1)
|
|
|
|
|
374 |
x = np.abs(x).astype(np.float32)
|
375 |
w = max(1, int(round(win_ms * 1e-3 * sr)))
|
376 |
if w > 1:
|
377 |
kern = np.ones(w, dtype=np.float32) / float(w)
|
378 |
x = np.convolve(x, kern, mode="same")
|
|
|
379 |
d = np.diff(x, prepend=x[:1])
|
380 |
d[d < 0] = 0.0
|
381 |
return d
|
382 |
|
383 |
+
def _estimate_first_offset_samples(ref_loop_wav, gen_head_wav, sr: int, spb: float) -> int:
|
384 |
+
"""Tempo-aware first-downbeat offset (positive => model late)."""
|
|
|
|
|
|
|
|
|
385 |
try:
|
386 |
+
max_ms = int(max(160.0, min(0.25 * spb * 1000.0, 450.0)))
|
387 |
ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
|
388 |
n_bar = int(round(spb * sr))
|
|
|
389 |
ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
|
390 |
+
gen_head = gen_head_wav.samples[: int(2 * n_bar), :]
|
391 |
if ref_tail.size == 0 or gen_head.size == 0:
|
392 |
return 0
|
393 |
|
394 |
+
# envelopes + z-score
|
395 |
+
import numpy as np
|
|
|
|
|
396 |
def _z(a):
|
397 |
+
m, s = float(a.mean()), float(a.std() or 1.0); return (a - m) / s
|
398 |
+
e_ref = _z(_mono_env(ref_tail, sr)).astype(np.float32)
|
399 |
+
e_gen = _z(_mono_env(gen_head, sr)).astype(np.float32)
|
|
|
400 |
|
401 |
+
# upsample x4 for finer lag
|
402 |
def _upsample(a, r=4):
|
403 |
+
n = len(a); grid = np.arange(n, dtype=np.float32)
|
|
|
404 |
fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
|
405 |
return np.interp(fine, grid, a).astype(np.float32)
|
406 |
up = 4
|
407 |
+
e_ref_u, e_gen_u = _upsample(e_ref, up), _upsample(e_gen, up)
|
|
|
408 |
|
|
|
409 |
max_lag_u = int(round((max_ms / 1000.0) * sr * up))
|
410 |
seg = min(len(e_ref_u), len(e_gen_u))
|
411 |
e_ref_u = e_ref_u[-seg:]
|
|
|
412 |
pad = np.zeros(max_lag_u, dtype=np.float32)
|
413 |
e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])
|
414 |
|
415 |
best_lag_u, best_score = 0, -1e9
|
|
|
416 |
for lag_u in range(-max_lag_u, max_lag_u + 1):
|
417 |
start = max_lag_u + lag_u
|
418 |
b = e_gen_u_pad[start : start + seg]
|
|
|
419 |
denom = (np.linalg.norm(e_ref_u) * np.linalg.norm(b)) or 1.0
|
420 |
score = float(np.dot(e_ref_u, b) / denom)
|
421 |
if score > best_score:
|
422 |
best_score, best_lag_u = score, lag_u
|
423 |
+
return int(round(best_lag_u / up))
|
|
|
|
|
|
|
424 |
except Exception:
|
425 |
return 0
|
426 |
|
427 |
+
print("π JamWorker started (bar-aligned streaming)β¦")
|
428 |
|
429 |
while not self._stop_event.is_set():
|
|
|
430 |
if not self._should_generate_next_chunk():
|
431 |
+
time.sleep(0.25)
|
|
|
432 |
continue
|
433 |
|
434 |
+
# 1) Generate until we have enough material in the stream
|
435 |
+
need = _need(first_chunk_extra=(self.idx == 0))
|
436 |
+
while need > 0 and not self._stop_event.is_set():
|
437 |
+
with self._lock:
|
438 |
+
style_vec = self.params.style_vec
|
439 |
+
self.mrt.guidance_weight = float(self.params.guidance_weight)
|
440 |
+
self.mrt.temperature = float(self.params.temperature)
|
441 |
+
self.mrt.topk = int(self.params.topk)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
442 |
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
443 |
+
self._append_model_chunk_to_stream(wav) # equal-power xfade into a persistent stream
|
444 |
+
need = _need(first_chunk_extra=(self.idx == 0))
|
|
|
445 |
|
446 |
if self._stop_event.is_set():
|
447 |
break
|
448 |
|
449 |
+
# 2) One-time: align the emit pointer to the groove
|
450 |
+
if self.idx == 0 and self.params.combined_loop is not None:
|
451 |
+
# Compare ref tail vs the head of what we're about to emit
|
452 |
+
head_len = min(self._stream.shape[0] - self._next_emit_start, int(round(2 * spb * sr)))
|
453 |
+
seg = self._stream[self._next_emit_start : self._next_emit_start + head_len]
|
454 |
+
gen_head = au.Waveform(seg.astype(np.float32, copy=False), sr).as_stereo()
|
455 |
+
offs = _estimate_first_offset_samples(self.params.combined_loop, gen_head, sr, spb)
|
456 |
+
if offs != 0:
|
457 |
+
# positive => model late: skip some samples; negative => model early: "rewind" by padding
|
458 |
+
self._next_emit_start = max(0, self._next_emit_start + offs)
|
459 |
+
print(f"π― First-chunk offset compensation: {offs/sr:+.3f}s")
|
460 |
+
# snap to next bar boundary
|
461 |
+
self._realign_emit_pointer_to_bar(sr)
|
462 |
+
|
463 |
+
# 3) Emit exactly bars_per_chunk Γ spb from the stream
|
464 |
+
start = self._next_emit_start
|
465 |
+
end = start + chunk_samps
|
466 |
+
if end > self._stream.shape[0]:
|
467 |
+
# shouldn't happen often; generate a bit more and loop
|
468 |
+
continue
|
469 |
+
|
470 |
+
slice_ = self._stream[start:end]
|
471 |
+
self._next_emit_start = end
|
472 |
|
473 |
+
y = au.Waveform(slice_.astype(np.float32, copy=False), sr).as_stereo()
|
474 |
+
|
475 |
+
# 4) Post-processing / loudness
|
476 |
+
if self.idx == 0 and self.params.ref_loop is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
y, _ = match_loudness_to_reference(
|
478 |
self.params.ref_loop, y,
|
479 |
method=self.params.loudness_mode,
|
|
|
482 |
else:
|
483 |
apply_micro_fades(y, 3)
|
484 |
|
485 |
+
# 5) Resample + exact-length snap + encode
|
486 |
b64, meta = self._snap_and_encode(
|
487 |
+
y, seconds=chunk_secs, target_sr=self.params.target_sr, bars=self.params.bars_per_chunk
|
|
|
|
|
|
|
488 |
)
|
489 |
+
meta["xfade_seconds"] = xfade
|
490 |
|
491 |
+
# 6) Publish
|
492 |
with self._lock:
|
493 |
+
self.idx += 1
|
494 |
+
self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
|
495 |
if len(self.outbox) > 10:
|
496 |
cutoff = self._last_delivered_index - 5
|
497 |
self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
|
498 |
|
499 |
+
print(f"β
Completed chunk {self.idx}")
|
|
|
500 |
|
501 |
print("π JamWorker stopped")
|
502 |
|