ZDingman commited on
Commit
fc270e2
·
verified ·
1 Parent(s): 56349a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -75
app.py CHANGED
@@ -1,106 +1,116 @@
1
- import gradio as gr
2
  import numpy as np
 
 
3
  import torch
4
 
5
- # Try torchaudio for high-quality resampling; fall back to scipy if unavailable
 
6
  try:
7
  import torchaudio
8
- HAS_TORCHAUDIO = True
9
  except Exception:
10
- HAS_TORCHAUDIO = False
11
  from scipy.signal import resample_poly
12
 
 
13
  from speechbrain.pretrained import SpectralMaskEnhancement
14
 
15
- # Download once and cache in the Space
 
 
 
 
16
  ENHANCER = SpectralMaskEnhancement.from_hparams(
17
- source="speechbrain/metricgan-plus-voicebank",
18
- savedir="pretrained/metricgan-plus-voicebank",
 
19
  )
20
 
21
- TARGET_SR = 16000 # model expects 16 kHz
22
 
23
- def _resample(x: torch.Tensor, in_sr: int, out_sr: int) -> torch.Tensor:
24
- if in_sr == out_sr:
25
- return x
26
- if HAS_TORCHAUDIO:
27
- return torchaudio.functional.resample(x, in_sr, out_sr)
28
- # fallback (scipy) — expects numpy
29
- xn = x.cpu().numpy()
30
- g = np.gcd(in_sr, out_sr)
31
- up, down = out_sr // g, in_sr // g
32
- y = resample_poly(xn, up, down).astype(np.float32)
33
- return torch.from_numpy(y)
34
-
35
- def _to_tensor(mono_np: np.ndarray) -> torch.Tensor:
36
- t = torch.from_numpy(mono_np.astype(np.float32))
37
- peak = t.abs().max().clamp(min=1e-8)
38
- return (t / peak)
39
-
40
- def _enhance_channel(wav_np: np.ndarray, in_sr: int, mix: float) -> np.ndarray:
41
- """Enhance one channel and wet/dry mix."""
42
- x = _to_tensor(wav_np) # shape [T]
43
- x16 = _resample(x, in_sr, TARGET_SR) # -> 16 kHz
44
 
45
- with torch.no_grad():
46
- # Correct call: enhance_batch(wavs [, lengths]), NOT sample_rate
47
- # expects [B, T]; returns [B, T]
48
- est16 = ENHANCER.enhance_batch(x16.unsqueeze(0))[0]
49
-
50
- # back to original sr and length
51
- est = _resample(est16, TARGET_SR, in_sr)
52
- if est.shape[0] >= x.shape[0]:
53
- est = est[: x.shape[0]]
54
- else:
55
- est = torch.nn.functional.pad(est, (0, x.shape[0] - est.shape[0]))
56
-
57
- y = (1.0 - mix) * x + mix * est
58
- return y.cpu().numpy()
59
-
60
- def denoise(audio, strength):
 
 
 
 
 
 
 
61
  """
62
- Gradio (type='numpy') -> (sample_rate:int, data: np.ndarray)
63
- data is [T] or [T,2].
64
  """
65
- try:
66
- if audio is None:
67
- return None, None
 
 
68
 
69
- sr, data = audio
70
- chs = [data] if data.ndim == 1 else [data[:, 0], data[:, 1]]
 
71
 
72
- mix_map = {"Light": 0.5, "Medium": 0.75, "Strong": 1.0}
73
- mix = mix_map.get(strength, 0.75)
74
 
75
- out_chs = [_enhance_channel(c, sr, mix) for c in chs]
 
 
 
 
 
 
76
 
77
- processed = (np.stack(out_chs, axis=1) if len(out_chs) == 2
78
- else out_chs[0])
79
 
80
- # Return (sr, audio) tuples for A/B players
81
- return (sr, data), (sr, processed)
82
 
83
- except Exception as e:
84
- # Show a readable error in the UI
85
- msg = f"Processing error: {type(e).__name__}: {e}"
86
- print(msg)
87
- raise RuntimeError(msg)
88
 
89
- # -------- UI --------
90
- with gr.Blocks(css="footer {visibility: hidden}") as demo:
91
- gr.Markdown("## Zack’s Audio Outpost — AI Noise Reducer\nUpload a file and compare **Original** vs **Processed**.")
92
 
 
 
 
93
  with gr.Row():
94
  audio_in = gr.Audio(type="numpy", label="Upload Audio")
95
- strength = gr.Radio(["Light", "Medium", "Strong"], value="Medium",
96
- label="Noise Reduction Strength")
97
-
98
- run = gr.Button("Run Noise Reduction", variant="primary")
99
-
100
  with gr.Row():
101
- out_orig = gr.Audio(label="Original Audio")
102
- out_proc = gr.Audio(label="Processed Audio")
 
 
 
 
 
 
 
103
 
104
- run.click(denoise, inputs=[audio_in, strength], outputs=[out_orig, out_proc])
105
 
106
  demo.launch()
 
1
+ import os
2
  import numpy as np
3
+ import gradio as gr
4
+ import soundfile as sf
5
  import torch
6
 
7
+ # Try torchaudio for resampling. If it's not usable, fall back to SciPy.
8
+ USE_TORCHAUDIO = True
9
  try:
10
  import torchaudio
11
+ import torchaudio.functional as AF
12
  except Exception:
13
+ USE_TORCHAUDIO = False
14
  from scipy.signal import resample_poly
15
 
16
+ # SpeechBrain MetricGAN+ enhancement (CPU)
17
  from speechbrain.pretrained import SpectralMaskEnhancement
18
 
19
+ torch.set_num_threads(1)
20
+ DEVICE = "cpu"
21
+ MODEL_ID = "speechbrain/metricgan-plus-voicebank"
22
+
23
+ # Load the enhancer once
24
  ENHANCER = SpectralMaskEnhancement.from_hparams(
25
+ source=MODEL_ID,
26
+ savedir="pretrained_metricganp",
27
+ run_opts={"device": DEVICE}
28
  )
29
 
30
+ TARGET_SR = 16000 # MetricGAN+ expects 16 kHz
31
 
32
+ def _to_mono(x: np.ndarray) -> np.ndarray:
33
+ # x shape: (samples,) or (samples, channels)
34
+ if x.ndim == 2 and x.shape[1] > 1:
35
+ return np.mean(x, axis=1, dtype=np.float32)
36
+ return x.astype(np.float32, copy=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ def _resample(x: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
39
+ if sr_in == sr_out:
40
+ return x
41
+ if USE_TORCHAUDIO:
42
+ with torch.no_grad():
43
+ t = torch.from_numpy(x).unsqueeze(0) # (1, time)
44
+ y = AF.resample(t, orig_freq=sr_in, new_freq=sr_out)
45
+ return y.squeeze(0).cpu().numpy().astype(np.float32)
46
+ # SciPy fall-back
47
+ g = np.gcd(sr_in, sr_out)
48
+ up, down = sr_out // g, sr_in // g
49
+ y = resample_poly(x, up, down).astype(np.float32)
50
+ return y
51
+
52
+ def _mix(dry: np.ndarray, wet: np.ndarray, strength: str) -> np.ndarray:
53
+ # Light / Medium / Strong → wet mix amounts
54
+ mix = {"Light": 0.4, "Medium": 0.7, "Strong": 1.0}.get(strength, 0.7)
55
+ # pad/truncate to the same length
56
+ n = min(len(dry), len(wet))
57
+ out = dry[:n] * (1.0 - mix) + wet[:n] * mix
58
+ return out
59
+
60
+ def denoise(audio: tuple, strength: str):
61
  """
62
+ Gradio passes (sr, np.ndarray[int16/float32, shape=(n,) or (n, ch)]) when type='numpy'
63
+ Return the processed audio as (sr, np.ndarray[float32]).
64
  """
65
+ if audio is None:
66
+ raise gr.Error("Please upload an audio file.")
67
+ sr, data = audio
68
+ if isinstance(data, list):
69
+ data = np.array(data, dtype=np.float32)
70
 
71
+ # To mono, float32 in [-1, 1]
72
+ x_mono = _to_mono(data)
73
+ x_mono = np.clip(x_mono, -1.0, 1.0).astype(np.float32)
74
 
75
+ # Resample to 16 kHz for the model
76
+ x_16k = _resample(x_mono, sr_in=sr, sr_out=TARGET_SR)
77
 
78
+ # Enhance with MetricGAN+
79
+ with torch.no_grad():
80
+ # Enhance expects torch.Tensor: shape (batch, time)
81
+ inp = torch.from_numpy(x_16k).unsqueeze(0)
82
+ enhanced = ENHANCER.enhance_batch(inp, TARGET_SR)
83
+ if isinstance(enhanced, torch.Tensor):
84
+ enhanced = enhanced.squeeze(0).cpu().numpy().astype(np.float32)
85
 
86
+ # Back to original sample rate
87
+ enhanced_sr = _resample(enhanced, sr_in=TARGET_SR, sr_out=sr)
88
 
89
+ # Mix according to strength (preserve dry transients)
90
+ y = _mix(dry=x_mono, wet=enhanced_sr, strength=strength)
91
 
92
+ # Return as mono track at original sr
93
+ return (sr, y.astype(np.float32))
 
 
 
94
 
 
 
 
95
 
96
+ # ---------- UI ----------
97
+ with gr.Blocks(theme=gr.themes.Soft(), css="footer {visibility:hidden}") as demo:
98
+ gr.Markdown("### Zack’s Audio Outpost — AI Noise Reducer\nUpload a file and compare **Original vs Processed**.")
99
  with gr.Row():
100
  audio_in = gr.Audio(type="numpy", label="Upload Audio")
101
+ strength = gr.Radio(["Light", "Medium", "Strong"], value="Medium", label="Noise Reduction Strength")
102
+ run_btn = gr.Button("Run Noise Reduction", variant="primary")
 
 
 
103
  with gr.Row():
104
+ orig = gr.Audio(label="Original")
105
+ clean = gr.Audio(label="Processed")
106
+
107
+ def run(audio, strength):
108
+ if audio is None:
109
+ raise gr.Error("Please upload an audio file.")
110
+ sr, data = audio
111
+ processed = denoise((sr, data), strength)
112
+ return (sr, data), processed
113
 
114
+ run_btn.click(fn=run, inputs=[audio_in, strength], outputs=[orig, clean])
115
 
116
  demo.launch()