ZDingman commited on
Commit
fbb28c2
·
verified ·
1 Parent(s): 7dc97a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -90
app.py CHANGED
@@ -1,92 +1,46 @@
1
- import numpy as np
2
- import gradio as gr
3
- import soundfile as sf
4
- from scipy.signal import resample_poly
5
  import torch
 
 
6
 
7
- # Lazy import to avoid failing at build time
8
- ENHANCER = None
9
- TARGET_SR = 16000 # MetricGAN+ expects 16 kHz
10
- DEVICE = "cpu"
11
- torch.set_num_threads(1)
12
-
13
- def get_enhancer():
14
- global ENHANCER
15
- if ENHANCER is None:
16
- from speechbrain.pretrained import SpectralMaskEnhancement
17
- ENHANCER = SpectralMaskEnhancement.from_hparams(
18
- source="speechbrain/metricgan-plus-voicebank",
19
- savedir="pretrained_metricganp",
20
- run_opts={"device": DEVICE}
21
- )
22
- return ENHANCER
23
-
24
- def _to_mono(x: np.ndarray) -> np.ndarray:
25
- # x shape: (n,) or (n, ch); keep as float32 in [-1,1]
26
- if x.ndim == 2 and x.shape[1] > 1:
27
- x = np.mean(x, axis=1)
28
- x = np.asarray(x, dtype=np.float32)
29
- return np.clip(x, -1.0, 1.0)
30
-
31
- def _resample(x: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
32
- if sr_in == sr_out:
33
- return x.astype(np.float32, copy=False)
34
- g = np.gcd(sr_in, sr_out)
35
- up, down = sr_out // g, sr_in // g
36
- y = resample_poly(x, up, down).astype(np.float32)
37
- return y
38
-
39
- def _mix(dry: np.ndarray, wet: np.ndarray, strength: str) -> np.ndarray:
40
- mix = {"Light": 0.4, "Medium": 0.7, "Strong": 1.0}.get(strength, 0.7)
41
- n = min(len(dry), len(wet))
42
- return dry[:n] * (1.0 - mix) + wet[:n] * mix
43
-
44
- def denoise(audio: tuple, strength: str):
45
- if audio is None:
46
- raise gr.Error("Please upload an audio file.")
47
-
48
- sr, data = audio
49
- data = np.asarray(data) # gradio sometimes gives list
50
-
51
- # to mono + float32
52
- dry_mono = _to_mono(data)
53
-
54
- # resample to 16k
55
- x16 = _resample(dry_mono, sr_in=sr, sr_out=TARGET_SR)
56
-
57
- # run enhancer (lazy load)
58
- enhancer = get_enhancer()
59
- with torch.no_grad():
60
- inp = torch.from_numpy(x16).unsqueeze(0) # (1, time)
61
- enhanced = enhancer.enhance_batch(inp, TARGET_SR)
62
- if isinstance(enhanced, torch.Tensor):
63
- enhanced = enhanced.squeeze(0).cpu().numpy().astype(np.float32)
64
-
65
- # back to original SR
66
- enh_sr = _resample(enhanced, sr_in=TARGET_SR, sr_out=sr)
67
-
68
- # wet/dry
69
- out = _mix(dry_mono, enh_sr, strength)
70
- return (sr, out.astype(np.float32))
71
-
72
- # -------- UI --------
73
- with gr.Blocks(theme=gr.themes.Soft(), css="footer{visibility:hidden}") as demo:
74
- gr.Markdown("### Zack’s Audio Outpost — AI Noise Reducer\nUpload a file and compare **Original vs Processed**.")
75
- with gr.Row():
76
- audio_in = gr.Audio(type="numpy", label="Upload Audio")
77
- strength = gr.Radio(["Light","Medium","Strong"], value="Medium", label="Noise Reduction Strength")
78
- run_btn = gr.Button("Run Noise Reduction", variant="primary")
79
- with gr.Row():
80
- orig = gr.Audio(label="Original")
81
- proc = gr.Audio(label="Processed")
82
-
83
- def run(audio, s):
84
- if audio is None:
85
- raise gr.Error("Please upload an audio file.")
86
- sr, x = audio
87
- y = denoise(audio, s)
88
- return (sr, x), y
89
-
90
- run_btn.click(run, [audio_in, strength], [orig, proc])
91
-
92
- demo.launch()
 
 
 
 
 
1
  import torch
2
+ import torchaudio
3
+ import numpy as np
4
 
5
+ TARGET_SR = 16000 # model expects 16 kHz
6
+
7
+ # strength -> wet mix
8
+ MIX_BY_STRENGTH = {
9
+ "Light": 0.5,
10
+ "Medium": 0.75,
11
+ "Strong": 1.0,
12
+ }
13
+
14
+ def _to_16k_mono(x: np.ndarray, sr: int) -> torch.Tensor:
15
+ """x: (time,) or (time, channels) float32 -1..1 -> torch (1, time) @16k"""
16
+ if x.ndim == 2: # stereo -> mono average
17
+ x = x.mean(axis=1)
18
+ wav = torch.from_numpy(x.astype(np.float32)) # (time,)
19
+ if sr != TARGET_SR:
20
+ wav = torchaudio.functional.resample(wav, sr, TARGET_SR)
21
+ return wav.unsqueeze(0) # (1, time)
22
+
23
+ @torch.no_grad()
24
+ def denoise(audio, strength):
25
+ # audio comes from gradio as (sr, np.ndarray) or filepath depending on your IO
26
+ # If you already have (sr, np.ndarray) upstream, keep that. Example below assumes tuple:
27
+ sr, x = audio # x shape (time, [channels]) float32 -1..1
28
+
29
+ # to 16k mono
30
+ wav16 = _to_16k_mono(x, sr) # (1, time) torch.float32
31
+ lengths = torch.tensor([1.0]) # full-length (relative) as required
32
+
33
+ # Run SpeechBrain enhancer (already created as `enhancer`)
34
+ enhanced = enhancer.enhance_batch(wav16, lengths=lengths) # (1, time)
35
+ enhanced = enhanced.squeeze(0) # (time,)
36
+ dry = wav16.squeeze(0)
37
+
38
+ # Wet/dry mix per UI strength
39
+ mix = MIX_BY_STRENGTH.get(strength, 0.75)
40
+ out = dry * (1.0 - mix) + enhanced * mix
41
+
42
+ # back to numpy @16k
43
+ y = out.cpu().numpy().astype(np.float32)
44
+
45
+ # Return (sr, waveform) to Gradio (or whatever your interface expects)
46
+ return (TARGET_SR, y)