Spaces:
Running
Running
File size: 5,544 Bytes
cd609af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# jam_worker.py
import base64
import io
import threading
import time
import uuid
from dataclasses import dataclass, field
from math import gcd
from typing import Any

import numpy as np
import soundfile as sf
from scipy.signal import resample_poly

# Pull in your helpers from app.py or refactor them into a shared utils module.
from utils import (
    match_loudness_to_reference, stitch_generated, hard_trim_seconds,
    apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
    resample_and_snap, wav_bytes_base64,
)
@dataclass
class JamParams:
bpm: float
beats_per_bar: int
bars_per_chunk: int # 4 or 8
target_sr: int
loudness_mode: str = "auto"
headroom_db: float = 1.0
style_vec: np.ndarray | None = None # combined_style vector
ref_loop: any = None # au.Waveform at model SR for 1st-chunk loudness
guidance_weight: float = 1.1
temperature: float = 1.1
topk: int = 40
@dataclass
class JamChunk:
    """A single finished chunk of generated audio, ready for client delivery."""
    index: int          # 1-based sequence number within the jam session
    audio_base64: str   # base64-encoded WAV bytes
    metadata: dict      # timing/loop info (bpm, bars, sample_rate, ...)
class JamWorker(threading.Thread):
    """Background thread that continuously generates bar-aligned audio chunks.

    Each pass of :meth:`run` produces exactly ``params.bars_per_chunk`` bars
    of audio, encodes it as base64 WAV at ``params.target_sr``, and appends a
    :class:`JamChunk` to :attr:`outbox`.  Style and sampling-knob updates may
    arrive from other threads via :meth:`update_style` / :meth:`update_knobs`;
    ``self._lock`` guards those shared fields and the outbox.
    """

    def __init__(self, mrt, params: JamParams):
        super().__init__(daemon=True)
        self.mrt = mrt                      # model handle (MagentaRT-style: init_state/generate_chunk/sample_rate/config)
        self.params = params
        self.state = mrt.init_state()       # streaming state, threaded through each generate_chunk call
        self.idx = 0                        # number of chunks completed so far (outbox indices are 1-based)
        # NOTE(review): outbox grows without bound — assumes a consumer drains it; confirm.
        self.outbox: list[JamChunk] = []
        self._stop = threading.Event()
        self.last_chunk_started_at = None     # wall-clock time (time.time()) of last generation start
        self.last_chunk_completed_at = None   # wall-clock time of last enqueued chunk
        self._lock = threading.Lock()

    def stop(self):
        """Request shutdown; run() exits at its next stop-event check."""
        self._stop.set()

    def update_style(self, new_style_vec: np.ndarray | None):
        """Swap in a new combined-style vector; a None argument is ignored."""
        with self._lock:
            if new_style_vec is not None:
                self.params.style_vec = new_style_vec

    def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
        """Atomically update any subset of the sampling knobs (None = leave unchanged)."""
        with self._lock:
            if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
            if temperature is not None: self.params.temperature = float(temperature)
            if topk is not None: self.params.topk = int(topk)

    def _seconds_per_bar(self) -> float:
        """Duration of one bar in seconds: beats per bar x seconds per beat."""
        return self.params.beats_per_bar * (60.0 / self.params.bpm)

    def _snap_and_encode(self, y, seconds, target_sr, bars):
        """Resample ``y`` to ``target_sr``, snap to an exact length, b64-encode.

        Returns ``(base64_wav, metadata_dict)`` where the metadata mirrors the
        session timing parameters and the current knob values.
        """
        cur_sr = int(self.mrt.sample_rate)
        # Ensure 2-D (frames, channels); mono arrives as 1-D.
        x = y.samples if y.samples.ndim == 2 else y.samples[:, None]
        x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
        b64, total_samples, channels = wav_bytes_base64(x, target_sr)
        meta = {
            "bpm": int(round(self.params.bpm)),
            "bars": int(bars),
            "beats_per_bar": int(self.params.beats_per_bar),
            "sample_rate": int(target_sr),
            "channels": channels,
            "total_samples": total_samples,
            "seconds_per_bar": self._seconds_per_bar(),
            "loop_duration_seconds": bars * self._seconds_per_bar(),
            "guidance_weight": self.params.guidance_weight,
            "temperature": self.params.temperature,
            "topk": self.params.topk,
        }
        return b64, meta

    def run(self):
        """Main loop: emit one JamChunk per iteration until stop() is called."""
        spb = self._seconds_per_bar()
        chunk_secs = self.params.bars_per_chunk * spb   # exact target duration of each chunk
        xfade = self.mrt.config.crossfade_length
        # Prime: set initial context on state (caller should have done this; safe to re-set here)
        # NOTE: We assume caller passed a style_vec computed from tail/whole/blend.
        while not self._stop.is_set():
            # honor live knob updates atomically
            with self._lock:
                style_vec = self.params.style_vec
                # Push current knob values onto the model handle.
                # NOTE(review): original comment called these "thread-local overrides",
                # but they mutate the shared mrt object — confirm behavior if one mrt
                # is shared by multiple workers.
                self.mrt.guidance_weight = self.params.guidance_weight
                self.mrt.temperature = self.params.temperature
                self.mrt.topk = self.params.topk
            # 1) generate enough model chunks to cover chunk_secs
            need = chunk_secs
            chunks = []
            self.last_chunk_started_at = time.time()
            while need > 0 and not self._stop.is_set():
                wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
                chunks.append(wav)
                # model chunk length (seconds) at model SR
                need -= (wav.samples.shape[0] / float(self.mrt.sample_rate))
            if self._stop.is_set():
                break   # shutting down: drop the partially generated chunk
            # 2) stitch and trim to exact seconds at model SR
            y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
            y = hard_trim_seconds(y, chunk_secs)
            # 3) post-process: loudness-match only the very first chunk to the
            #    reference loop; later chunks just get short anti-click fades.
            if self.idx == 0 and self.params.ref_loop is not None:
                y, _ = match_loudness_to_reference(self.params.ref_loop, y,
                                                   method=self.params.loudness_mode,
                                                   headroom_db=self.params.headroom_db)
            else:
                apply_micro_fades(y, 3)
            # 4) resample + snap + b64
            b64, meta = self._snap_and_encode(y, seconds=chunk_secs,
                                              target_sr=self.params.target_sr,
                                              bars=self.params.bars_per_chunk)
            # 5) enqueue
            with self._lock:
                self.idx += 1
                self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
                self.last_chunk_completed_at = time.time()
        # optional: cleanup here if needed
|