thecollabagepatch committed on
Commit
cd609af
·
1 Parent(s): 1e82ab2

always save

Browse files
Files changed (1) hide show
  1. jam_worker.py +139 -0
jam_worker.py CHANGED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # jam_worker.py
2
+ import threading, time, base64, io, uuid
3
+ from dataclasses import dataclass, field
4
+ import numpy as np
5
+ import soundfile as sf
6
+
7
+ # Pull in your helpers from app.py or refactor them into a shared utils module.
8
+ from utils import (
9
+ match_loudness_to_reference, stitch_generated, hard_trim_seconds,
10
+ apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
11
+ resample_and_snap, wav_bytes_base64
12
+ )
13
+ from scipy.signal import resample_poly
14
+ from math import gcd
15
+
16
+ @dataclass
17
+ class JamParams:
18
+ bpm: float
19
+ beats_per_bar: int
20
+ bars_per_chunk: int # 4 or 8
21
+ target_sr: int
22
+ loudness_mode: str = "auto"
23
+ headroom_db: float = 1.0
24
+ style_vec: np.ndarray | None = None # combined_style vector
25
+ ref_loop: any = None # au.Waveform at model SR for 1st-chunk loudness
26
+ guidance_weight: float = 1.1
27
+ temperature: float = 1.1
28
+ topk: int = 40
29
+
30
@dataclass
class JamChunk:
    """One finished piece of audio as queued in ``JamWorker.outbox``."""

    # Value of the worker's chunk counter at enqueue time; the counter is
    # incremented before the chunk is built, so the first chunk has index 1.
    index: int
    # Base64 text produced by ``wav_bytes_base64`` for the encoded audio.
    audio_base64: str
    # Descriptive dict assembled by ``JamWorker._snap_and_encode``
    # (bpm, bars, sample_rate, channels, sampling knobs, ...).
    metadata: dict
35
+
36
+ class JamWorker(threading.Thread):
37
+ def __init__(self, mrt, params: JamParams):
38
+ super().__init__(daemon=True)
39
+ self.mrt = mrt
40
+ self.params = params
41
+ self.state = mrt.init_state()
42
+ self.idx = 0
43
+ self.outbox: list[JamChunk] = []
44
+ self._stop = threading.Event()
45
+ self.last_chunk_started_at = None
46
+ self.last_chunk_completed_at = None
47
+ self._lock = threading.Lock()
48
+
49
+ def stop(self):
50
+ self._stop.set()
51
+
52
+ def update_style(self, new_style_vec: np.ndarray | None):
53
+ with self._lock:
54
+ if new_style_vec is not None:
55
+ self.params.style_vec = new_style_vec
56
+
57
+ def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
58
+ with self._lock:
59
+ if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
60
+ if temperature is not None: self.params.temperature = float(temperature)
61
+ if topk is not None: self.params.topk = int(topk)
62
+
63
+ def _seconds_per_bar(self) -> float:
64
+ return self.params.beats_per_bar * (60.0 / self.params.bpm)
65
+
66
+ def _snap_and_encode(self, y, seconds, target_sr, bars):
67
+ cur_sr = int(self.mrt.sample_rate)
68
+ x = y.samples if y.samples.ndim == 2 else y.samples[:, None]
69
+ x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
70
+ b64, total_samples, channels = wav_bytes_base64(x, target_sr)
71
+ meta = {
72
+ "bpm": int(round(self.params.bpm)),
73
+ "bars": int(bars),
74
+ "beats_per_bar": int(self.params.beats_per_bar),
75
+ "sample_rate": int(target_sr),
76
+ "channels": channels,
77
+ "total_samples": total_samples,
78
+ "seconds_per_bar": self._seconds_per_bar(),
79
+ "loop_duration_seconds": bars * self._seconds_per_bar(),
80
+ "guidance_weight": self.params.guidance_weight,
81
+ "temperature": self.params.temperature,
82
+ "topk": self.params.topk,
83
+ }
84
+ return b64, meta
85
+
86
+ def run(self):
87
+ spb = self._seconds_per_bar()
88
+ chunk_secs = self.params.bars_per_chunk * spb
89
+ xfade = self.mrt.config.crossfade_length
90
+
91
+ # Prime: set initial context on state (caller should have done this; safe to re-set here)
92
+ # NOTE: We assume caller passed a style_vec computed from tail/whole/blend.
93
+ while not self._stop.is_set():
94
+ # honor live knob updates atomically
95
+ with self._lock:
96
+ style_vec = self.params.style_vec
97
+ # Temporarily override MRT knobs (thread-local overrides)
98
+ self.mrt.guidance_weight = self.params.guidance_weight
99
+ self.mrt.temperature = self.params.temperature
100
+ self.mrt.topk = self.params.topk
101
+
102
+ # 1) generate enough model chunks to cover chunk_secs
103
+ need = chunk_secs
104
+ chunks = []
105
+ self.last_chunk_started_at = time.time()
106
+ while need > 0 and not self._stop.is_set():
107
+ wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
108
+ chunks.append(wav)
109
+ # model chunk length (seconds) at model SR
110
+ need -= (wav.samples.shape[0] / float(self.mrt.sample_rate))
111
+
112
+ if self._stop.is_set():
113
+ break
114
+
115
+ # 2) stitch and trim to exact seconds at model SR
116
+ y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
117
+ y = hard_trim_seconds(y, chunk_secs)
118
+
119
+ # 3) post-process
120
+ if self.idx == 0 and self.params.ref_loop is not None:
121
+ y, _ = match_loudness_to_reference(self.params.ref_loop, y,
122
+ method=self.params.loudness_mode,
123
+ headroom_db=self.params.headroom_db)
124
+ else:
125
+ apply_micro_fades(y, 3)
126
+
127
+ # 4) resample + snap + b64
128
+ b64, meta = self._snap_and_encode(y, seconds=chunk_secs,
129
+ target_sr=self.params.target_sr,
130
+ bars=self.params.bars_per_chunk)
131
+
132
+ # 5) enqueue
133
+ with self._lock:
134
+ self.idx += 1
135
+ self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
136
+
137
+ self.last_chunk_completed_at = time.time()
138
+
139
+ # optional: cleanup here if needed