Spaces:
Running
Running
from magenta_rt import system, audio as au | |
import numpy as np | |
from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response | |
import tempfile, io, base64, math, threading | |
from fastapi.middleware.cors import CORSMiddleware | |
from contextlib import contextmanager | |
import soundfile as sf | |
import numpy as np | |
from math import gcd | |
from scipy.signal import resample_poly | |
from utils import ( | |
match_loudness_to_reference, stitch_generated, hard_trim_seconds, | |
apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail, | |
resample_and_snap, wav_bytes_base64 | |
) | |
from jam_worker import JamWorker, JamParams, JamChunk | |
import uuid, threading | |
# Guards every lookup/insert/removal on jam_registry below.
jam_lock = threading.Lock()
# Jam workers keyed by the session id returned from jam_start.
jam_registry: dict[str, JamWorker] = {}
@contextmanager
def mrt_overrides(mrt, **kwargs):
    """Temporarily set attributes on ``mrt``; restore the old values on exit.

    Only attributes that already exist on ``mrt`` are overridden; keyword
    names that ``mrt`` does not have are silently ignored. The previous
    values are put back when the ``with`` block ends, including when it ends
    because of an exception.

    Args:
        mrt: Object whose attributes are overridden in place.
        **kwargs: Attribute name -> temporary value.

    Yields:
        None. Use as ``with mrt_overrides(mrt, topk=20): ...``.
    """
    saved = {}
    try:
        for name, value in kwargs.items():
            if hasattr(mrt, name):
                # Remember the original before overwriting so it can be restored.
                saved[name] = getattr(mrt, name)
                setattr(mrt, name, value)
        yield
    finally:
        # Runs on normal exit and on error; restores only what was changed.
        for name, value in saved.items():
            setattr(mrt, name, value)
# loudness utils
# pyloudnorm is optional; _HAS_LOUDNORM records whether the import succeeded.
try:
    import pyloudnorm as pyln
except Exception:
    _HAS_LOUDNORM = False
else:
    _HAS_LOUDNORM = True
# ----------------------------
# Main generation (single combined style vector)
# ----------------------------
def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",
    loudness_headroom_db: float = 1.0,
    intro_bars_to_drop: int = 0,
):
    """Generate ``bars`` bars of audio continuing the loop at ``input_wav_path``.

    The bar-aligned tail of the loop seeds the model context and the loop's
    style embedding; optional text styles are blended in by weight. Audio is
    generated chunk by chunk, stitched, optionally stripped of its first
    ``intro_bars_to_drop`` bars, trimmed to the requested length, and finally
    loudness-matched against the input loop.

    Args:
        mrt: Model object providing ``codec``, ``config``, ``sample_rate``,
            ``init_state``, ``embed_style`` and ``generate_chunk``.
        input_wav_path: Path of the audio file to continue.
        bpm: Tempo of the loop; must be a positive, finite number.
        extra_styles: Optional iterable of style strings; blank entries are
            skipped.
        style_weights: Optional weights, matched to ``extra_styles`` by
            position; missing entries default to 1.0.
        bars: Number of bars to return; must be >= 1.
        beats_per_bar: Beats in one bar; must be >= 1.
        loop_weight: Weight of the loop's own style embedding in the blend.
        loudness_mode: Passed to ``match_loudness_to_reference`` as ``method``.
        loudness_headroom_db: Passed to ``match_loudness_to_reference`` as
            ``headroom_db``.
        intro_bars_to_drop: Extra bars generated up front and then discarded
            (negative values count as 0; clamped to at most ``bars``).

    Returns:
        ``(out, loud_stats)`` exactly as returned by
        ``match_loudness_to_reference``.

    Raises:
        ValueError: If ``bpm``, ``bars`` or ``beats_per_bar`` is not positive.
    """
    # Fail fast: these would otherwise surface later as ZeroDivisionError or
    # as a zero/negative target duration.
    if not (math.isfinite(bpm) and bpm > 0):
        raise ValueError(f"bpm must be a positive, finite number (got {bpm!r})")
    if int(bars) < 1:
        raise ValueError(f"bars must be >= 1 (got {bars!r})")
    if int(beats_per_bar) < 1:
        raise ValueError(f"beats_per_bar must be >= 1 (got {beats_per_bar!r})")

    # Load the loop at the model's sample rate, as stereo.
    loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()

    # Use the bar-aligned tail of the loop for context.
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
    tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]

    # Bar-aligned token window
    context_tokens = make_bar_aligned_context(
        tokens, bpm=bpm, fps=int(mrt.codec.frame_rate),
        ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
    )
    state = mrt.init_state()
    state.context_tokens = context_tokens

    # STYLE embed: taken from the same tail that seeds the context.
    loop_embed = mrt.embed_style(loop_for_context)
    embeds, weights = [loop_embed], [float(loop_weight)]
    if extra_styles:
        for i, s in enumerate(extra_styles):
            if s.strip():
                embeds.append(mrt.embed_style(s.strip()))
                w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
                weights.append(float(w))
    # Normalize weights; a zero sum falls back to 1.0 to avoid dividing by zero.
    wsum = float(sum(weights)) or 1.0
    weights = [w / wsum for w in weights]
    combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)

    # --- Length math ---
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = max(0, int(intro_bars_to_drop))
    drop_secs = min(drop_bars, bars) * seconds_per_bar  # clamp to <= bars
    gen_total_secs = total_secs + drop_secs             # generate extra

    # Chunk scheduling to cover gen_total_secs
    chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate  # ~2.0
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1  # pad then trim

    # Generate
    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=combined_style)
        chunks.append(wav)

    # Stitch continuous audio
    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()

    # Trim to generated length (bars + dropped bars)
    stitched = hard_trim_seconds(stitched, gen_total_secs)

    # Drop the intro bars
    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)

    # Final exact-length trim to requested bars
    out = hard_trim_seconds(stitched, total_secs)

    # Final polish AFTER drop
    out = out.peak_normalize(0.95)
    # Return value is not used, so this presumably fades `out` in place — TODO confirm.
    apply_micro_fades(out, 5)

    # Loudness match to input (after drop) so bar 1 sits right
    out, loud_stats = match_loudness_to_reference(
        ref=loop, target=out,
        method=loudness_mode, headroom_db=loudness_headroom_db
    )
    return out, loud_stats
# ----------------------------
# FastAPI app with lazy, thread-safe model init
# ----------------------------
app = FastAPI()
# CORS: any origin may call this API. Credentials stay disabled because a
# wildcard origin must not be combined with credentialed requests (the FastAPI
# CORS docs require explicit origins/methods/headers when
# allow_credentials=True). If cookies or Authorization headers are ever
# needed from a browser, replace "*" with the exact origin list and re-enable.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # or lock to your domain(s)
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
_MRT = None
_MRT_LOCK = threading.Lock()


def get_mrt():
    """Return the module-wide MagentaRT instance, building it on first use."""
    global _MRT
    # Fast path: already built, no locking needed.
    if _MRT is not None:
        return _MRT
    with _MRT_LOCK:
        # Check again now that the lock is held, so the model is built only once.
        if _MRT is None:
            _MRT = system.MagentaRT(tag="base", guidance_weight=1.0, device="gpu", lazy=False)
        return _MRT
def generate(
    loop_audio: UploadFile = File(...),
    bpm: float = Form(...),
    bars: int = Form(8),
    beats_per_bar: int = Form(4),
    styles: str = Form("acid house"),
    style_weights: str = Form(""),
    loop_weight: float = Form(1.0),
    loudness_mode: str = Form("auto"),
    loudness_headroom_db: float = Form(1.0),
    guidance_weight: float = Form(5.0),
    temperature: float = Form(1.1),
    topk: int = Form(40),
    target_sample_rate: int | None = Form(None),
    intro_bars_to_drop: int = Form(0),
):
    """Generate a continuation of an uploaded loop and return it as base64 WAV.

    The upload is written to a temporary file, continued with
    ``generate_loop_continuation_with_mrt`` (with ``guidance_weight``,
    ``temperature`` and ``topk`` temporarily applied to the shared model),
    resampled to the target rate, snapped to the exact bar length and encoded.
    The temporary file is always removed before returning.

    Args:
        loop_audio: Uploaded audio to continue.
        bpm: Tempo of the loop; must be positive.
        bars: Number of bars to generate; must be >= 1.
        beats_per_bar: Beats per bar; must be >= 1.
        styles: Comma-separated style strings; blank entries are dropped.
        style_weights: Comma-separated numbers matched to ``styles`` by
            position; empty means default weights.
        loop_weight: Weight of the loop's own style in the blend.
        loudness_mode: Loudness matching method passed through.
        loudness_headroom_db: Headroom passed through to loudness matching.
        guidance_weight, temperature, topk: Per-request model overrides.
        target_sample_rate: Output sample rate; defaults to the input file's.
        intro_bars_to_drop: Extra leading bars to generate and discard.

    Returns:
        ``{"audio_base64": ..., "metadata": {...}}`` on success, or
        ``{"error": "Empty file"}`` when the upload has no bytes.

    Raises:
        HTTPException: 400 when ``bpm``, ``bars``, ``beats_per_bar`` or
            ``style_weights`` is invalid.
    """
    # NOTE(review): no route decorator is visible on this handler in this view
    # of the file — confirm it is registered on `app`.
    import os  # function-scope: only needed for temp-file cleanup below

    # Read file
    data = loop_audio.file.read()
    if not data:
        # Kept as a 200 payload so existing clients checking "error" still work.
        return {"error": "Empty file"}

    # Reject inputs that would otherwise fail deep in generation as a 500.
    if not (math.isfinite(bpm) and bpm > 0):
        raise HTTPException(status_code=400, detail="bpm must be a positive number")
    if bars < 1 or beats_per_bar < 1:
        raise HTTPException(status_code=400, detail="bars and beats_per_bar must be >= 1")

    # Parse styles + weights
    extra_styles = [s for s in (styles.split(",") if styles else []) if s.strip()]
    try:
        weights = [float(x) for x in style_weights.split(",")] if style_weights else None
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail="style_weights must be comma-separated numbers",
        ) from None

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(data)
        tmp_path = tmp.name

    try:
        mrt = get_mrt()  # warm once, in this worker thread

        # Temporarily override MRT inference knobs for this request
        with mrt_overrides(mrt,
                           guidance_weight=guidance_weight,
                           temperature=temperature,
                           topk=topk):
            wav, loud_stats = generate_loop_continuation_with_mrt(
                mrt,
                input_wav_path=tmp_path,
                bpm=bpm,
                extra_styles=extra_styles,
                style_weights=weights,
                bars=bars,
                beats_per_bar=beats_per_bar,
                loop_weight=loop_weight,
                loudness_mode=loudness_mode,
                loudness_headroom_db=loudness_headroom_db,
                intro_bars_to_drop=intro_bars_to_drop,
            )

        # 1) Figure out the desired SR (last use of the temp file)
        inp_info = sf.info(tmp_path)
        input_sr = int(inp_info.samplerate)
    finally:
        # delete=False above means nothing else removes this file.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass

    target_sr = int(target_sample_rate or input_sr)

    # 2) Convert to target SR + snap to exact bars
    cur_sr = int(mrt.sample_rate)
    x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
    seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
    expected_secs = float(bars) * seconds_per_bar
    x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=expected_secs)

    # 3) Encode WAV once (no extra write)
    audio_b64, total_samples, channels = wav_bytes_base64(x, target_sr)
    loop_duration_seconds = total_samples / float(target_sr)

    # 4) Metadata
    metadata = {
        "bpm": int(round(bpm)),
        "bars": int(bars),
        "beats_per_bar": int(beats_per_bar),
        "styles": extra_styles,
        "style_weights": weights,
        "loop_weight": loop_weight,
        "loudness": loud_stats,
        "sample_rate": int(target_sr),
        "channels": int(channels),
        "crossfade_seconds": mrt.config.crossfade_length,
        "total_samples": int(total_samples),
        "seconds_per_bar": seconds_per_bar,
        "loop_duration_seconds": loop_duration_seconds,
        "guidance_weight": guidance_weight,
        "temperature": temperature,
        "topk": topk,
    }
    return {"audio_base64": audio_b64, "metadata": metadata}
# ----------------------------
# the 'keep jamming' button
# ----------------------------
def jam_start(
    loop_audio: UploadFile = File(...),
    bpm: float = Form(...),
    bars_per_chunk: int = Form(4),
    beats_per_bar: int = Form(4),
    styles: str = Form(""),
    style_weights: str = Form(""),
    loop_weight: float = Form(1.0),
    loudness_mode: str = Form("auto"),
    loudness_headroom_db: float = Form(1.0),
    guidance_weight: float = Form(1.1),
    temperature: float = Form(1.1),
    topk: int = Form(40),
    target_sample_rate: int | None = Form(None),
):
    """Start a jam session from an uploaded loop and return its session id.

    Only one live jam is allowed at a time. The upload is written to a
    temporary file (removed again once it has been read), a tail-biased style
    vector is built from the loop plus optional text styles, and a
    ``JamWorker`` is registered in ``jam_registry`` and started.

    Args:
        loop_audio: Uploaded audio that seeds the jam.
        bpm: Tempo of the loop; must be positive.
        bars_per_chunk: Bars per generated chunk; must be >= 1.
        beats_per_bar: Beats per bar; must be >= 1.
        styles: Comma-separated style strings; blank entries are dropped.
        style_weights: Comma-separated numbers matched to ``styles`` by
            position; missing entries default to 1.0.
        loop_weight: Weight of the loop's own style in the blend.
        loudness_mode, loudness_headroom_db: Forwarded to ``JamParams``.
        guidance_weight, temperature, topk: Forwarded to ``JamParams``.
        target_sample_rate: Output sample rate; defaults to the input file's.

    Returns:
        ``{"session_id": <uuid string>}``.

    Raises:
        HTTPException: 429 if a jam is already running; 400 for an empty
            upload or invalid ``bpm``/``bars_per_chunk``/``beats_per_bar``/
            ``style_weights``.
    """
    # NOTE(review): no route decorator is visible on this handler in this view
    # of the file — confirm it is registered on `app`.
    import os  # function-scope: only needed for temp-file cleanup below

    # enforce single active jam per GPU (cheap early rejection)
    with jam_lock:
        if any(w.is_alive() for w in jam_registry.values()):
            raise HTTPException(status_code=429, detail="A jam is already running. Try again later.")

    # read input
    data = loop_audio.file.read()
    if not data:
        raise HTTPException(status_code=400, detail="Empty file")

    # Reject inputs that would otherwise fail later as a 500.
    if not (math.isfinite(bpm) and bpm > 0):
        raise HTTPException(status_code=400, detail="bpm must be a positive number")
    if bars_per_chunk < 1 or beats_per_bar < 1:
        raise HTTPException(status_code=400, detail="bars_per_chunk and beats_per_bar must be >= 1")

    extra = [s for s in (styles.split(",") if styles else []) if s.strip()]
    try:
        sw = [float(x) for x in style_weights.split(",")] if style_weights else []
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail="style_weights must be comma-separated numbers",
        ) from None

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(data)
        tmp_path = tmp.name

    try:
        mrt = get_mrt()
        loop = au.Waveform.from_file(tmp_path).resample(mrt.sample_rate).as_stereo()
        # target SR (default input SR)
        inp_info = sf.info(tmp_path)
        input_sr = int(inp_info.samplerate)
    finally:
        # delete=False above means nothing else removes this file; both
        # readers of the path are done by this point.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
    target_sr = int(target_sample_rate or input_sr)

    # build tail context + style vec (tail-biased)
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    loop_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)

    # style vec = normalized mix of loop_tail + extra styles
    embeds, weights = [mrt.embed_style(loop_tail)], [float(loop_weight)]
    for i, s in enumerate(extra):
        embeds.append(mrt.embed_style(s.strip()))
        weights.append(sw[i] if i < len(sw) else 1.0)
    wsum = sum(weights) or 1.0  # zero sum falls back to 1.0 to avoid dividing by zero
    weights = [w / wsum for w in weights]
    style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(embeds[0].dtype)

    params = JamParams(
        bpm=bpm,
        beats_per_bar=beats_per_bar,
        bars_per_chunk=bars_per_chunk,
        target_sr=target_sr,
        loudness_mode=loudness_mode,
        headroom_db=loudness_headroom_db,
        style_vec=style_vec,
        ref_loop=loop_tail,    # For loudness matching
        combined_loop=loop,    # Full loop for context setup
        guidance_weight=guidance_weight,
        temperature=temperature,
        topk=topk
    )
    worker = JamWorker(mrt, params)
    sid = str(uuid.uuid4())
    with jam_lock:
        # Re-check under the lock: another request may have registered a jam
        # while this one was preparing. Starting inside the lock makes the new
        # worker visible as alive to the next caller's check.
        if any(w.is_alive() for w in jam_registry.values()):
            raise HTTPException(status_code=429, detail="A jam is already running. Try again later.")
        jam_registry[sid] = worker
        worker.start()
    return {"session_id": sid}
def jam_next(session_id: str):
    """
    Return the next chunk of a jam session, in sequence.

    Responds 404 when the session is unknown or its worker is no longer
    alive, and 408 when the worker returns no chunk.
    """
    with jam_lock:
        session = jam_registry.get(session_id)
        session_live = session is not None and session.is_alive()
    if not session_live:
        raise HTTPException(status_code=404, detail="Session not found")

    # Blocks until the next sequential chunk is ready; done outside jam_lock.
    next_chunk = session.get_next_chunk()
    if next_chunk is None:
        raise HTTPException(status_code=408, detail="Chunk not ready within timeout")

    payload = {
        "index": next_chunk.index,
        "audio_base64": next_chunk.audio_base64,
        "metadata": next_chunk.metadata,
    }
    return {"chunk": payload}
def jam_consume(session_id: str = Form(...), chunk_index: int = Form(...)):
    """
    Tell the session's worker that the frontend has consumed ``chunk_index``.

    Responds 404 when the session is unknown or its worker is no longer alive.
    """
    with jam_lock:
        session = jam_registry.get(session_id)
        session_live = session is not None and session.is_alive()
    if not session_live:
        raise HTTPException(status_code=404, detail="Session not found")

    session.mark_chunk_consumed(chunk_index)
    return {"consumed": chunk_index}
def jam_stop(session_id: str = Body(..., embed=True)):
    """Stop a jam session's worker and remove it from the registry.

    Responds 404 when the session id is unknown. Waits up to 5 seconds for
    the worker to finish; a worker still alive after that is only reported.
    """
    with jam_lock:
        session = jam_registry.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    session.stop()
    session.join(timeout=5.0)
    still_running = session.is_alive()
    if still_running:
        # Per the original note the worker is a daemon thread, so it will not
        # block process exit; just report it.
        print(f"⚠️ JamWorker {session_id} did not stop within timeout")

    with jam_lock:
        jam_registry.pop(session_id, None)
    return {"stopped": True}
def jam_update(session_id: str = Form(...),
               guidance_weight: float | None = Form(None),
               temperature: float | None = Form(None),
               topk: int | None = Form(None)):
    """Forward new generation knobs to a running jam session's worker.

    Responds 404 when the session is unknown or its worker is no longer alive.
    All three knobs are passed through as given, including ``None``.
    """
    with jam_lock:
        session = jam_registry.get(session_id)
        session_live = session is not None and session.is_alive()
    if not session_live:
        raise HTTPException(status_code=404, detail="Session not found")

    knobs = {
        "guidance_weight": guidance_weight,
        "temperature": temperature,
        "topk": topk,
    }
    session.update_knobs(**knobs)
    return {"ok": True}
def jam_status(session_id: str):
    """Report progress counters and timing parameters for a jam session.

    Responds 404 when the session id is unknown. A session whose worker has
    stopped is still reported, with ``"running": False``.
    """
    with jam_lock:
        session = jam_registry.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    is_running = session.is_alive()

    # Read the worker's counters together under its own lock.
    with session._lock:
        generated_idx = int(session.idx)
        delivered_idx = int(session._last_delivered_index)
        outbox_size = len(session.outbox)
    ahead = generated_idx - delivered_idx

    cfg = session.params
    bar_seconds = cfg.beats_per_bar * (60.0 / cfg.bpm)
    chunk_seconds = cfg.bars_per_chunk * bar_seconds

    return {
        "running": is_running,
        "last_generated_index": generated_idx,   # Last chunk that finished generating
        "last_delivered_index": delivered_idx,   # Last chunk sent to frontend
        "buffer_ahead": ahead,                   # How many chunks ahead we are
        "queued_chunks": outbox_size,            # Total chunks in outbox
        "bpm": cfg.bpm,
        "beats_per_bar": cfg.beats_per_bar,
        "bars_per_chunk": cfg.bars_per_chunk,
        "seconds_per_bar": bar_seconds,
        "chunk_duration_seconds": chunk_seconds,
        "target_sample_rate": cfg.target_sr,
        "last_chunk_started_at": session.last_chunk_started_at,
        "last_chunk_completed_at": session.last_chunk_completed_at,
    }
def health():
    """Liveness check: always returns ``{"ok": True}``."""
    status = {"ok": True}
    return status