Dominik Macháček committed on
Commit
6387098
·
1 Parent(s): 7eeb73f

Add FixedVADIterator to support chunk sizes other than 512 with Silero VAD v5

Browse files
Files changed (2) hide show
  1. silero_vad.py +37 -0
  2. whisper_online.py +1 -1
silero_vad.py CHANGED
@@ -94,4 +94,41 @@ class VADIterator:
94
 
95
  return None
96
 
 
 
 
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  return None
96
 
97
+ #######################
98
+ # this is our workaround for Silero v5 requiring at least 512-sized audio chunks
99
+ # (see https://github.com/ufal/whisper_streaming/issues/116 )
100
 
101
+ import numpy as np
102
class FixedVADIterator(VADIterator):
    """VADIterator variant that accepts audio chunks of any length.

    Silero VAD v5 rejects chunks shorter than 512 samples (see
    https://github.com/ufal/whisper_streaming/issues/116). Incoming audio is
    therefore buffered and handed to the parent iterator in windows of
    exactly ``WINDOW_SIZE`` samples. Samples that do not fill a whole window
    stay in the buffer until a later call completes it.

    NOTE(review): this relies on the parent constructor calling
    ``reset_states()`` (which is what creates ``self.buffer``) and on the
    parent ``__call__`` returning either ``None`` or a dict keyed by
    ``'start'`` or ``'end'`` -- confirm against VADIterator.
    """

    # Number of samples passed to the parent iterator per call.
    WINDOW_SIZE = 512

    def reset_states(self):
        """Reset the parent's state and discard any buffered samples."""
        super().reset_states()
        self.buffer = np.array([], dtype=np.float32)

    def __call__(self, x, return_seconds=False):
        """Buffer ``x`` and run the parent VAD on every complete window.

        Args:
            x: audio samples to append to the internal buffer.
            return_seconds: forwarded unchanged to the parent ``__call__``.

        Returns:
            ``None`` when no complete window was available or no event was
            reported. Otherwise a dict merging the events of all windows
            processed in this call: the earliest ``'start'`` and the latest
            ``'end'``. An ``'end'`` that is followed by a new ``'start'``
            within the same call is dropped, i.e. the two voiced segments
            are reported as one that is still ongoing.
        """
        self.buffer = np.append(self.buffer, x)
        merged = None
        # Feed exactly one window at a time; passing the whole buffer would
        # hand the model a chunk of arbitrary length.
        while len(self.buffer) >= self.WINDOW_SIZE:
            event = super().__call__(
                self.buffer[:self.WINDOW_SIZE], return_seconds=return_seconds
            )
            self.buffer = self.buffer[self.WINDOW_SIZE:]
            if event is None:
                continue
            if merged is None:
                # Copy so the dict returned by the parent is never mutated.
                merged = dict(event)
                continue
            if 'end' in event:
                merged['end'] = event['end']  # keep the latest end
            if 'start' in event and 'end' in merged:
                # Speech resumed after an earlier end: merge the segments.
                del merged['end']
        # An end immediately cancelled by a new start leaves an empty dict.
        return merged or None
115
+
116
if __name__ == "__main__":
    # Small demonstration of why FixedVADIterator is needed.

    import torch

    model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')

    vac = FixedVADIterator(model)
    # vac = VADIterator(model)  # with the plain iterator the second call below crashes

    # A chunk of exactly 512 samples is accepted by both iterator classes.
    full_chunk = np.zeros(512, dtype=np.float32)
    vac(full_chunk)

    # A chunk one sample short makes the plain VADIterator fail with
    #   ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
    short_chunk = np.zeros(512 - 1, dtype=np.float32)
    vac(short_chunk)
whisper_online.py CHANGED
@@ -531,7 +531,7 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
531
  # VAC:
532
  import torch
533
  model, _ = torch.hub.load(
534
- repo_or_dir='snakers4/silero-vad:v4.0',
535
  model='silero_vad'
536
  )
537
  from silero_vad import VADIterator
 
531
  # VAC:
532
  import torch
533
  model, _ = torch.hub.load(
534
+ repo_or_dir='snakers4/silero-vad',
535
  model='silero_vad'
536
  )
537
  from silero_vad import VADIterator