from typing import Optional, Tuple

import gradio as gr
import numpy as np
import torch
import torchaudio

# On SpeechBrain >= 1.0 this class is also exposed as
# speechbrain.inference.enhancement.SpectralMaskEnhancement; the import below
# still works there but is deprecated.
from speechbrain.pretrained import SpectralMaskEnhancement

TARGET_SR = 16_000

# Dry/wet mix: fraction of the enhanced (wet) signal blended into the output.
MIX_BY_STRENGTH = {
    "Light": 0.50,
    "Medium": 0.75,
    "Strong": 1.00,
}

MODEL_SOURCE = "speechbrain/metricgan-plus-voicebank"
MODEL_DIR = "pretrained_models/metricgan-plus-voicebank"

_enhancer: Optional[SpectralMaskEnhancement] = None

def get_enhancer() -> SpectralMaskEnhancement:
    """Lazy-load the SpeechBrain enhancer once and keep it in eval mode."""
    global _enhancer
    if _enhancer is None:
        _enhancer = SpectralMaskEnhancement.from_hparams(
            source=MODEL_SOURCE,
            savedir=MODEL_DIR,
        )
        # Inference only: freeze the module in eval mode at load time.
        # Autograd is already disabled by @torch.no_grad() on the callback.
        _enhancer.mods.eval()
    return _enhancer
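
# Direct use outside Gradio (a minimal sketch): enhance_batch expects a
# (batch, time) float32 tensor at 16 kHz plus per-item relative lengths
# in [0, 1].
#
#   wav = torch.zeros(1, TARGET_SR)  # one second of silence, illustrative only
#   clean = get_enhancer().enhance_batch(wav, lengths=torch.tensor([1.0]))
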
def to_mono(x: np.ndarray) -> np.ndarray:
    """
    Ensure mono. Accepts shapes:
      - (time,)           already mono
      - (time, channels)  -> average channels
      - (channels, time)  -> average channels
    Returns a float32 (time,) array; no rescaling is applied.
    """
    if x.ndim == 1:
        y = x
    elif x.ndim == 2:
        if x.shape[0] < x.shape[1]:
            # Fewer rows than columns: treat as (channels, time).
            y = x.mean(axis=0)
        else:
            # Treat as (time, channels), the layout Gradio uses for type="numpy".
            y = x.mean(axis=1)
    else:
        raise ValueError("Unsupported audio shape; expected a 1D or 2D ndarray")
    return y.astype(np.float32, copy=False)
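
# Shape handling at a glance (sketch values):
#   to_mono(np.zeros(100))       -> shape (100,)
#   to_mono(np.zeros((100, 2)))  -> shape (100,)  # (time, channels)
#   to_mono(np.zeros((2, 100)))  -> shape (100,)  # (channels, time)
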
def resample_to_16k_mono(x: np.ndarray, sr: int) -> torch.Tensor:
    """Numpy waveform -> float32 torch tensor of shape (1, time) at 16 kHz mono."""
    mono = to_mono(x)
    wav = torch.from_numpy(mono)
    if sr != TARGET_SR:
        wav = torchaudio.functional.resample(wav, sr, TARGET_SR)
    return wav.unsqueeze(0)
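
# Quick check (sketch): one second at 44.1 kHz comes back as 16 000 samples.
#   x = np.random.randn(44_100).astype(np.float32)
#   resample_to_16k_mono(x, 44_100).shape  # torch.Size([1, 16000])
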
@torch.no_grad()
def denoise_numpy(
    audio: Tuple[int, np.ndarray], strength: str
) -> Tuple[Tuple[int, np.ndarray], Tuple[int, np.ndarray]]:
    """
    Gradio callback.

    Input:
        audio: (sr, numpy waveform)
        strength: "Light" | "Medium" | "Strong"
    Output:
        (original_sr, original_wav), (TARGET_SR, processed_wav),
        both float32 in [-1, 1].
    """
    if audio is None:
        return None, None

    in_sr, in_wav = audio
    if in_wav is None or in_wav.size == 0:
        return None, None

    # Gradio's type="numpy" often delivers integer PCM (e.g. int16); scale to
    # [-1, 1] so the model and the dry/wet mix see normalized floats.
    if np.issubdtype(in_wav.dtype, np.integer):
        in_wav = in_wav.astype(np.float32) / np.iinfo(in_wav.dtype).max
    else:
        in_wav = in_wav.astype(np.float32, copy=False)

    wav16 = resample_to_16k_mono(in_wav, in_sr)

    # Relative lengths: our single batch item uses all of its frames.
    lengths = torch.tensor([1.0])

    enhancer = get_enhancer()
    enhanced = enhancer.enhance_batch(wav16, lengths=lengths).squeeze(0)
    dry = wav16.squeeze(0)

    # Dry/wet blend chosen by the strength radio button.
    mix = MIX_BY_STRENGTH.get(strength, MIX_BY_STRENGTH["Medium"])
    out = dry * (1.0 - mix) + enhanced * mix

    y = torch.clamp(out, -1.0, 1.0).cpu().numpy().astype(np.float32)

    original = (in_sr, to_mono(in_wav))
    processed = (TARGET_SR, y)
    return original, processed
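
# Offline sketch (assumes a local "noisy.wav"; the filename is illustrative):
#
#   wav, sr = torchaudio.load("noisy.wav")  # (channels, time) float32
#   _, (proc_sr, proc) = denoise_numpy((sr, wav.numpy().T), "Medium")
#   torchaudio.save("clean.wav", torch.from_numpy(proc).unsqueeze(0), proc_sr)
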
CSS = """ |
|
.gradio-container { max-width: 1100px !important; } |
|
#title { font-weight: 700; font-size: 1.4rem; margin-bottom: .25rem; } |
|
#subtitle { opacity: .8; margin-bottom: .75rem; } |
|
""" |
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.HTML('<div id="title">Zack’s Audio Outpost — AI Noise Reducer</div>')
    gr.HTML('<div id="subtitle">Upload a file and compare <b>Original</b> vs <b>Processed</b>.</div>')

    with gr.Row():
        audio_in = gr.Audio(
            sources=["upload"],
            type="numpy",
            label="Upload Audio",
        )
        strength = gr.Radio(
            choices=["Light", "Medium", "Strong"],
            value="Medium",
            label="Noise Reduction Strength",
        )

    btn = gr.Button("Run Noise Reduction", variant="primary")

    with gr.Row():
        out_orig = gr.Audio(type="numpy", label="Original")
        out_proc = gr.Audio(type="numpy", label="Processed")

    btn.click(denoise_numpy, inputs=[audio_in, strength], outputs=[out_orig, out_proc])

if __name__ == "__main__":
    demo.launch()
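    # To expose a temporary public URL instead (sketch): demo.launch(share=True)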