thecollabagepatch committed
Commit 956f1a2 · 1 Parent(s): c4aed03

keep jamming button

Files changed (3)
  1. app.py +173 -161
  2. jam_worker.py +0 -0
  3. utils.py +168 -0
app.py CHANGED
@@ -1,6 +1,6 @@
 from magenta_rt import system, audio as au
 import numpy as np
-from fastapi import FastAPI, UploadFile, File, Form
+from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response
 import tempfile, io, base64, math, threading
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import contextmanager
@@ -8,6 +8,16 @@ import soundfile as sf
 import numpy as np
 from math import gcd
 from scipy.signal import resample_poly
+from utils import (
+    match_loudness_to_reference, stitch_generated, hard_trim_seconds,
+    apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
+    resample_and_snap, wav_bytes_base64
+)
+
+from jam_worker import JamWorker, JamParams, JamChunk
+import uuid, threading
+jam_registry: dict[str, JamWorker] = {}
+jam_lock = threading.Lock()
 
 @contextmanager
 def mrt_overrides(mrt, **kwargs):
@@ -30,133 +40,6 @@ try:
 except Exception:
     _HAS_LOUDNORM = False
 
-def _measure_lufs(wav: au.Waveform) -> float:
-    # pyloudnorm expects float32/float64, shape (n,) or (n, ch)
-    meter = pyln.Meter(wav.sample_rate)  # defaults to BS.1770-4
-    return float(meter.integrated_loudness(wav.samples))
-
-def _rms(x: np.ndarray) -> float:
-    if x.size == 0: return 0.0
-    return float(np.sqrt(np.mean(x**2)))
-
-def match_loudness_to_reference(
-    ref: au.Waveform,
-    target: au.Waveform,
-    method: str = "auto",   # "auto"|"lufs"|"rms"|"none"
-    headroom_db: float = 1.0
-) -> tuple[au.Waveform, dict]:
-    """
-    Scales `target` to match `ref` loudness. Returns (adjusted_wave, stats).
-    """
-    stats = {"method": method, "applied_gain_db": 0.0}
-
-    if method == "none":
-        return target, stats
-
-    if method == "auto":
-        method = "lufs" if _HAS_LOUDNORM else "rms"
-
-    if method == "lufs" and _HAS_LOUDNORM:
-        L_ref = _measure_lufs(ref)
-        L_tgt = _measure_lufs(target)
-        delta_db = L_ref - L_tgt
-        gain = 10.0 ** (delta_db / 20.0)
-        y = target.samples.astype(np.float32) * gain
-        stats.update({"ref_lufs": L_ref, "tgt_lufs_before": L_tgt, "applied_gain_db": delta_db})
-    else:
-        # RMS fallback
-        ra = _rms(ref.samples)
-        rb = _rms(target.samples)
-        if rb <= 1e-12:
-            return target, stats
-        gain = ra / rb
-        y = target.samples.astype(np.float32) * gain
-        stats.update({"ref_rms": ra, "tgt_rms_before": rb, "applied_gain_db": 20*np.log10(max(gain, 1e-12))})
-
-    # simple peak "limiter" to keep headroom
-    limit = 10 ** (-headroom_db / 20.0)  # e.g., -1 dBFS
-    peak = float(np.max(np.abs(y))) if y.size else 0.0
-    if peak > limit:
-        y *= (limit / peak)
-        stats["post_peak_limited"] = True
-    else:
-        stats["post_peak_limited"] = False
-
-    target.samples = y.astype(np.float32)
-    return target, stats
-
-# ----------------------------
-# Crossfade stitch (your good path)
-# ----------------------------
-def stitch_generated(chunks, sr, xfade_s):
-    if not chunks:
-        raise ValueError("no chunks")
-    xfade_n = int(round(xfade_s * sr))
-    if xfade_n <= 0:
-        return au.Waveform(np.concatenate([c.samples for c in chunks], axis=0), sr)
-
-    t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)
-    eq_in, eq_out = np.sin(t)[:, None], np.cos(t)[:, None]
-
-    first = chunks[0].samples
-    if first.shape[0] < xfade_n:
-        raise ValueError("chunk shorter than crossfade prefix")
-    out = first[xfade_n:].copy()  # drop model pre-roll
-
-    for i in range(1, len(chunks)):
-        cur = chunks[i].samples
-        if cur.shape[0] < xfade_n:
-            continue
-        head, tail = cur[:xfade_n], cur[xfade_n:]
-        mixed = out[-xfade_n:] * eq_out + head * eq_in
-        out = np.concatenate([out[:-xfade_n], mixed, tail], axis=0)
-
-    return au.Waveform(out, sr)
-
-# ----------------------------
-# Bar-aligned token context
-# ----------------------------
-def make_bar_aligned_context(tokens, bpm, fps=25, ctx_frames=250, beats_per_bar=4):
-    frames_per_bar_f = (beats_per_bar * 60.0 / bpm) * fps
-    frames_per_bar = int(round(frames_per_bar_f))
-    if abs(frames_per_bar - frames_per_bar_f) > 1e-3:
-        reps = int(np.ceil(ctx_frames / len(tokens)))
-        return np.tile(tokens, (reps, 1))[-ctx_frames:]
-    reps = int(np.ceil(ctx_frames / len(tokens)))
-    tiled = np.tile(tokens, (reps, 1))
-    end = (len(tiled) // frames_per_bar) * frames_per_bar
-    if end < ctx_frames:
-        return tiled[-ctx_frames:]
-    start = end - ctx_frames
-    return tiled[start:end]
-
-def hard_trim_seconds(wav: au.Waveform, seconds: float) -> au.Waveform:
-    n = int(round(seconds * wav.sample_rate))
-    return au.Waveform(wav.samples[:n], wav.sample_rate)
-
-def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
-    n = int(wav.sample_rate * ms / 1000.0)
-    if n > 0 and wav.samples.shape[0] > 2*n:
-        env = np.linspace(0.0, 1.0, n, dtype=np.float32)[:, None]
-        wav.samples[:n] *= env
-        wav.samples[-n:] *= env[::-1]
-
-def take_bar_aligned_tail(wav, bpm, beats_per_bar, ctx_seconds, max_bars=None):
-    """
-    Return the LAST N bars whose duration is as close as possible to ctx_seconds,
-    anchored to the end of `wav`, and bar-aligned.
-    """
-    spb = (60.0 / bpm) * beats_per_bar
-
-    bars_needed = max(1, int(round(ctx_seconds / spb)))
-    if max_bars is not None:
-        bars_needed = min(bars_needed, max_bars)
-    tail_seconds = bars_needed * spb
-    n = int(round(tail_seconds * wav.sample_rate))
-    if n >= wav.samples.shape[0]:
-        return wav
-    return au.Waveform(wav.samples[-n:], wav.sample_rate)
-
 # ----------------------------
 # Main generation (single combined style vector)
 # ----------------------------
@@ -326,42 +209,18 @@ def generate(
     input_sr = int(inp_info.samplerate)
     target_sr = int(target_sample_rate or input_sr)
 
-    # 2) Convert magenta output to target_sr if needed
-    #    wav.samples: shape [num_samples, num_channels], float32/-1..1 (per your code)
+    # 2) Convert to target SR + snap to exact bars
     cur_sr = int(mrt.sample_rate)
-    x = wav.samples  # np.ndarray (S, C)
-
-    if cur_sr != target_sr:
-        g = gcd(cur_sr, target_sr)
-        up, down = target_sr // g, cur_sr // g
-        # ensure 2D shape (S, C)
-        x = wav.samples
-        if x.ndim == 1:
-            x = x[:, None]
-        y = np.column_stack([resample_poly(x[:, ch], up, down) for ch in range(x.shape[1])])
-    else:
-        y = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
-
-    # 3) Snap to exact frame count for loop-perfect length
+    x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
     seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
-    expected_len = int(round(float(bars) * seconds_per_bar * target_sr))
+    expected_secs = float(bars) * seconds_per_bar
+    x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=expected_secs)
 
-    if y.shape[0] < expected_len:
-        pad = np.zeros((expected_len - y.shape[0], y.shape[1]), dtype=y.dtype)
-        y = np.vstack([y, pad])
-    elif y.shape[0] > expected_len:
-        y = y[:expected_len, :]
-
-    total_samples = int(y.shape[0])
+    # 3) Encode WAV once (no extra write)
+    audio_b64, total_samples, channels = wav_bytes_base64(x, target_sr)
     loop_duration_seconds = total_samples / float(target_sr)
 
-    # 4) Write y into buf as WAV @ target_sr
-    buf = io.BytesIO()
-    sf.write(buf, y, target_sr, subtype="FLOAT", format="WAV")
-    buf.seek(0)
-    audio_b64 = base64.b64encode(buf.read()).decode("utf-8")
-
-    # 5) Update metadata to be authoritative
+    # 4) Metadata
     metadata = {
         "bpm": int(round(bpm)),
         "bars": int(bars),
@@ -371,9 +230,9 @@ def generate(
         "loop_weight": loop_weight,
         "loudness": loud_stats,
         "sample_rate": int(target_sr),
-        "channels": int(y.shape[1]),
+        "channels": int(channels),
         "crossfade_seconds": mrt.config.crossfade_length,
-        "total_samples": total_samples,
+        "total_samples": int(total_samples),
         "seconds_per_bar": seconds_per_bar,
         "loop_duration_seconds": loop_duration_seconds,
         "guidance_weight": guidance_weight,
@@ -382,6 +241,159 @@ def generate(
     }
     return {"audio_base64": audio_b64, "metadata": metadata}
 
+# ----------------------------
+# the 'keep jamming' button
+# ----------------------------
+
+@app.post("/jam/start")
+def jam_start(
+    loop_audio: UploadFile = File(...),
+    bpm: float = Form(...),
+    bars_per_chunk: int = Form(4),
+    beats_per_bar: int = Form(4),
+    styles: str = Form(""),
+    style_weights: str = Form(""),
+    loop_weight: float = Form(1.0),
+    loudness_mode: str = Form("auto"),
+    loudness_headroom_db: float = Form(1.0),
+    guidance_weight: float = Form(1.1),
+    temperature: float = Form(1.1),
+    topk: int = Form(40),
+    target_sample_rate: int | None = Form(None),
+):
+    # enforce single active jam per GPU
+    with jam_lock:
+        for sid, w in list(jam_registry.items()):
+            if w.is_alive():
+                raise HTTPException(status_code=429, detail="A jam is already running. Try again later.")
+
+    # read input + prep context/style (reuse your existing code)
+    data = loop_audio.file.read()
+    if not data: raise HTTPException(status_code=400, detail="Empty file")
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
+        tmp.write(data); tmp_path = tmp.name
+
+    mrt = get_mrt()
+    loop = au.Waveform.from_file(tmp_path).resample(mrt.sample_rate).as_stereo()
+
+    # build tail context + style vec (tail-biased)
+    codec_fps = float(mrt.codec.frame_rate)
+    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
+    loop_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
+
+    # style vec = normalized mix of loop_tail + extra styles
+    embeds, weights = [mrt.embed_style(loop_tail)], [float(loop_weight)]
+    extra = [s for s in (styles.split(",") if styles else []) if s.strip()]
+    sw = [float(x) for x in style_weights.split(",")] if style_weights else []
+    for i, s in enumerate(extra):
+        embeds.append(mrt.embed_style(s.strip()))
+        weights.append(sw[i] if i < len(sw) else 1.0)
+    wsum = sum(weights) or 1.0
+    weights = [w / wsum for w in weights]
+    style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(embeds[0].dtype)
+
+    # target SR (default input SR)
+    inp_info = sf.info(tmp_path)
+    input_sr = int(inp_info.samplerate)
+    target_sr = int(target_sample_rate or input_sr)
+
+    params = JamParams(
+        bpm=bpm, beats_per_bar=beats_per_bar, bars_per_chunk=bars_per_chunk,
+        target_sr=target_sr, loudness_mode=loudness_mode, headroom_db=loudness_headroom_db,
+        style_vec=style_vec, ref_loop=loop_tail,
+        guidance_weight=guidance_weight, temperature=temperature, topk=topk
+    )
+
+    worker = JamWorker(mrt, params)
+    sid = str(uuid.uuid4())
+    with jam_lock:
+        jam_registry[sid] = worker
+    worker.start()
+
+    return {"session_id": sid}
+
+@app.get("/jam/next")
+def jam_next(session_id: str, since: int = 0):
+    with jam_lock:
+        worker = jam_registry.get(session_id)
+    if worker is None or not worker.is_alive():
+        raise HTTPException(status_code=404, detail="Session not found")
+
+    # drain outbox entries with index > since
+    items = []
+    with worker._lock:
+        for ch in worker.outbox:
+            if ch.index > since:
+                items.append({"index": ch.index, "audio_base64": ch.audio_base64, "metadata": ch.metadata})
+        # optional: truncate old items to keep memory bounded
+        if len(worker.outbox) > 32:
+            worker.outbox = worker.outbox[-16:]
+
+    if not items:
+        return Response(status_code=204)  # nothing yet
+    return {"chunks": items}
+
+@app.post("/jam/stop")
+def jam_stop(session_id: str = Body(..., embed=True)):
+    with jam_lock:
+        worker = jam_registry.get(session_id)
+    if worker is None:
+        raise HTTPException(status_code=404, detail="Session not found")
+    worker.stop()
+    worker.join(timeout=2.0)
+    with jam_lock:
+        jam_registry.pop(session_id, None)
+    return {"stopped": True}
+
+@app.post("/jam/update")
+def jam_update(session_id: str = Form(...),
+               guidance_weight: float | None = Form(None),
+               temperature: float | None = Form(None),
+               topk: int | None = Form(None)):
+    with jam_lock:
+        worker = jam_registry.get(session_id)
+    if worker is None or not worker.is_alive():
+        raise HTTPException(status_code=404, detail="Session not found")
+    worker.update_knobs(guidance_weight=guidance_weight, temperature=temperature, topk=topk)
+    return {"ok": True}
+
+@app.get("/jam/status")
+def jam_status(session_id: str):
+    with jam_lock:
+        worker = jam_registry.get(session_id)
+
+    if worker is None:
+        raise HTTPException(status_code=404, detail="Session not found")
+
+    running = worker.is_alive()
+
+    # Snapshot safely
+    with worker._lock:
+        last_index = int(worker.idx)
+        queued = len(worker.outbox)
+        p = worker.params
+        spb = p.beats_per_bar * (60.0 / p.bpm)
+        chunk_secs = p.bars_per_chunk * spb
+        target_sr = p.target_sr
+        bars_per_chunk = p.bars_per_chunk
+        beats_per_bar = p.beats_per_bar
+        bpm = p.bpm
+
+    return {
+        "running": running,
+        "last_index": last_index,    # last finished chunk index (0 if none yet)
+        "queued_chunks": queued,     # how many not-yet-fetched chunks are in the outbox
+        "bpm": bpm,
+        "beats_per_bar": beats_per_bar,
+        "bars_per_chunk": bars_per_chunk,
+        "seconds_per_bar": spb,
+        "chunk_duration_seconds": chunk_secs,
+        "target_sample_rate": target_sr,
+        "last_chunk_started_at": worker.last_chunk_started_at,
+        "last_chunk_completed_at": worker.last_chunk_completed_at,
+    }
+
+
 @app.get("/health")
 def health():
     return {"ok": True}
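For reference, the new endpoints compose into a simple polling loop: POST /jam/start returns a session_id, GET /jam/next?since=N drains chunks with index > N (204 while nothing is ready), and POST /jam/stop tears the session down. Below is a minimal client sketch, not part of the commit; the base URL, loop filename, and chunk count are placeholders.

# client_sketch.py — hypothetical client for the /jam endpoints above (not in this commit)
import base64, time, requests

BASE = "http://localhost:8000"  # placeholder: wherever app.py is served

with open("loop.wav", "rb") as f:  # placeholder input loop
    r = requests.post(f"{BASE}/jam/start",
                      files={"loop_audio": f},
                      data={"bpm": 120, "bars_per_chunk": 4})
r.raise_for_status()
sid = r.json()["session_id"]

last = 0
try:
    for _ in range(8):  # fetch a few chunks, then quit
        r = requests.get(f"{BASE}/jam/next", params={"session_id": sid, "since": last})
        if r.status_code == 204:   # nothing ready yet
            time.sleep(0.5)
            continue
        for ch in r.json()["chunks"]:
            last = max(last, ch["index"])
            with open(f"chunk_{ch['index']:04d}.wav", "wb") as out:
                out.write(base64.b64decode(ch["audio_base64"]))
finally:
    # jam_stop expects an embedded JSON body: {"session_id": ...}
    requests.post(f"{BASE}/jam/stop", json={"session_id": sid})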
jam_worker.py ADDED
File without changes
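The diff viewer does not render jam_worker.py's contents here. From the call sites in app.py, its interface presumably looks roughly like the sketch below; every field and method name is inferred from usage above, not copied from the committed file, and the run() body is elided.

# jam_worker.py interface as inferred from app.py usage (a sketch, not the committed source)
from dataclasses import dataclass
import threading
from typing import Any
import numpy as np

@dataclass
class JamParams:
    bpm: float
    beats_per_bar: int
    bars_per_chunk: int
    target_sr: int
    loudness_mode: str
    headroom_db: float
    style_vec: np.ndarray
    ref_loop: Any                # au.Waveform tail used as loudness/context reference
    guidance_weight: float = 1.1
    temperature: float = 1.1
    topk: int = 40

@dataclass
class JamChunk:
    index: int
    audio_base64: str
    metadata: dict

class JamWorker(threading.Thread):
    def __init__(self, mrt, params: JamParams):
        super().__init__(daemon=True)
        self.mrt, self.params = mrt, params
        self.idx = 0                      # last finished chunk index
        self.outbox: list[JamChunk] = []  # chunks awaiting /jam/next
        self._lock = threading.Lock()
        self._stop = threading.Event()
        self.last_chunk_started_at = None
        self.last_chunk_completed_at = None

    def stop(self):
        self._stop.set()

    def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
        with self._lock:
            if guidance_weight is not None: self.params.guidance_weight = guidance_weight
            if temperature is not None: self.params.temperature = temperature
            if topk is not None: self.params.topk = topk

    def run(self):
        while not self._stop.is_set():
            ...  # generate one bar-aligned chunk, loudness-match it, append a JamChunk to outbox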
utils.py ADDED
@@ -0,0 +1,168 @@
+# utils.py
+from __future__ import annotations
+import io, base64, math
+from math import gcd
+import numpy as np
+import soundfile as sf
+from scipy.signal import resample_poly
+
+# Magenta RT audio types
+from magenta_rt import audio as au
+
+# Optional loudness
+try:
+    import pyloudnorm as pyln
+    _HAS_LOUDNORM = True
+except Exception:
+    _HAS_LOUDNORM = False
+
+
+# ---------- Loudness ----------
+def _measure_lufs(wav: au.Waveform) -> float:
+    meter = pyln.Meter(wav.sample_rate)  # BS.1770-4
+    return float(meter.integrated_loudness(wav.samples))
+
+def _rms(x: np.ndarray) -> float:
+    if x.size == 0: return 0.0
+    return float(np.sqrt(np.mean(x**2)))
+
+def match_loudness_to_reference(
+    ref: au.Waveform,
+    target: au.Waveform,
+    method: str = "auto",   # "auto"|"lufs"|"rms"|"none"
+    headroom_db: float = 1.0
+) -> tuple[au.Waveform, dict]:
+    stats = {"method": method, "applied_gain_db": 0.0}
+    if method == "none":
+        return target, stats
+
+    if method == "auto":
+        method = "lufs" if _HAS_LOUDNORM else "rms"
+
+    if method == "lufs" and _HAS_LOUDNORM:
+        L_ref = _measure_lufs(ref)
+        L_tgt = _measure_lufs(target)
+        delta_db = L_ref - L_tgt
+        gain = 10.0 ** (delta_db / 20.0)
+        y = target.samples.astype(np.float32) * gain
+        stats.update({"ref_lufs": L_ref, "tgt_lufs_before": L_tgt, "applied_gain_db": delta_db})
+    else:
+        ra = _rms(ref.samples)
+        rb = _rms(target.samples)
+        if rb <= 1e-12:
+            return target, stats
+        gain = ra / rb
+        y = target.samples.astype(np.float32) * gain
+        stats.update({"ref_rms": ra, "tgt_rms_before": rb, "applied_gain_db": 20*np.log10(max(gain, 1e-12))})
+
+    # simple peak "limiter" to keep headroom
+    limit = 10 ** (-headroom_db / 20.0)  # e.g., -1 dBFS
+    peak = float(np.max(np.abs(y))) if y.size else 0.0
+    if peak > limit:
+        y *= (limit / peak)
+        stats["post_peak_limited"] = True
+    else:
+        stats["post_peak_limited"] = False
+
+    target.samples = y.astype(np.float32)
+    return target, stats
+
+
+# ---------- Stitch / fades / trims ----------
+def stitch_generated(chunks, sr: int, xfade_s: float) -> au.Waveform:
+    if not chunks:
+        raise ValueError("no chunks")
+    xfade_n = int(round(xfade_s * sr))
+    if xfade_n <= 0:
+        return au.Waveform(np.concatenate([c.samples for c in chunks], axis=0), sr)
+
+    t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)
+    eq_in, eq_out = np.sin(t)[:, None], np.cos(t)[:, None]
+
+    first = chunks[0].samples
+    if first.shape[0] < xfade_n:
+        raise ValueError("chunk shorter than crossfade prefix")
+    out = first[xfade_n:].copy()  # drop model pre-roll
+
+    for i in range(1, len(chunks)):
+        cur = chunks[i].samples
+        if cur.shape[0] < xfade_n:
+            continue
+        head, tail = cur[:xfade_n], cur[xfade_n:]
+        mixed = out[-xfade_n:] * eq_out + head * eq_in
+        out = np.concatenate([out[:-xfade_n], mixed, tail], axis=0)
+
+    return au.Waveform(out, sr)
+
+def hard_trim_seconds(wav: au.Waveform, seconds: float) -> au.Waveform:
+    n = int(round(seconds * wav.sample_rate))
+    return au.Waveform(wav.samples[:n], wav.sample_rate)
+
+def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
+    n = int(wav.sample_rate * ms / 1000.0)
+    if n > 0 and wav.samples.shape[0] > 2*n:
+        env = np.linspace(0.0, 1.0, n, dtype=np.float32)[:, None]
+        wav.samples[:n] *= env
+        wav.samples[-n:] *= env[::-1]
+
+
+# ---------- Token context helpers ----------
+def make_bar_aligned_context(tokens, bpm, fps=25, ctx_frames=250, beats_per_bar=4):
+    frames_per_bar_f = (beats_per_bar * 60.0 / bpm) * fps
+    frames_per_bar = int(round(frames_per_bar_f))
+    if abs(frames_per_bar - frames_per_bar_f) > 1e-3:
+        reps = int(np.ceil(ctx_frames / len(tokens)))
+        return np.tile(tokens, (reps, 1))[-ctx_frames:]
+    reps = int(np.ceil(ctx_frames / len(tokens)))
+    tiled = np.tile(tokens, (reps, 1))
+    end = (len(tiled) // frames_per_bar) * frames_per_bar
+    if end < ctx_frames:
+        return tiled[-ctx_frames:]
+    start = end - ctx_frames
+    return tiled[start:end]
+
+def take_bar_aligned_tail(wav: au.Waveform, bpm: float, beats_per_bar: int, ctx_seconds: float, max_bars=None) -> au.Waveform:
+    spb = (60.0 / bpm) * beats_per_bar
+    bars_needed = max(1, int(round(ctx_seconds / spb)))
+    if max_bars is not None:
+        bars_needed = min(bars_needed, max_bars)
+    tail_seconds = bars_needed * spb
+    n = int(round(tail_seconds * wav.sample_rate))
+    if n >= wav.samples.shape[0]:
+        return wav
+    return au.Waveform(wav.samples[-n:], wav.sample_rate)
+
+
+# ---------- SR normalize + snap ----------
+def resample_and_snap(x: np.ndarray, cur_sr: int, target_sr: int, seconds: float) -> np.ndarray:
+    """
+    x: np.ndarray shape (S, C), float32
+    Returns: exact-length array (round(seconds*target_sr), C)
+    """
+    if x.ndim == 1:
+        x = x[:, None]
+    if cur_sr != target_sr:
+        g = gcd(cur_sr, target_sr)
+        up, down = target_sr // g, cur_sr // g
+        x = resample_poly(x, up, down, axis=0)
+
+    expected_len = int(round(seconds * target_sr))
+    if x.shape[0] < expected_len:
+        pad = np.zeros((expected_len - x.shape[0], x.shape[1]), dtype=x.dtype)
+        x = np.vstack([x, pad])
+    elif x.shape[0] > expected_len:
+        x = x[:expected_len, :]
+    return x.astype(np.float32, copy=False)
+
+
+# ---------- WAV encode ----------
+def wav_bytes_base64(x: np.ndarray, sr: int) -> tuple[str, int, int]:
+    """
+    x: np.ndarray shape (S, C)
+    returns: (base64_wav, total_samples, channels)
+    """
+    buf = io.BytesIO()
+    sf.write(buf, x, sr, subtype="FLOAT", format="WAV")
+    buf.seek(0)
+    b64 = base64.b64encode(buf.read()).decode("utf-8")
+    return b64, int(x.shape[0]), int(x.shape[1])
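To see how the two helpers at the bottom fit together, here is a quick usage sketch with assumed example values (44.1 kHz model output, 48 kHz target, 2 bars at 120 BPM): a slightly over-long stereo buffer is resampled, snapped to the exact bar length, and encoded once.

# usage sketch for resample_and_snap + wav_bytes_base64 (assumed example values)
import numpy as np
from utils import resample_and_snap, wav_bytes_base64

cur_sr, target_sr, bpm, bars, beats_per_bar = 44100, 48000, 120.0, 2, 4
seconds = bars * beats_per_bar * (60.0 / bpm)  # 4.0 s for 2 bars at 120 BPM

# pretend model output: slightly over-long (S, C) float32 buffer
x = np.random.randn(int(seconds * cur_sr) + 37, 2).astype(np.float32)
y = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
assert y.shape == (int(round(seconds * target_sr)), 2)  # exact loop length: 192000 frames

b64, total_samples, channels = wav_bytes_base64(y, target_sr)
print(total_samples, channels)  # 192000 2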