# magenta / jam_worker.py
# Last commit 53cce5a by thecollabagepatch: "refresh context" (file size 7.74 kB)
# jam_worker.py
import base64
import io
import threading
import time
import uuid
from dataclasses import dataclass, field
from math import gcd
from typing import Any

import numpy as np
import soundfile as sf
from scipy.signal import resample_poly

# Pull in your helpers from app.py or refactor them into a shared utils module.
from utils import (
    match_loudness_to_reference, stitch_generated, hard_trim_seconds,
    apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
    resample_and_snap, wav_bytes_base64
)
@dataclass
class JamParams:
bpm: float
beats_per_bar: int
bars_per_chunk: int
target_sr: int
loudness_mode: str = "auto"
headroom_db: float = 1.0
style_vec: np.ndarray | None = None
ref_loop: any = None # au.Waveform at model SR for 1st-chunk loudness
combined_loop: any = None # NEW: Full combined audio for context setup
guidance_weight: float = 1.1
temperature: float = 1.1
topk: int = 40
@dataclass()
class JamChunk:
    """A single finished audio chunk queued for delivery to a client.

    Fields are declared in positional-constructor order:
    ``JamChunk(index, audio_base64, metadata)``.
    """

    # Sequence number of this chunk within the session.
    index: int
    # Audio payload encoded as a base64 string.
    audio_base64: str
    # Descriptive fields accompanying the audio (tempo, sample rate, knobs, ...).
    metadata: dict
class JamWorker(threading.Thread):
    """Daemon thread that keeps generating bar-aligned audio chunks.

    Each loop iteration calls ``mrt.generate_chunk`` until the model audio
    covers ``bars_per_chunk`` bars, then:
    - stitches and trims the audio;
    - post-processes it;
    - resamples it to ``params.target_sr``;
    - appends a base64-encoded :class:`JamChunk` to ``self.outbox``.

    ``update_style`` / ``update_knobs`` may be called from other threads.
    They mutate ``self.params`` under ``self._lock``, and their values are
    picked up at the start of the next chunk.
    """

    def __init__(self, mrt, params: JamParams):
        """Create the worker; call ``start()`` to begin generating.

        Args:
            mrt: Model wrapper. This class uses its ``init_state()``,
                ``generate_chunk()``, ``codec``, ``config`` and
                ``sample_rate`` members.
            params: Session settings; mutated in place by ``update_style``
                and ``update_knobs``.
        """
        super().__init__(daemon=True)
        self.mrt = mrt
        self.params = params
        # Start from a fresh generation state for this session.
        self.state = mrt.init_state()
        # Seed state.context_tokens from the session's combined audio, if any.
        if params.combined_loop is not None:
            self._setup_context_from_combined_loop()
        # Number of chunks published so far; also the index of the latest chunk.
        self.idx = 0
        # Finished chunks, appended under self._lock.
        self.outbox: list[JamChunk] = []
        self._stop_event = threading.Event()
        # Wall-clock timestamps (time.time()) of the current/last chunk.
        self.last_chunk_started_at = None
        self.last_chunk_completed_at = None
        # Guards self.params knobs/style, self.idx and self.outbox.
        self._lock = threading.Lock()

    def _setup_context_from_combined_loop(self):
        """Encode the tail of ``params.combined_loop`` into ``state.context_tokens``.

        Best-effort: on any failure the error is printed and the state keeps
        whatever context ``mrt.init_state()`` gave it.
        """
        try:
            codec_fps = float(self.mrt.codec.frame_rate)
            ctx_frames = self.mrt.config.context_length_frames
            ctx_seconds = float(ctx_frames) / codec_fps

            # Keep only the bar-aligned tail that fits the model's context window.
            loop_for_context = take_bar_aligned_tail(
                self.params.combined_loop,
                self.params.bpm,
                self.params.beats_per_bar,
                ctx_seconds,
            )

            # Encode to codec tokens, keeping only the RVQ depth the decoder uses.
            tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
            tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]

            context_tokens = make_bar_aligned_context(
                tokens,
                bpm=self.params.bpm,
                fps=int(self.mrt.codec.frame_rate),
                ctx_frames=ctx_frames,
                beats_per_bar=self.params.beats_per_bar,
            )

            self.state.context_tokens = context_tokens
            print("✅ JamWorker: Set up fresh context from combined loop")
            print(f" Context shape: {context_tokens.shape if context_tokens is not None else None}")
        except Exception as e:
            # Deliberate best-effort: continue without the seeded context.
            print(f"❌ Failed to setup context from combined loop: {e}")

    def stop(self):
        """Ask the generation loop to exit; it checks between model chunks."""
        self._stop_event.set()

    def update_style(self, new_style_vec: np.ndarray | None):
        """Replace the style vector used for subsequent chunks (None is ignored)."""
        with self._lock:
            if new_style_vec is not None:
                self.params.style_vec = new_style_vec

    def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
        """Update sampling knobs for subsequent chunks.

        Knobs left as None are unchanged. Values are coerced with
        ``float`` (guidance_weight, temperature) or ``int`` (topk).
        """
        with self._lock:
            if guidance_weight is not None:
                self.params.guidance_weight = float(guidance_weight)
            if temperature is not None:
                self.params.temperature = float(temperature)
            if topk is not None:
                self.params.topk = int(topk)

    def _snapshot_knobs(self) -> dict:
        """Return the current sampling knobs; the caller should hold ``self._lock``."""
        return {
            "guidance_weight": self.params.guidance_weight,
            "temperature": self.params.temperature,
            "topk": self.params.topk,
        }

    def _seconds_per_bar(self) -> float:
        """Duration of one bar in seconds: beats_per_bar * (60 / bpm)."""
        return self.params.beats_per_bar * (60.0 / self.params.bpm)

    def _snap_and_encode(self, y, seconds, target_sr, bars, knobs=None):
        """Resample ``y`` to ``target_sr``, snap it to ``seconds``, and base64-encode it.

        Args:
            y: Waveform at the model sample rate (``y.samples`` is read).
            seconds: Exact output duration passed to ``resample_and_snap``.
            target_sr: Output sample rate.
            bars: Number of bars the chunk represents (metadata only).
            knobs: Knob values the chunk was generated with. If None, the
                current ``self.params`` values are reported instead.

        Returns:
            ``(b64, meta)``: the encoded audio and its metadata dict.
        """
        if knobs is None:
            knobs = self._snapshot_knobs()

        cur_sr = int(self.mrt.sample_rate)
        # Ensure a 2-D samples array by adding a channel axis to 1-D input.
        x = y.samples if y.samples.ndim == 2 else y.samples[:, None]
        x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
        b64, total_samples, channels = wav_bytes_base64(x, target_sr)

        spb = self._seconds_per_bar()
        meta = {
            "bpm": int(round(self.params.bpm)),
            "bars": int(bars),
            "beats_per_bar": int(self.params.beats_per_bar),
            "sample_rate": int(target_sr),
            "channels": channels,
            "total_samples": total_samples,
            "seconds_per_bar": spb,
            "loop_duration_seconds": bars * spb,
            "guidance_weight": knobs["guidance_weight"],
            "temperature": knobs["temperature"],
            "topk": knobs["topk"],
        }
        return b64, meta

    def run(self):
        """Generate, post-process and enqueue chunks until ``stop()`` is called."""
        spb = self._seconds_per_bar()
        chunk_secs = self.params.bars_per_chunk * spb
        xfade = self.mrt.config.crossfade_length

        # Context was already primed in __init__ (if combined_loop was given).
        # NOTE: We assume caller passed a style_vec computed from tail/whole/blend.
        while not self._stop_event.is_set():
            # Snapshot style + knobs atomically so a concurrent update_* call
            # cannot yield a half-applied set of values for this chunk. The same
            # snapshot is reported in the chunk's metadata.
            with self._lock:
                style_vec = self.params.style_vec
                knobs = self._snapshot_knobs()
                # Not thread-local: this mutates the shared self.mrt object.
                self.mrt.guidance_weight = knobs["guidance_weight"]
                self.mrt.temperature = knobs["temperature"]
                self.mrt.topk = knobs["topk"]

            # 1) Generate model chunks until they cover chunk_secs.
            need = chunk_secs
            chunks = []
            self.last_chunk_started_at = time.time()
            while need > 0 and not self._stop_event.is_set():
                wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
                chunks.append(wav)
                # Length of this model chunk in seconds at the model sample rate.
                need -= wav.samples.shape[0] / float(self.mrt.sample_rate)
            if self._stop_event.is_set():
                break

            # 2) Stitch with crossfades, then trim to exactly chunk_secs at model SR.
            # NOTE(review): `need` counts each chunk's full length. If
            # stitch_generated overlaps crossfade samples, the stitched audio
            # may fall short of chunk_secs; confirm. Conversely, any audio past
            # chunk_secs is trimmed away even though self.state has already
            # advanced past it.
            y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
            y = hard_trim_seconds(y, chunk_secs)

            # 3) Post-process: loudness-match the first chunk to the reference
            #    loop; otherwise apply micro fades.
            if self.idx == 0 and self.params.ref_loop is not None:
                y, _ = match_loudness_to_reference(
                    self.params.ref_loop,
                    y,
                    method=self.params.loudness_mode,
                    headroom_db=self.params.headroom_db,
                )
            else:
                # Return value is unused, so this presumably modifies y in place;
                # confirm in utils (units of the `3` argument are defined there).
                apply_micro_fades(y, 3)

            # 4) Resample + snap + base64, reporting the knobs actually used.
            b64, meta = self._snap_and_encode(
                y,
                seconds=chunk_secs,
                target_sr=self.params.target_sr,
                bars=self.params.bars_per_chunk,
                knobs=knobs,
            )

            # 5) Publish the chunk.
            with self._lock:
                self.idx += 1
                self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
            self.last_chunk_completed_at = time.time()
        # optional: cleanup here if needed