thecollabagepatch committed on
Commit
87147f5
·
1 Parent(s): 56dfd15
Files changed (1) hide show
  1. app.py +78 -0
app.py CHANGED
@@ -15,6 +15,8 @@ from utils import (
15
 
16
  from jam_worker import JamWorker, JamParams, JamChunk
17
  import uuid, threading
 
 
18
 
19
  import gradio as gr
20
  from typing import Optional
@@ -358,6 +360,82 @@ def get_mrt():
358
  _MRT = system.MagentaRT(tag="large", guidance_weight=5.0, device="gpu", lazy=False)
359
  return _MRT
360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  @app.post("/generate")
362
  def generate(
363
  loop_audio: UploadFile = File(...),
 
15
 
16
  from jam_worker import JamWorker, JamParams, JamChunk
17
  import uuid, threading
18
+ import os
19
+ import logging
20
 
21
  import gradio as gr
22
  from typing import Optional
 
360
  _MRT = system.MagentaRT(tag="large", guidance_weight=5.0, device="gpu", lazy=False)
361
  return _MRT
362
 
363
# True once a warmup pass has run to completion; later calls return early.
_WARMED = False
# Serializes warmup attempts. The lock is held for the whole pass, so a
# concurrent caller blocks until the in-flight warmup finishes.
_WARMUP_LOCK = threading.Lock()


def _mrt_warmup():
    """
    Build a minimal, bar-aligned silent context and run one throwaway
    generate_chunk (~2s) to trigger XLA JIT & autotune so the first real
    request is fast.

    The function is a no-op once a previous pass has succeeded. Any exception
    raised during warmup is logged and swallowed, so a failed warmup never
    takes the server down; in that case ``_WARMED`` stays False and a later
    call will try again.
    """
    global _WARMED
    with _WARMUP_LOCK:
        if _WARMED:
            return
        try:
            mrt = get_mrt()

            # --- derive timing from model config ---
            codec_fps = float(mrt.codec.frame_rate)
            ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
            sr = int(mrt.sample_rate)

            # We'll align to 120 BPM, 4/4, and generate one ~2s chunk
            bpm = 120.0
            beats_per_bar = 4

            # --- build a silent, stereo context of ctx_seconds ---
            import tempfile

            import numpy as np
            import soundfile as sf

            samples = int(max(1, round(ctx_seconds * sr)))
            silent = np.zeros((samples, 2), dtype=np.float32)

            # Only reserve the path here. The handle is closed before anything
            # writes to it, because reopening a still-open NamedTemporaryFile
            # by name is not portable (it fails on Windows).
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp_path = tmp.name

            try:
                # Written inside the try so the file is removed even if the
                # write itself fails (delete=False means nobody else will).
                sf.write(tmp_path, silent, sr, subtype="PCM_16")

                # Load as Waveform and take a tail of exactly ctx_seconds
                loop = au.Waveform.from_file(tmp_path).resample(sr).as_stereo()
                ctx_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)

                # Tokens for context window
                tokens_full = mrt.codec.encode(ctx_tail).astype(np.int32)
                tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
                context_tokens = make_bar_aligned_context(
                    tokens,
                    bpm=bpm,
                    fps=int(mrt.codec.frame_rate),
                    ctx_frames=mrt.config.context_length_frames,
                    beats_per_bar=beats_per_bar,
                )

                # Init state and a basic style vector (text token is fine)
                state = mrt.init_state()
                state.context_tokens = context_tokens
                style_vec = mrt.embed_style("warmup")

                # --- one throwaway chunk (~2s); the outputs are discarded ---
                _wav, _state = mrt.generate_chunk(state=state, style=style_vec)

                logging.info("MagentaRT warmup complete.")
            finally:
                # Best-effort cleanup: a leftover temp file must not turn a
                # successful warmup into a failure, but it is worth a trace.
                try:
                    os.unlink(tmp_path)
                except OSError:
                    logging.debug("Could not remove warmup temp file %s", tmp_path)

            # Only reached when every step above succeeded.
            _WARMED = True
        except Exception as e:
            # Never crash on warmup errors; log and continue serving
            logging.exception("MagentaRT warmup failed (continuing without warmup): %s", e)
432
+
433
+ # Kick it off in the background on server start
434
@app.on_event("startup")
def _kickoff_warmup():
    """
    Startup hook: launch ``_mrt_warmup`` on a daemon thread.

    The warmup is skipped when the ``MRT_WARMUP`` environment variable is set
    to "0"; when the variable is unset it is treated as "1" (enabled).
    """
    if os.getenv("MRT_WARMUP", "1") == "0":
        return
    # Daemon thread: created with daemon=True and never joined here.
    warmup_thread = threading.Thread(
        target=_mrt_warmup,
        name="mrt-warmup",
        daemon=True,
    )
    warmup_thread.start()
438
+
439
  @app.post("/generate")
440
  def generate(
441
  loop_audio: UploadFile = File(...),