thecollabagepatch committed on
Commit
f8b3793
·
1 Parent(s): 8cf69d0

ofc sample rates

Browse files
Files changed (1) hide show
  1. app.py +57 -32
app.py CHANGED
@@ -4,6 +4,10 @@ from fastapi import FastAPI, UploadFile, File, Form
4
  import tempfile, io, base64, math, threading
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from contextlib import contextmanager
 
 
 
 
7
 
8
  @contextmanager
9
  def mrt_overrides(mrt, **kwargs):
@@ -257,10 +261,10 @@ def generate(
257
  loop_weight: float = Form(1.0),
258
  loudness_mode: str = Form("auto"),
259
  loudness_headroom_db: float = Form(1.0),
260
- # NEW per-request knobs
261
  guidance_weight: float = Form(5.0),
262
  temperature: float = Form(1.1),
263
  topk: int = Form(40),
 
264
  ):
265
  # Read file
266
  data = loop_audio.file.read()
@@ -293,45 +297,66 @@ def generate(
293
  loudness_headroom_db=loudness_headroom_db,
294
  )
295
 
296
- total_samples = int(wav.samples.shape[0])
297
- sample_rate = int(get_mrt().sample_rate) # or mrt.sample_rate (same instance here)
298
- loop_duration_seconds = total_samples / float(sample_rate)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
- # Also include the bar math (useful for sanity checks downstream)
301
  seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
 
 
 
 
 
 
 
 
 
 
302
 
303
- # Return base64 WAV + minimal metadata
304
  buf = io.BytesIO()
305
- # add format="WAV" when writing to a file-like object
306
- wav.write(buf, subtype="FLOAT", format="WAV")
307
  buf.seek(0)
308
  audio_b64 = base64.b64encode(buf.read()).decode("utf-8")
309
 
310
- return {
311
- "audio_base64": audio_b64,
312
- "metadata": {
313
- "bpm": int(round(bpm)),
314
- "bars": int(bars),
315
- "beats_per_bar": int(beats_per_bar),
316
- "styles": extra_styles,
317
- "style_weights": weights,
318
- "loop_weight": loop_weight,
319
- "loudness": loud_stats,
320
- "sample_rate": sample_rate,
321
- "channels": mrt.num_channels,
322
- "crossfade_seconds": mrt.config.crossfade_length,
323
-
324
- # New timing fields
325
- "total_samples": total_samples,
326
- "seconds_per_bar": seconds_per_bar,
327
- "loop_duration_seconds": loop_duration_seconds,
328
-
329
- # Echo the actual knobs used
330
- "guidance_weight": guidance_weight,
331
- "temperature": temperature,
332
- "topk": topk,
333
- },
334
  }
 
335
 
336
  @app.get("/health")
337
  def health():
 
4
  import tempfile, io, base64, math, threading
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from contextlib import contextmanager
7
+ import soundfile as sf
8
+ import numpy as np
9
+ from math import gcd
10
+ from scipy.signal import resample_poly
11
 
12
  @contextmanager
13
  def mrt_overrides(mrt, **kwargs):
 
261
  loop_weight: float = Form(1.0),
262
  loudness_mode: str = Form("auto"),
263
  loudness_headroom_db: float = Form(1.0),
 
264
  guidance_weight: float = Form(5.0),
265
  temperature: float = Form(1.1),
266
  topk: int = Form(40),
267
+ target_sample_rate: int | None = Form(None), # <-- add this
268
  ):
269
  # Read file
270
  data = loop_audio.file.read()
 
297
  loudness_headroom_db=loudness_headroom_db,
298
  )
299
 
300
+ # 1) Figure out the desired SR
301
+ inp_info = sf.info(tmp_path)
302
+ input_sr = int(inp_info.samplerate)
303
+ target_sr = int(target_sample_rate or input_sr)
304
+
305
+ # 2) Convert magenta output to target_sr if needed
306
+ # wav.samples: shape [num_samples, num_channels], float32/-1..1 (per your code)
307
+ cur_sr = int(mrt.sample_rate)
308
+ x = wav.samples # np.ndarray (S, C)
309
+
310
+ if cur_sr != target_sr:
311
+ g = gcd(cur_sr, target_sr)
312
+ up, down = target_sr // g, cur_sr // g
313
+ # ensure 2D shape (S, C)
314
+ x = wav.samples
315
+ if x.ndim == 1:
316
+ x = x[:, None]
317
+ y = np.column_stack([resample_poly(x[:, ch], up, down) for ch in range(x.shape[1])])
318
+ else:
319
+ y = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
320
 
321
+ # 3) Snap to exact frame count for loop-perfect length
322
  seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
323
+ expected_len = int(round(float(bars) * seconds_per_bar * target_sr))
324
+
325
+ if y.shape[0] < expected_len:
326
+ pad = np.zeros((expected_len - y.shape[0], y.shape[1]), dtype=y.dtype)
327
+ y = np.vstack([y, pad])
328
+ elif y.shape[0] > expected_len:
329
+ y = y[:expected_len, :]
330
+
331
+ total_samples = int(y.shape[0])
332
+ loop_duration_seconds = total_samples / float(target_sr)
333
 
334
+ # 4) Write y into buf as WAV @ target_sr
335
  buf = io.BytesIO()
336
+ sf.write(buf, y, target_sr, subtype="FLOAT", format="WAV")
 
337
  buf.seek(0)
338
  audio_b64 = base64.b64encode(buf.read()).decode("utf-8")
339
 
340
+ # 5) Update metadata to be authoritative
341
+ metadata = {
342
+ "bpm": int(round(bpm)),
343
+ "bars": int(bars),
344
+ "beats_per_bar": int(beats_per_bar),
345
+ "styles": extra_styles,
346
+ "style_weights": weights,
347
+ "loop_weight": loop_weight,
348
+ "loudness": loud_stats,
349
+ "sample_rate": int(target_sr),
350
+ "channels": int(y.shape[1]),
351
+ "crossfade_seconds": mrt.config.crossfade_length,
352
+ "total_samples": total_samples,
353
+ "seconds_per_bar": seconds_per_bar,
354
+ "loop_duration_seconds": loop_duration_seconds,
355
+ "guidance_weight": guidance_weight,
356
+ "temperature": temperature,
357
+ "topk": topk,
 
 
 
 
 
 
358
  }
359
+ return {"audio_base64": audio_b64, "metadata": metadata}
360
 
361
  @app.get("/health")
362
  def health():