Spaces:
cuio
/
No application file

cuio committed on
Commit
67ffb29
·
verified ·
1 Parent(s): ea385ed

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +194 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect, Query
3
+ from fastapi.responses import HTMLResponse, StreamingResponse
4
+ from fastapi.staticfiles import StaticFiles
5
+ import asyncio
6
+ import logging
7
+ from pydantic import BaseModel, Field
8
+ import uvicorn
9
+ from voiceapi.tts import TTSResult, start_tts_stream, TTSStream
10
+ from voiceapi.asr import start_asr_stream, ASRStream, ASRResult
11
+ import logging
12
+ import argparse
13
+ import os
14
+
15
# ASGI application object; the route decorators in this module register
# their handlers on it.
app = FastAPI()
# Name the logger after the module rather than the file path: a path such as
# "/srv/app.py" contains dots, which the logging module interprets as
# hierarchy separators.
logger = logging.getLogger(__name__)
17
+
18
+
19
@app.websocket("/asr")
async def websocket_asr(websocket: WebSocket,
                        samplerate: int = Query(16000, title="Sample Rate",
                                                description="The sample rate of the audio."),):
    """Bridge a WebSocket to an ASR stream.

    Binary messages received from the client are written to the ASR stream,
    and every result read from the stream is sent back as a JSON message.
    The two directions run concurrently until one of them finishes or the
    client disconnects.

    Args:
        websocket: The client connection.
        samplerate: Sample rate passed to ``start_asr_stream``.

    NOTE(review): ``args`` is a module global that is only assigned when the
    file is executed as a script; confirm how it is provided when the module
    is imported by an ASGI server instead.
    """
    await websocket.accept()

    asr_stream: ASRStream = await start_asr_stream(samplerate, args)
    if not asr_stream:
        logger.error("failed to start ASR stream")
        await websocket.close()
        return

    async def task_recv_pcm():
        # Forward client audio to the recognizer; an empty message ends the loop.
        while True:
            pcm_bytes = await websocket.receive_bytes()
            if not pcm_bytes:
                return
            await asr_stream.write(pcm_bytes)

    async def task_send_result():
        # Forward recognizer results to the client; a falsy result ends the loop.
        while True:
            result: ASRResult = await asr_stream.read()
            if not result:
                return
            await websocket.send_json(result.to_dict())

    workers = [asyncio.create_task(task_recv_pcm()),
               asyncio.create_task(task_send_result())]
    try:
        await asyncio.gather(*workers)
    except WebSocketDisconnect:
        logger.info("asr: disconnected")
    finally:
        try:
            # gather() does not cancel the sibling when one coroutine raises,
            # so cancel whatever is still pending (a no-op for finished tasks)
            # and reap both tasks so no exception is left unretrieved.
            for worker in workers:
                worker.cancel()
            await asyncio.gather(*workers, return_exceptions=True)
        finally:
            await asr_stream.close()
50
+
51
+
52
@app.websocket("/tts")
async def websocket_tts(websocket: WebSocket,
                        samplerate: int = Query(16000,
                                                title="Sample Rate",
                                                description="The sample rate of the generated audio."),
                        interrupt: bool = Query(True,
                                                title="Interrupt",
                                                description="Interrupt the current TTS stream when a new text is received."),
                        sid: int = Query(0,
                                         title="Speaker ID",
                                         description="The ID of the speaker to use for TTS."),
                        chunk_size: int = Query(1024,
                                                ge=1,
                                                title="Chunk Size",
                                                description="The size of the chunk to send to the client."),
                        speed: float = Query(1.0,
                                             title="Speed",
                                             description="The speed of the generated audio."),
                        split: bool = Query(True,
                                            title="Split",
                                            description="Split the text into sentences.")):
    """Bridge a WebSocket to a TTS stream.

    Text messages received from the client are written to a TTS stream.
    Audio read from the stream is sent back as binary messages of at most
    ``chunk_size`` bytes, and a result flagged ``finished`` is sent as JSON.

    Args:
        websocket: The client connection.
        samplerate: Sample rate passed to ``start_tts_stream``.
        interrupt: When true, every new text closes the current stream and
            starts a fresh one; otherwise one stream is reused.
        sid: Speaker ID passed to ``start_tts_stream``.
        chunk_size: Maximum size of each binary message. Must be >= 1,
            because it is used as a ``range`` step and a step of 0 raises
            ``ValueError``.
        speed: Speed passed to ``start_tts_stream``.
        split: Forwarded to ``TTSStream.write`` together with the text.

    NOTE(review): ``args`` is a module global that is only assigned when the
    file is executed as a script; confirm how it is provided when the module
    is imported by an ASGI server instead.
    """

    await websocket.accept()
    tts_stream: TTSStream = None
    # Set as soon as a stream exists, and also when the receiver exits, so the
    # sender never waits forever for a stream that will not be created.
    stream_ready = asyncio.Event()

    async def task_recv_text():
        nonlocal tts_stream
        try:
            while True:
                text = await websocket.receive_text()
                if not text:
                    return

                if interrupt or not tts_stream:
                    if tts_stream:
                        await tts_stream.close()
                        logger.info("tts: stream interrupt")

                    tts_stream = await start_tts_stream(sid, samplerate, speed, args)
                    if not tts_stream:
                        logger.error("tts: failed to allocate tts stream")
                        await websocket.close()
                        return
                    stream_ready.set()
                logger.info(f"tts: received: {text} (split={split})")
                await tts_stream.write(text, split)
        finally:
            # Wake the sender even if no stream was ever created.
            stream_ready.set()

    async def task_send_pcm():
        # Wait for the first stream instead of polling.
        await stream_ready.wait()

        while True:
            # Re-read the shared variable on every iteration so a stream
            # replaced by the receiver is picked up; stop if there is none.
            stream = tts_stream
            if not stream:
                return

            result: TTSResult = await stream.read()
            if not result:
                return

            if result.finished:
                await websocket.send_json(result.to_dict())
            else:
                for i in range(0, len(result.pcm_bytes), chunk_size):
                    await websocket.send_bytes(result.pcm_bytes[i:i+chunk_size])

    workers = [asyncio.create_task(task_recv_text()),
               asyncio.create_task(task_send_pcm())]
    try:
        await asyncio.gather(*workers)
    except WebSocketDisconnect:
        logger.info("tts: disconnected")
    finally:
        try:
            # gather() does not cancel the sibling when one coroutine raises,
            # so cancel whatever is still pending (a no-op for finished tasks)
            # and reap both tasks so no exception is left unretrieved.
            for worker in workers:
                worker.cancel()
            await asyncio.gather(*workers, return_exceptions=True)
        finally:
            if tts_stream:
                await tts_stream.close()
120
+
121
+
122
class TTSRequest(BaseModel):
    # Request body of the POST /tts endpoint.
    # (Kept as comments rather than a docstring so the generated schema
    # description is unchanged.)

    # Text to synthesize; required, there is no default.
    text: str = Field(
        ...,
        title="Text",
        description="The text to be converted to speech.",
        examples=["Hello, world!"],
    )
    # Speaker selection, defaults to speaker 0.
    sid: int = Field(
        0,
        title="Speaker ID",
        description="The ID of the speaker to use for TTS.",
    )
    # Output sample rate, defaults to 16000.
    samplerate: int = Field(
        16000,
        title="Sample Rate",
        description="The sample rate of the generated audio.",
    )
    # Speed of the generated audio, defaults to 1.0.
    speed: float = Field(
        1.0,
        title="Speed",
        description="The speed of the generated audio.",
    )
132
+
133
+
134
@app.post(
    "/tts",
    description="Generate speech audio from text.",
    response_class=StreamingResponse,
    responses={200: {"content": {"audio/wav": {}}}},
)
async def tts_generate(req: TTSRequest):
    """Synthesize ``req.text`` and stream the result as ``audio/wav``.

    Raises:
        HTTPException: 400 when the text is empty, 500 when no TTS stream
            could be started.
    """
    if not req.text:
        raise HTTPException(status_code=400, detail="text is required")

    stream = await start_tts_stream(req.sid, req.samplerate, req.speed, args)
    if not stream:
        raise HTTPException(status_code=500,
                            detail="failed to start TTS stream")

    # Whatever generate() returns is handed to StreamingResponse as the body.
    audio = await stream.generate(req.text)
    return StreamingResponse(audio, media_type="audio/wav")
148
+
149
+
150
if __name__ == "__main__":
    # Configure logging before anything below emits a record; otherwise the
    # CUDA fallback warning would be written without the configured format.
    logging.basicConfig(format='%(levelname)s: %(asctime)s %(name)s:%(lineno)s %(message)s',
                        level=logging.INFO)

    # Default model directory: the first existing "models" folder found in
    # the current directory or up to two levels above it.
    models_root = './models'

    for d in ['.', '..', '../..']:
        if os.path.isdir(f'{d}/models'):
            models_root = f'{d}/models'
            break

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument("--addr", type=str,
                        default="0.0.0.0", help="serve address")

    parser.add_argument("--asr-provider", type=str,
                        default="cpu", help="asr provider, cpu or cuda")
    parser.add_argument("--tts-provider", type=str,
                        default="cpu", help="tts provider, cpu or cuda")

    parser.add_argument("--threads", type=int, default=2,
                        help="number of threads")

    parser.add_argument("--models-root", type=str, default=models_root,
                        help="model root directory")

    parser.add_argument("--asr-model", type=str, default='sensevoice',
                        help="ASR model name: zipformer-bilingual, sensevoice, paraformer-trilingual, paraformer-en")

    parser.add_argument("--asr-lang", type=str, default='zh',
                        help="ASR language, zh, en, ja, ko, yue")

    parser.add_argument("--tts-model", type=str, default='vits-zh-hf-theresa',
                        help="TTS model name: vits-zh-hf-theresa, vits-melo-tts-zh_en")

    # Module-level on purpose: the request handlers read ``args`` as a global.
    args = parser.parse_args()

    if args.tts_model == 'vits-melo-tts-zh_en' and args.tts_provider == 'cuda':
        logger.warning(
            "vits-melo-tts-zh_en does not support CUDA, falling back to CPU")
        args.tts_provider = 'cpu'

    # Mounted at "/" after the API routes have been registered above.
    app.mount("/", app=StaticFiles(directory="./assets", html=True), name="assets")

    uvicorn.run(app, host=args.addr, port=args.port)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ sherpa-onnx == 1.10.24
2
+ soundfile == 0.12.1
3
+ fastapi == 0.114.1
4
+ uvicorn == 0.30.6
5
+ scipy == 1.13.1
6
+ numpy == 1.26.4
7
+ websockets == 13.0.1