magenta-retry / utils.py
thecollabagepatch's picture
oh boy bar boundaries is hard
d41a575
raw
history blame
7.94 kB
# utils.py
from __future__ import annotations
import io, base64, math
from math import gcd
import numpy as np
import soundfile as sf
from scipy.signal import resample_poly
# Magenta RT audio types
from magenta_rt import audio as au
# Optional loudness
try:
import pyloudnorm as pyln
_HAS_LOUDNORM = True
except Exception:
_HAS_LOUDNORM = False
# ---------- Loudness ----------
def _measure_lufs(wav: au.Waveform) -> float:
meter = pyln.Meter(wav.sample_rate) # BS.1770-4
return float(meter.integrated_loudness(wav.samples))
def _rms(x: np.ndarray) -> float:
if x.size == 0: return 0.0
return float(np.sqrt(np.mean(x**2)))
def match_loudness_to_reference(
ref: au.Waveform,
target: au.Waveform,
method: str = "auto", # "auto"|"lufs"|"rms"|"none"
headroom_db: float = 1.0
) -> tuple[au.Waveform, dict]:
stats = {"method": method, "applied_gain_db": 0.0}
if method == "none":
return target, stats
if method == "auto":
method = "lufs" if _HAS_LOUDNORM else "rms"
if method == "lufs" and _HAS_LOUDNORM:
L_ref = _measure_lufs(ref)
L_tgt = _measure_lufs(target)
delta_db = L_ref - L_tgt
gain = 10.0 ** (delta_db / 20.0)
y = target.samples.astype(np.float32) * gain
stats.update({"ref_lufs": L_ref, "tgt_lufs_before": L_tgt, "applied_gain_db": delta_db})
else:
ra = _rms(ref.samples)
rb = _rms(target.samples)
if rb <= 1e-12:
return target, stats
gain = ra / rb
y = target.samples.astype(np.float32) * gain
stats.update({"ref_rms": ra, "tgt_rms_before": rb, "applied_gain_db": 20*np.log10(max(gain,1e-12))})
# simple peak “limiter” to keep headroom
limit = 10 ** (-headroom_db / 20.0) # e.g., -1 dBFS
peak = float(np.max(np.abs(y))) if y.size else 0.0
if peak > limit:
y *= (limit / peak)
stats["post_peak_limited"] = True
else:
stats["post_peak_limited"] = False
target.samples = y.astype(np.float32)
return target, stats
# ---------- Stitch / fades / trims ----------
def stitch_generated(chunks, sr: int, xfade_s: float, drop_first_pre_roll: bool = True):
if not chunks:
raise ValueError("no chunks")
xfade_n = int(round(xfade_s * sr))
if xfade_n <= 0:
return au.Waveform(np.concatenate([c.samples for c in chunks], axis=0), sr)
t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)
eq_in, eq_out = np.sin(t)[:, None], np.cos(t)[:, None]
first = chunks[0].samples
if first.shape[0] < xfade_n:
raise ValueError("chunk shorter than crossfade prefix")
# 🔧 key change:
out = first[xfade_n:].copy() if drop_first_pre_roll else first.copy()
for i in range(1, len(chunks)):
cur = chunks[i].samples
if cur.shape[0] < xfade_n:
continue
head, tail = cur[:xfade_n], cur[xfade_n:]
mixed = out[-xfade_n:] * eq_out + head * eq_in
out = np.concatenate([out[:-xfade_n], mixed, tail], axis=0)
return au.Waveform(out, sr)
def hard_trim_seconds(wav: au.Waveform, seconds: float) -> au.Waveform:
n = int(round(seconds * wav.sample_rate))
return au.Waveform(wav.samples[:n], wav.sample_rate)
def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
n = int(wav.sample_rate * ms / 1000.0)
if n > 0 and wav.samples.shape[0] > 2*n:
env = np.linspace(0.0, 1.0, n, dtype=np.float32)[:, None]
wav.samples[:n] *= env
wav.samples[-n:] *= env[::-1]
# ---------- Token context helpers ----------
def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
"""
Return a ctx_frames-long slice of `tokens` whose **end** lands on the nearest
whole-bar boundary in codec-frame space, even when frames_per_bar is fractional.
tokens: np.ndarray of shape (T, D) or (T,) where T = codec frames
bpm: float
fps: float (codec frames per second; keep this as float)
ctx_frames: int (length of context window in codec frames)
beats_per_bar: int
"""
if tokens is None:
raise ValueError("tokens is None")
tokens = np.asarray(tokens)
if tokens.ndim == 1:
tokens = tokens[:, None] # promote to (T, 1) for uniform tiling
T = tokens.shape[0]
if T == 0:
return tokens
fps = float(fps)
frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps # float frames per bar
# Tile a little more than we need so we can always snap the END to a bar boundary
reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
tiled = np.tile(tokens, (reps, 1))
total = tiled.shape[0]
# How many whole bars fit?
k_bars = int(np.floor(total / frames_per_bar_f))
if k_bars <= 0:
# Fallback: just take the last ctx_frames
window = tiled[-ctx_frames:]
return window
# Snap END index to the nearest integer frame at a whole-bar boundary
end_idx = int(round(k_bars * frames_per_bar_f))
end_idx = min(max(end_idx, ctx_frames), total)
start_idx = end_idx - ctx_frames
if start_idx < 0:
start_idx = 0
end_idx = ctx_frames
window = tiled[start_idx:end_idx]
# Guard against rare off-by-one due to rounding
if window.shape[0] < ctx_frames:
pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
window = np.vstack([window, pad])[:ctx_frames]
elif window.shape[0] > ctx_frames:
window = window[-ctx_frames:]
return window
def take_bar_aligned_tail(
wav: au.Waveform,
bpm: float,
beats_per_bar: int,
ctx_seconds: float,
max_bars=None
) -> au.Waveform:
"""
Take a tail whose length is an integer number of bars, with the END aligned
to a bar boundary. Uses ceil for bars_needed so we never under-fill the context.
"""
import math
# seconds per bar
spb = (60.0 / float(bpm)) * float(beats_per_bar)
# Pick enough whole bars to cover ctx_seconds (avoid underfilling on round-down).
# The small epsilon avoids an extra bar due to FP jitter when ctx_seconds ~= k * spb.
eps = 1e-9
bars_needed = max(1, int(math.ceil((float(ctx_seconds) - eps) / spb)))
if max_bars is not None:
bars_needed = min(bars_needed, int(max_bars))
# Convert bars -> samples (do rounding once at the end for stability)
samples_per_bar_f = spb * float(wav.sample_rate)
n = int(round(bars_needed * samples_per_bar_f))
total = int(wav.samples.shape[0])
if n >= total:
# Not enough audio to take that many bars—return as-is (current behavior).
return wav
start = total - n
return au.Waveform(wav.samples[start:], wav.sample_rate)
# ---------- SR normalize + snap ----------
def resample_and_snap(x: np.ndarray, cur_sr: int, target_sr: int, seconds: float) -> np.ndarray:
"""
x: np.ndarray shape (S, C), float32
Returns: exact-length array (round(seconds*target_sr), C)
"""
if x.ndim == 1:
x = x[:, None]
if cur_sr != target_sr:
g = gcd(cur_sr, target_sr)
up, down = target_sr // g, cur_sr // g
x = resample_poly(x, up, down, axis=0)
expected_len = int(round(seconds * target_sr))
if x.shape[0] < expected_len:
pad = np.zeros((expected_len - x.shape[0], x.shape[1]), dtype=x.dtype)
x = np.vstack([x, pad])
elif x.shape[0] > expected_len:
x = x[:expected_len, :]
return x.astype(np.float32, copy=False)
# ---------- WAV encode ----------
def wav_bytes_base64(x: np.ndarray, sr: int) -> tuple[str, int, int]:
"""
x: np.ndarray shape (S, C)
returns: (base64_wav, total_samples, channels)
"""
buf = io.BytesIO()
sf.write(buf, x, sr, subtype="FLOAT", format="WAV")
buf.seek(0)
b64 = base64.b64encode(buf.read()).decode("utf-8")
return b64, int(x.shape[0]), int(x.shape[1])