Dominik Macháček committed
Commit 7edc534 · 1 Parent(s): 14c2bbe

clean code with VAC

Files changed (3)
  1. silero_vad.py +95 -0
  2. voice_activity_controller.py +28 -104
  3. whisper_online.py +60 -18
silero_vad.py ADDED
@@ -0,0 +1,95 @@
+import torch
+
+# this is copypasted from silero-vad's vad_utils.py:
+# https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340
+
+class VADIterator:
+    def __init__(self,
+                 model,
+                 threshold: float = 0.5,
+                 sampling_rate: int = 16000,
+                 min_silence_duration_ms: int = 100,
+                 speech_pad_ms: int = 30
+                 ):
+
+        """
+        Class for stream imitation
+
+        Parameters
+        ----------
+        model: preloaded .jit silero VAD model
+
+        threshold: float (default - 0.5)
+            Speech threshold. Silero VAD outputs speech probabilities for each audio chunk; probabilities ABOVE this value are considered as SPEECH.
+            It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
+
+        sampling_rate: int (default - 16000)
+            Currently silero VAD models support 8000 and 16000 sample rates
+
+        min_silence_duration_ms: int (default - 100 milliseconds)
+            At the end of each speech chunk, wait for min_silence_duration_ms before separating it
+
+        speech_pad_ms: int (default - 30 milliseconds)
+            Final speech chunks are padded by speech_pad_ms on each side
+        """
+
+        self.model = model
+        self.threshold = threshold
+        self.sampling_rate = sampling_rate
+
+        if sampling_rate not in [8000, 16000]:
+            raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
+
+        self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+        self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+        self.reset_states()
+
+    def reset_states(self):
+
+        self.model.reset_states()
+        self.triggered = False
+        self.temp_end = 0
+        self.current_sample = 0
+
+    def __call__(self, x, return_seconds=False):
+        """
+        x: torch.Tensor
+            audio chunk (see examples in repo)
+
+        return_seconds: bool (default - False)
+            whether to return timestamps in seconds (default - samples)
+        """
+
+        if not torch.is_tensor(x):
+            try:
+                x = torch.Tensor(x)
+            except:
+                raise TypeError("Audio cannot be casted to tensor. Cast it manually")
+
+        window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
+        self.current_sample += window_size_samples
+
+        speech_prob = self.model(x, self.sampling_rate).item()
+
+        if (speech_prob >= self.threshold) and self.temp_end:
+            self.temp_end = 0
+
+        if (speech_prob >= self.threshold) and not self.triggered:
+            self.triggered = True
+            speech_start = self.current_sample - self.speech_pad_samples
+            return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
+
+        if (speech_prob < self.threshold - 0.15) and self.triggered:
+            if not self.temp_end:
+                self.temp_end = self.current_sample
+            if self.current_sample - self.temp_end < self.min_silence_samples:
+                return None
+            else:
+                speech_end = self.temp_end + self.speech_pad_samples
+                self.temp_end = 0
+                self.triggered = False
+                return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
+
+        return None
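For reference, a minimal driver loop for the VADIterator added above. This is a sketch, not part of the commit: the 512-sample (roughly 32 ms at 16 kHz) window size and the zero-filled audio are assumptions standing in for a real stream, so no events fire here.

import numpy as np
import torch

from silero_vad import VADIterator

model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
vad = VADIterator(model)

audio = np.zeros(16000, dtype=np.float32)   # stand-in for 1 s of real 16 kHz mono audio
for i in range(0, len(audio), 512):         # 512-sample windows are an assumption
    event = vad(audio[i:i+512], return_seconds=True)
    if event is not None:                   # {'start': ...} or {'end': ...}, else None
        print(event)
vad.reset_states()                          # reset between independent streams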
voice_activity_controller.py CHANGED
@@ -1,111 +1,35 @@
 import torch
-import numpy as np
+from silero_vad import VADIterator
+import time
 
 class VoiceActivityController:
-    def __init__(
-            self,
-            sampling_rate = 16000,
-            min_silence_to_final_ms = 500,
-            min_speech_to_final_ms = 100,
-            min_silence_duration_ms = 100,
-            use_vad_result = True,
-            # activity_detected_callback=None,
-            threshold = 0.3
-    ):
-        # self.activity_detected_callback=activity_detected_callback
-        self.model, self.utils = torch.hub.load(
+    SAMPLING_RATE = 16000
+    def __init__(self):
+        self.model, _ = torch.hub.load(
             repo_or_dir='snakers4/silero-vad',
             model='silero_vad'
         )
-        # (self.get_speech_timestamps,
-        #  save_audio,
-        #  read_audio,
-        #  VADIterator,
-        #  collect_chunks) = self.utils
-
-        self.sampling_rate = sampling_rate
-        self.final_silence_limit = min_silence_to_final_ms * self.sampling_rate / 1000
-        self.final_speech_limit = min_speech_to_final_ms * self.sampling_rate / 1000
-        self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
-
-        self.use_vad_result = use_vad_result
-        self.threshold = threshold
-        self.reset_states()
-
-    def reset_states(self):
-        self.model.reset_states()
-        self.temp_end = 0
-        self.current_sample = 0
-
-        self.last_silence_len = 0
-        self.speech_len = 0
-
-    def apply_vad(self, audio):
-        """
-        returns: triple
-            (voice_audio,
-             speech_in_wav,
-             silence_in_wav)
-        """
-        print("applying vad here")
+        # we use the default options: 500ms silence, etc.
+        self.iterator = VADIterator(self.model)
+
+    def reset(self):
+        self.iterator.reset_states()
+
+    def __call__(self, audio):
+        '''
+        audio: audio chunk in the current np.array format
+        returns:
+        - { 'start': time_frame } ... when a start of voice was detected; time_frame is the frame number (can be converted to seconds)
+        - { 'end': time_frame } ... when an end of voice was detected
+        - None ... when no change was detected in the current chunk
+        '''
         x = audio
-        if not torch.is_tensor(x):
-            try:
-                x = torch.Tensor(x)
-            except:
-                raise TypeError("Audio cannot be casted to tensor. Cast it manually")
-
-        speech_prob = self.model(x, self.sampling_rate).item()
-        print("speech_prob", speech_prob)
-
-        window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
-        self.current_sample += window_size_samples
-
-        if speech_prob >= self.threshold:  # speech is detected
-            self.temp_end = 0
-            return audio, window_size_samples, 0
-
-        else:  # silence detected, counting
-            if not self.temp_end:
-                self.temp_end = self.current_sample
-
-            if self.current_sample - self.temp_end < self.min_silence_samples:
-                return audio, 0, window_size_samples
-            else:
-                return np.array([], dtype=np.float16) if self.use_vad_result else audio, 0, window_size_samples
-
-    def detect_speech_iter(self, data, audio_in_int16 = False):
-        audio_block = data
-        wav = audio_block
-
-        is_final = False
-        voice_audio, speech_in_wav, last_silent_in_wav = self.apply_vad(wav)
-        print("speech, last silence", speech_in_wav, last_silent_in_wav)
-
-        if speech_in_wav > 0:
-            self.last_silence_len = 0
-            self.speech_len += speech_in_wav
-            # if self.activity_detected_callback is not None:
-            #     self.activity_detected_callback()
-
-        self.last_silence_len += last_silent_in_wav
-        print("self.last_silence_len", self.last_silence_len, self.final_silence_limit, self.last_silence_len >= self.final_silence_limit)
-        print("self.speech_len, final_speech_limit", self.speech_len, self.final_speech_limit, self.speech_len >= self.final_speech_limit)
-        if self.last_silence_len >= self.final_silence_limit and self.speech_len >= self.final_speech_limit:
-            for i in range(10): print("TADY!!!")
-
-            is_final = True
-            self.last_silence_len = 0
-            self.speech_len = 0
-
-        return voice_audio, is_final
-
-    def detect_user_speech(self, audio_stream, audio_in_int16 = False):
-        self.last_silence_len = 0
-        self.speech_len = 0
-
-        for data in audio_stream:  # replace with your condition of choice
-            yield self.detect_speech_iter(data, audio_in_int16)
+        # if not torch.is_tensor(x):
+        #     try:
+        #         x = torch.Tensor(x)
+        #     except:
+        #         raise TypeError("Audio cannot be casted to tensor. Cast it manually")
+        t = time.time()
+        a = self.iterator(x)
+        print("VAD took", time.time()-t, "seconds")
+        return a
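The controller is now a thin wrapper that loads the model, delegates to VADIterator with its default options, and times each call. A usage sketch, not part of the commit; the chunking and the silent input are assumptions (with real speech, start/end events would fire):

import numpy as np
from voice_activity_controller import VoiceActivityController

vac = VoiceActivityController()
stream = np.zeros(32000, dtype=np.float32)  # stand-in for 2 s of real 16 kHz mono audio
for i in range(0, len(stream), 512):
    event = vac(stream[i:i+512])            # {'start': frame}, {'end': frame}, or None
    if event is not None:
        frame = list(event.values())[0]
        print(event, "~%.2f s" % (frame / VoiceActivityController.SAMPLING_RATE))
vac.reset()                                 # reset between independent recordings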
whisper_online.py CHANGED
@@ -331,16 +331,14 @@ class OnlineASRProcessor:
 
         self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
 
-    def init(self, keep_offset=False):
+    def init(self, offset=None):
         """run this when starting or restarting processing"""
         self.audio_buffer = np.array([],dtype=np.float32)
         self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
-        if not keep_offset:
-            self.buffer_time_offset = 0
-            self.transcript_buffer.last_commited_time = 0
-        else:
-            self.transcript_buffer.last_commited_time = self.buffer_time_offset
-
+        self.buffer_time_offset = 0
+        if offset is not None:
+            self.buffer_time_offset = offset
+        self.transcript_buffer.last_commited_time = self.buffer_time_offset
         self.commited = []
 
     def insert_audio_chunk(self, audio):
@@ -529,27 +527,71 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
         self.online_chunk_size = online_chunk_size
 
         self.online = OnlineASRProcessor(*a, **kw)
-        from voice_activity_controller import VoiceActivityController
-        self.vac = VoiceActivityController(use_vad_result = False)
 
-        self.logfile = self.online.logfile
+        # VAC:
+        import torch
+        model, _ = torch.hub.load(
+            repo_or_dir='snakers4/silero-vad',
+            model='silero_vad'
+        )
+        from silero_vad import VADIterator
+        self.vac = VADIterator(model)  # we use all the default options: 500ms silence, etc.
 
+        self.logfile = self.online.logfile
         self.init()
 
     def init(self):
         self.online.init()
         self.vac.reset_states()
         self.current_online_chunk_buffer_size = 0
+
         self.is_currently_final = False
 
+        self.status = None  # or "voice" or "nonvoice"
+        self.audio_buffer = np.array([],dtype=np.float32)
+        self.buffer_offset = 0  # in frames
+
+    def clear_buffer(self):
+        self.buffer_offset += len(self.audio_buffer)
+        self.audio_buffer = np.array([],dtype=np.float32)
+
     def insert_audio_chunk(self, audio):
-        r = self.vac.detect_speech_iter(audio, audio_in_int16=False)
-        audio, is_final = r
-        print(is_final)
-        self.is_currently_final = is_final
-        self.online.insert_audio_chunk(audio)
-        self.current_online_chunk_buffer_size += len(audio)
+        res = self.vac(audio)
+        print(res)
+        self.audio_buffer = np.append(self.audio_buffer, audio)
+
+        if res is not None:
+            frame = list(res.values())[0]
+            if 'start' in res and 'end' not in res:
+                self.status = 'voice'
+                send_audio = self.audio_buffer[frame-self.buffer_offset:]
+                self.online.init(offset=frame/self.SAMPLING_RATE)
+                self.online.insert_audio_chunk(send_audio)
+                self.current_online_chunk_buffer_size += len(send_audio)
+                self.clear_buffer()
+            elif 'end' in res and 'start' not in res:
+                self.status = 'nonvoice'
+                send_audio = self.audio_buffer[:frame-self.buffer_offset]
+                self.online.insert_audio_chunk(send_audio)
+                self.current_online_chunk_buffer_size += len(send_audio)
+                self.is_currently_final = True
+                self.clear_buffer()
+            else:
+                # This doesn't happen in the current code.
+                raise NotImplementedError("both start and end of voice in one chunk!!!")
+        else:
+            if self.status == 'voice':
+                self.online.insert_audio_chunk(self.audio_buffer)
+                self.current_online_chunk_buffer_size += len(self.audio_buffer)
+            if self.status is not None:
+                self.clear_buffer()
+            else:  # we are at the beginning of the process; no voice has been detected yet
+                # We keep the last 1 s of audio because VAD may later find a start of voice in it,
+                # but we trim the rest of the buffer to prevent OOM.
+                self.buffer_offset += max(0, len(self.audio_buffer)-self.SAMPLING_RATE)
+                self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
 
     def process_iter(self):
         if self.is_currently_final:
@@ -559,13 +601,13 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
             ret = self.online.process_iter()
             return ret
         else:
-            print("no online update, only VAD", file=self.logfile)
+            print("no online update, only VAD", self.status, file=self.logfile)
            return (None, None, "")
 
     def finish(self):
         ret = self.online.finish()
-        self.online.init(keep_offset=True)
         self.current_online_chunk_buffer_size = 0
+        self.is_currently_final = False
         return ret
 
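In the new insert_audio_chunk, VAD events carry absolute frame indices counted from the start of the stream, while self.audio_buffer holds only the frames not yet flushed; self.buffer_offset records how many frames have already been dropped, so frame - self.buffer_offset indexes correctly into the buffer, and frame / self.SAMPLING_RATE converts a detected voice start into the seconds offset passed to self.online.init(offset=...), keeping committed timestamps on the global timeline. A minimal end-to-end sketch of the resulting loop, not part of the commit; the ASR constructor arguments, the 640-sample chunking, and the silent input are assumptions:

import numpy as np
from whisper_online import FasterWhisperASR, VACOnlineASRProcessor

asr = FasterWhisperASR("en", "large-v2")     # language and model size are assumed
online = VACOnlineASRProcessor(1.0, asr)     # ~1 s online chunks; remaining args go to OnlineASRProcessor

audio = np.zeros(3*16000, dtype=np.float32)  # stand-in for 3 s of real 16 kHz speech
for i in range(0, len(audio), 640):          # 40 ms chunks, an assumption
    online.insert_audio_chunk(audio[i:i+640])
    beg, end, text = online.process_iter()
    if text:
        print("%.2f-%.2f %s" % (beg, end, text))
online.finish()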