Dominik Macháček
commited on
Commit
·
6387098
1
Parent(s):
7eeb73f
FixedSileroVADIterator to support other than 512-sized chunks with v5
Browse files- silero_vad.py +37 -0
- whisper_online.py +1 -1
silero_vad.py
CHANGED
@@ -94,4 +94,41 @@ class VADIterator:
|
|
94 |
|
95 |
return None
|
96 |
|
|
|
|
|
|
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
return None
|
96 |
|
97 |
+
#######################
|
98 |
+
# this is our workaround for Silero v5 requiring at least 512-sized audio chunks
|
99 |
+
# (see https://github.com/ufal/whisper_streaming/issues/116 )
|
100 |
|
101 |
+
import numpy as np
|
102 |
+
class FixedVADIterator(VADIterator):
|
103 |
+
|
104 |
+
def reset_states(self):
|
105 |
+
super().reset_states()
|
106 |
+
self.buffer = np.array([],dtype=np.float32)
|
107 |
+
|
108 |
+
def __call__(self, x, return_seconds=False):
|
109 |
+
self.buffer = np.append(self.buffer, x)
|
110 |
+
if len(self.buffer) >= 512:
|
111 |
+
ret = super().__call__(self.buffer, return_seconds=return_seconds)
|
112 |
+
self.buffer = np.array([],dtype=np.float32)
|
113 |
+
return ret
|
114 |
+
return None
|
115 |
+
|
116 |
+
if __name__ == "__main__":
|
117 |
+
# test/demonstrate the need for FixedVADIterator:
|
118 |
+
|
119 |
+
import torch
|
120 |
+
model, _ = torch.hub.load(
|
121 |
+
repo_or_dir='snakers4/silero-vad',
|
122 |
+
model='silero_vad'
|
123 |
+
)
|
124 |
+
vac = FixedVADIterator(model)
|
125 |
+
# vac = VADIterator(model) # the second case crashes with this
|
126 |
+
|
127 |
+
# this works: for both
|
128 |
+
audio_buffer = np.array([0]*(512),dtype=np.float32)
|
129 |
+
vac(audio_buffer)
|
130 |
+
|
131 |
+
# this crashes on the non FixedVADIterator with
|
132 |
+
# ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
|
133 |
+
audio_buffer = np.array([0]*(512-1),dtype=np.float32)
|
134 |
+
vac(audio_buffer)
|
whisper_online.py
CHANGED
@@ -531,7 +531,7 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
|
|
531 |
# VAC:
|
532 |
import torch
|
533 |
model, _ = torch.hub.load(
|
534 |
-
repo_or_dir='snakers4/silero-vad
|
535 |
model='silero_vad'
|
536 |
)
|
537 |
from silero_vad import VADIterator
|
|
|
531 |
# VAC:
|
532 |
import torch
|
533 |
model, _ = torch.hub.load(
|
534 |
+
repo_or_dir='snakers4/silero-vad',
|
535 |
model='silero_vad'
|
536 |
)
|
537 |
from silero_vad import VADIterator
|