File size: 6,380 Bytes
5d74f79
48811fe
5d74f79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48811fe
5d74f79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406e6ac
5d74f79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48811fe
5d74f79
 
 
 
48811fe
5d74f79
 
 
 
48811fe
5d74f79
48811fe
 
 
5d74f79
48811fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
基于Transformers实现的语音识别模块,使用distil-whisper模型
"""

import os
from pydub import AudioSegment
from typing import Dict, List, Union, Literal
import logging
import numpy as np

# 导入基类
from .asr_base import BaseTranscriber, TranscriptionResult

# 配置日志
logger = logging.getLogger("asr")


class TransformersDistilWhisperTranscriber(BaseTranscriber):
    """使用Transformers加载和运行distil-whisper模型的转录器"""
    
    def __init__(
        self, 
        model_name: str = "distil-whisper/distil-large-v3.5",
        device: str = "cpu",
    ):
        """
        初始化转录器
        
        参数:
            model_name: 模型名称
            device: 推理设备,'cpu'或'cuda'
        """
        super().__init__(model_name=model_name, device=device)
        
    def _load_model(self):
        """加载Distil Whisper Transformers模型"""
        try:
            # 懒加载transformers
            try:
                from transformers import pipeline
            except ImportError:
                raise ImportError("请先安装transformers库: pip install transformers")
                
            logger.info(f"开始加载模型 {self.model_name} 设备: {self.device}")

            pipeline_device_arg = None
            if self.device == "cuda":
                pipeline_device_arg = 0  # 使用第一个 CUDA 设备
            elif self.device == "mps":
                pipeline_device_arg = "mps"  # 使用 MPS 设备
            elif self.device == "cpu":
                pipeline_device_arg = -1 # 使用 CPU
            else:
                # 对于其他未明确支持的 device 字符串,记录警告并默认使用 CPU
                logger.warning(f"不支持的设备字符串 '{self.device}',将默认使用 CPU。")
                pipeline_device_arg = -1
            
            # 导入必要的模块来配置模型
            import warnings
            from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
            
            # 抑制特定的警告
            warnings.filterwarnings("ignore", message="The input name `inputs` is deprecated")
            warnings.filterwarnings("ignore", message="You have passed task=transcribe")
            warnings.filterwarnings("ignore", message="The attention mask is not set")
            
            self.pipeline = pipeline(
                "automatic-speech-recognition",
                model=self.model_name,
                device=pipeline_device_arg,
                return_timestamps=True,
                chunk_length_s=30,      # 使用30秒的块长度
                stride_length_s=5,      # 块之间5秒的重叠
                batch_size=32,           # 顺序处理
                # 添加以下参数来减少警告
                generate_kwargs={
                    "task": "transcribe",
                    "language": None,  # 自动检测语言
                    "forced_decoder_ids": None,  # 避免冲突
                }
            )
            logger.info(f"模型加载成功")
        except Exception as e:
            logger.error(f"加载模型失败: {str(e)}", exc_info=True)
            raise RuntimeError(f"加载模型失败: {str(e)}")
    
    def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
        """
        将模型的分段结果转换为所需格式
        
        参数:
            result: 模型返回的结果
            
        返回:
            转换后的分段列表
        """
        segments = []
        
        # transformers pipeline 的结果格式
        if "chunks" in result:
            for chunk in result["chunks"]:
                segments.append({
                    "start": chunk["timestamp"][0] if chunk["timestamp"][0] is not None else 0.0,
                    "end": chunk["timestamp"][1] if chunk["timestamp"][1] is not None else 0.0,
                    "text": chunk["text"].strip()
                })
        else:
            # 如果没有分段信息,创建一个单一分段
            segments.append({
                "start": 0.0,
                "end": 0.0,  # 无法确定结束时间
                "text": result.get("text", "").strip()
            })
        
        return segments
    
    def _perform_transcription(self, audio_data):
        """
        执行转录
        
        参数:
            audio_data: 音频数据(numpy数组)
            
        返回:
            模型的转录结果
        """
        # transformers pipeline 接受numpy数组作为输入
        # 音频数据已经在_prepare_audio中确保是16kHz采样率
        
        # 确保音频数据格式正确
        if audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        
        # 使用正确的参数名称调用pipeline
        try:
            result = self.pipeline(
                audio_data,
                generate_kwargs={
                    "task": "transcribe",
                    "language": None,  # 自动检测语言
                    "forced_decoder_ids": None,  # 避免冲突
                }
            )
            return result
        except Exception as e:
            logger.warning(f"使用新参数格式失败,尝试使用默认参数: {str(e)}")
            # 如果新格式失败,回退到简单调用
            return self.pipeline(audio_data)

# 统一的接口函数
def transcribe_audio(
    audio_segment: AudioSegment,
    model_name: str = None,
    device: str = "cpu",
) -> TranscriptionResult:
    """
    使用Distil Whisper模型转录音频 (Transformers后端)
    
    参数:
        audio_segment: 输入的AudioSegment对象
        model_name: 使用的模型名称,如果不指定则使用默认模型
        device: 推理设备,'cpu'或'cuda'
        
    返回:
        TranscriptionResult对象,包含转录的文本、分段和语言
    """
    logger.info(f"调用 transcribe_audio 函数 (Transformers后端),音频长度: {len(audio_segment)/1000:.2f}秒,设备: {device}")
    
    default_model = "distil-whisper/distil-large-v3.5"
    model = model_name or default_model
    transcriber = TransformersDistilWhisperTranscriber(model_name=model, device=device)
    
    return transcriber.transcribe(audio_segment)