Spaces:
Running
Running
File size: 7,739 Bytes
cd609af 53cce5a cd609af 53cce5a cd609af 53cce5a cd609af 53cce5a cd609af 53cce5a cd609af 1889c0a cd609af 53cce5a cd609af 1889c0a cd609af 1889c0a cd609af 1889c0a cd609af 1889c0a cd609af 53cce5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
# jam_worker.py
# Standard library
import base64
import io
import threading
import time
import uuid
from dataclasses import dataclass, field
from math import gcd
from typing import Any

# Third-party
import numpy as np
import soundfile as sf
from scipy.signal import resample_poly

# Pull in your helpers from app.py or refactor them into a shared utils module.
from utils import (
    match_loudness_to_reference, stitch_generated, hard_trim_seconds,
    apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
    resample_and_snap, wav_bytes_base64
)
@dataclass
class JamParams:
bpm: float
beats_per_bar: int
bars_per_chunk: int
target_sr: int
loudness_mode: str = "auto"
headroom_db: float = 1.0
style_vec: np.ndarray | None = None
ref_loop: any = None # au.Waveform at model SR for 1st-chunk loudness
combined_loop: any = None # NEW: Full combined audio for context setup
guidance_weight: float = 1.1
temperature: float = 1.1
topk: int = 40
@dataclass
class JamChunk:
    """One finished chunk of generated audio, queued in JamWorker.outbox."""
    # 1-based sequence number (JamWorker increments idx before enqueueing).
    index: int
    # Complete WAV file encoded as base64 (produced by wav_bytes_base64).
    audio_base64: str
    # Timing/format/knob info built in JamWorker._snap_and_encode
    # (bpm, bars, sample_rate, total_samples, temperature, ...).
    metadata: dict
class JamWorker(threading.Thread):
    """Daemon thread that continuously generates bar-aligned audio chunks.

    Each iteration of :meth:`run` asks the MRT model (``mrt``) for enough raw
    model chunks to cover ``bars_per_chunk`` bars, stitches/trims them to the
    exact musical duration, post-processes loudness, and appends the encoded
    result to ``self.outbox`` for a consumer to drain.

    Concurrency: ``update_style``/``update_knobs`` may be called from other
    threads; they mutate ``self.params`` under ``self._lock``, and the run loop
    snapshots/applies those values under the same lock once per chunk.
    ``outbox`` and ``idx`` are likewise mutated under the lock.
    """

    def __init__(self, mrt, params: JamParams):
        # mrt: external model runtime exposing init_state/generate_chunk/codec/
        # config/sample_rate. Project type — contract not visible in this file.
        super().__init__(daemon=True)
        self.mrt = mrt
        self.params = params
        # Initialize fresh state
        self.state = mrt.init_state()
        # CRITICAL: Set up fresh context from the new combined audio
        if params.combined_loop is not None:
            self._setup_context_from_combined_loop()
        self.idx = 0                        # chunks completed so far (1st emitted chunk has index 1)
        self.outbox: list[JamChunk] = []    # finished chunks; grows unbounded until drained by caller
        self._stop_event = threading.Event()
        self.last_chunk_started_at = None   # wall-clock timestamps for external monitoring
        self.last_chunk_completed_at = None
        self._lock = threading.Lock()

    def _setup_context_from_combined_loop(self):
        """Set up MRT context tokens from the combined loop audio"""
        try:
            # Import the utility functions (same as used in main generation)
            # NOTE(review): redundant with the module-level `from utils import ...`.
            from utils import make_bar_aligned_context, take_bar_aligned_tail
            # Extract context from combined loop (same logic as generate_loop_continuation_with_mrt)
            codec_fps = float(self.mrt.codec.frame_rate)
            ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
            # Take tail portion for context (matches main generation)
            loop_for_context = take_bar_aligned_tail(
                self.params.combined_loop,
                self.params.bpm,
                self.params.beats_per_bar,
                ctx_seconds
            )
            # Encode to tokens
            # (assumes codec.encode returns a (frames, rvq_depth) int array — TODO confirm)
            tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
            # Keep only the RVQ depth the decoder actually consumes.
            tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
            # Create bar-aligned context
            context_tokens = make_bar_aligned_context(
                tokens,
                bpm=self.params.bpm,
                fps=int(self.mrt.codec.frame_rate),
                ctx_frames=self.mrt.config.context_length_frames,
                beats_per_bar=self.params.beats_per_bar
            )
            # Set context on state - this is the key fix!
            self.state.context_tokens = context_tokens
            print(f"✅ JamWorker: Set up fresh context from combined loop")
            print(f"   Context shape: {context_tokens.shape if context_tokens is not None else None}")
        except Exception as e:
            print(f"❌ Failed to setup context from combined loop: {e}")
            # Continue without context rather than crashing

    def stop(self):
        """Ask the run loop to exit; it finishes the current model sub-chunk first."""
        self._stop_event.set()

    def update_style(self, new_style_vec: np.ndarray | None):
        """Swap the style embedding used for subsequent chunks (None is a no-op)."""
        with self._lock:
            if new_style_vec is not None:
                self.params.style_vec = new_style_vec

    def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
        """Update sampling knobs; each takes effect at the start of the next chunk."""
        with self._lock:
            if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
            if temperature is not None: self.params.temperature = float(temperature)
            if topk is not None: self.params.topk = int(topk)

    def _seconds_per_bar(self) -> float:
        """Duration of one bar in seconds at the configured tempo."""
        return self.params.beats_per_bar * (60.0 / self.params.bpm)

    def _snap_and_encode(self, y, seconds, target_sr, bars):
        """Resample ``y`` to ``target_sr``, snap to exactly ``seconds`` of audio,
        and return ``(base64_wav, metadata_dict)``."""
        cur_sr = int(self.mrt.sample_rate)
        # Ensure a 2-D (samples, channels) array before resampling.
        x = y.samples if y.samples.ndim == 2 else y.samples[:, None]
        x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
        b64, total_samples, channels = wav_bytes_base64(x, target_sr)
        meta = {
            "bpm": int(round(self.params.bpm)),
            "bars": int(bars),
            "beats_per_bar": int(self.params.beats_per_bar),
            "sample_rate": int(target_sr),
            "channels": channels,
            "total_samples": total_samples,
            "seconds_per_bar": self._seconds_per_bar(),
            "loop_duration_seconds": bars * self._seconds_per_bar(),
            "guidance_weight": self.params.guidance_weight,
            "temperature": self.params.temperature,
            "topk": self.params.topk,
        }
        return b64, meta

    def run(self):
        """Main loop: emit one bars_per_chunk-long chunk per iteration until stopped."""
        spb = self._seconds_per_bar()
        chunk_secs = self.params.bars_per_chunk * spb
        xfade = self.mrt.config.crossfade_length
        # Prime: set initial context on state (caller should have done this; safe to re-set here)
        # NOTE: We assume caller passed a style_vec computed from tail/whole/blend.
        while not self._stop_event.is_set():
            # honor live knob updates atomically
            with self._lock:
                style_vec = self.params.style_vec
                # Temporarily override MRT knobs (thread-local overrides)
                self.mrt.guidance_weight = self.params.guidance_weight
                self.mrt.temperature = self.params.temperature
                self.mrt.topk = self.params.topk
            # 1) generate enough model chunks to cover chunk_secs
            need = chunk_secs
            chunks = []
            self.last_chunk_started_at = time.time()
            while need > 0 and not self._stop_event.is_set():
                wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
                chunks.append(wav)
                # model chunk length (seconds) at model SR
                need -= (wav.samples.shape[0] / float(self.mrt.sample_rate))
            # Stopped mid-generation: discard the partial chunk rather than emit it.
            if self._stop_event.is_set():
                break
            # 2) stitch and trim to exact seconds at model SR
            y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
            y = hard_trim_seconds(y, chunk_secs)
            # 3) post-process: match the very first chunk to the reference loop's
            # loudness; later chunks just get click-guard fades at the edges.
            if self.idx == 0 and self.params.ref_loop is not None:
                y, _ = match_loudness_to_reference(self.params.ref_loop, y,
                                                   method=self.params.loudness_mode,
                                                   headroom_db=self.params.headroom_db)
            else:
                apply_micro_fades(y, 3)
            # 4) resample + snap + b64
            b64, meta = self._snap_and_encode(y, seconds=chunk_secs,
                                              target_sr=self.params.target_sr,
                                              bars=self.params.bars_per_chunk)
            # 5) enqueue
            with self._lock:
                self.idx += 1
                self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
                self.last_chunk_completed_at = time.time()
        # optional: cleanup here if needed
|