Spaces: cuio
cuio committed · Commit fb50115 · verified · Parent: 67ffb29

Upload 2 files

Files changed (2)
  1. asr.py +233 -0
  2. tts.py +216 -0
asr.py ADDED
@@ -0,0 +1,233 @@
from typing import *
import logging
import time
import sherpa_onnx
import os
import asyncio
import numpy as np

logger = logging.getLogger(__file__)
_asr_engines = {}


class ASRResult:
    def __init__(self, text: str, finished: bool, idx: int):
        self.text = text
        self.finished = finished
        self.idx = idx

    def to_dict(self):
        return {"text": self.text, "finished": self.finished, "idx": self.idx}


class ASRStream:
    def __init__(self, recognizer: Union[sherpa_onnx.OnlineRecognizer, sherpa_onnx.OfflineRecognizer], sample_rate: int) -> None:
        self.recognizer = recognizer
        self.inbuf = asyncio.Queue()
        self.outbuf = asyncio.Queue()
        self.sample_rate = sample_rate
        self.is_closed = False
        self.online = isinstance(recognizer, sherpa_onnx.OnlineRecognizer)

    async def start(self):
        if self.online:
            asyncio.create_task(self.run_online())
        else:
            asyncio.create_task(self.run_offline())

    async def run_online(self):
        stream = self.recognizer.create_stream()
        last_result = ""
        segment_id = 0
        logger.info('asr: start real-time recognizer')
        while not self.is_closed:
            samples = await self.inbuf.get()
            stream.accept_waveform(self.sample_rate, samples)
            while self.recognizer.is_ready(stream):
                self.recognizer.decode_stream(stream)

            is_endpoint = self.recognizer.is_endpoint(stream)
            result = self.recognizer.get_result(stream)

            if result and (last_result != result):
                last_result = result
                logger.info(f' > {segment_id}:{result}')
                self.outbuf.put_nowait(
                    ASRResult(result, False, segment_id))

            if is_endpoint:
                if result:
                    logger.info(f'{segment_id}: {result}')
                    self.outbuf.put_nowait(
                        ASRResult(result, True, segment_id))
                    segment_id += 1
                self.recognizer.reset(stream)

    async def run_offline(self):
        vad = _asr_engines['vad']
        segment_id = 0
        st = None
        while not self.is_closed:
            samples = await self.inbuf.get()
            vad.accept_waveform(samples)
            while not vad.empty():
                if not st:
                    st = time.time()
                stream = self.recognizer.create_stream()
                stream.accept_waveform(self.sample_rate, vad.front.samples)

                vad.pop()
                self.recognizer.decode_stream(stream)

                result = stream.result.text.strip()
                if result:
                    duration = time.time() - st
                    logger.info(f'{segment_id}:{result} ({duration:.2f}s)')
                    self.outbuf.put_nowait(ASRResult(result, True, segment_id))
                    segment_id += 1
                st = None

    async def close(self):
        self.is_closed = True
        self.outbuf.put_nowait(None)

    async def write(self, pcm_bytes: bytes):
        pcm_data = np.frombuffer(pcm_bytes, dtype=np.int16)
        samples = pcm_data.astype(np.float32) / 32768.0
        self.inbuf.put_nowait(samples)

    async def read(self) -> ASRResult:
        return await self.outbuf.get()


def create_zipformer(samplerate: int, args) -> sherpa_onnx.OnlineRecognizer:
    d = os.path.join(
        args.models_root, 'sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20')
    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")

    encoder = os.path.join(d, "encoder-epoch-99-avg-1.onnx")
    decoder = os.path.join(d, "decoder-epoch-99-avg-1.onnx")
    joiner = os.path.join(d, "joiner-epoch-99-avg-1.onnx")
    tokens = os.path.join(d, "tokens.txt")

    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        provider=args.asr_provider,
        num_threads=args.threads,
        sample_rate=samplerate,
        feature_dim=80,
        enable_endpoint_detection=True,
        rule1_min_trailing_silence=2.4,
        rule2_min_trailing_silence=1.2,
        rule3_min_utterance_length=20,  # it essentially disables this rule
    )
    return recognizer


def create_sensevoice(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
    d = os.path.join(args.models_root,
                     'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17')

    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")

    recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
        model=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        use_itn=True,
        debug=0,
        language=args.asr_lang,
    )
    return recognizer


def create_paraformer_trilingual(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
    d = os.path.join(
        args.models_root, 'sherpa-onnx-paraformer-trilingual-zh-cantonese-en')
    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")

    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        debug=0,
        provider=args.asr_provider,
    )
    return recognizer


def create_paraformer_en(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer:
    d = os.path.join(
        args.models_root, 'sherpa-onnx-paraformer-en')
    if not os.path.exists(d):
        raise ValueError(f"asr: model not found {d}")

    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=os.path.join(d, 'model.onnx'),
        tokens=os.path.join(d, 'tokens.txt'),
        num_threads=args.threads,
        sample_rate=samplerate,
        use_itn=True,
        debug=0,
        provider=args.asr_provider,
    )
    return recognizer


def load_asr_engine(samplerate: int, args) -> Union[sherpa_onnx.OnlineRecognizer, sherpa_onnx.OfflineRecognizer]:
    cache_engine = _asr_engines.get(args.asr_model)
    if cache_engine:
        return cache_engine
    st = time.time()
    if args.asr_model == 'zipformer-bilingual':
        cache_engine = create_zipformer(samplerate, args)
    elif args.asr_model == 'sensevoice':
        cache_engine = create_sensevoice(samplerate, args)
        _asr_engines['vad'] = load_vad_engine(samplerate, args)
    elif args.asr_model == 'paraformer-trilingual':
        cache_engine = create_paraformer_trilingual(samplerate, args)
        _asr_engines['vad'] = load_vad_engine(samplerate, args)
    elif args.asr_model == 'paraformer-en':
        cache_engine = create_paraformer_en(samplerate, args)
        _asr_engines['vad'] = load_vad_engine(samplerate, args)
    else:
        raise ValueError(f"asr: unknown model {args.asr_model}")
    _asr_engines[args.asr_model] = cache_engine
    logger.info(f"asr: engine loaded in {time.time() - st:.2f}s")
    return cache_engine


def load_vad_engine(samplerate: int, args, min_silence_duration: float = 0.25, buffer_size_in_seconds: int = 100) -> sherpa_onnx.VoiceActivityDetector:
    config = sherpa_onnx.VadModelConfig()
    d = os.path.join(args.models_root, 'silero_vad')
    if not os.path.exists(d):
        raise ValueError(f"vad: model not found {d}")

    config.silero_vad.model = os.path.join(d, 'silero_vad.onnx')
    config.silero_vad.min_silence_duration = min_silence_duration
    config.sample_rate = samplerate
    config.provider = args.asr_provider
    config.num_threads = args.threads

    vad = sherpa_onnx.VoiceActivityDetector(
        config,
        buffer_size_in_seconds=buffer_size_in_seconds)
    return vad


async def start_asr_stream(samplerate: int, args) -> ASRStream:
    """
    Start an ASR stream
    """
    stream = ASRStream(load_asr_engine(samplerate, args), samplerate)
    await stream.start()
    return stream
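
For reference, a minimal usage sketch of the ASR API added above (not part of this commit). It assumes an `args` namespace carrying the fields the module reads (models_root, asr_model, asr_provider, asr_lang, threads) and a hypothetical raw 16 kHz, 16-bit mono PCM file as the audio source; feeding and reading run as separate tasks, matching ASRStream's queue-based design.

import argparse
import asyncio

from asr import start_asr_stream


async def main():
    # Hypothetical configuration; adjust paths and model name to your setup.
    args = argparse.Namespace(
        models_root='./models',
        asr_model='sensevoice',
        asr_provider='cpu',
        asr_lang='zh',
        threads=2,
    )
    stream = await start_asr_stream(16000, args)

    async def feed():
        # audio.pcm is a placeholder: raw little-endian int16 mono at 16 kHz.
        with open('audio.pcm', 'rb') as f:
            while chunk := f.read(3200):  # ~100 ms per chunk
                await stream.write(chunk)
                await asyncio.sleep(0.1)
        await stream.close()

    asyncio.create_task(feed())
    # read() yields partial and final ASRResult objects, and None after close().
    while (result := await stream.read()) is not None:
        print(result.to_dict())


asyncio.run(main())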
tts.py ADDED
@@ -0,0 +1,216 @@
from typing import *
import os
import time
import sherpa_onnx
import logging
import numpy as np
import asyncio
import soundfile
from scipy.signal import resample
import io
import re

logger = logging.getLogger(__file__)

splitter = re.compile(r'[,,。.!?!?;;、\n]')
_tts_engines = {}

tts_configs = {
    'vits-zh-hf-theresa': {
        'model': 'theresa.onnx',
        'lexicon': 'lexicon.txt',
        'dict_dir': 'dict',
        'tokens': 'tokens.txt',
        'sample_rate': 22050,
        # 'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'],
    },
    'vits-melo-tts-zh_en': {
        'model': 'model.onnx',
        'lexicon': 'lexicon.txt',
        'dict_dir': 'dict',
        'tokens': 'tokens.txt',
        'sample_rate': 44100,
        'rule_fsts': ['phone.fst', 'date.fst', 'number.fst'],
    },
}


def load_tts_model(name: str, model_root: str, provider: str, num_threads: int = 1, max_num_sentences: int = 20) -> sherpa_onnx.OfflineTtsConfig:
    cfg = tts_configs[name]
    fsts = []
    model_dir = os.path.join(model_root, name)
    for f in cfg.get('rule_fsts', []):
        fsts.append(os.path.join(model_dir, f))
    tts_rule_fsts = ','.join(fsts) if fsts else ''

    model_config = sherpa_onnx.OfflineTtsModelConfig(
        vits=sherpa_onnx.OfflineTtsVitsModelConfig(
            model=os.path.join(model_dir, cfg['model']),
            lexicon=os.path.join(model_dir, cfg['lexicon']),
            dict_dir=os.path.join(model_dir, cfg['dict_dir']),
            tokens=os.path.join(model_dir, cfg['tokens']),
        ),
        provider=provider,
        debug=0,
        num_threads=num_threads,
    )
    tts_config = sherpa_onnx.OfflineTtsConfig(
        model=model_config,
        rule_fsts=tts_rule_fsts,
        max_num_sentences=max_num_sentences)

    if not tts_config.validate():
        raise ValueError("tts: invalid config")

    return tts_config


def get_tts_engine(args) -> Tuple[sherpa_onnx.OfflineTts, int]:
    sample_rate = tts_configs[args.tts_model]['sample_rate']
    cache_engine = _tts_engines.get(args.tts_model)
    if cache_engine:
        return cache_engine, sample_rate
    st = time.time()
    tts_config = load_tts_model(
        args.tts_model, args.models_root, args.tts_provider)

    cache_engine = sherpa_onnx.OfflineTts(tts_config)
    elapsed = time.time() - st
    logger.info(f"tts: loaded {args.tts_model} in {elapsed:.2f}s")
    _tts_engines[args.tts_model] = cache_engine

    return cache_engine, sample_rate


class TTSResult:
    def __init__(self, pcm_bytes: Optional[bytes], finished: bool):
        self.pcm_bytes = pcm_bytes
        self.finished = finished
        self.progress: float = 0.0
        self.elapsed: float = 0.0
        self.audio_duration: float = 0.0
        self.audio_size: int = 0

    def to_dict(self):
        return {
            "progress": self.progress,
            "elapsed": f'{int(self.elapsed * 1000)}ms',
            "duration": f'{self.audio_duration:.2f}s',
            "size": self.audio_size
        }


class TTSStream:
    def __init__(self, engine, sid: int, speed: float = 1.0, sample_rate: int = 16000, original_sample_rate: int = 16000):
        self.engine = engine
        self.sid = sid
        self.speed = speed
        self.outbuf: asyncio.Queue[TTSResult | None] = asyncio.Queue()
        self.is_closed = False
        self.target_sample_rate = sample_rate
        self.original_sample_rate = original_sample_rate

    def on_process(self, chunk: np.ndarray, progress: float):
        if self.is_closed:
            return 0

        # resample to target sample rate
        if self.target_sample_rate != self.original_sample_rate:
            num_samples = int(
                len(chunk) * self.target_sample_rate / self.original_sample_rate)
            resampled_chunk = resample(chunk, num_samples)
            chunk = resampled_chunk.astype(np.float32)

        scaled_chunk = chunk * 32768.0
        clipped_chunk = np.clip(scaled_chunk, -32768, 32767)
        int16_chunk = clipped_chunk.astype(np.int16)
        samples = int16_chunk.tobytes()
        self.outbuf.put_nowait(TTSResult(samples, False))
        return 0 if self.is_closed else 1

    async def write(self, text: str, split: bool, pause: float = 0.2):
        start = time.time()
        if split:
            texts = re.split(splitter, text)
        else:
            texts = [text]

        audio_duration = 0.0
        audio_size = 0

        for idx, text in enumerate(texts):
            text = text.strip()
            if not text:
                continue
            sub_start = time.time()

            audio = await asyncio.to_thread(self.engine.generate,
                                            text, self.sid, self.speed,
                                            self.on_process)

            if not audio or not audio.sample_rate or not audio.samples:
                logger.error(f"tts: failed to generate audio for "
                             f"'{text}' (audio={audio})")
                continue

            if split and idx < len(texts) - 1:  # add a pause between sentences
                noise = np.zeros(int(audio.sample_rate * pause))
                self.on_process(noise, 1.0)
                audio.samples = np.concatenate([audio.samples, noise])

            audio_duration += len(audio.samples) / audio.sample_rate
            audio_size += len(audio.samples)
            elapsed_seconds = time.time() - sub_start
            logger.info(f"tts: generated audio for '{text}', "
                        f"audio duration: {audio_duration:.2f}s, "
                        f"elapsed: {elapsed_seconds:.2f}s")

        elapsed_seconds = time.time() - start
        logger.info(f"tts: generated audio in {elapsed_seconds:.2f}s, "
                    f"audio duration: {audio_duration:.2f}s")

        r = TTSResult(None, True)
        r.elapsed = elapsed_seconds
        r.audio_duration = audio_duration
        r.progress = 1.0
        r.finished = True
        await self.outbuf.put(r)

    async def close(self):
        self.is_closed = True
        self.outbuf.put_nowait(None)
        logger.info("tts: stream closed")

    async def read(self) -> TTSResult:
        return await self.outbuf.get()

    async def generate(self, text: str) -> io.BytesIO:
        start = time.time()
        audio = await asyncio.to_thread(self.engine.generate,
                                        text, self.sid, self.speed)
        elapsed_seconds = time.time() - start
        audio_duration = len(audio.samples) / audio.sample_rate

        logger.info(f"tts: generated audio in {elapsed_seconds:.2f}s, "
                    f"audio duration: {audio_duration:.2f}s, "
                    f"sample rate: {audio.sample_rate}")

        if self.target_sample_rate != audio.sample_rate:
            audio.samples = resample(audio.samples,
                                     int(len(audio.samples) * self.target_sample_rate / audio.sample_rate))
            audio.sample_rate = self.target_sample_rate

        output = io.BytesIO()
        soundfile.write(output,
                        audio.samples,
                        samplerate=audio.sample_rate,
                        subtype="PCM_16",
                        format="WAV")
        output.seek(0)
        return output


async def start_tts_stream(sid: int, sample_rate: int, speed: float, args) -> TTSStream:
    engine, original_sample_rate = get_tts_engine(args)
    return TTSStream(engine, sid, speed, sample_rate, original_sample_rate)
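
Likewise, a minimal usage sketch for the TTS API (not part of this commit). The `args` values are placeholders: tts_model must be a key of tts_configs, models_root must contain the matching model directory, and tts_provider selects the onnxruntime provider.

import argparse
import asyncio

from tts import start_tts_stream


async def main():
    # Hypothetical configuration; adjust to your local model layout.
    args = argparse.Namespace(
        models_root='./models',
        tts_model='vits-melo-tts-zh_en',
        tts_provider='cpu',
    )
    stream = await start_tts_stream(sid=0, sample_rate=16000, speed=1.0, args=args)

    # Synthesize sentence by sentence and drain the PCM chunks as they arrive.
    writer = asyncio.create_task(stream.write("你好，世界。Hello, world.", split=True))
    while True:
        result = await stream.read()
        if result.finished:
            print(result.to_dict())
            break
        # result.pcm_bytes holds int16 PCM at the requested sample rate.
    await writer
    await stream.close()


asyncio.run(main())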