Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -9,61 +9,40 @@ import json
|
|
9 |
from typing import Optional
|
10 |
|
11 |
import torch
|
12 |
-
|
13 |
-
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
|
14 |
-
from fastapi.responses import StreamingResponse, Response
|
15 |
|
16 |
from config import model_config
|
17 |
|
18 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
19 |
model_dir = snapshot_download(model_config['model_dir'])
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
vad_kwargs: str = Form('{"max_single_segment_time": 30000}'),
|
31 |
-
ncpu: int = Form(4),
|
32 |
-
batch_size: int = Form(1),
|
33 |
-
language: str = Form("auto"),
|
34 |
-
use_itn: bool = Form(True),
|
35 |
-
batch_size_s: int = Form(60),
|
36 |
-
merge_vad: bool = Form(True),
|
37 |
-
merge_length_s: int = Form(15),
|
38 |
-
batch_size_threshold_s: int = Form(50),
|
39 |
-
hotword: Optional[str] = Form(" "),
|
40 |
-
spk_model: str = Form("cam++"),
|
41 |
-
ban_emo_unk: bool = Form(False),
|
42 |
-
) -> StreamingResponse:
|
43 |
try:
|
44 |
# 将字符串转换为字典
|
45 |
vad_kwargs = json.loads(vad_kwargs)
|
46 |
-
|
47 |
# 创建临时文件并保存上传的音频文件
|
48 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
|
49 |
temp_file_path = temp_file.name
|
50 |
-
|
51 |
-
temp_file.write(input_wav_bytes)
|
52 |
|
53 |
try:
|
54 |
-
# 初始化模型
|
55 |
-
model = AutoModel(
|
56 |
-
model=model_dir,
|
57 |
-
trust_remote_code=False,
|
58 |
-
remote_code="./model.py",
|
59 |
-
vad_model=vad_model,
|
60 |
-
vad_kwargs=vad_kwargs,
|
61 |
-
ncpu=ncpu,
|
62 |
-
batch_size=batch_size,
|
63 |
-
hub="ms",
|
64 |
-
device=device,
|
65 |
-
)
|
66 |
-
|
67 |
# 生成结果
|
68 |
res = model.generate(
|
69 |
input=temp_file_path, # 使用临时文件路径作为输入
|
@@ -82,8 +61,7 @@ async def generate(
|
|
82 |
# 处理结果
|
83 |
text = rich_transcription_postprocess(res[0]["text"])
|
84 |
|
85 |
-
|
86 |
-
return StreamingResponse(io.BytesIO(text.encode('utf-8')), media_type="text/plain")
|
87 |
|
88 |
finally:
|
89 |
# 确保在处理完毕后删除临时文件
|
@@ -91,4 +69,31 @@ async def generate(
|
|
91 |
os.remove(temp_file_path)
|
92 |
|
93 |
except Exception as e:
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from typing import Optional
|
10 |
|
11 |
import torch
|
12 |
+
import gradio as gr # 添加Gradio库
|
|
|
|
|
13 |
|
14 |
from config import model_config
|
15 |
|
16 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
17 |
model_dir = snapshot_download(model_config['model_dir'])
|
18 |
|
19 |
+
# Initialize the model once at import time so every request reuses the same
# instance instead of rebuilding it per call.
# NOTE(review): `remote_code` is presumably only consulted when
# `trust_remote_code` is True -- confirm against the FunASR version in use.
_model_options = {
    "model": model_dir,
    "trust_remote_code": False,
    "remote_code": "./model.py",
    "vad_model": "fsmn-vad",
    "vad_kwargs": {"max_single_segment_time": 30000},
    "ncpu": 4,
    "batch_size": 1,
    "hub": "ms",
    "device": device,
}
model = AutoModel(**_model_options)
|
31 |
|
32 |
+
def transcribe_audio(file, vad_model="fsmn-vad", vad_kwargs='{"max_single_segment_time": 30000}',
|
33 |
+
ncpu=4, batch_size=1, language="auto", use_itn=True, batch_size_s=60,
|
34 |
+
merge_vad=True, merge_length_s=15, batch_size_threshold_s=50,
|
35 |
+
hotword=" ", spk_model="cam++", ban_emo_unk=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
try:
|
37 |
# 将字符串转换为字典
|
38 |
vad_kwargs = json.loads(vad_kwargs)
|
39 |
+
|
40 |
# 创建临时文件并保存上传的音频文件
|
41 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
|
42 |
temp_file_path = temp_file.name
|
43 |
+
temp_file.write(file.read())
|
|
|
44 |
|
45 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
# 生成结果
|
47 |
res = model.generate(
|
48 |
input=temp_file_path, # 使用临时文件路径作为输入
|
|
|
61 |
# 处理结果
|
62 |
text = rich_transcription_postprocess(res[0]["text"])
|
63 |
|
64 |
+
return text
|
|
|
65 |
|
66 |
finally:
|
67 |
# 确保在处理完毕后删除临时文件
|
|
|
69 |
os.remove(temp_file_path)
|
70 |
|
71 |
except Exception as e:
|
72 |
+
return str(e)
|
73 |
+
|
74 |
+
# Build the Gradio interface.
# The components are listed in the same order as the parameters of
# `transcribe_audio`, because the interface hands their values to the
# function positionally.
# NOTE(review): `source="upload"` / `type="file"` look like Gradio 3-era
# arguments; newer Gradio releases appear to expect `sources=["upload"]` and
# `type="filepath"` -- confirm against the installed version before changing,
# since `transcribe_audio` currently calls `.read()` on the uploaded value.
inputs = [
    gr.Audio(source="upload", type="file"),  # uploaded audio
    gr.Textbox(label="VAD Model", value="fsmn-vad"),
    gr.Textbox(label="VAD Kwargs", value='{"max_single_segment_time": 30000}'),
    gr.Slider(minimum=1, maximum=8, step=1, value=4, label="NCPU"),
    gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Batch Size"),
    gr.Textbox(label="Language", value="auto"),
    gr.Checkbox(label="Use ITN", value=True),
    gr.Slider(minimum=30, maximum=120, step=1, value=60, label="Batch Size (seconds)"),
    gr.Checkbox(label="Merge VAD", value=True),
    gr.Slider(minimum=5, maximum=60, step=1, value=15, label="Merge Length (seconds)"),
    gr.Slider(minimum=10, maximum=100, step=1, value=50, label="Batch Size Threshold (seconds)"),
    gr.Textbox(label="Hotword", value=" "),
    gr.Textbox(label="Speaker Model", value="cam++"),
    gr.Checkbox(label="Ban Emotional Unknown", value=False),
]

# Single text box that shows the transcription (or the error message string
# returned by `transcribe_audio`).
outputs = gr.Textbox(label="Transcription")

demo = gr.Interface(
    fn=transcribe_audio,
    inputs=inputs,
    outputs=outputs,
    title="ASR Transcription with FunASR",
)
demo.launch()
|