import os
from typing import Tuple
import numpy as np
import torch
import torchaudio
import gradio as gr
# SpeechBrain 1.0 moved the pretrained interfaces to speechbrain.inference;
# fall back to the legacy module path for older installs.
try:
    from speechbrain.inference.enhancement import SpectralMaskEnhancement
except ImportError:
    from speechbrain.pretrained import SpectralMaskEnhancement
# ----------------------------
# Constants / Globals
# ----------------------------
TITLE = "Zack's Audio Outpost — Voice Denoiser"
DESCRIPTION = """
Upload a short audio clip with speech (mono or stereo).
Choose **Light**, **Medium**, or **Strong** reduction and compare **Original vs Processed**.
"""
TARGET_SR = 16000 # MetricGAN+ VoiceBank expects 16 kHz
MIX_BY_STRENGTH = {"Light": 0.35, "Medium": 0.65, "Strong": 0.9}
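# The wet/dry blend applied later in denoise_numpy is a simple linear mix:
#   out = dry * (1 - mix) + enhanced * mix
# so "Light" (0.35) keeps 65% of the original signal, while "Strong" (0.9)
# is almost entirely the enhanced output.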
_enhancer = None # lazy-loaded model
# ----------------------------
# Model Loader (cached)
# ----------------------------
def get_enhancer() -> SpectralMaskEnhancement:
"""
Loads SpeechBrain MetricGAN+ VoiceBank denoiser (once) and caches it.
Runs on CPU inside Spaces by default.
"""
global _enhancer
if _enhancer is None:
_enhancer = SpectralMaskEnhancement.from_hparams(
source="speechbrain/metricgan-plus-voicebank",
savedir="pretrained_models/metricgan-plus-voicebank",
run_opts={"device": "cpu"},
)
# Put underlying nn.Module in eval mode
_enhancer.mods.eval()
return _enhancer
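# Usage sketch (comments only, not executed): the cached enhancer takes
# (batch, time) float tensors plus relative lengths. The waveform below is
# placeholder noise, not real speech.
#
#   enhancer = get_enhancer()
#   wav = torch.rand(1, TARGET_SR) * 0.1       # 1 s of quiet noise @ 16 kHz
#   clean = enhancer.enhance_batch(wav, lengths=torch.tensor([1.0]))
#   clean.shape                                # torch.Size([1, 16000])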
# ----------------------------
# Audio utilities
# ----------------------------
def to_mono(x: np.ndarray) -> np.ndarray:
"""
    Collapse 1D or 2D input to mono float32 in [-1, 1].
Accepts:
(time,) -> mono
(time, channels) -> last dim is channels
(channels, time) -> first dim is channels
"""
if x.ndim == 1:
y = x
elif x.ndim == 2:
t, c = x.shape
# If last dim is 1 or 2, treat as (time, ch)
if c in (1, 2) and t >= c:
y = x if c == 1 else x.mean(axis=1)
# If first dim is 1 or 2, treat as (ch, time)
        elif t in (1, 2) and c > t:
            y = x[0] if t == 1 else x.mean(axis=0)
else:
# Fallback: assume (time, ch)
y = x.mean(axis=1)
else:
raise ValueError(f"Unsupported audio shape {x.shape} (need 1D or 2D).")
    # Ensure float32 in [-1, 1]; scale integer PCM by its dtype's full range
    # (int16 -> 32768, int32 -> 2**31) rather than assuming int16.
    if np.issubdtype(y.dtype, np.integer):
        y = y.astype(np.float32) / float(np.iinfo(y.dtype).max + 1)
else:
y = y.astype(np.float32, copy=False)
# Remove NaNs/Infs and hard-clip
y = np.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0)
return np.clip(y, -1.0, 1.0)
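# Shape handling at a glance (hypothetical arrays, for illustration):
#   to_mono(np.zeros(8000))          -> (8000,)  already mono
#   to_mono(np.zeros((8000, 2)))     -> (8000,)  (time, ch): channels averaged
#   to_mono(np.zeros((2, 8000)))     -> (8000,)  (ch, time): channels averaged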
def resample_to_16k_mono(x: np.ndarray, sr: int) -> torch.Tensor:
"""
Returns torch float32 (1, time) @ 16 kHz mono on CPU.
"""
mono = to_mono(x)
wav = torch.from_numpy(mono).to(torch.float32) # (time,)
if sr != TARGET_SR:
wav = torchaudio.functional.resample(
wav, orig_freq=sr, new_freq=TARGET_SR
)
return wav.unsqueeze(0) # (1, time)
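# Example (sketch): one second of 44.1 kHz audio resamples to exactly 16,000
# samples. The random array is a stand-in for real audio.
#
#   x = np.random.randn(44100).astype(np.float32) * 0.1
#   resample_to_16k_mono(x, 44100).shape       # torch.Size([1, 16000])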
# ----------------------------
# Denoise core
# ----------------------------
@torch.no_grad()
def denoise_numpy(
audio: Tuple[int, np.ndarray], strength: str
) -> Tuple[Tuple[int, np.ndarray], Tuple[int, np.ndarray]]:
"""
Gradio callback:
input: (sr, np.ndarray) where array is mono or stereo
output: ((sr_in, mono_orig), (16k, mono_processed))
"""
if audio is None:
return None, None
in_sr, in_wav = audio
if in_wav is None or in_wav.size == 0:
return None, None
# Load model and prepare input
enhancer = get_enhancer()
device = next(enhancer.mods.parameters()).device # CPU
wav16 = resample_to_16k_mono(in_wav, in_sr).to(device) # (1, time)
    lengths = torch.tensor([1.0], dtype=torch.float32, device=device)  # relative lengths: 1.0 = full waveform
# If effectively silent, skip processing
if wav16.abs().mean().item() < 1e-6:
original = (in_sr, to_mono(in_wav))
processed = (TARGET_SR, wav16.squeeze(0).cpu().numpy())
return original, processed
    # Enhance (MetricGAN+ predicts a spectral mask; enhance_batch returns the resynthesized waveform)
enhanced = enhancer.enhance_batch(wav16, lengths=lengths).squeeze(0) # (time,)
dry = wav16.squeeze(0)
# Wet/Dry mix by strength
mix = MIX_BY_STRENGTH.get(strength, MIX_BY_STRENGTH["Medium"])
out = dry * (1.0 - mix) + enhanced * mix
# Clean up output
out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
y = torch.clamp(out, -1.0, 1.0).cpu().numpy().astype(np.float32)
original = (in_sr, to_mono(in_wav))
processed = (TARGET_SR, y)
return original, processed
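# Smoke test (sketch): call the callback directly with the (sr, ndarray) tuple
# that Gradio passes for type="numpy". The noise input here is a placeholder.
#
#   sr = 44100
#   noisy = (np.random.randn(sr) * 0.05).astype(np.float32)
#   (osr, orig), (psr, proc) = denoise_numpy((sr, noisy), "Medium")
#   assert psr == TARGET_SR and proc.ndim == 1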
# ----------------------------
# Gradio UI
# ----------------------------
CSS = """
:root { --brand: #4b6bfb; } /* tweak to your brand color */
.gradio-container { font-family: Inter, system-ui, -apple-system, Segoe UI, Roboto, Arial, sans-serif; }
#title { font-weight: 800; font-size: 1.4rem; }
#footer { opacity: 0.75; font-size: 0.85rem; }
button.primary { background: var(--brand) !important; }
"""
with gr.Blocks(css=CSS, title=TITLE, fill_height=True) as demo:
gr.Markdown(f"<div id='title'>{TITLE}</div>")
gr.Markdown(DESCRIPTION)
with gr.Row():
audio_in = gr.Audio(
label="Upload or record (mono or stereo)",
sources=["upload", "microphone"],
type="numpy", # (sr, np.ndarray)
            waveform_options={"show_controls": True},  # show the waveform's extra playback controls
)
strength = gr.Radio(
choices=["Light", "Medium", "Strong"],
value="Medium",
label="Reduction Strength",
)
run_btn = gr.Button("Process", variant="primary")
with gr.Row():
orig_out = gr.Audio(label="Original (mono)", interactive=False)
proc_out = gr.Audio(label="Processed (16 kHz mono)", interactive=False)
gr.Markdown(
"<div id='footer'>Tip: Try Medium first. Strong may sound more 'processed' but removes more traffic/hiss.</div>"
)
run_btn.click(
fn=denoise_numpy,
inputs=[audio_in, strength],
outputs=[orig_out, proc_out],
scroll_to_output=True,
        show_progress="full",  # Gradio 4+ expects "full" / "minimal" / "hidden"
)
if __name__ == "__main__":
    # Hugging Face Spaces serves the app on 0.0.0.0:7860; binding explicitly
    # also keeps local runs consistent.
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=True)