# app.py
# Zack's Audio Outpost — AI Noise Reducer (SpeechBrain MetricGAN+)
# CPU-friendly; offers Light/Medium/Strong wet-mix strengths and an Original vs. Processed comparison.
import os
from typing import Tuple
import gradio as gr
import numpy as np
import torch
import torchaudio
from speechbrain.pretrained import SpectralMaskEnhancement
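# Note (assumption about library versions, not part of the original code): this is the
# pre-1.0 import path. On SpeechBrain >= 1.0 the class reportedly lives at
# `speechbrain.inference.enhancement.SpectralMaskEnhancement`; the `speechbrain.pretrained`
# alias generally still resolves but may emit a deprecation warning.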
# -----------------------------
# Config
# -----------------------------
TARGET_SR = 16_000 # MetricGAN+ expects 16 kHz mono
# Wet/dry mix by "strength"
MIX_BY_STRENGTH = {
    "Light": 0.50,   # 50% wet
    "Medium": 0.75,  # 75% wet
    "Strong": 1.00,  # 100% wet
}
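# The strength value is applied later in denoise_numpy() as a simple
# time-domain crossfade: out = (1 - mix) * dry + mix * enhanced.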
# ✅ Correct SpeechBrain model for SpectralMaskEnhancement
MODEL_SOURCE = "speechbrain/metricgan-plus-voicebank"
MODEL_DIR = "pretrained_models/metricgan-plus-voicebank"
# Global enhancer (loaded once)
_enhancer: SpectralMaskEnhancement | None = None
def get_enhancer() -> SpectralMaskEnhancement:
    """Lazy-load the SpeechBrain enhancer once."""
    global _enhancer
    if _enhancer is None:
        _enhancer = SpectralMaskEnhancement.from_hparams(
            source=MODEL_SOURCE,
            savedir=MODEL_DIR,
        )
        _enhancer.mods.eval()
        torch.set_grad_enabled(False)
    return _enhancer
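# Note: the first call to get_enhancer() downloads the checkpoint from the
# Hugging Face Hub into MODEL_DIR; subsequent calls reuse the cached files and
# the module-level singleton, so the model is loaded only once per process.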
# -----------------------------
# Audio helpers
# -----------------------------
def to_mono(x: np.ndarray) -> np.ndarray:
    """
    Ensure mono. Accepts shapes:
      - (time,)            already mono
      - (time, channels)   -> average channels
      - (channels, time)   -> average channels, return (time,)
    Returns float32 in [-1, 1].
    """
    if x.ndim == 1:
        y = x
    elif x.ndim == 2:
        if x.shape[0] < x.shape[1]:
            # (channels, time)
            y = x.mean(axis=0)
        else:
            # (time, channels)
            y = x.mean(axis=1)
    else:
        raise ValueError("Unsupported audio shape; expected a 1D or 2D ndarray")
    return y.astype(np.float32, copy=False)
def resample_to_16k_mono(x: np.ndarray, sr: int) -> torch.Tensor:
    """
    NumPy -> torch (1, time) @ 16 kHz mono, float32 in [-1, 1].
    """
    mono = to_mono(x)
    wav = torch.from_numpy(mono)  # (time,)
    if sr != TARGET_SR:
        wav = torchaudio.functional.resample(wav, sr, TARGET_SR)
    return wav.unsqueeze(0)  # (1, time)
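# Example (illustrative numbers, not from the app): a 2-second stereo clip at
# 44.1 kHz with shape (88200, 2) comes back as a float32 tensor of shape
# (1, 32000), i.e. mono at 16 kHz.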
# -----------------------------
# Core processing
# -----------------------------
@torch.no_grad()
def denoise_numpy(
    audio: Tuple[int, np.ndarray], strength: str
) -> Tuple[Tuple[int, np.ndarray], Tuple[int, np.ndarray]]:
    """
    Gradio callback.
    Input:
        audio:    (sr, numpy waveform)
        strength: "Light" | "Medium" | "Strong"
    Output:
        (original_sr, original_wav), (TARGET_SR, processed_wav)
        Both as float32 in [-1, 1].
    """
    if audio is None:
        return None, None
    in_sr, in_wav = audio
    if in_wav is None or in_wav.size == 0:
        return None, None
    # Gradio's numpy audio is typically int16 PCM; normalize integer input to
    # float32 in [-1, 1] before it reaches the model.
    if np.issubdtype(in_wav.dtype, np.integer):
        in_wav = in_wav.astype(np.float32) / np.iinfo(in_wav.dtype).max
    else:
        in_wav = in_wav.astype(np.float32, copy=False)
    # Prepare input for the model (mono, 16 kHz)
    wav16 = resample_to_16k_mono(in_wav, in_sr)  # torch (1, time)
    # SpeechBrain expects a tensor of relative lengths (batch size == 1)
    lengths = torch.tensor([1.0])
    # Enhance
    enhancer = get_enhancer()
    enhanced = enhancer.enhance_batch(wav16, lengths=lengths).squeeze(0)  # (time,)
    dry = wav16.squeeze(0)
    # Wet/dry mix
    mix = MIX_BY_STRENGTH.get(strength, MIX_BY_STRENGTH["Medium"])
    out = dry * (1.0 - mix) + enhanced * mix
    # Clamp & back to numpy
    y = torch.clamp(out, -1.0, 1.0).cpu().numpy().astype(np.float32)
    # Return original (mono copy for consistent playback) + processed @ 16 kHz
    original = (in_sr, to_mono(in_wav))
    processed = (TARGET_SR, y)
    return original, processed
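# Quick sanity check without the UI (illustrative only; the synthetic int16
# signal below is an assumption, not part of the app):
#
#   sr = 44_100
#   noisy = (np.random.randn(sr * 2) * 3000).astype(np.int16)  # 2 s of noise
#   (orig_sr, orig), (proc_sr, proc) = denoise_numpy((sr, noisy), "Strong")
#   assert proc_sr == TARGET_SR and proc.dtype == np.float32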
# -----------------------------
# UI
# -----------------------------
CSS = """
.gradio-container { max-width: 1100px !important; }
#title { font-weight: 700; font-size: 1.4rem; margin-bottom: .25rem; }
#subtitle { opacity: .8; margin-bottom: .75rem; }
"""
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.HTML('<div id="title">Zack’s Audio Outpost — AI Noise Reducer</div>')
    gr.HTML('<div id="subtitle">Upload a file and compare <b>Original</b> vs <b>Processed</b>.</div>')
    with gr.Row():
        audio_in = gr.Audio(
            sources=["upload"],
            type="numpy",  # returns (sr, np.ndarray)
            label="Upload Audio",
            # show_controls is deprecated; we leave the default controls on
        )
        strength = gr.Radio(
            choices=["Light", "Medium", "Strong"],
            value="Medium",
            label="Noise Reduction Strength",
        )
    btn = gr.Button("Run Noise Reduction", variant="primary")
    with gr.Row():
        out_orig = gr.Audio(type="numpy", label="Original")
        out_proc = gr.Audio(type="numpy", label="Processed")
    btn.click(denoise_numpy, inputs=[audio_in, strength], outputs=[out_orig, out_proc])
if __name__ == "__main__":
    demo.launch()