# Page header captured with this file ("Spaces: Running"); not part of the program.
import base64
import io
import math
import os
import tempfile
import threading

import numpy as np
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from magenta_rt import audio as au, system
# Loudness utilities
# pyloudnorm is optional: when it cannot be imported, loudness matching
# falls back to RMS. Any failure during the import (not only ImportError)
# simply disables the LUFS path.
try:
    import pyloudnorm as pyln
except Exception:
    _HAS_LOUDNORM = False
else:
    _HAS_LOUDNORM = True
def _measure_lufs(wav: au.Waveform) -> float:
    """Return the integrated loudness of `wav` in LUFS, measured with pyloudnorm.

    Only valid when _HAS_LOUDNORM is True (`pyln` is unbound otherwise).
    pyloudnorm expects float32/float64 samples shaped (n,) or (n, ch); its
    Meter defaults to BS.1770-4.
    """
    loudness = pyln.Meter(wav.sample_rate).integrated_loudness(wav.samples)
    return float(loudness)
def _rms(x: np.ndarray) -> float: | |
if x.size == 0: return 0.0 | |
return float(np.sqrt(np.mean(x**2))) | |
def match_loudness_to_reference(
    ref: au.Waveform,
    target: au.Waveform,
    method: str = "auto",  # "auto"|"lufs"|"rms"|"none"
    headroom_db: float = 1.0,
) -> tuple[au.Waveform, dict]:
    """Scale `target` so its loudness matches `ref`, then apply a peak guard.

    Args:
      ref: Waveform whose loudness is the goal.
      target: Waveform to adjust. When a gain is applied, `target.samples` is
        replaced (as float32) and the same object is returned.
      method: "none" returns `target` untouched. "lufs" matches integrated
        loudness measured with pyloudnorm. "rms" matches root-mean-square
        level. "auto" picks "lufs" when pyloudnorm imported, else "rms".
        "lufs" falls back to RMS when pyloudnorm is unavailable, rejects the
        audio (ValueError, e.g. a clip shorter than its gating block), or
        returns a non-finite loudness (e.g. for silence). Any other string is
        handled as "rms".
      headroom_db: After the gain, the signal is scaled down if its absolute
        peak exceeds -headroom_db dBFS.

    Returns:
      (adjusted_wave, stats).
      - stats["method"] echoes the requested method.
      - stats["method_used"] ("lufs" or "rms") names the measurement actually
        applied.
      - stats["applied_gain_db"] is the matching gain before the peak guard.
      If either signal is (near-)silent, no gain is applied and `target` is
      returned unchanged. Numeric stats are always finite, because non-finite
      floats cannot be encoded as strict JSON.
    """
    stats = {"method": method, "applied_gain_db": 0.0}
    if method == "none":
        return target, stats
    if method == "auto":
        method = "lufs" if _HAS_LOUDNORM else "rms"

    gain = None  # linear gain; stays None if the LUFS path cannot produce one
    if method == "lufs" and _HAS_LOUDNORM:
        try:
            ref_lufs = _measure_lufs(ref)
            tgt_lufs = _measure_lufs(target)
        except ValueError:
            # pyloudnorm rejects audio it cannot meter (e.g. too short).
            ref_lufs = tgt_lufs = float("-inf")
        # A non-finite reading (silence) would give an inf/NaN gain and NaN audio.
        if math.isfinite(ref_lufs) and math.isfinite(tgt_lufs):
            delta_db = ref_lufs - tgt_lufs
            gain = 10.0 ** (delta_db / 20.0)
            stats.update({
                "method_used": "lufs",
                "ref_lufs": ref_lufs,
                "tgt_lufs_before": tgt_lufs,
                "applied_gain_db": delta_db,
            })

    if gain is None:
        # RMS fallback
        ra = _rms(ref.samples)
        rb = _rms(target.samples)
        stats["method_used"] = "rms"
        if not (ra > 1e-12 and rb > 1e-12):
            # (Near-)silent input: the matching gain would be 0 or unbounded.
            stats["post_peak_limited"] = False
            return target, stats
        gain = ra / rb
        stats.update({
            "ref_rms": ra,
            "tgt_rms_before": rb,
            "applied_gain_db": 20.0 * math.log10(gain),
        })

    y = target.samples.astype(np.float32) * gain

    # Simple peak "limiter": uniformly scale down to keep `headroom_db` of headroom.
    limit = 10 ** (-headroom_db / 20.0)  # e.g. 1 dB -> -1 dBFS
    peak = float(np.max(np.abs(y))) if y.size else 0.0
    if peak > limit:
        y *= limit / peak
        stats["post_peak_limited"] = True
    else:
        stats["post_peak_limited"] = False

    target.samples = y.astype(np.float32)
    return target, stats
# ----------------------------
# Crossfade stitching of generated chunks (the known-good path)
# ----------------------------
def stitch_generated(chunks, sr, xfade_s):
    """Join generated chunks into one Waveform using equal-power crossfades.

    The first `xfade_s` seconds of the first chunk are dropped (treated as
    model pre-roll). Each later chunk's first `xfade_s` seconds are blended
    into the end of the output built so far, using sin/cos (equal-power)
    curves. Later chunks shorter than the crossfade are skipped.
    Assumption: each chunk's head overlaps the previous chunk's end; this is
    inferred from the stitching scheme, not verified here (TODO confirm).

    Args:
      chunks: Non-empty sequence of objects with a `.samples` array. Time is
        axis 0; any trailing axes (e.g. channels) are broadcast over.
      sr: Sample rate in Hz, used for the crossfade length and the result.
      xfade_s: Crossfade length in seconds. If it rounds to <= 0 samples, the
        chunks are simply concatenated.

    Returns:
      au.Waveform holding the stitched samples at rate `sr`.

    Raises:
      ValueError: if `chunks` is empty, the first chunk is shorter than the
        crossfade, or the output is too short to crossfade into. The last
        case previously surfaced as a numpy broadcasting error or silently
        produced wrong audio.
    """
    if not chunks:
        raise ValueError("no chunks")
    xfade_n = int(round(xfade_s * sr))
    if xfade_n <= 0:
        return au.Waveform(np.concatenate([c.samples for c in chunks], axis=0), sr)

    first = chunks[0].samples
    if first.shape[0] < xfade_n:
        raise ValueError("chunk shorter than crossfade prefix")

    # Curves shaped (xfade_n, 1, ..., 1) broadcast over trailing axes,
    # including plain 1-D (mono) sample arrays.
    curve_shape = (xfade_n,) + (1,) * (first.ndim - 1)
    t = np.linspace(0, np.pi / 2, xfade_n, endpoint=False, dtype=np.float32)
    eq_in = np.sin(t).reshape(curve_shape)
    eq_out = np.cos(t).reshape(curve_shape)

    # `done` holds finished segments. `pending` is the open end of the output,
    # whose last xfade_n samples get blended with the next chunk's head.
    # Concatenating once at the end avoids re-copying the whole output for
    # every chunk.
    done = []
    pending = first[xfade_n:]  # drop model pre-roll
    for chunk in chunks[1:]:
        cur = chunk.samples
        if cur.shape[0] < xfade_n:
            continue
        if pending.shape[0] < xfade_n:
            raise ValueError("stitched output shorter than crossfade")
        head, tail = cur[:xfade_n], cur[xfade_n:]
        mixed = pending[-xfade_n:] * eq_out + head * eq_in
        done.append(pending[:-xfade_n])
        pending = np.concatenate([mixed, tail], axis=0)
    done.append(pending)
    # The final concatenate copies, so the result never aliases chunk buffers.
    return au.Waveform(np.concatenate(done, axis=0), sr)
# ----------------------------
# Bar-aligned token context
# ----------------------------
def make_bar_aligned_context(tokens, bpm, fps=25, ctx_frames=250, beats_per_bar=4):
    """Tile a loop's tokens into a `ctx_frames`-long context that ends on a bar line.

    The loop is repeated end-to-end along axis 0 until it covers the context.
    The window is chosen to end on a bar boundary, counted from the start of
    the tiled loop (so this assumes the loop itself starts on a bar line).
    If a bar is not a whole number of frames (beyond a 1e-3 tolerance), the
    window is simply the last `ctx_frames` tiled frames.

    Args:
      tokens: Per-frame tokens with time on axis 0. The caller in this file
        passes codec tokens shaped (frames, rvq_depth); 1-D input also works.
      bpm: Tempo in beats per minute; must be > 0.
      fps: Token frames per second.
      ctx_frames: Number of frames to return (<= 0 returns an empty slice).
      beats_per_bar: Beats in one bar.

    Returns:
      Array with `ctx_frames` entries along axis 0, same dtype as `tokens`.

    Raises:
      ValueError: if `tokens` is empty or `bpm` is not positive.
    """
    tokens = np.asarray(tokens)
    n_tok = len(tokens)
    if n_tok == 0:
        raise ValueError("tokens must contain at least one frame")
    if not bpm > 0:
        raise ValueError(f"bpm must be positive, got {bpm!r}")
    if ctx_frames <= 0:
        return tokens[:0]

    def tiled(reps):
        # Repeat the whole loop `reps` times along the time axis only.
        return np.tile(tokens, (reps,) + (1,) * (tokens.ndim - 1))

    frames_per_bar_f = (beats_per_bar * 60.0 / bpm) * fps
    frames_per_bar = int(round(frames_per_bar_f))
    reps = int(np.ceil(ctx_frames / n_tok))

    if frames_per_bar < 1 or abs(frames_per_bar - frames_per_bar_f) > 1e-3:
        # Bars do not land on whole frames: take the newest ctx_frames.
        return tiled(reps)[-ctx_frames:]

    # `end` = last bar line inside the tiled loop. Add repetitions until a
    # full window fits before it; the old code fell back to an unaligned
    # window in that case.
    end = (reps * n_tok // frames_per_bar) * frames_per_bar
    while end < ctx_frames:
        reps += 1
        end = (reps * n_tok // frames_per_bar) * frames_per_bar
    return tiled(reps)[end - ctx_frames:end]
def hard_trim_seconds(wav: au.Waveform, seconds: float) -> au.Waveform:
    """Return a new Waveform holding only the first `seconds` of `wav`.

    The sample count is rounded to the nearest integer. Audio shorter than
    `seconds` is returned whole. The result wraps a slice of `wav.samples`,
    so it may share memory with the input unless au.Waveform copies — confirm
    before mutating either.
    """
    keep = int(round(seconds * wav.sample_rate))
    head = wav.samples[:keep]
    return au.Waveform(head, wav.sample_rate)
def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
    """Fade the first and last `ms` milliseconds of `wav` in place.

    Linear 0->1 (fade-in) and 1->0 (fade-out) ramps are multiplied into
    `wav.samples`. Nothing happens when the fade length truncates to zero
    samples or the clip is not longer than both fades together. The ramp is
    shaped (n, 1), so samples are assumed to be (time, channels) — TODO
    confirm for mono input.
    """
    fade_len = int(wav.sample_rate * ms / 1000.0)
    if fade_len <= 0 or wav.samples.shape[0] <= 2 * fade_len:
        return
    ramp_up = np.linspace(0.0, 1.0, fade_len, dtype=np.float32)[:, None]
    wav.samples[:fade_len] *= ramp_up
    wav.samples[-fade_len:] *= ramp_up[::-1]
# ----------------------------
# Main generation (single combined style vector)
# ----------------------------
def _combine_style_embeddings(mrt, loop, extra_styles, style_weights, loop_weight):
    """Blend the loop's style embedding and text-prompt embeddings into one vector.

    The loop embedding gets `loop_weight`. Each non-blank prompt in
    `extra_styles` gets `style_weights[i]` (matched by position, default 1.0).
    Weights are normalized to sum to 1. If their sum is <= 0 or not finite,
    only the loop embedding is used, so the result contains no NaNs. The
    result is cast to the loop embedding's dtype.
    """
    loop_embed = mrt.embed_style(loop)
    embeds = [loop_embed]
    weights = [float(loop_weight)]
    for i, style in enumerate(extra_styles or ()):
        prompt = style.strip()
        if not prompt:
            continue
        embeds.append(mrt.embed_style(prompt))
        has_weight = bool(style_weights) and i < len(style_weights)
        weights.append(float(style_weights[i]) if has_weight else 1.0)

    total = float(sum(weights))
    if not (math.isfinite(total) and total > 0.0):
        # Fallback: rely on the loop alone to avoid NaNs.
        weights = [1.0] + [0.0] * (len(weights) - 1)
        total = 1.0
    weights = [w / total for w in weights]
    combined = np.sum([w * e for w, e in zip(weights, embeds)], axis=0)
    return combined.astype(loop_embed.dtype)


def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",  # "auto"|"lufs"|"rms"|"none"
    loudness_headroom_db: float = 1.0,  # for the peak guard
):
    """Generate `bars` bars of audio continuing the loop at `input_wav_path`.

    Pipeline:
      1. Load and encode the loop.
      2. Seed the model state with a bar-aligned token context built from it.
      3. Blend loop and text-prompt style embeddings into one style vector.
      4. Generate enough chunks to cover the requested bars.
      5. Crossfade-stitch the chunks and trim to the exact length.
      6. Peak-normalize, add micro fades, and match loudness to the input loop.

    Args:
      mrt: MagentaRT system providing codec, config, style embedding and
        chunk generation.
      input_wav_path: Path to the loop audio file.
      bpm: Tempo in beats per minute; must be positive and finite.
      extra_styles: Optional text style prompts; blank entries are ignored.
      style_weights: Optional weights matched to `extra_styles` by position;
        missing entries default to 1.0.
      bars: Number of bars to generate (>= 1).
      beats_per_bar: Beats per bar (>= 1).
      loop_weight: Weight of the loop's own style embedding in the blend.
      loudness_mode: Passed as `method` to match_loudness_to_reference.
      loudness_headroom_db: Passed as `headroom_db` to match_loudness_to_reference.

    Returns:
      (waveform, loudness_stats) as produced by match_loudness_to_reference.

    Raises:
      ValueError: if bpm is not a positive finite number, or bars /
        beats_per_bar < 1.
    """
    # Validate before the expensive load/encode; a zero bpm would otherwise
    # fail much later with ZeroDivisionError.
    if not (math.isfinite(bpm) and bpm > 0):
        raise ValueError(f"bpm must be a positive number, got {bpm!r}")
    if bars < 1 or beats_per_bar < 1:
        raise ValueError(
            f"bars and beats_per_bar must be >= 1, got {bars!r} and {beats_per_bar!r}"
        )

    # Load loop & encode; keep only the RVQ levels the decoder uses.
    loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
    tokens_full = mrt.codec.encode(loop).astype(np.int32)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]

    context_tokens = make_bar_aligned_context(
        tokens,
        bpm=bpm,
        fps=int(mrt.codec.frame_rate),
        ctx_frames=mrt.config.context_length_frames,
        beats_per_bar=beats_per_bar,
    )
    state = mrt.init_state()
    state.context_tokens = context_tokens

    combined_style = _combine_style_embeddings(
        mrt, loop, extra_styles, style_weights, loop_weight
    )

    # Enough chunks to cover the exact bar length, plus one chunk of padding
    # that is trimmed off after stitching.
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate
    steps = int(math.ceil(total_secs / chunk_secs)) + 1

    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=combined_style)  # ONE style vector
        chunks.append(wav)

    # Stitch -> trim -> polish.
    out = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    out = hard_trim_seconds(out, total_secs).peak_normalize(0.95)
    apply_micro_fades(out, 5)

    # Loudness-match to the *input loop* so the returned level feels consistent.
    return match_loudness_to_reference(
        ref=loop,
        target=out,
        method=loudness_mode,
        headroom_db=loudness_headroom_db,
    )
# ----------------------------
# FastAPI app with lazy, thread-safe model init
# ----------------------------
# The ASGI application object for this service.
app = FastAPI()

# Permissive CORS: any origin, method and header.
# NOTE(review): FastAPI's CORS docs say wildcards should not be combined
# with allow_credentials=True. Lock allow_origins to your domain(s) if this
# is exposed publicly — confirm the intended setting.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],  # or lock to your domain(s)
    allow_headers=["*"],
    allow_methods=["*"],
)

# MagentaRT instance, created on first use by get_mrt() under _MRT_LOCK.
_MRT = None
_MRT_LOCK = threading.Lock()
def get_mrt():
    """Return the shared MagentaRT model, constructing it on the first call.

    Uses double-checked locking. The unlocked `is not None` test keeps later
    calls lock-free, and the re-check under _MRT_LOCK ensures that concurrent
    first calls build only one instance.
    """
    global _MRT
    if _MRT is not None:
        return _MRT
    with _MRT_LOCK:
        if _MRT is None:
            _MRT = system.MagentaRT(tag="base", guidance_weight=1.0, device="gpu", lazy=False)
    return _MRT
@app.post("/generate")
def generate(
    loop_audio: UploadFile = File(...),
    bpm: float = Form(...),
    bars: int = Form(8),
    beats_per_bar: int = Form(4),
    styles: str = Form("acid house"),
    style_weights: str = Form(""),
    loop_weight: float = Form(1.0),
    loudness_mode: str = Form("auto"),
    loudness_headroom_db: float = Form(1.0),
):
    """Generate a continuation of an uploaded loop and return it as base64 WAV.

    A plain `def` endpoint: FastAPI runs it in a worker thread, so the
    blocking file read and model calls below are acceptable here.

    Form fields:
      loop_audio: The loop audio file (written to a temporary ".wav" path).
      bpm: Tempo in beats per minute; must be a positive finite number.
      bars, beats_per_bar: Output length; each must be >= 1.
      styles: Comma-separated text style prompts; blank entries are ignored.
      style_weights: Optional comma-separated numbers, paired with `styles`
        by position. Missing trailing weights default to 1.0 downstream.
      loop_weight: Weight of the loop's own style embedding.
      loudness_mode: "auto" | "lufs" | "rms" | "none".
      loudness_headroom_db: Peak headroom applied after loudness matching.

    Returns:
      {"audio_base64": ..., "metadata": {...}} on success, or
      {"error": <message>} for an empty upload or invalid form values.
    """
    # Read file
    data = loop_audio.file.read()
    if not data:
        return {"error": "Empty file"}
    if not (math.isfinite(bpm) and bpm > 0):
        return {"error": "bpm must be a positive number"}
    if bars < 1 or beats_per_bar < 1:
        return {"error": "bars and beats_per_bar must be at least 1"}

    # Parse styles + weights. Weights pair with styles by position in the raw
    # comma-separated lists, so a blank style entry also drops its weight
    # (e.g. styles="a,,b", style_weights="1,0,2" -> a: 1.0, b: 2.0).
    raw_styles = styles.split(",") if styles else []
    try:
        raw_weights = [float(x) for x in style_weights.split(",")] if style_weights else None
    except ValueError:
        return {"error": "style_weights must be comma-separated numbers"}
    kept = [i for i, s in enumerate(raw_styles) if s.strip()]
    extra_styles = [raw_styles[i] for i in kept]
    weights = None if raw_weights is None else [raw_weights[i] for i in kept if i < len(raw_weights)]

    # delete=False lets the closed file be reopened by path; it is removed in
    # the `finally` below (previously every request leaked one temp file).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(data)
        tmp_path = tmp.name
    try:
        mrt = get_mrt()  # first call builds the model, in this worker thread
        wav, loud_stats = generate_loop_continuation_with_mrt(
            mrt,
            input_wav_path=tmp_path,
            bpm=bpm,
            extra_styles=extra_styles,
            style_weights=weights,
            bars=bars,
            beats_per_bar=beats_per_bar,
            loop_weight=loop_weight,
            loudness_mode=loudness_mode,
            loudness_headroom_db=loudness_headroom_db,
        )
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass  # best-effort cleanup must not mask the generation result/error

    # Return base64 WAV + minimal metadata. format="WAV" is needed when
    # writing to a file-like object.
    buf = io.BytesIO()
    wav.write(buf, subtype="FLOAT", format="WAV")
    buf.seek(0)
    audio_b64 = base64.b64encode(buf.read()).decode("utf-8")

    return {
        "audio_base64": audio_b64,
        "metadata": {
            "bpm": int(round(bpm)),
            "bars": int(bars),
            "beats_per_bar": int(beats_per_bar),
            "styles": extra_styles,
            "style_weights": weights,
            "loop_weight": loop_weight,
            "loudness": loud_stats,
            "sample_rate": mrt.sample_rate,
            "channels": mrt.num_channels,
            "crossfade_seconds": mrt.config.crossfade_length,
        },
    }
@app.get("/health")
def health():
    """Liveness probe: always returns {"ok": True}; it does not load or check the model."""
    return {"ok": True}