thecollabagepatch committed on
Commit
ede0049
·
1 Parent(s): 159b8b1

expose more params

Browse files
Files changed (1) hide show
  1. app.py +44 -17
app.py CHANGED
@@ -3,6 +3,21 @@ import numpy as np
3
  from fastapi import FastAPI, UploadFile, File, Form
4
  import tempfile, io, base64, math, threading
5
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # loudness utils
8
  try:
@@ -239,9 +254,13 @@ def generate(
239
  beats_per_bar: int = Form(4),
240
  styles: str = Form("acid house"),
241
  style_weights: str = Form(""),
242
- loop_weight: float = Form(1.0), # NEW
243
- loudness_mode: str = Form("auto"), # NEW
244
- loudness_headroom_db: float = Form(1.0), # NEW
 
 
 
 
245
  ):
246
  # Read file
247
  data = loop_audio.file.read()
@@ -256,19 +275,23 @@ def generate(
256
  weights = [float(x) for x in style_weights.split(",")] if style_weights else None
257
 
258
  mrt = get_mrt() # warm once, in this worker thread
259
- mrt = get_mrt()
260
- wav, loud_stats = generate_loop_continuation_with_mrt(
261
- mrt,
262
- input_wav_path=tmp_path,
263
- bpm=bpm,
264
- extra_styles=extra_styles,
265
- style_weights=weights,
266
- bars=bars,
267
- beats_per_bar=beats_per_bar,
268
- loop_weight=loop_weight,
269
- loudness_mode=loudness_mode,
270
- loudness_headroom_db=loudness_headroom_db,
271
- )
 
 
 
 
272
 
273
  # Return base64 WAV + minimal metadata
274
  buf = io.BytesIO()
@@ -286,10 +309,14 @@ def generate(
286
  "styles": extra_styles,
287
  "style_weights": weights,
288
  "loop_weight": loop_weight,
289
- "loudness": loud_stats, # NEW
290
  "sample_rate": mrt.sample_rate,
291
  "channels": mrt.num_channels,
292
  "crossfade_seconds": mrt.config.crossfade_length,
 
 
 
 
293
  },
294
  }
295
 
 
3
  from fastapi import FastAPI, UploadFile, File, Form
4
  import tempfile, io, base64, math, threading
5
  from fastapi.middleware.cors import CORSMiddleware
6
+ from contextlib import contextmanager
7
+
8
@contextmanager
def mrt_overrides(mrt, **kwargs):
    """Temporarily set attributes on ``mrt`` and restore them on exit.

    Each keyword argument names an attribute to override for the duration of
    the ``with`` block. Names that ``mrt`` does not already have are skipped
    silently, so no new attributes are ever created on the object.

    The previous values are put back when the block exits, whether it
    finishes normally or raises. Restoration happens in reverse order of
    application.

    NOTE(review): ``mrt`` is mutated in place. If the same object is shared
    by requests that run concurrently, one request's overrides are visible to
    the others while the block is active -- confirm callers serialize access.

    Args:
        mrt: Object whose attributes are overridden.
        **kwargs: Attribute name -> temporary value.

    Yields:
        None. The overrides are active while the ``with`` body runs.
    """
    # (name, previous value) pairs, in the order the overrides were applied.
    applied = []
    try:
        for name, value in kwargs.items():
            if not hasattr(mrt, name):
                continue
            previous = getattr(mrt, name)
            setattr(mrt, name, value)
            # Record only after the assignment succeeded: an attribute whose
            # setter rejected the value was never changed and must not be
            # written to again during cleanup.
            applied.append((name, previous))
        yield
    finally:
        # Undo in LIFO order so interdependent setters unwind cleanly.
        for name, previous in reversed(applied):
            setattr(mrt, name, previous)
21
 
22
  # loudness utils
23
  try:
 
254
  beats_per_bar: int = Form(4),
255
  styles: str = Form("acid house"),
256
  style_weights: str = Form(""),
257
+ loop_weight: float = Form(1.0),
258
+ loudness_mode: str = Form("auto"),
259
+ loudness_headroom_db: float = Form(1.0),
260
+ # NEW per-request knobs
261
+ guidance_weight: float = Form(5.0),
262
+ temperature: float = Form(1.1),
263
+ topk: int = Form(40),
264
  ):
265
  # Read file
266
  data = loop_audio.file.read()
 
275
  weights = [float(x) for x in style_weights.split(",")] if style_weights else None
276
 
277
  mrt = get_mrt() # warm once, in this worker thread
278
+ # Temporarily override MRT inference knobs for this request
279
+ with mrt_overrides(mrt,
280
+ guidance_weight=guidance_weight,
281
+ temperature=temperature,
282
+ topk=topk):
283
+ wav, loud_stats = generate_loop_continuation_with_mrt(
284
+ mrt,
285
+ input_wav_path=tmp_path,
286
+ bpm=bpm,
287
+ extra_styles=extra_styles,
288
+ style_weights=weights,
289
+ bars=bars,
290
+ beats_per_bar=beats_per_bar,
291
+ loop_weight=loop_weight,
292
+ loudness_mode=loudness_mode,
293
+ loudness_headroom_db=loudness_headroom_db,
294
+ )
295
 
296
  # Return base64 WAV + minimal metadata
297
  buf = io.BytesIO()
 
309
  "styles": extra_styles,
310
  "style_weights": weights,
311
  "loop_weight": loop_weight,
312
+ "loudness": loud_stats,
313
  "sample_rate": mrt.sample_rate,
314
  "channels": mrt.num_channels,
315
  "crossfade_seconds": mrt.config.crossfade_length,
316
+ # Echo the actual knobs used
317
+ "guidance_weight": guidance_weight,
318
+ "temperature": temperature,
319
+ "topk": topk,
320
  },
321
  }
322