from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

import json
import torch
import gradio as gr

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# AutoModel pulls the "FunAudioLLM/SenseVoiceSmall" weights from the Hugging Face
# hub automatically (hub="hf"), so no separate snapshot_download step is needed.

model = AutoModel(
    model="FunAudioLLM/SenseVoiceSmall",
    # With trust_remote_code=False, FunASR's built-in SenseVoice implementation is
    # used; set it to True to load the local ./model.py referenced by remote_code.
    trust_remote_code=False,
    remote_code="./model.py",
    vad_model="fsmn-vad",    # voice activity detection
    punc_model="ct-punc",    # punctuation restoration
    spk_model="cam++",       # speaker embedding model used for diarization
    vad_kwargs={"max_single_segment_time": 15000},
    ncpu=torch.get_num_threads(),
    batch_size=1,
    hub="hf",
    device=device,
)

def transcribe_audio(file_path, vad_model="fsmn-vad", punc_model="ct-punc", spk_model="cam++", vad_kwargs='{"max_single_segment_time": 15000}',
                     batch_size=1, language="auto", use_itn=True, batch_size_s=60,
                     merge_vad=True, merge_length_s=15, batch_size_threshold_s=50,
                     hotword=" ", ban_emo_unk=True):
    # vad_model, punc_model, spk_model, vad_kwargs and batch_size are accepted here
    # so the Gradio inputs below map one-to-one onto parameters; the globally
    # constructed model above is used as-is for inference.
    try:
        vad_kwargs = json.loads(vad_kwargs)  # validate the JSON string from the UI
        temp_file_path = file_path

        res = model.generate(
            input=temp_file_path,
            cache={},
            language=language,
            use_itn=use_itn,
            batch_size_s=batch_size_s,
            merge_vad=merge_vad,
            merge_length_s=merge_length_s,
            batch_size_threshold_s=batch_size_threshold_s,
            hotword=hotword,
            ban_emo_unk=ban_emo_unk
        )

        segments = res[0]["segments"]
        transcription = ""

        for segment in segments:
            start_time = segment["start"]
            end_time = segment["end"]
            speaker = segment.get("speaker", "unknown")
            # Strip SenseVoice rich tags (language/emotion/event markers) from the raw text.
            text = rich_transcription_postprocess(segment["text"])

            transcription += f"[{start_time:.2f}s - {end_time:.2f}s] Speaker {speaker}: {text}\n"

        return transcription

    except Exception as e:
        return f"Transcription failed: {e}"
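
# A minimal sketch of calling transcribe_audio directly, bypassing the Gradio UI;
# "example.wav" is a placeholder path, not a file that ships with this script:
#
#   result = transcribe_audio("example.wav", language="auto", use_itn=True)
#   print(result)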

inputs = [
    gr.Audio(type="filepath"),
    gr.Textbox(value="fsmn-vad", label="VAD Model"),
    gr.Textbox(value="ct-punc", label="PUNC Model"),
    gr.Textbox(value="cam++", label="SPK Model"),
    gr.Textbox(value='{"max_single_segment_time": 15000}', label="VAD Kwargs"),
    gr.Slider(1, 10, value=1, step=1, label="Batch Size"),
    gr.Textbox(value="auto", label="Language"),
    gr.Checkbox(value=True, label="Use ITN"),
    gr.Slider(30, 120, value=60, step=1, label="Batch Size (seconds)"),
    gr.Checkbox(value=True, label="Merge VAD"),
    gr.Slider(5, 60, value=15, step=1, label="Merge Length (seconds)"),
    gr.Slider(10, 100, value=50, step=1, label="Batch Size Threshold (seconds)"),
    gr.Textbox(value=" ", label="Hotword"),
    gr.Checkbox(value=True, label="Ban Emotion-Unknown Token"),
]

outputs = gr.Textbox(label="Transcription")

gr.Interface(
    fn=transcribe_audio,
    inputs=inputs,
    outputs=outputs,
    title="ASR Transcription with Speaker Diarization and Timestamps"
).launch()