Commit
·
87147f5
1
Parent(s):
56dfd15
warmup
Browse files
app.py
CHANGED
@@ -15,6 +15,8 @@ from utils import (
|
|
15 |
|
16 |
from jam_worker import JamWorker, JamParams, JamChunk
|
17 |
import uuid, threading
|
|
|
|
|
18 |
|
19 |
import gradio as gr
|
20 |
from typing import Optional
|
@@ -358,6 +360,82 @@ def get_mrt():
|
|
358 |
_MRT = system.MagentaRT(tag="large", guidance_weight=5.0, device="gpu", lazy=False)
|
359 |
return _MRT
|
360 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
@app.post("/generate")
|
362 |
def generate(
|
363 |
loop_audio: UploadFile = File(...),
|
|
|
15 |
|
16 |
from jam_worker import JamWorker, JamParams, JamChunk
|
17 |
import uuid, threading
|
18 |
+
import os
|
19 |
+
import logging
|
20 |
|
21 |
import gradio as gr
|
22 |
from typing import Optional
|
|
|
360 |
_MRT = system.MagentaRT(tag="large", guidance_weight=5.0, device="gpu", lazy=False)
|
361 |
return _MRT
|
362 |
|
363 |
+
_WARMED = False
|
364 |
+
_WARMUP_LOCK = threading.Lock()
|
365 |
+
|
366 |
+
def _mrt_warmup():
|
367 |
+
"""
|
368 |
+
Build a minimal, bar-aligned silent context and run one 2s generate_chunk
|
369 |
+
to trigger XLA JIT & autotune so first real request is fast.
|
370 |
+
"""
|
371 |
+
global _WARMED
|
372 |
+
with _WARMUP_LOCK:
|
373 |
+
if _WARMED:
|
374 |
+
return
|
375 |
+
try:
|
376 |
+
mrt = get_mrt()
|
377 |
+
|
378 |
+
# --- derive timing from model config ---
|
379 |
+
codec_fps = float(mrt.codec.frame_rate)
|
380 |
+
ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
|
381 |
+
sr = int(mrt.sample_rate)
|
382 |
+
|
383 |
+
# We'll align to 120 BPM, 4/4, and generate one ~2s chunk
|
384 |
+
bpm = 120.0
|
385 |
+
beats_per_bar = 4
|
386 |
+
|
387 |
+
# --- build a silent, stereo context of ctx_seconds ---
|
388 |
+
import numpy as np, soundfile as sf
|
389 |
+
samples = int(max(1, round(ctx_seconds * sr)))
|
390 |
+
silent = np.zeros((samples, 2), dtype=np.float32)
|
391 |
+
|
392 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
393 |
+
sf.write(tmp.name, silent, sr, subtype="PCM_16")
|
394 |
+
tmp_path = tmp.name
|
395 |
+
|
396 |
+
try:
|
397 |
+
# Load as Waveform and take a tail of exactly ctx_seconds
|
398 |
+
loop = au.Waveform.from_file(tmp_path).resample(sr).as_stereo()
|
399 |
+
seconds_per_bar = beats_per_bar * (60.0 / bpm)
|
400 |
+
ctx_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
|
401 |
+
|
402 |
+
# Tokens for context window
|
403 |
+
tokens_full = mrt.codec.encode(ctx_tail).astype(np.int32)
|
404 |
+
tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
|
405 |
+
context_tokens = make_bar_aligned_context(
|
406 |
+
tokens,
|
407 |
+
bpm=bpm,
|
408 |
+
fps=int(mrt.codec.frame_rate),
|
409 |
+
ctx_frames=mrt.config.context_length_frames,
|
410 |
+
beats_per_bar=beats_per_bar,
|
411 |
+
)
|
412 |
+
|
413 |
+
# Init state and a basic style vector (text token is fine)
|
414 |
+
state = mrt.init_state()
|
415 |
+
state.context_tokens = context_tokens
|
416 |
+
style_vec = mrt.embed_style("warmup")
|
417 |
+
|
418 |
+
# --- one throwaway chunk (~2s) ---
|
419 |
+
_wav, _state = mrt.generate_chunk(state=state, style=style_vec)
|
420 |
+
|
421 |
+
logging.info("MagentaRT warmup complete.")
|
422 |
+
finally:
|
423 |
+
try:
|
424 |
+
os.unlink(tmp_path)
|
425 |
+
except Exception:
|
426 |
+
pass
|
427 |
+
|
428 |
+
_WARMED = True
|
429 |
+
except Exception as e:
|
430 |
+
# Never crash on warmup errors; log and continue serving
|
431 |
+
logging.exception("MagentaRT warmup failed (continuing without warmup): %s", e)
|
432 |
+
|
433 |
+
# Kick it off in the background on server start
|
434 |
+
@app.on_event("startup")
|
435 |
+
def _kickoff_warmup():
|
436 |
+
if os.getenv("MRT_WARMUP", "1") != "0":
|
437 |
+
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
438 |
+
|
439 |
@app.post("/generate")
|
440 |
def generate(
|
441 |
loop_audio: UploadFile = File(...),
|