import os from typing import Tuple import numpy as np import torch import torchaudio import gradio as gr from speechbrain.pretrained import SpectralMaskEnhancement # ---------------------------- # Constants / Globals # ---------------------------- TITLE = "Zack's Audio Outpost — Voice Denoiser" DESCRIPTION = """ Upload a short audio clip with speech (mono or stereo). Choose **Light**, **Medium**, or **Strong** reduction and compare **Original vs Processed**. """ TARGET_SR = 16000 # MetricGAN+ VoiceBank expects 16 kHz MIX_BY_STRENGTH = {"Light": 0.35, "Medium": 0.65, "Strong": 0.9} _enhancer = None # lazy-loaded model # ---------------------------- # Model Loader (cached) # ---------------------------- def get_enhancer() -> SpectralMaskEnhancement: """ Loads SpeechBrain MetricGAN+ VoiceBank denoiser (once) and caches it. Runs on CPU inside Spaces by default. """ global _enhancer if _enhancer is None: _enhancer = SpectralMaskEnhancement.from_hparams( source="speechbrain/metricgan-plus-voicebank", savedir="pretrained_models/metricgan-plus-voicebank", run_opts={"device": "cpu"}, ) # Put underlying nn.Module in eval mode _enhancer.mods.eval() return _enhancer # ---------------------------- # Audio utilities # ---------------------------- def to_mono(x: np.ndarray) -> np.ndarray: """ Normalize shape consistently to mono float32 in [-1, 1]. Accepts: (time,) -> mono (time, channels) -> last dim is channels (channels, time) -> first dim is channels """ if x.ndim == 1: y = x elif x.ndim == 2: t, c = x.shape # If last dim is 1 or 2, treat as (time, ch) if c in (1, 2) and t >= c: y = x if c == 1 else x.mean(axis=1) # If first dim is 1 or 2, treat as (ch, time) elif t in (1, 2) and x.shape[1] > t: y = x[0] if t == 1 else x.mean(axis=0) else: # Fallback: assume (time, ch) y = x.mean(axis=1) else: raise ValueError(f"Unsupported audio shape {x.shape} (need 1D or 2D).") # Ensure float32 in [-1, 1], handle int16/24/32 just in case if np.issubdtype(y.dtype, np.integer): # assume int16-like full scale y = y.astype(np.float32) / 32768.0 else: y = y.astype(np.float32, copy=False) # Remove NaNs/Infs and hard-clip y = np.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0) return np.clip(y, -1.0, 1.0) def resample_to_16k_mono(x: np.ndarray, sr: int) -> torch.Tensor: """ Returns torch float32 (1, time) @ 16 kHz mono on CPU. """ mono = to_mono(x) wav = torch.from_numpy(mono).to(torch.float32) # (time,) if sr != TARGET_SR: wav = torchaudio.functional.resample( wav, orig_freq=sr, new_freq=TARGET_SR ) return wav.unsqueeze(0) # (1, time) # ---------------------------- # Denoise core # ---------------------------- @torch.no_grad() def denoise_numpy( audio: Tuple[int, np.ndarray], strength: str ) -> Tuple[Tuple[int, np.ndarray], Tuple[int, np.ndarray]]: """ Gradio callback: input: (sr, np.ndarray) where array is mono or stereo output: ((sr_in, mono_orig), (16k, mono_processed)) """ if audio is None: return None, None in_sr, in_wav = audio if in_wav is None or in_wav.size == 0: return None, None # Load model and prepare input enhancer = get_enhancer() device = next(enhancer.mods.parameters()).device # CPU wav16 = resample_to_16k_mono(in_wav, in_sr).to(device) # (1, time) lengths = torch.tensor([1.0], dtype=torch.float32, device=device) # If effectively silent, skip processing if wav16.abs().mean().item() < 1e-6: original = (in_sr, to_mono(in_wav)) processed = (TARGET_SR, wav16.squeeze(0).cpu().numpy()) return original, processed # Enhance (MetricGAN+ outputs enhanced waveform directly via mix-mask pipeline) enhanced = enhancer.enhance_batch(wav16, lengths=lengths).squeeze(0) # (time,) dry = wav16.squeeze(0) # Wet/Dry mix by strength mix = MIX_BY_STRENGTH.get(strength, MIX_BY_STRENGTH["Medium"]) out = dry * (1.0 - mix) + enhanced * mix # Clean up output out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0) y = torch.clamp(out, -1.0, 1.0).cpu().numpy().astype(np.float32) original = (in_sr, to_mono(in_wav)) processed = (TARGET_SR, y) return original, processed # ---------------------------- # Gradio UI # ---------------------------- CSS = """ :root { --brand: #4b6bfb; } /* tweak to your brand color */ .gradio-container { font-family: Inter, system-ui, -apple-system, Segoe UI, Roboto, Arial, sans-serif; } #title { font-weight: 800; font-size: 1.4rem; } #footer { opacity: 0.75; font-size: 0.85rem; } button.primary { background: var(--brand) !important; } """ with gr.Blocks(css=CSS, title=TITLE, fill_height=True) as demo: gr.Markdown(f"