File size: 5,344 Bytes
2ee63b1
 
 
 
 
 
 
 
 
0af1c3f
fbb28c2
2ee63b1
56349a1
2ee63b1
 
 
 
fbb28c2
2ee63b1
fbb28c2
2ee63b1
 
 
fbb28c2
 
2ee63b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbb28c2
 
 
 
2ee63b1
 
 
 
fbb28c2
2ee63b1
 
 
 
 
 
 
 
 
 
 
 
 
fbb28c2
2ee63b1
 
 
fbb28c2
2ee63b1
 
 
 
 
 
 
 
 
 
 
fbb28c2
2ee63b1
fbb28c2
 
2ee63b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbb28c2
2ee63b1
fbb28c2
2ee63b1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# app.py
# Zack's Audio Outpost — AI Noise Reducer (SpeechBrain MetricGAN)
# Works on CPU in a Hugging Face Space. No GPU required.

import os
from typing import Tuple

import gradio as gr
import numpy as np
import torch
import torchaudio
from speechbrain.pretrained import SpectralMaskEnhancement

# -----------------------------
# Config
# -----------------------------
TARGET_SR = 16_000  # The SpeechBrain mtl-mimic-voicebank model expects 16 kHz mono

# Wet/dry mix by "strength" — fraction of the enhanced ("wet") signal blended
# into the output; the remainder is the untouched ("dry") input.
MIX_BY_STRENGTH = {
    "Light": 0.50,   # 50% wet
    "Medium": 0.75,  # 75% wet
    "Strong": 1.00,  # 100% wet
}

# Hugging Face Hub model id and the local directory the checkpoint is cached in.
MODEL_SOURCE = "speechbrain/mtl-mimic-voicebank"
MODEL_DIR = "pretrained_models/mtl-mimic-voicebank"

# Global enhancer (loaded once, lazily, by get_enhancer())
_enhancer: SpectralMaskEnhancement | None = None


def get_enhancer() -> SpectralMaskEnhancement:
    """Return the process-wide SpeechBrain enhancer, loading it on first call.

    The checkpoint is downloaded from the Hugging Face Hub the first time and
    cached under MODEL_DIR; every later call returns the same module-level
    singleton.
    """
    global _enhancer
    if _enhancer is not None:
        return _enhancer

    # First use: fetch/load the checkpoint, then freeze it for inference.
    model = SpectralMaskEnhancement.from_hparams(
        source=MODEL_SOURCE, savedir=MODEL_DIR
    )
    model.mods.eval()
    # Inference-only app: turning autograd off globally saves memory/compute.
    torch.set_grad_enabled(False)

    _enhancer = model
    return _enhancer


# -----------------------------
# Audio helpers
# -----------------------------
def to_mono(x: np.ndarray) -> np.ndarray:
    """
    Downmix to mono and return float32 samples in [-1, 1].

    Accepted shapes:
      - (time,)            already mono
      - (time, channels)   -> average channels
      - (channels, time)   -> average channels (heuristic: the shorter axis
                              is assumed to be the channel axis)

    Integer PCM input (e.g. int16 as delivered by Gradio) is rescaled to
    [-1, 1]; float input is assumed to be normalized already and is only
    cast to float32.

    Raises:
        ValueError: if x is not a 1-D or 2-D ndarray.
    """
    if np.issubdtype(x.dtype, np.integer):
        # Map the full integer range onto [-1, 1]. The offset/scale formula
        # handles both signed (offset 0) and unsigned (e.g. uint8) dtypes.
        info = np.iinfo(x.dtype)
        offset = (info.min + info.max + 1) / 2.0
        scale = (info.max - info.min + 1) / 2.0
        x = (x.astype(np.float32) - offset) / scale

    if x.ndim == 1:
        y = x
    elif x.ndim == 2:
        # Heuristic: the smaller dimension is the channel axis.
        y = x.mean(axis=0) if x.shape[0] < x.shape[1] else x.mean(axis=1)
    else:
        raise ValueError("Unsupported audio shape; expected 1D or 2D ndarray")
    return y.astype(np.float32, copy=False)


def resample_to_16k_mono(x: np.ndarray, sr: int) -> torch.Tensor:
    """
    Convert a numpy waveform into a torch tensor shaped (1, time) at 16 kHz mono.

    Values are float32 in [-1, 1]; resampling only runs when the incoming
    sample rate differs from TARGET_SR.
    """
    wav = torch.from_numpy(to_mono(x))  # (time,)
    if sr == TARGET_SR:
        return wav.unsqueeze(0)
    resampled = torchaudio.functional.resample(wav, sr, TARGET_SR)
    return resampled.unsqueeze(0)  # (1, time)


# -----------------------------
# Core processing
# -----------------------------
@torch.no_grad()
def denoise_numpy(audio: Tuple[int, np.ndarray], strength: str) -> Tuple[Tuple[int, np.ndarray], Tuple[int, np.ndarray]]:
    """
    Gradio callback: denoise an uploaded clip and return both versions.

    Args:
        audio: (sample_rate, waveform) tuple from gr.Audio(type="numpy"),
               or None when nothing was uploaded.
        strength: "Light" | "Medium" | "Strong" — selects the wet/dry mix;
                  unknown values fall back to "Medium".

    Returns:
        ((original_sr, original_wav), (TARGET_SR, processed_wav)), both
        float32 in [-1, 1], or (None, None) when there is no usable input.
    """
    if audio is None:
        # Nothing uploaded
        return None, None

    in_sr, in_wav = audio
    if in_wav is None or in_wav.size == 0:
        return None, None

    # Gradio may deliver integer PCM (commonly int16). Rescale it to [-1, 1]
    # instead of blindly casting — a bare astype(float32) would feed the model
    # (and play back) values in the tens of thousands.
    if np.issubdtype(in_wav.dtype, np.integer):
        info = np.iinfo(in_wav.dtype)
        offset = (info.min + info.max + 1) / 2.0
        scale = (info.max - info.min + 1) / 2.0
        in_wav = (in_wav.astype(np.float32) - offset) / scale
    else:
        in_wav = in_wav.astype(np.float32, copy=False)

    # Prepare input for the model: mono, 16 kHz, shape (1, time).
    wav16 = resample_to_16k_mono(in_wav, in_sr)

    # SpeechBrain expects relative lengths per batch item (batch size == 1)
    lengths = torch.tensor([1.0])

    # Enhance
    enhancer = get_enhancer()
    enhanced = enhancer.enhance_batch(wav16, lengths=lengths).squeeze(0)  # (time,)
    dry = wav16.squeeze(0)

    # Wet/dry mix
    mix = MIX_BY_STRENGTH.get(strength, MIX_BY_STRENGTH["Medium"])
    out = dry * (1.0 - mix) + enhanced * mix  # (time,)

    # Clamp just in case, then back to numpy
    y = torch.clamp(out, -1.0, 1.0).cpu().numpy().astype(np.float32)

    # For "Original", we return the user’s uploaded audio (mono so both
    # players behave alike); Gradio prefers (sr, waveform) for type="numpy".
    original = (in_sr, to_mono(in_wav))
    processed = (TARGET_SR, y)

    return original, processed


# -----------------------------
# UI
# -----------------------------
CSS = """
/* simple brand-ish tweaks */
.gradio-container { max-width: 1100px !important; }
#title { font-weight: 700; font-size: 1.4rem; margin-bottom: .25rem; }
#subtitle { opacity: .8; margin-bottom: .75rem; }
"""

# Layout: upload + strength selector on one row, a run button, then the two
# audio players side by side for A/B comparison.
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.HTML('<div id="title">Zack’s Audio Outpost — AI Noise Reducer</div>')
    gr.HTML('<div id="subtitle">Upload a file and compare <b>Original</b> vs <b>Processed</b>.</div>')

    with gr.Row():
        audio_in = gr.Audio(
            sources=["upload"],
            type="numpy",          # returns (sr, np.ndarray)
            label="Upload Audio",
            waveform_options=gr.WaveformOptions(show_controls=True),
        )
        strength = gr.Radio(
            choices=["Light", "Medium", "Strong"],
            value="Medium",
            label="Noise Reduction Strength",
        )

    btn = gr.Button("Run Noise Reduction", variant="primary")

    with gr.Row():
        out_orig = gr.Audio(type="numpy", label="Original")
        out_proc = gr.Audio(type="numpy", label="Processed")

    # denoise_numpy returns (original, processed), matching the two outputs.
    btn.click(denoise_numpy, inputs=[audio_in, strength], outputs=[out_orig, out_proc])

# Recommended: SSR is fine on Spaces; leave default
if __name__ == "__main__":
    # In Spaces this is ignored; locally it runs on http://0.0.0.0:7860
    demo.launch()