thecollabagepatch committed on
Commit f70477a · 1 Parent(s): b564cf9

lets try it

Files changed (2):
  1. Dockerfile +137 -0
  2. app.py +298 -0
Dockerfile ADDED
@@ -0,0 +1,137 @@
+# thecollabagepatch/magenta:latest
+FROM nvidia/cuda:12.6.2-cudnn-runtime-ubuntu22.04
+
+# CUDA libs present + on loader path
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    cuda-libraries-12-4 && rm -rf /var/lib/apt/lists/*
+ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda-12.4/lib64:/usr/local/cuda-12.4/compat:/usr/local/cuda/targets/x86_64-linux/lib:${LD_LIBRARY_PATH}
+RUN ln -sf /usr/local/cuda/targets/x86_64-linux/lib /usr/local/cuda/lib64 || true
+
+# Ensure the NVIDIA repo key is present (non-interactive) and install cuDNN 9.8
+RUN set -eux; \
+    apt-get update && apt-get install -y --no-install-recommends gnupg ca-certificates curl; \
+    install -d -m 0755 /usr/share/keyrings; \
+    # Refresh the *same* keyring the base source uses (no second source file)
+    curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub \
+      | gpg --batch --yes --dearmor -o /usr/share/keyrings/cuda-archive-keyring.gpg; \
+    apt-get update; \
+    # If libcudnn is "held", unhold it so we can move to 9.8
+    apt-mark unhold libcudnn9-cuda-12 || true; \
+    # Install cuDNN 9.8 for CUDA 12 (correct dev package name!)
+    apt-get install -y --no-install-recommends \
+      'libcudnn9-cuda-12=9.8.*' \
+      'libcudnn9-dev-cuda-12=9.8.*' \
+      --allow-downgrades --allow-change-held-packages; \
+    apt-mark hold libcudnn9-cuda-12 || true; \
+    ldconfig; \
+    rm -rf /var/lib/apt/lists/*
+
+# (optional) preload workaround if still needed
+ENV LD_PRELOAD=/usr/local/cuda/lib64/libcusparse.so.12:/usr/local/cuda/lib64/libcublas.so.12:/usr/local/cuda/lib64/libcublasLt.so.12:/usr/local/cuda/lib64/libcufft.so.11:/usr/local/cuda/lib64/libcusolver.so.11
+
+ENV DEBIAN_FRONTEND=noninteractive \
+    PYTHONUNBUFFERED=1 \
+    PIP_NO_CACHE_DIR=1 \
+    TF_FORCE_GPU_ALLOW_GROWTH=true \
+    XLA_PYTHON_CLIENT_PREALLOCATE=false
+
+ENV JAX_PLATFORMS=""
+
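
TF_FORCE_GPU_ALLOW_GROWTH and XLA_PYTHON_CLIENT_PREALLOCATE keep TensorFlow and JAX from pre-allocating all GPU memory up front, and the empty JAX_PLATFORMS leaves backend selection to JAX. A minimal sketch of the same settings applied in-process; they only take effect if set before either framework is imported, and the snippet below is illustrative, not part of the image:

import os

# Must happen before `import tensorflow` / `import jax` to have any effect.
os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
os.environ.setdefault("JAX_PLATFORMS", "")  # empty -> let JAX pick its backend
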
+# --- OS deps ---
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    software-properties-common curl ca-certificates git \
+    libsndfile1 ffmpeg \
+    build-essential pkg-config \
+ && add-apt-repository ppa:deadsnakes/ppa -y \
+ && apt-get update && apt-get install -y --no-install-recommends \
+    python3.11 python3.11-venv python3.11-distutils python3-pip \
+ && rm -rf /var/lib/apt/lists/*
+
+# Make python3 => 3.11 for convenience
+RUN ln -sf /usr/bin/python3.11 /usr/bin/python && python -m pip install --upgrade pip
+
+# --- Python deps (pin order matters!) ---
+# 1) JAX CUDA pins
+RUN python -m pip install "jax[cuda12]==0.6.2" "jaxlib==0.6.2"
+
+# 2) Lock seqio early to avoid backtracking madness
+RUN python -m pip install "seqio==0.0.11"
+
+# 3) Install Magenta RT *without* deps so we control pins
+RUN python -m pip install --no-deps 'git+https://github.com/magenta/magenta-realtime#egg=magenta_rt[gpu]'
+
+# 4) TF nightlies (MATCH DATES!)
+RUN python -m pip install \
+    "tf_nightly==2.20.0.dev20250619" \
+    "tensorflow-text-nightly==2.20.0.dev20250316" \
+    "tf-hub-nightly"
+
+# 5) tf2jax pinned alongside tf_nightly so pip doesn’t drag stable TF
+RUN python -m pip install tf2jax "tf_nightly==2.20.0.dev20250619"
+
+# 6) The rest of MRT deps + API runtime deps
+RUN python -m pip install \
+    gin-config librosa resampy soundfile \
+    google-auth google-auth-oauthlib google-auth-httplib2 \
+    google-api-core googleapis-common-protos google-resumable-media \
+    google-cloud-storage requests tqdm typing-extensions numpy==2.1.3 \
+    fastapi uvicorn[standard] python-multipart pyloudnorm
+
+# 7) Exact commits for T5X/Flaxformer as in pyproject
+RUN python -m pip install \
+    "t5x @ git+https://github.com/google-research/t5x.git@92c5b46" \
+    "flaxformer @ git+https://github.com/google/flaxformer@399ea3a"
+
+# ---- FINAL: enforce TF nightlies and clean any stable TF ----
+RUN python - <<'PY'
+import sys, sysconfig, glob, os, shutil
+# Find a writable site dir (site-packages OR dist-packages)
+cands = [sysconfig.get_paths().get('purelib'), sysconfig.get_paths().get('platlib')]
+cands += [p for p in sys.path if p and p.endswith(('site-packages','dist-packages'))]
+site = next(p for p in cands if p and os.path.isdir(p))
+
+patterns = [
+    "tensorflow", "tensorflow-*.dist-info", "tensorflow-*.egg-info",
+    "tf-nightly-*.dist-info", "tf_nightly-*.dist-info",
+    "tensorflow_text", "tensorflow_text-*.dist-info",
+    "tf-hub-nightly-*.dist-info", "tf_hub_nightly-*.dist-info",
+    "tf_keras-nightly-*.dist-info", "tf_keras_nightly-*.dist-info",
+    "tensorboard*", "tb-nightly-*.dist-info",
+    "keras*",  # remove stray keras
+    "tensorflow_hub*", "tensorflow_io*",
+]
+for pat in patterns:
+    for path in glob.glob(os.path.join(site, pat)):
+        if os.path.isdir(path): shutil.rmtree(path, ignore_errors=True)
+        else:
+            try: os.remove(path)
+            except FileNotFoundError: pass
+
+print("TF/Hub/Text cleared in:", site)
+PY
+
+# Reinstall pinned nightlies in ONE transaction
+RUN python -m pip install --no-cache-dir --force-reinstall \
+    "tf-nightly==2.20.0.dev20250619" \
+    "tensorflow-text-nightly==2.20.0.dev20250316" \
+    "tf-hub-nightly"
+
+RUN python -m pip install huggingface_hub
+
+RUN python -m pip install --no-cache-dir --force-reinstall "protobuf==4.25.3"
+
+# Switch to Spaces’ preferred user
+RUN useradd -m -u 1000 appuser
+RUN mkdir -p /home/appuser/app && chown -R appuser:appuser /home/appuser
+WORKDIR /home/appuser/app
+# keep app under the user’s home (optional); source path is relative to the build context
+COPY --chown=appuser:appuser app.py /home/appuser/app/app.py
+
+USER appuser
+
+
+# expose Spaces’ default
+EXPOSE 7860
+
+# respect HF’s PORT env var (falls back to 7860 if not set)
+CMD ["bash", "-lc", "python -m uvicorn app:app --host 0.0.0.0 --port ${PORT:-7860}"]
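
Once the image is built, a quick import check from inside the container confirms that the CUDA/cuDNN, JAX, and TF-nightly pins actually line up; a minimal sketch (the script name is arbitrary):

# check_stack.py — run with `python check_stack.py` inside the container
import jax
import tensorflow as tf

print("jax", jax.__version__, jax.devices())               # expect a CUDA device on GPU hosts
print("tf", tf.__version__)                                 # expect the pinned 2.20.0-dev build
print("tf GPUs:", tf.config.list_physical_devices("GPU"))   # expect at least one entry
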
app.py ADDED
@@ -0,0 +1,298 @@
+from magenta_rt import system, audio as au
+import numpy as np
+from fastapi import FastAPI, UploadFile, File, Form
+import tempfile, io, base64, math, threading
+from fastapi.middleware.cors import CORSMiddleware
+
+# loudness utils
+try:
+    import pyloudnorm as pyln
+    _HAS_LOUDNORM = True
+except Exception:
+    _HAS_LOUDNORM = False
+
+def _measure_lufs(wav: au.Waveform) -> float:
+    # pyloudnorm expects float32/float64, shape (n,) or (n, ch)
+    meter = pyln.Meter(wav.sample_rate)  # defaults to BS.1770-4
+    return float(meter.integrated_loudness(wav.samples))
+
+def _rms(x: np.ndarray) -> float:
+    if x.size == 0: return 0.0
+    return float(np.sqrt(np.mean(x**2)))
+
+def match_loudness_to_reference(
+    ref: au.Waveform,
+    target: au.Waveform,
+    method: str = "auto",      # "auto"|"lufs"|"rms"|"none"
+    headroom_db: float = 1.0
+) -> tuple[au.Waveform, dict]:
+    """
+    Scales `target` to match `ref` loudness. Returns (adjusted_wave, stats).
+    """
+    stats = {"method": method, "applied_gain_db": 0.0}
+
+    if method == "none":
+        return target, stats
+
+    if method == "auto":
+        method = "lufs" if _HAS_LOUDNORM else "rms"
+
+    if method == "lufs" and _HAS_LOUDNORM:
+        L_ref = _measure_lufs(ref)
+        L_tgt = _measure_lufs(target)
+        delta_db = L_ref - L_tgt
+        gain = 10.0 ** (delta_db / 20.0)
+        y = target.samples.astype(np.float32) * gain
+        stats.update({"ref_lufs": L_ref, "tgt_lufs_before": L_tgt, "applied_gain_db": delta_db})
+    else:
+        # RMS fallback
+        ra = _rms(ref.samples)
+        rb = _rms(target.samples)
+        if rb <= 1e-12:
+            return target, stats
+        gain = ra / rb
+        y = target.samples.astype(np.float32) * gain
+        stats.update({"ref_rms": ra, "tgt_rms_before": rb, "applied_gain_db": 20*np.log10(max(gain, 1e-12))})
+
+    # simple peak “limiter” to keep headroom
+    limit = 10 ** (-headroom_db / 20.0)  # e.g., -1 dBFS
+    peak = float(np.max(np.abs(y))) if y.size else 0.0
+    if peak > limit:
+        y *= (limit / peak)
+        stats["post_peak_limited"] = True
+    else:
+        stats["post_peak_limited"] = False
+
+    target.samples = y.astype(np.float32)
+    return target, stats
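
For intuition on the gain math above: the scale factor is 10^(Δ/20), so if the reference loop measures 6 LUFS louder than the generated audio, the samples are multiplied by roughly 2, and the headroom guard then caps any resulting peak at about 0.89 (-1 dBFS). A small illustrative calculation:

delta_db = 6.0                      # reference is 6 LU louder than the target
gain = 10.0 ** (delta_db / 20.0)    # ~1.995x amplitude
limit = 10 ** (-1.0 / 20.0)         # -1 dBFS headroom ~0.891
peak = 0.6 * gain                   # a 0.6-peak signal would overshoot the limit...
print(round(gain, 3), round(limit, 3), round(min(peak, limit), 3))  # ...so it is pulled back to ~0.891
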
+
+# ----------------------------
+# Crossfade stitch (your good path)
+# ----------------------------
+def stitch_generated(chunks, sr, xfade_s):
+    if not chunks:
+        raise ValueError("no chunks")
+    xfade_n = int(round(xfade_s * sr))
+    if xfade_n <= 0:
+        return au.Waveform(np.concatenate([c.samples for c in chunks], axis=0), sr)
+
+    t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)
+    eq_in, eq_out = np.sin(t)[:, None], np.cos(t)[:, None]
+
+    first = chunks[0].samples
+    if first.shape[0] < xfade_n:
+        raise ValueError("chunk shorter than crossfade prefix")
+    out = first[xfade_n:].copy()  # drop model pre-roll
+
+    for i in range(1, len(chunks)):
+        cur = chunks[i].samples
+        if cur.shape[0] < xfade_n:
+            continue
+        head, tail = cur[:xfade_n], cur[xfade_n:]
+        mixed = out[-xfade_n:] * eq_out + head * eq_in
+        out = np.concatenate([out[:-xfade_n], mixed, tail], axis=0)
+
+    return au.Waveform(out, sr)
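
The sin/cos ramps above form an equal-power crossfade: sin²(t) + cos²(t) = 1 at every sample, so the summed power through the overlap stays flat rather than dipping the way a linear fade would. A tiny check of that property:

import numpy as np

xfade_n = 1024
t = np.linspace(0, np.pi / 2, xfade_n, endpoint=False, dtype=np.float32)
eq_in, eq_out = np.sin(t), np.cos(t)

# Equal-power: the two gains always sum to unit power across the overlap.
assert np.allclose(eq_in**2 + eq_out**2, 1.0, atol=1e-5)
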
+
+# ----------------------------
+# Bar-aligned token context
+# ----------------------------
+def make_bar_aligned_context(tokens, bpm, fps=25, ctx_frames=250, beats_per_bar=4):
+    frames_per_bar_f = (beats_per_bar * 60.0 / bpm) * fps
+    frames_per_bar = int(round(frames_per_bar_f))
+    if abs(frames_per_bar - frames_per_bar_f) > 1e-3:
+        reps = int(np.ceil(ctx_frames / len(tokens)))
+        return np.tile(tokens, (reps, 1))[-ctx_frames:]
+    reps = int(np.ceil(ctx_frames / len(tokens)))
+    tiled = np.tile(tokens, (reps, 1))
+    end = (len(tiled) // frames_per_bar) * frames_per_bar
+    if end < ctx_frames:
+        return tiled[-ctx_frames:]
+    start = end - ctx_frames
+    return tiled[start:end]
+
+def hard_trim_seconds(wav: au.Waveform, seconds: float) -> au.Waveform:
+    n = int(round(seconds * wav.sample_rate))
+    return au.Waveform(wav.samples[:n], wav.sample_rate)
+
+def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
+    n = int(wav.sample_rate * ms / 1000.0)
+    if n > 0 and wav.samples.shape[0] > 2*n:
+        env = np.linspace(0.0, 1.0, n, dtype=np.float32)[:, None]
+        wav.samples[:n] *= env
+        wav.samples[-n:] *= env[::-1]
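
To make the bar math concrete: at 120 BPM with 4 beats per bar and a 25 fps token rate, one bar is exactly 50 frames, so a 250-frame context holds 5 whole bars and can end on a bar boundary; a tempo like 90 BPM gives a non-integer frame count, and the function falls back to plain tiling.

def frames_per_bar(bpm, fps=25, beats_per_bar=4):
    # seconds per bar * frames per second
    return (beats_per_bar * 60.0 / bpm) * fps

print(frames_per_bar(120))  # 50.0     -> bar-aligned slice of the 250-frame context
print(frames_per_bar(90))   # 66.66... -> falls back to simple tiling
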
+
+# ----------------------------
+# Main generation (single combined style vector)
+# ----------------------------
+def generate_loop_continuation_with_mrt(
+    mrt,
+    input_wav_path: str,
+    bpm: float,
+    extra_styles=None,
+    style_weights=None,
+    bars: int = 8,
+    beats_per_bar: int = 4,
+    loop_weight: float = 1.0,            # NEW
+    loudness_mode: str = "auto",         # "auto"|"lufs"|"rms"|"none"
+    loudness_headroom_db: float = 1.0,   # for the peak guard
+):
+    # Load loop & encode
+    loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
+    tokens_full = mrt.codec.encode(loop).astype(np.int32)
+    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
+
+    # Context
+    context_tokens = make_bar_aligned_context(
+        tokens,
+        bpm=bpm,
+        fps=int(mrt.codec.frame_rate),
+        ctx_frames=mrt.config.context_length_frames,
+        beats_per_bar=beats_per_bar,
+    )
+    state = mrt.init_state()
+    state.context_tokens = context_tokens
+
+    # ---------- STYLE: weighted avg into ONE vector ----------
+    # Base embed from loop with adjustable loop_weight
+    embeds = []
+    weights = []
+
+    # loop embedding
+    loop_embed = mrt.embed_style(loop)
+    embeds.append(loop_embed)
+    weights.append(float(loop_weight))  # <--- use requested loop weight
+
+    # extra styles
+    if extra_styles:
+        for i, s in enumerate(extra_styles):
+            if s.strip():
+                embeds.append(mrt.embed_style(s.strip()))
+                w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
+                weights.append(float(w))
+
+    # Prevent all-zero weights; normalize
+    wsum = float(sum(weights))
+    if wsum <= 0.0:
+        # fallback: rely on loop to avoid NaNs
+        weights = [1.0] + [0.0] * (len(weights) - 1)
+        wsum = 1.0
+
+    weights = [w / wsum for w in weights]
+
+    # weighted sum -> single style vector (match dtype)
+    combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)
+
+    # Chunks to cover exact bars
+    seconds_per_bar = beats_per_bar * (60.0 / bpm)
+    total_secs = bars * seconds_per_bar
+    chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate  # ~2.0
+    steps = int(math.ceil(total_secs / chunk_secs)) + 1  # pad then trim
+
+    # Generate
+    chunks = []
+    for _ in range(steps):
+        wav, state = mrt.generate_chunk(state=state, style=combined_style)  # ONE style vector
+        chunks.append(wav)
+
+    # Stitch -> trim -> polish
+    out = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
+    out = hard_trim_seconds(out, total_secs).peak_normalize(0.95)
+    apply_micro_fades(out, 5)
+    # Loudness match to the *input loop* so the return level feels consistent
+    out, loud_stats = match_loudness_to_reference(
+        ref=loop, target=out,
+        method=loudness_mode,
+        headroom_db=loudness_headroom_db,
+    )
+    return out, loud_stats
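
A worked example of the chunk budget above: 8 bars at 120 BPM is 16 s of audio, and with roughly 2 s chunks the loop requests ceil(16 / 2) + 1 = 9 chunks; the surplus is cut by the crossfade trimming and hard_trim_seconds. (The 2 s figure is the approximate value the inline comment cites, not read from mrt.config here.)

import math

bpm, bars, beats_per_bar = 120.0, 8, 4
chunk_secs = 2.0                                     # approximate; real value derives from mrt.config
total_secs = bars * beats_per_bar * (60.0 / bpm)     # 16.0 seconds
steps = int(math.ceil(total_secs / chunk_secs)) + 1  # pad, then trim back to total_secs
print(total_secs, steps)                             # 16.0 9
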
+
+# ----------------------------
+# FastAPI app with lazy, thread-safe model init
+# ----------------------------
+app = FastAPI()
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # or lock to your domain(s)
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+_MRT = None
+_MRT_LOCK = threading.Lock()
+
+def get_mrt():
+    global _MRT
+    if _MRT is None:
+        with _MRT_LOCK:
+            if _MRT is None:
+                _MRT = system.MagentaRT(tag="base", guidance_weight=1.0, device="gpu", lazy=False)
+    return _MRT
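
Because get_mrt() loads the model lazily, the first /generate call absorbs the full model-load time. A minimal sketch of warming it when the server starts instead; the startup hook below is an assumption, not something this commit registers:

@app.on_event("startup")
def _warm_mrt():
    # Runs once per worker before traffic arrives; reuses the same
    # lock-guarded initializer, so later calls just return the cached model.
    get_mrt()
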
+
+@app.post("/generate")
+def generate(
+    loop_audio: UploadFile = File(...),
+    bpm: float = Form(...),
+    bars: int = Form(8),
+    beats_per_bar: int = Form(4),
+    styles: str = Form("acid house"),
+    style_weights: str = Form(""),
+    loop_weight: float = Form(1.0),           # NEW
+    loudness_mode: str = Form("auto"),        # NEW
+    loudness_headroom_db: float = Form(1.0),  # NEW
+):
+    # Read file
+    data = loop_audio.file.read()
+    if not data:
+        return {"error": "Empty file"}
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
+        tmp.write(data)
+        tmp_path = tmp.name
+
+    # Parse styles + weights
+    extra_styles = [s for s in (styles.split(",") if styles else []) if s.strip()]
+    weights = [float(x) for x in style_weights.split(",")] if style_weights else None
+
+    mrt = get_mrt()  # warm once, in this worker thread
+    wav, loud_stats = generate_loop_continuation_with_mrt(
+        mrt,
+        input_wav_path=tmp_path,
+        bpm=bpm,
+        extra_styles=extra_styles,
+        style_weights=weights,
+        bars=bars,
+        beats_per_bar=beats_per_bar,
+        loop_weight=loop_weight,
+        loudness_mode=loudness_mode,
+        loudness_headroom_db=loudness_headroom_db,
+    )
+
+    # Return base64 WAV + minimal metadata
+    buf = io.BytesIO()
+    # add format="WAV" when writing to a file-like object
+    wav.write(buf, subtype="FLOAT", format="WAV")
+    buf.seek(0)
+    audio_b64 = base64.b64encode(buf.read()).decode("utf-8")
+
+    return {
+        "audio_base64": audio_b64,
+        "metadata": {
+            "bpm": int(round(bpm)),
+            "bars": int(bars),
+            "beats_per_bar": int(beats_per_bar),
+            "styles": extra_styles,
+            "style_weights": weights,
+            "loop_weight": loop_weight,
+            "loudness": loud_stats,  # NEW
+            "sample_rate": mrt.sample_rate,
+            "channels": mrt.num_channels,
+            "crossfade_seconds": mrt.config.crossfade_length,
+        },
+    }
+
+@app.get("/health")
+def health():
+    return {"ok": True}
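
A minimal client sketch for the /generate endpoint, assuming the server runs at a placeholder URL and a local loop.wav exists; it posts the form fields defined above and decodes the base64 WAV in the JSON response:

import base64
import requests

url = "http://localhost:7860/generate"  # placeholder; substitute your Space URL
with open("loop.wav", "rb") as f:
    resp = requests.post(
        url,
        files={"loop_audio": ("loop.wav", f, "audio/wav")},
        data={
            "bpm": 120,
            "bars": 8,
            "styles": "acid house,warehouse techno",   # example free-text styles
            "style_weights": "1.0,0.5",
            "loop_weight": 1.0,
        },
        timeout=600,
    )
resp.raise_for_status()
payload = resp.json()
with open("continuation.wav", "wb") as out:
    out.write(base64.b64decode(payload["audio_base64"]))
print(payload["metadata"])
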