qfuxa commited on
Commit
1ac9078
·
1 Parent(s): bd15235

silero vad

Browse files
Files changed (1) hide show
  1. silero_vad_iterator.py +163 -0
silero_vad_iterator.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ # This is copied from silero-vad's vad_utils.py:
4
+ # https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340
5
+ # (except changed defaults)
6
+
7
+ # Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
8
+
9
+
10
+ class VADIterator:
11
+ def __init__(
12
+ self,
13
+ model,
14
+ threshold: float = 0.5,
15
+ sampling_rate: int = 16000,
16
+ min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
17
+ speech_pad_ms: int = 100, # same
18
+ ):
19
+ """
20
+ Class for stream imitation
21
+
22
+ Parameters
23
+ ----------
24
+ model: preloaded .jit silero VAD model
25
+
26
+ threshold: float (default - 0.5)
27
+ Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
28
+ It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
29
+
30
+ sampling_rate: int (default - 16000)
31
+ Currently silero VAD models support 8000 and 16000 sample rates
32
+
33
+ min_silence_duration_ms: int (default - 100 milliseconds)
34
+ In the end of each speech chunk wait for min_silence_duration_ms before separating it
35
+
36
+ speech_pad_ms: int (default - 30 milliseconds)
37
+ Final speech chunks are padded by speech_pad_ms each side
38
+ """
39
+
40
+ self.model = model
41
+ self.threshold = threshold
42
+ self.sampling_rate = sampling_rate
43
+
44
+ if sampling_rate not in [8000, 16000]:
45
+ raise ValueError(
46
+ "VADIterator does not support sampling rates other than [8000, 16000]"
47
+ )
48
+
49
+ self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
50
+ self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
51
+ self.reset_states()
52
+
53
+ def reset_states(self):
54
+
55
+ self.model.reset_states()
56
+ self.triggered = False
57
+ self.temp_end = 0
58
+ self.current_sample = 0
59
+
60
+ def __call__(self, x, return_seconds=False):
61
+ """
62
+ x: torch.Tensor
63
+ audio chunk (see examples in repo)
64
+
65
+ return_seconds: bool (default - False)
66
+ whether return timestamps in seconds (default - samples)
67
+ """
68
+
69
+ if not torch.is_tensor(x):
70
+ try:
71
+ x = torch.Tensor(x)
72
+ except:
73
+ raise TypeError("Audio cannot be casted to tensor. Cast it manually")
74
+
75
+ window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
76
+ self.current_sample += window_size_samples
77
+
78
+ speech_prob = self.model(x, self.sampling_rate).item()
79
+
80
+ if (speech_prob >= self.threshold) and self.temp_end:
81
+ self.temp_end = 0
82
+
83
+ if (speech_prob >= self.threshold) and not self.triggered:
84
+ self.triggered = True
85
+ speech_start = self.current_sample - self.speech_pad_samples
86
+ return {
87
+ "start": (
88
+ int(speech_start)
89
+ if not return_seconds
90
+ else round(speech_start / self.sampling_rate, 1)
91
+ )
92
+ }
93
+
94
+ if (speech_prob < self.threshold - 0.15) and self.triggered:
95
+ if not self.temp_end:
96
+ self.temp_end = self.current_sample
97
+ if self.current_sample - self.temp_end < self.min_silence_samples:
98
+ return None
99
+ else:
100
+ speech_end = self.temp_end + self.speech_pad_samples
101
+ self.temp_end = 0
102
+ self.triggered = False
103
+ return {
104
+ "end": (
105
+ int(speech_end)
106
+ if not return_seconds
107
+ else round(speech_end / self.sampling_rate, 1)
108
+ )
109
+ }
110
+
111
+ return None
112
+
113
+
114
+ #######################
115
+ # because Silero now requires exactly 512-sized audio chunks
116
+
117
+ import numpy as np
118
+
119
+
120
+ class FixedVADIterator(VADIterator):
121
+ """It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
122
+ If audio to be processed at once is long and multiple voiced segments detected,
123
+ then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
124
+ """
125
+
126
+ def reset_states(self):
127
+ super().reset_states()
128
+ self.buffer = np.array([], dtype=np.float32)
129
+
130
+ def __call__(self, x, return_seconds=False):
131
+ self.buffer = np.append(self.buffer, x)
132
+ ret = None
133
+ while len(self.buffer) >= 512:
134
+ r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
135
+ self.buffer = self.buffer[512:]
136
+ if ret is None:
137
+ ret = r
138
+ elif r is not None:
139
+ if "end" in r:
140
+ ret["end"] = r["end"] # the latter end
141
+ if "start" in r and "end" in ret: # there is an earlier start.
142
+ # Remove end, merging this segment with the previous one.
143
+ del ret["end"]
144
+ return ret if ret != {} else None
145
+
146
+
147
+ if __name__ == "__main__":
148
+ # test/demonstrate the need for FixedVADIterator:
149
+
150
+ import torch
151
+
152
+ model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
153
+ vac = FixedVADIterator(model)
154
+ # vac = VADIterator(model) # the second case crashes with this
155
+
156
+ # this works: for both
157
+ audio_buffer = np.array([0] * (512), dtype=np.float32)
158
+ vac(audio_buffer)
159
+
160
+ # this crashes on the non FixedVADIterator with
161
+ # ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
162
+ audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
163
+ vac(audio_buffer)