TaiYouWeb committed · Commit 47af204 · verified · 1 Parent(s): 889a5cc

Update app.py

Files changed (1)
  1. app.py +48 -43
app.py CHANGED
@@ -9,61 +9,40 @@ import json
 from typing import Optional
 
 import torch
-
-from fastapi import FastAPI, File, Form, UploadFile, HTTPException
-from fastapi.responses import StreamingResponse, Response
+import gradio as gr  # add the Gradio library
 
 from config import model_config
 
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 model_dir = snapshot_download(model_config['model_dir'])
 
-class SynthesizeResponse(Response):
-    media_type = 'text/plain'
-
-app = FastAPI()
+# Initialize the model
+model = AutoModel(
+    model=model_dir,
+    trust_remote_code=False,
+    remote_code="./model.py",
+    vad_model="fsmn-vad",
+    vad_kwargs={"max_single_segment_time": 30000},
+    ncpu=4,
+    batch_size=1,
+    hub="ms",
+    device=device,
+)
 
-@app.post('/asr', response_class=SynthesizeResponse)
-async def generate(
-    file: UploadFile = File(...),
-    vad_model: str = Form("fsmn-vad"),
-    vad_kwargs: str = Form('{"max_single_segment_time": 30000}'),
-    ncpu: int = Form(4),
-    batch_size: int = Form(1),
-    language: str = Form("auto"),
-    use_itn: bool = Form(True),
-    batch_size_s: int = Form(60),
-    merge_vad: bool = Form(True),
-    merge_length_s: int = Form(15),
-    batch_size_threshold_s: int = Form(50),
-    hotword: Optional[str] = Form(" "),
-    spk_model: str = Form("cam++"),
-    ban_emo_unk: bool = Form(False),
-) -> StreamingResponse:
+def transcribe_audio(file, vad_model="fsmn-vad", vad_kwargs='{"max_single_segment_time": 30000}',
+                     ncpu=4, batch_size=1, language="auto", use_itn=True, batch_size_s=60,
+                     merge_vad=True, merge_length_s=15, batch_size_threshold_s=50,
+                     hotword=" ", spk_model="cam++", ban_emo_unk=False):
     try:
         # Convert the JSON string into a dictionary
         vad_kwargs = json.loads(vad_kwargs)
-
+
         # Create a temporary file and save the uploaded audio into it
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
             temp_file_path = temp_file.name
-            input_wav_bytes = await file.read()
-            temp_file.write(input_wav_bytes)
+            temp_file.write(file.read())
 
         try:
-            # Initialize the model
-            model = AutoModel(
-                model=model_dir,
-                trust_remote_code=False,
-                remote_code="./model.py",
-                vad_model=vad_model,
-                vad_kwargs=vad_kwargs,
-                ncpu=ncpu,
-                batch_size=batch_size,
-                hub="ms",
-                device=device,
-            )
-
             # Generate the result
             res = model.generate(
                 input=temp_file_path,  # use the temporary file path as input
@@ -82,8 +61,7 @@ async def generate(
             # Post-process the result
             text = rich_transcription_postprocess(res[0]["text"])
 
-            # Return the result
-            return StreamingResponse(io.BytesIO(text.encode('utf-8')), media_type="text/plain")
+            return text
 
         finally:
             # Make sure the temporary file is deleted once processing is done
@@ -91,4 +69,31 @@ async def generate(
                 os.remove(temp_file_path)
 
     except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+        return str(e)
+
+# Build the Gradio interface
+inputs = [
+    gr.Audio(source="upload", type="file"),  # audio upload
+    gr.Textbox(value="fsmn-vad", label="VAD Model"),
+    gr.Textbox(value='{"max_single_segment_time": 30000}', label="VAD Kwargs"),
+    gr.Slider(1, 8, value=4, step=1, label="NCPU"),
+    gr.Slider(1, 10, value=1, step=1, label="Batch Size"),
+    gr.Textbox(value="auto", label="Language"),
+    gr.Checkbox(value=True, label="Use ITN"),
+    gr.Slider(30, 120, value=60, step=1, label="Batch Size (seconds)"),
+    gr.Checkbox(value=True, label="Merge VAD"),
+    gr.Slider(5, 60, value=15, step=1, label="Merge Length (seconds)"),
+    gr.Slider(10, 100, value=50, step=1, label="Batch Size Threshold (seconds)"),
+    gr.Textbox(value=" ", label="Hotword"),
+    gr.Textbox(value="cam++", label="Speaker Model"),
+    gr.Checkbox(value=False, label="Ban Emotional Unknown"),
+]
+
+outputs = gr.Textbox(label="Transcription")
+
+gr.Interface(
+    fn=transcribe_audio,
+    inputs=inputs,
+    outputs=outputs,
+    title="ASR Transcription with FunASR"
+).launch()
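
Note: this change removes the FastAPI /asr endpoint, so programmatic callers would now go through the Gradio API instead. Below is a minimal client sketch, not part of the commit, assuming the app is reachable at http://127.0.0.1:7860, a local sample.wav, and a recent gradio_client release that provides handle_file; the fourteen positional arguments mirror the transcribe_audio signature, and /predict is the default endpoint name generated for a gr.Interface.

# Sketch: call the new Gradio app from Python (URL, file name, and client
# version are assumptions, not part of this commit).
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860")
text = client.predict(
    handle_file("sample.wav"),              # file
    "fsmn-vad",                             # vad_model
    '{"max_single_segment_time": 30000}',   # vad_kwargs
    4,                                      # ncpu
    1,                                      # batch_size
    "auto",                                 # language
    True,                                   # use_itn
    60,                                     # batch_size_s
    True,                                   # merge_vad
    15,                                     # merge_length_s
    50,                                     # batch_size_threshold_s
    " ",                                    # hotword
    "cam++",                                # spk_model
    False,                                  # ban_emo_unk
    api_name="/predict",
)
print(text)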
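
Note: gr.Audio(source="upload", type="file") uses arguments from Gradio 3.x; in Gradio 4.x the component takes sources (a list) and type="file" is no longer accepted. A minimal compatibility sketch under that assumed upgrade is shown below; transcribe_from_path is a hypothetical helper, and model / rich_transcription_postprocess are the objects already defined in app.py.

# Sketch for Gradio >= 4.x (an assumed upgrade, not part of this commit):
# with type="filepath" the callback receives the uploaded file's path, so the
# audio can be handed to model.generate() directly instead of file.read().
import gradio as gr

audio_input = gr.Audio(sources=["upload"], type="filepath", label="Audio")

def transcribe_from_path(path: str) -> str:
    # `path` is the temporary file Gradio saved for the upload.
    res = model.generate(input=path, language="auto", use_itn=True)
    return rich_transcription_postprocess(res[0]["text"])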