Spaces:
Running
Running
Commit
·
cd609af
1
Parent(s):
1e82ab2
always save
Browse files- jam_worker.py +139 -0
jam_worker.py
CHANGED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# jam_worker.py
|
2 |
+
import base64
import io
import threading
import time
import uuid
from dataclasses import dataclass, field
from math import gcd
from typing import Any

import numpy as np
import soundfile as sf
from scipy.signal import resample_poly

# Pull in your helpers from app.py or refactor them into a shared utils module.
from utils import (
    match_loudness_to_reference, stitch_generated, hard_trim_seconds,
    apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
    resample_and_snap, wav_bytes_base64
)
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class JamParams:
|
18 |
+
bpm: float
|
19 |
+
beats_per_bar: int
|
20 |
+
bars_per_chunk: int # 4 or 8
|
21 |
+
target_sr: int
|
22 |
+
loudness_mode: str = "auto"
|
23 |
+
headroom_db: float = 1.0
|
24 |
+
style_vec: np.ndarray | None = None # combined_style vector
|
25 |
+
ref_loop: any = None # au.Waveform at model SR for 1st-chunk loudness
|
26 |
+
guidance_weight: float = 1.1
|
27 |
+
temperature: float = 1.1
|
28 |
+
topk: int = 40
|
29 |
+
|
30 |
+
@dataclass
class JamChunk:
    """One finished chunk as placed in ``JamWorker.outbox``."""

    # Sequence number assigned by the worker; the first chunk gets 1 because
    # the counter is incremented before the chunk is enqueued.
    index: int
    # Encoded audio string returned by wav_bytes_base64.
    audio_base64: str
    # Tempo / format / knob details built alongside the audio by the worker.
    metadata: dict
|
35 |
+
|
36 |
+
class JamWorker(threading.Thread):
|
37 |
+
def __init__(self, mrt, params: JamParams):
|
38 |
+
super().__init__(daemon=True)
|
39 |
+
self.mrt = mrt
|
40 |
+
self.params = params
|
41 |
+
self.state = mrt.init_state()
|
42 |
+
self.idx = 0
|
43 |
+
self.outbox: list[JamChunk] = []
|
44 |
+
self._stop = threading.Event()
|
45 |
+
self.last_chunk_started_at = None
|
46 |
+
self.last_chunk_completed_at = None
|
47 |
+
self._lock = threading.Lock()
|
48 |
+
|
49 |
+
def stop(self):
|
50 |
+
self._stop.set()
|
51 |
+
|
52 |
+
def update_style(self, new_style_vec: np.ndarray | None):
|
53 |
+
with self._lock:
|
54 |
+
if new_style_vec is not None:
|
55 |
+
self.params.style_vec = new_style_vec
|
56 |
+
|
57 |
+
def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
|
58 |
+
with self._lock:
|
59 |
+
if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
|
60 |
+
if temperature is not None: self.params.temperature = float(temperature)
|
61 |
+
if topk is not None: self.params.topk = int(topk)
|
62 |
+
|
63 |
+
def _seconds_per_bar(self) -> float:
|
64 |
+
return self.params.beats_per_bar * (60.0 / self.params.bpm)
|
65 |
+
|
66 |
+
def _snap_and_encode(self, y, seconds, target_sr, bars):
|
67 |
+
cur_sr = int(self.mrt.sample_rate)
|
68 |
+
x = y.samples if y.samples.ndim == 2 else y.samples[:, None]
|
69 |
+
x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
|
70 |
+
b64, total_samples, channels = wav_bytes_base64(x, target_sr)
|
71 |
+
meta = {
|
72 |
+
"bpm": int(round(self.params.bpm)),
|
73 |
+
"bars": int(bars),
|
74 |
+
"beats_per_bar": int(self.params.beats_per_bar),
|
75 |
+
"sample_rate": int(target_sr),
|
76 |
+
"channels": channels,
|
77 |
+
"total_samples": total_samples,
|
78 |
+
"seconds_per_bar": self._seconds_per_bar(),
|
79 |
+
"loop_duration_seconds": bars * self._seconds_per_bar(),
|
80 |
+
"guidance_weight": self.params.guidance_weight,
|
81 |
+
"temperature": self.params.temperature,
|
82 |
+
"topk": self.params.topk,
|
83 |
+
}
|
84 |
+
return b64, meta
|
85 |
+
|
86 |
+
def run(self):
|
87 |
+
spb = self._seconds_per_bar()
|
88 |
+
chunk_secs = self.params.bars_per_chunk * spb
|
89 |
+
xfade = self.mrt.config.crossfade_length
|
90 |
+
|
91 |
+
# Prime: set initial context on state (caller should have done this; safe to re-set here)
|
92 |
+
# NOTE: We assume caller passed a style_vec computed from tail/whole/blend.
|
93 |
+
while not self._stop.is_set():
|
94 |
+
# honor live knob updates atomically
|
95 |
+
with self._lock:
|
96 |
+
style_vec = self.params.style_vec
|
97 |
+
# Temporarily override MRT knobs (thread-local overrides)
|
98 |
+
self.mrt.guidance_weight = self.params.guidance_weight
|
99 |
+
self.mrt.temperature = self.params.temperature
|
100 |
+
self.mrt.topk = self.params.topk
|
101 |
+
|
102 |
+
# 1) generate enough model chunks to cover chunk_secs
|
103 |
+
need = chunk_secs
|
104 |
+
chunks = []
|
105 |
+
self.last_chunk_started_at = time.time()
|
106 |
+
while need > 0 and not self._stop.is_set():
|
107 |
+
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
108 |
+
chunks.append(wav)
|
109 |
+
# model chunk length (seconds) at model SR
|
110 |
+
need -= (wav.samples.shape[0] / float(self.mrt.sample_rate))
|
111 |
+
|
112 |
+
if self._stop.is_set():
|
113 |
+
break
|
114 |
+
|
115 |
+
# 2) stitch and trim to exact seconds at model SR
|
116 |
+
y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
|
117 |
+
y = hard_trim_seconds(y, chunk_secs)
|
118 |
+
|
119 |
+
# 3) post-process
|
120 |
+
if self.idx == 0 and self.params.ref_loop is not None:
|
121 |
+
y, _ = match_loudness_to_reference(self.params.ref_loop, y,
|
122 |
+
method=self.params.loudness_mode,
|
123 |
+
headroom_db=self.params.headroom_db)
|
124 |
+
else:
|
125 |
+
apply_micro_fades(y, 3)
|
126 |
+
|
127 |
+
# 4) resample + snap + b64
|
128 |
+
b64, meta = self._snap_and_encode(y, seconds=chunk_secs,
|
129 |
+
target_sr=self.params.target_sr,
|
130 |
+
bars=self.params.bars_per_chunk)
|
131 |
+
|
132 |
+
# 5) enqueue
|
133 |
+
with self._lock:
|
134 |
+
self.idx += 1
|
135 |
+
self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
|
136 |
+
|
137 |
+
self.last_chunk_completed_at = time.time()
|
138 |
+
|
139 |
+
# optional: cleanup here if needed
|