Spaces:
Running
Running
Commit
·
f8b3793
1
Parent(s):
8cf69d0
ofc sample rates
Browse files
app.py
CHANGED
@@ -4,6 +4,10 @@ from fastapi import FastAPI, UploadFile, File, Form
|
|
4 |
import tempfile, io, base64, math, threading
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
from contextlib import contextmanager
|
|
|
|
|
|
|
|
|
7 |
|
8 |
@contextmanager
|
9 |
def mrt_overrides(mrt, **kwargs):
|
@@ -257,10 +261,10 @@ def generate(
|
|
257 |
loop_weight: float = Form(1.0),
|
258 |
loudness_mode: str = Form("auto"),
|
259 |
loudness_headroom_db: float = Form(1.0),
|
260 |
-
# NEW per-request knobs
|
261 |
guidance_weight: float = Form(5.0),
|
262 |
temperature: float = Form(1.1),
|
263 |
topk: int = Form(40),
|
|
|
264 |
):
|
265 |
# Read file
|
266 |
data = loop_audio.file.read()
|
@@ -293,45 +297,66 @@ def generate(
|
|
293 |
loudness_headroom_db=loudness_headroom_db,
|
294 |
)
|
295 |
|
296 |
-
|
297 |
-
|
298 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
|
300 |
-
#
|
301 |
seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
302 |
|
303 |
-
#
|
304 |
buf = io.BytesIO()
|
305 |
-
|
306 |
-
wav.write(buf, subtype="FLOAT", format="WAV")
|
307 |
buf.seek(0)
|
308 |
audio_b64 = base64.b64encode(buf.read()).decode("utf-8")
|
309 |
|
310 |
-
|
311 |
-
|
312 |
-
"
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
# Echo the actual knobs used
|
330 |
-
"guidance_weight": guidance_weight,
|
331 |
-
"temperature": temperature,
|
332 |
-
"topk": topk,
|
333 |
-
},
|
334 |
}
|
|
|
335 |
|
336 |
@app.get("/health")
|
337 |
def health():
|
|
|
4 |
import tempfile, io, base64, math, threading
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
from contextlib import contextmanager
|
7 |
+
import soundfile as sf
|
8 |
+
import numpy as np
|
9 |
+
from math import gcd
|
10 |
+
from scipy.signal import resample_poly
|
11 |
|
12 |
@contextmanager
|
13 |
def mrt_overrides(mrt, **kwargs):
|
|
|
261 |
loop_weight: float = Form(1.0),
|
262 |
loudness_mode: str = Form("auto"),
|
263 |
loudness_headroom_db: float = Form(1.0),
|
|
|
264 |
guidance_weight: float = Form(5.0),
|
265 |
temperature: float = Form(1.1),
|
266 |
topk: int = Form(40),
|
267 |
+
target_sample_rate: int | None = Form(None), # <-- add this
|
268 |
):
|
269 |
# Read file
|
270 |
data = loop_audio.file.read()
|
|
|
297 |
loudness_headroom_db=loudness_headroom_db,
|
298 |
)
|
299 |
|
300 |
+
# 1) Figure out the desired SR
|
301 |
+
inp_info = sf.info(tmp_path)
|
302 |
+
input_sr = int(inp_info.samplerate)
|
303 |
+
target_sr = int(target_sample_rate or input_sr)
|
304 |
+
|
305 |
+
# 2) Convert magenta output to target_sr if needed
|
306 |
+
# wav.samples: shape [num_samples, num_channels], float32/-1..1 (per your code)
|
307 |
+
cur_sr = int(mrt.sample_rate)
|
308 |
+
x = wav.samples # np.ndarray (S, C)
|
309 |
+
|
310 |
+
if cur_sr != target_sr:
|
311 |
+
g = gcd(cur_sr, target_sr)
|
312 |
+
up, down = target_sr // g, cur_sr // g
|
313 |
+
# ensure 2D shape (S, C)
|
314 |
+
x = wav.samples
|
315 |
+
if x.ndim == 1:
|
316 |
+
x = x[:, None]
|
317 |
+
y = np.column_stack([resample_poly(x[:, ch], up, down) for ch in range(x.shape[1])])
|
318 |
+
else:
|
319 |
+
y = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
|
320 |
|
321 |
+
# 3) Snap to exact frame count for loop-perfect length
|
322 |
seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
|
323 |
+
expected_len = int(round(float(bars) * seconds_per_bar * target_sr))
|
324 |
+
|
325 |
+
if y.shape[0] < expected_len:
|
326 |
+
pad = np.zeros((expected_len - y.shape[0], y.shape[1]), dtype=y.dtype)
|
327 |
+
y = np.vstack([y, pad])
|
328 |
+
elif y.shape[0] > expected_len:
|
329 |
+
y = y[:expected_len, :]
|
330 |
+
|
331 |
+
total_samples = int(y.shape[0])
|
332 |
+
loop_duration_seconds = total_samples / float(target_sr)
|
333 |
|
334 |
+
# 4) Write y into buf as WAV @ target_sr
|
335 |
buf = io.BytesIO()
|
336 |
+
sf.write(buf, y, target_sr, subtype="FLOAT", format="WAV")
|
|
|
337 |
buf.seek(0)
|
338 |
audio_b64 = base64.b64encode(buf.read()).decode("utf-8")
|
339 |
|
340 |
+
# 5) Update metadata to be authoritative
|
341 |
+
metadata = {
|
342 |
+
"bpm": int(round(bpm)),
|
343 |
+
"bars": int(bars),
|
344 |
+
"beats_per_bar": int(beats_per_bar),
|
345 |
+
"styles": extra_styles,
|
346 |
+
"style_weights": weights,
|
347 |
+
"loop_weight": loop_weight,
|
348 |
+
"loudness": loud_stats,
|
349 |
+
"sample_rate": int(target_sr),
|
350 |
+
"channels": int(y.shape[1]),
|
351 |
+
"crossfade_seconds": mrt.config.crossfade_length,
|
352 |
+
"total_samples": total_samples,
|
353 |
+
"seconds_per_bar": seconds_per_bar,
|
354 |
+
"loop_duration_seconds": loop_duration_seconds,
|
355 |
+
"guidance_weight": guidance_weight,
|
356 |
+
"temperature": temperature,
|
357 |
+
"topk": topk,
|
|
|
|
|
|
|
|
|
|
|
|
|
358 |
}
|
359 |
+
return {"audio_base64": audio_b64, "metadata": metadata}
|
360 |
|
361 |
@app.get("/health")
|
362 |
def health():
|