from typing import Tuple

import numpy as np
import torch
import torchaudio
import gradio as gr

from speechbrain.pretrained import SpectralMaskEnhancement
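# NOTE: newer SpeechBrain releases (>= 1.0) move this class to
# speechbrain.inference.enhancement; the path above assumes an older pinned release.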


TITLE = "Zack's Audio Outpost — Voice Denoiser"
DESCRIPTION = """
Upload a short audio clip with speech (mono or stereo).
Choose **Light**, **Medium**, or **Strong** reduction and compare **Original vs Processed**.
"""

# MetricGAN+ is trained on 16 kHz speech (VoiceBank-DEMAND), so all processing runs at 16 kHz.
TARGET_SR = 16000
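# Wet/dry blend fractions: 0.0 would keep the input untouched, 1.0 is fully denoised output.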
MIX_BY_STRENGTH = {"Light": 0.35, "Medium": 0.65, "Strong": 0.9}

# Module-level cache so the model is loaded only once per process.
_enhancer = None


def get_enhancer() -> SpectralMaskEnhancement:
    """
    Loads the SpeechBrain MetricGAN+ VoiceBank denoiser (once) and caches it.
    Runs on CPU inside Spaces by default.
    """
    global _enhancer
    if _enhancer is None:
        _enhancer = SpectralMaskEnhancement.from_hparams(
            source="speechbrain/metricgan-plus-voicebank",
            savedir="pretrained_models/metricgan-plus-voicebank",
            run_opts={"device": "cpu"},
        )
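    # Inference only: eval() disables dropout/batch-norm updates in the cached modules.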
    _enhancer.mods.eval()
    return _enhancer


def to_mono(x: np.ndarray) -> np.ndarray:
    """
    Normalize shape consistently to mono float32 in [-1, 1].

    Accepts:
        (time,)           -> mono
        (time, channels)  -> last dim is channels
        (channels, time)  -> first dim is channels
    """
    if x.ndim == 1:
        y = x
    elif x.ndim == 2:
        t, c = x.shape
        if c in (1, 2) and t >= c:
            # (time, channels): pass mono through or average the two channels.
            y = x if c == 1 else x.mean(axis=1)
        elif t in (1, 2) and x.shape[1] > t:
            # (channels, time): take the single channel or average along channels.
            y = x[0] if t == 1 else x.mean(axis=0)
        else:
            # Ambiguous shape: fall back to treating the last dim as channels.
            y = x.mean(axis=1)
    else:
        raise ValueError(f"Unsupported audio shape {x.shape} (need 1D or 2D).")

    if np.issubdtype(y.dtype, np.integer):
        # Scale integer PCM to [-1, 1] by its dtype's full-scale value
        # (32768 for the int16 arrays Gradio's "numpy" type delivers).
        y = y.astype(np.float32) / float(-np.iinfo(y.dtype).min)
    else:
        y = y.astype(np.float32, copy=False)

    y = np.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0)
    return np.clip(y, -1.0, 1.0)
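
# Example: both to_mono(np.zeros((48000, 2))) and to_mono(np.zeros((2, 48000)))
# return a float32 mono array of shape (48000,).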


def resample_to_16k_mono(x: np.ndarray, sr: int) -> torch.Tensor:
    """
    Returns torch float32 (1, time) @ 16 kHz mono on CPU.
    """
    mono = to_mono(x)
    wav = torch.from_numpy(mono).to(torch.float32)
    if sr != TARGET_SR:
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=TARGET_SR)
    return wav.unsqueeze(0)
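
# Example: resample_to_16k_mono(np.zeros(44100, dtype=np.float32), 44100)
# yields a tensor of shape (1, 16000).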


@torch.no_grad()
def denoise_numpy(
    audio: Tuple[int, np.ndarray], strength: str
) -> Tuple[Tuple[int, np.ndarray], Tuple[int, np.ndarray]]:
    """
    Gradio callback:
        input:  (sr, np.ndarray) where the array is mono or stereo
        output: ((sr_in, mono_orig), (16k, mono_processed))
    """
    if audio is None:
        return None, None

    in_sr, in_wav = audio
    if in_wav is None or in_wav.size == 0:
        return None, None

    enhancer = get_enhancer()
    device = next(enhancer.mods.parameters()).device
    wav16 = resample_to_16k_mono(in_wav, in_sr).to(device)
    # SpeechBrain takes relative lengths (fraction of the longest item in the
    # batch); a single full-length clip is simply 1.0.
    lengths = torch.tensor([1.0], dtype=torch.float32, device=device)

    # Near-silent input: skip the model and return the (resampled) input as-is.
    if wav16.abs().mean().item() < 1e-6:
        original = (in_sr, to_mono(in_wav))
        processed = (TARGET_SR, wav16.squeeze(0).cpu().numpy())
        return original, processed

    enhanced = enhancer.enhance_batch(wav16, lengths=lengths).squeeze(0)
    dry = wav16.squeeze(0)

    # Parallel (wet/dry) mix: crossfade between the untouched and denoised signals.
    mix = MIX_BY_STRENGTH.get(strength, MIX_BY_STRENGTH["Medium"])
    out = dry * (1.0 - mix) + enhanced * mix

    # Guard against NaN/Inf from the model, then hard-limit to the valid sample range.
    out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
    y = torch.clamp(out, -1.0, 1.0).cpu().numpy().astype(np.float32)

    original = (in_sr, to_mono(in_wav))
    processed = (TARGET_SR, y)
    return original, processed
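
# Example (direct call, bypassing the UI):
#   denoise_numpy((44100, np.zeros(44100, dtype=np.int16)), "Medium")
# returns ((44100, mono_orig), (16000, mono_processed)).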


CSS = """
:root { --brand: #4b6bfb; } /* tweak to your brand color */
.gradio-container { font-family: Inter, system-ui, -apple-system, "Segoe UI", Roboto, Arial, sans-serif; }
#title { font-weight: 800; font-size: 1.4rem; }
#footer { opacity: 0.75; font-size: 0.85rem; }
button.primary { background: var(--brand) !important; }
"""

with gr.Blocks(css=CSS, title=TITLE, fill_height=True) as demo:
    gr.Markdown(f"<div id='title'>{TITLE}</div>")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        audio_in = gr.Audio(
            label="Upload or record (mono or stereo)",
            sources=["upload", "microphone"],
            # "numpy" hands the callback a (sample_rate, np.ndarray) tuple.
            type="numpy",
            waveform_options={"show_controls": True},
        )
        strength = gr.Radio(
            choices=["Light", "Medium", "Strong"],
            value="Medium",
            label="Reduction Strength",
        )

    run_btn = gr.Button("Process", variant="primary")

    with gr.Row():
        orig_out = gr.Audio(label="Original (mono)", interactive=False)
        proc_out = gr.Audio(label="Processed (16 kHz mono)", interactive=False)

    gr.Markdown(
        "<div id='footer'>Tip: Try Medium first. Strong may sound more 'processed' but removes more traffic/hiss.</div>"
    )

    run_btn.click(
        fn=denoise_numpy,
        inputs=[audio_in, strength],
        outputs=[orig_out, proc_out],
        scroll_to_output=True,
        show_progress="full",
    )

if __name__ == "__main__":
    # Bind on all interfaces at the default port Hugging Face Spaces expects.
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=True)