Spaces:
Running
Running
File size: 5,849 Bytes
956f1a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
# utils.py
from __future__ import annotations
import io, base64, math
from math import gcd
import numpy as np
import soundfile as sf
from scipy.signal import resample_poly
# Magenta RT audio types
from magenta_rt import audio as au
# Optional loudness
try:
import pyloudnorm as pyln
_HAS_LOUDNORM = True
except Exception:
_HAS_LOUDNORM = False
# ---------- Loudness ----------
def _measure_lufs(wav: au.Waveform) -> float:
    """Return the integrated loudness of `wav` as reported by pyloudnorm.

    Uses the module-level `pyln` name, which is only bound when the optional
    pyloudnorm import at the top of the file succeeded, so callers are
    expected to check `_HAS_LOUDNORM` first.
    """
    # A meter built for the waveform's sample rate (BS.1770-4).
    loudness = pyln.Meter(wav.sample_rate).integrated_loudness(wav.samples)
    return float(loudness)
def _rms(x: np.ndarray) -> float:
    """Return the root-mean-square level of `x` over every element.

    All samples and channels are pooled into one value. An empty array
    yields 0.0 instead of the NaN that a mean over nothing would produce.
    """
    if x.size == 0:
        return 0.0
    # Square and average in float64: squaring integer PCM (e.g. int16) in its
    # own dtype wraps around, and a float32 accumulator loses precision on
    # long clips.
    wide = x.astype(np.float64, copy=False)
    return float(np.sqrt(np.mean(wide ** 2)))
def match_loudness_to_reference(
    ref: au.Waveform,
    target: au.Waveform,
    method: str = "auto",  # "auto"|"lufs"|"rms"|"none"
    headroom_db: float = 1.0
) -> tuple[au.Waveform, dict]:
    """Scale `target` so its loudness matches `ref`, then cap its peak.

    Args:
        ref: Waveform whose loudness is the goal.
        target: Waveform to adjust. Its `samples` attribute is REPLACED with
            the scaled float32 array; the same object is returned.
        method: "none" returns `target` untouched. "lufs" measures with
            pyloudnorm when it is installed. "rms" (and any other value)
            uses plain RMS. "auto" picks "lufs" if pyloudnorm is available,
            otherwise "rms".
        headroom_db: After the gain is applied, the whole signal is scaled
            down if its absolute peak exceeds -headroom_db dBFS.

    Returns:
        (target, stats). `stats` always carries "method" (as requested, not
        as resolved) and "applied_gain_db". It also carries either
        "ref_lufs"/"tgt_lufs_before" or "ref_rms"/"tgt_rms_before",
        depending on which measurement produced the gain, and
        "post_peak_limited" for every method except "none".

    The LUFS path falls back to RMS when the loudness difference is not a
    finite number (pyloudnorm reports -inf for silent or fully gated audio)
    or when pyloudnorm rejects the input with ValueError (for example a clip
    shorter than its gating block). Without the fallback those cases turn
    the output into inf/NaN or abort the call.
    """
    stats = {"method": method, "applied_gain_db": 0.0}
    if method == "none":
        return target, stats
    if method == "auto":
        method = "lufs" if _HAS_LOUDNORM else "rms"

    gain = None
    if method == "lufs" and _HAS_LOUDNORM:
        try:
            ref_lufs = _measure_lufs(ref)
            tgt_lufs = _measure_lufs(target)
        except ValueError:
            # pyloudnorm could not meter one of the clips; use RMS below.
            ref_lufs = tgt_lufs = None
        if ref_lufs is not None:
            delta_db = ref_lufs - tgt_lufs
            # -inf on either side makes the difference inf or NaN, which
            # would poison every sample; only accept a finite difference.
            if np.isfinite(delta_db):
                gain = 10.0 ** (delta_db / 20.0)
                stats.update({"ref_lufs": ref_lufs, "tgt_lufs_before": tgt_lufs, "applied_gain_db": delta_db})

    if gain is None:
        ref_rms = _rms(ref.samples)
        tgt_rms = _rms(target.samples)
        if tgt_rms <= 1e-12:
            # Target is effectively silent: no gain can match it, leave as is.
            stats["post_peak_limited"] = False
            return target, stats
        gain = ref_rms / tgt_rms
        stats.update({"ref_rms": ref_rms, "tgt_rms_before": tgt_rms, "applied_gain_db": float(20 * np.log10(max(gain, 1e-12)))})

    y = target.samples.astype(np.float32) * gain

    # simple peak "limiter" to keep headroom
    limit = 10 ** (-headroom_db / 20.0)  # e.g., -1 dBFS
    peak = float(np.max(np.abs(y))) if y.size else 0.0
    if peak > limit:
        # Uniform scale-down (not a dynamic limiter), so the result ends up
        # quieter than the reference whenever this triggers.
        y *= (limit / peak)
        stats["post_peak_limited"] = True
    else:
        stats["post_peak_limited"] = False

    target.samples = y.astype(np.float32, copy=False)
    return target, stats
# ---------- Stitch / fades / trims ----------
def stitch_generated(chunks, sr: int, xfade_s: float) -> au.Waveform:
    """Join generated chunks into one waveform with equal-power crossfades.

    Args:
        chunks: Non-empty sequence of objects exposing `.samples`. The fade
            ramps are shaped (xfade_n, 1), so samples are assumed to be 2-D
            (frames, channels) — TODO confirm against au.Waveform.
        sr: Sample rate used to convert `xfade_s` to samples and stamped on
            the result.
        xfade_s: Crossfade length in seconds. If it rounds to zero samples
            or less, the chunks are simply concatenated.

    Returns:
        A new au.Waveform. The first `xfade_n` samples of the first chunk are
        dropped; every later chunk has its first `xfade_n` samples blended
        into the tail of what has been stitched so far.

    Raises:
        ValueError: `chunks` is empty, or the first chunk is shorter than
            the crossfade.
    """
    if not chunks:
        raise ValueError("no chunks")

    overlap = int(round(xfade_s * sr))
    if overlap <= 0:
        joined = np.concatenate([chunk.samples for chunk in chunks], axis=0)
        return au.Waveform(joined, sr)

    # Quarter-cycle sine/cosine ramps: sin^2 + cos^2 = 1 keeps power constant.
    phase = np.linspace(0, np.pi/2, overlap, endpoint=False, dtype=np.float32)
    fade_in = np.sin(phase)[:, None]
    fade_out = np.cos(phase)[:, None]

    lead = chunks[0].samples
    if lead.shape[0] < overlap:
        raise ValueError("chunk shorter than crossfade prefix")

    # drop model pre-roll
    stitched = lead[overlap:].copy()

    for idx in range(1, len(chunks)):
        incoming = chunks[idx].samples
        if incoming.shape[0] < overlap:
            # Too short to crossfade: this chunk is skipped entirely.
            continue
        blended = stitched[-overlap:] * fade_out + incoming[:overlap] * fade_in
        stitched = np.concatenate(
            [stitched[:-overlap], blended, incoming[overlap:]], axis=0
        )

    return au.Waveform(stitched, sr)
def hard_trim_seconds(wav: au.Waveform, seconds: float) -> au.Waveform:
    """Return a waveform holding at most the first `seconds` of `wav`.

    The sample count is round(seconds * wav.sample_rate). If `wav` is
    shorter than that, everything is kept (no padding). The result is built
    from a slice of `wav.samples`, not an explicit copy.
    """
    # Clamp at zero: NumPy would read a negative bound as "drop the last |n|
    # samples", so a negative duration would otherwise keep most of the audio.
    n = max(0, int(round(seconds * wav.sample_rate)))
    return au.Waveform(wav.samples[:n], wav.sample_rate)
def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
    """Apply a short linear fade-in and fade-out to `wav.samples` IN PLACE.

    Args:
        wav: Object exposing `sample_rate` and a float `samples` array whose
            first axis is time. Both (frames, channels) and 1-D mono layouts
            are accepted.
        ms: Fade length in milliseconds at each end.

    Nothing happens when the fade rounds down to zero samples or when the
    audio is not longer than the two fades combined.
    """
    n = int(wav.sample_rate * ms / 1000.0)
    samples = wav.samples
    if n > 0 and samples.shape[0] > 2 * n:
        ramp = np.linspace(0.0, 1.0, n, dtype=np.float32)
        # One trailing singleton axis per non-time dimension, so the ramp
        # broadcasts across channels and still fits a 1-D mono array (a fixed
        # (n, 1) envelope cannot be multiplied in place into shape (n,)).
        ramp = ramp.reshape((n,) + (1,) * (samples.ndim - 1))
        samples[:n] *= ramp
        samples[-n:] *= ramp[::-1]
# ---------- Token context helpers ----------
def make_bar_aligned_context(tokens, bpm, fps=25, ctx_frames=250, beats_per_bar=4):
    """Build a `ctx_frames`-long token window by looping `tokens`.

    `tokens` is tiled along axis 0 (the tiling pattern `(reps, 1)` implies a
    2-D array) until it is at least `ctx_frames` long. When one bar lasts a
    whole number of frames at `fps` (within 1e-3), the window is chosen to
    end on the last bar boundary inside the tiled array, counted from its
    start. Otherwise, or when that boundary comes before `ctx_frames`
    frames, the window is just the final `ctx_frames` frames.
    """
    exact_bar_frames = (beats_per_bar * 60.0 / bpm) * fps
    bar_frames = int(round(exact_bar_frames))
    off_grid = abs(bar_frames - exact_bar_frames) > 1e-3

    repeats = int(np.ceil(ctx_frames / len(tokens)))
    looped = np.tile(tokens, (repeats, 1))

    if not off_grid:
        # Largest multiple of a bar that fits in the looped material.
        aligned_end = (len(looped) // bar_frames) * bar_frames
        if aligned_end >= ctx_frames:
            return looped[aligned_end - ctx_frames:aligned_end]

    # No usable bar grid: fall back to the plain tail.
    return looped[-ctx_frames:]
def take_bar_aligned_tail(wav: au.Waveform, bpm: float, beats_per_bar: int, ctx_seconds: float, max_bars=None) -> au.Waveform:
    """Return the end of `wav`, cut to a whole number of bars.

    The bar count is `ctx_seconds` divided by the bar duration, rounded to
    the nearest integer and raised to at least 1; `max_bars`, when given,
    caps it afterwards. If the resulting length covers the whole waveform
    (or more), `wav` itself is returned; otherwise a new au.Waveform is
    built from the trailing slice of `wav.samples`.
    """
    seconds_per_bar = (60.0 / bpm) * beats_per_bar

    bar_count = int(round(ctx_seconds / seconds_per_bar))
    if bar_count < 1:
        bar_count = 1
    if max_bars is not None and max_bars < bar_count:
        bar_count = max_bars

    tail_len = int(round(bar_count * seconds_per_bar * wav.sample_rate))
    if tail_len < wav.samples.shape[0]:
        return au.Waveform(wav.samples[-tail_len:], wav.sample_rate)
    return wav
# ---------- SR normalize + snap ----------
def resample_and_snap(x: np.ndarray, cur_sr: int, target_sr: int, seconds: float) -> np.ndarray:
    """
    Resample audio to `target_sr` and force it to an exact duration.

    x: np.ndarray shape (S, C); a 1-D array is treated as a single channel
    cur_sr, target_sr: integer sample rates (their gcd sets the polyphase ratio)
    Returns: float32 array of shape (round(seconds*target_sr), C), zero-padded
    at the end if the audio is short, truncated if it is long
    """
    frames = x[:, None] if x.ndim == 1 else x

    if cur_sr != target_sr:
        common = gcd(cur_sr, target_sr)
        frames = resample_poly(frames, target_sr // common, cur_sr // common, axis=0)

    want = int(round(seconds * target_sr))
    have = frames.shape[0]
    if have > want:
        frames = frames[:want, :]
    elif have < want:
        silence = np.zeros((want - have, frames.shape[1]), dtype=frames.dtype)
        frames = np.vstack([frames, silence])

    return frames.astype(np.float32, copy=False)
# ---------- WAV encode ----------
def wav_bytes_base64(x: np.ndarray, sr: int) -> tuple[str, int, int]:
    """
    Encode audio as an in-memory 32-bit float WAV file, returned as base64 text.

    x: np.ndarray shape (S, C); a 1-D array of shape (S,) is treated as mono
    sr: sample rate passed to the WAV writer
    returns: (base64_wav, total_samples, channels)
    """
    if x.ndim == 1:
        # Promote mono to (S, 1) so the channel count below is well defined
        # (x.shape[1] would otherwise raise IndexError after the file had
        # already been written); resample_and_snap treats 1-D input the same way.
        x = x[:, None]
    buf = io.BytesIO()
    sf.write(buf, x, sr, subtype="FLOAT", format="WAV")
    # Rewind before reading: the writer leaves the cursor at the end.
    buf.seek(0)
    b64 = base64.b64encode(buf.read()).decode("utf-8")
    return b64, int(x.shape[0]), int(x.shape[1])
|