File size: 5,795 Bytes
889a5cc
 
 
 
 
 
 
 
 
 
 
47af204
889a5cc
 
 
56bdf87
 
 
 
 
889a5cc
 
 
47af204
 
 
 
 
 
 
 
 
 
 
 
889a5cc
15ad32c
47af204
 
 
889a5cc
 
 
47af204
15ad32c
 
889a5cc
e8db4c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
889a5cc
 
e8db4c4
47af204
 
 
 
15ad32c
47af204
48b5bc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f7df9f
56bdf87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from modelscope import snapshot_download

import io
import os
import tempfile
import json
from typing import Optional

import torch
import gradio as gr  # 添加Gradio库

from config import model_config

from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import StreamingResponse, Response

import uvicorn

# Use the first CUDA device when torch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Fetch the model named in the project config; the returned value is handed to
# AutoModel below as the model location.
model_dir = snapshot_download(model_config['model_dir'])

# Shared model instance, built once when this module is loaded and reused by
# transcribe_audio.
model = AutoModel(
    model=model_dir,
    hub="ms",
    device=device,
    trust_remote_code=False,
    remote_code="./model.py",
    vad_model="fsmn-vad",
    vad_kwargs={"max_single_segment_time": 30000},
    ncpu=4,
    batch_size=1,
)

def transcribe_audio(file_path, vad_model="fsmn-vad", vad_kwargs='{"max_single_segment_time": 30000}',
                     ncpu=4, batch_size=1, language="auto", use_itn=True, batch_size_s=60,
                     merge_vad=True, merge_length_s=15, batch_size_threshold_s=50,
                     hotword=" ", spk_model="cam++", ban_emo_unk=False):
    """Transcribe an audio file with the shared module-level model.

    This is the callback behind the Gradio form, so failures are reported as
    the returned text instead of being raised.

    Args:
        file_path: Path of the audio file to transcribe. ``None`` or an empty
            string (nothing uploaded) yields a short explanatory message.
        vad_model, ncpu, batch_size: Accepted so the signature lines up with
            the form fields, but not used here: the shared model was already
            constructed with its own settings.
        vad_kwargs: JSON text. It is parsed only to validate it; the parsed
            value is not applied to the shared model.
        language, use_itn, batch_size_s, merge_vad, merge_length_s,
        batch_size_threshold_s, hotword, spk_model, ban_emo_unk: Forwarded
            unchanged as keyword arguments to ``model.generate``.

    Returns:
        The post-processed text of the first result, an empty string when the
        model returns no results, or the error message if anything fails.
    """
    # Nothing was uploaded: answer directly instead of surfacing whatever
    # exception the model would raise for a missing input.
    if not file_path:
        return "No audio file provided."

    try:
        # Validation only: malformed JSON is reported back to the user even
        # though the parsed options are not used by the shared model.
        json.loads(vad_kwargs)

        res = model.generate(
            input=file_path,
            cache={},
            language=language,
            use_itn=use_itn,
            batch_size_s=batch_size_s,
            merge_vad=merge_vad,
            merge_length_s=merge_length_s,
            batch_size_threshold_s=batch_size_threshold_s,
            hotword=hotword,
            spk_model=spk_model,
            ban_emo_unk=ban_emo_unk
        )

        # An empty result list would otherwise show up as the text
        # "list index out of range".
        if not res:
            return ""

        return rich_transcription_postprocess(res[0]["text"])

    except Exception as e:
        # Deliberate catch-all: the UI displays the message as the output text.
        return str(e)

# Build the Gradio form. The components are listed in the same order as the
# parameters of transcribe_audio, which receives their values positionally.
inputs = [
    gr.Audio(type="filepath"),  # "filepath" hands the callback a path on disk
    gr.Textbox(label="VAD Model", value="fsmn-vad"),
    gr.Textbox(label="VAD Kwargs", value='{"max_single_segment_time": 30000}'),
    gr.Slider(minimum=1, maximum=8, step=1, value=4, label="NCPU"),
    gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Batch Size"),
    gr.Textbox(label="Language", value="auto"),
    gr.Checkbox(label="Use ITN", value=True),
    gr.Slider(minimum=30, maximum=120, step=1, value=60, label="Batch Size (seconds)"),
    gr.Checkbox(label="Merge VAD", value=True),
    gr.Slider(minimum=5, maximum=60, step=1, value=15, label="Merge Length (seconds)"),
    gr.Slider(minimum=10, maximum=100, step=1, value=50, label="Batch Size Threshold (seconds)"),
    gr.Textbox(label="Hotword", value=" "),
    gr.Textbox(label="Speaker Model", value="cam++"),
    gr.Checkbox(label="Ban Emotional Unknown", value=False),
]

# Single text field that shows the transcription (or the error message).
outputs = gr.Textbox(label="Transcription")

# NOTE(review): Gradio's launch() blocks the main thread by default, so the
# FastAPI code further down would not run until this server stops - confirm
# that this is intended.
gr.Interface(
    title="ASR Transcription with FunASR",
    fn=transcribe_audio,
    inputs=inputs,
    outputs=outputs,
).launch()


class SynthesizeResponse(Response):
    """Response subclass that declares a plain-text media type."""

    media_type = 'text/plain'

# FastAPI application object; the route decorators below register on it.
app = FastAPI()

@app.post('/asr', response_class=SynthesizeResponse)
async def generate(
    file: UploadFile = File(...),
    vad_model: str = Form("fsmn-vad"),
    vad_kwargs: str = Form('{"max_single_segment_time": 30000}'),
    ncpu: int = Form(4),
    batch_size: int = Form(1),
    language: str = Form("auto"),
    use_itn: bool = Form(True),
    batch_size_s: int = Form(60),
    merge_vad: bool = Form(True),
    merge_length_s: int = Form(15),
    batch_size_threshold_s: int = Form(50),
    hotword: Optional[str] = Form(" "),
    spk_model: str = Form("cam++"),
    ban_emo_unk: bool = Form(False),
) -> StreamingResponse:
    """Transcribe an uploaded audio file and return the text as plain text.

    The upload is written to a temporary ``.wav`` file, a model is built from
    the request's settings, and the post-processed text of the first result
    is sent back UTF-8 encoded. The temporary file is always removed.

    Args:
        file: The uploaded audio file.
        vad_model, ncpu, batch_size: Forwarded to the ``AutoModel`` constructor.
        vad_kwargs: JSON object (as text) forwarded to the ``AutoModel``
            constructor as ``vad_kwargs``.
        language, use_itn, batch_size_s, merge_vad, merge_length_s,
        batch_size_threshold_s, hotword, spk_model, ban_emo_unk: Forwarded
            unchanged as keyword arguments to ``model.generate``.

    Returns:
        A plain-text streaming response holding the transcription; the body is
        empty when the model returns no results.

    Raises:
        HTTPException: Status 400 when ``vad_kwargs`` is not a JSON object or
            the upload is empty; status 500 for any other failure.
    """
    # Malformed options are the caller's mistake: answer 400, not 500.
    try:
        vad_options = json.loads(vad_kwargs)
    except (TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=400, detail=f"vad_kwargs is not valid JSON: {exc}"
        ) from exc
    if not isinstance(vad_options, dict):
        raise HTTPException(status_code=400, detail="vad_kwargs must be a JSON object")

    temp_file_path = None
    try:
        input_wav_bytes = await file.read()
        if not input_wav_bytes:
            raise HTTPException(status_code=400, detail="Uploaded audio file is empty")

        # delete=False keeps the file on disk after the handle is closed so the
        # model can open it by path; the finally clause below removes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(input_wav_bytes)

        # NOTE(review): a model is constructed on every request so that the
        # per-request VAD/threading settings take effect. This is likely
        # expensive; consider reusing instances if latency matters.
        asr_model = AutoModel(
            model=model_dir,
            trust_remote_code=False,
            remote_code="./model.py",
            vad_model=vad_model,
            vad_kwargs=vad_options,
            ncpu=ncpu,
            batch_size=batch_size,
            hub="ms",
            device=device,
        )

        res = asr_model.generate(
            input=temp_file_path,
            cache={},
            language=language,
            use_itn=use_itn,
            batch_size_s=batch_size_s,
            merge_vad=merge_vad,
            merge_length_s=merge_length_s,
            batch_size_threshold_s=batch_size_threshold_s,
            hotword=hotword,
            spk_model=spk_model,
            ban_emo_unk=ban_emo_unk
        )

        # An empty result list becomes an empty body rather than an IndexError.
        text = rich_transcription_postprocess(res[0]["text"]) if res else ""

        # The body is fully in memory, so removing the temp file in the
        # finally clause cannot affect the response.
        return StreamingResponse(io.BytesIO(text.encode('utf-8')), media_type="text/plain")

    except HTTPException:
        # Keep deliberate client-error responses intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Covers every stage after the file was created, including a failed write.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except FileNotFoundError:
                pass

# Fixed-response GET route; takes no input and always returns the same JSON body.
@app.get("/root")
async def read_root():
    greeting = {"message": "Hello World"}
    return greeting
    
if __name__ == "__main__":
    # Serve the FastAPI app on all interfaces when this file is run directly.
    # NOTE(review): 7860 is also Gradio's default port - confirm the two
    # servers are not expected to bind it at the same time.
    uvicorn.run(app, host="0.0.0.0", port=7860)