podcast-transcriber / src /podcast_transcribe /asr /asr_distil_whisper_transformers.py
konieshadow's picture
更新ASR模块,重构转录逻辑,新增基于Transformers的distil-whisper模型支持,并优化音频处理流程。
48811fe
"""
基于Transformers实现的语音识别模块,使用distil-whisper模型
"""
import os
from pydub import AudioSegment
from typing import Dict, List, Union, Literal
import logging
import numpy as np
# 导入基类
from .asr_base import BaseTranscriber, TranscriptionResult
# 配置日志
logger = logging.getLogger("asr")
class TransformersDistilWhisperTranscriber(BaseTranscriber):
"""使用Transformers加载和运行distil-whisper模型的转录器"""
def __init__(
self,
model_name: str = "distil-whisper/distil-large-v3.5",
device: str = "cpu",
):
"""
初始化转录器
参数:
model_name: 模型名称
device: 推理设备,'cpu'或'cuda'
"""
super().__init__(model_name=model_name, device=device)
def _load_model(self):
"""加载Distil Whisper Transformers模型"""
try:
# 懒加载transformers
try:
from transformers import pipeline
except ImportError:
raise ImportError("请先安装transformers库: pip install transformers")
logger.info(f"开始加载模型 {self.model_name} 设备: {self.device}")
pipeline_device_arg = None
if self.device == "cuda":
pipeline_device_arg = 0 # 使用第一个 CUDA 设备
elif self.device == "mps":
pipeline_device_arg = "mps" # 使用 MPS 设备
elif self.device == "cpu":
pipeline_device_arg = -1 # 使用 CPU
else:
# 对于其他未明确支持的 device 字符串,记录警告并默认使用 CPU
logger.warning(f"不支持的设备字符串 '{self.device}',将默认使用 CPU。")
pipeline_device_arg = -1
# 导入必要的模块来配置模型
import warnings
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# 抑制特定的警告
warnings.filterwarnings("ignore", message="The input name `inputs` is deprecated")
warnings.filterwarnings("ignore", message="You have passed task=transcribe")
warnings.filterwarnings("ignore", message="The attention mask is not set")
self.pipeline = pipeline(
"automatic-speech-recognition",
model=self.model_name,
device=pipeline_device_arg,
return_timestamps=True,
chunk_length_s=30, # 使用30秒的块长度
stride_length_s=5, # 块之间5秒的重叠
batch_size=32, # 顺序处理
# 添加以下参数来减少警告
generate_kwargs={
"task": "transcribe",
"language": None, # 自动检测语言
"forced_decoder_ids": None, # 避免冲突
}
)
logger.info(f"模型加载成功")
except Exception as e:
logger.error(f"加载模型失败: {str(e)}", exc_info=True)
raise RuntimeError(f"加载模型失败: {str(e)}")
def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
"""
将模型的分段结果转换为所需格式
参数:
result: 模型返回的结果
返回:
转换后的分段列表
"""
segments = []
# transformers pipeline 的结果格式
if "chunks" in result:
for chunk in result["chunks"]:
segments.append({
"start": chunk["timestamp"][0] if chunk["timestamp"][0] is not None else 0.0,
"end": chunk["timestamp"][1] if chunk["timestamp"][1] is not None else 0.0,
"text": chunk["text"].strip()
})
else:
# 如果没有分段信息,创建一个单一分段
segments.append({
"start": 0.0,
"end": 0.0, # 无法确定结束时间
"text": result.get("text", "").strip()
})
return segments
def _perform_transcription(self, audio_data):
"""
执行转录
参数:
audio_data: 音频数据(numpy数组)
返回:
模型的转录结果
"""
# transformers pipeline 接受numpy数组作为输入
# 音频数据已经在_prepare_audio中确保是16kHz采样率
# 确保音频数据格式正确
if audio_data.dtype != np.float32:
audio_data = audio_data.astype(np.float32)
# 使用正确的参数名称调用pipeline
try:
result = self.pipeline(
audio_data,
generate_kwargs={
"task": "transcribe",
"language": None, # 自动检测语言
"forced_decoder_ids": None, # 避免冲突
}
)
return result
except Exception as e:
logger.warning(f"使用新参数格式失败,尝试使用默认参数: {str(e)}")
# 如果新格式失败,回退到简单调用
return self.pipeline(audio_data)
# 统一的接口函数
def transcribe_audio(
audio_segment: AudioSegment,
model_name: str = None,
device: str = "cpu",
) -> TranscriptionResult:
"""
使用Distil Whisper模型转录音频 (Transformers后端)
参数:
audio_segment: 输入的AudioSegment对象
model_name: 使用的模型名称,如果不指定则使用默认模型
device: 推理设备,'cpu'或'cuda'
返回:
TranscriptionResult对象,包含转录的文本、分段和语言
"""
logger.info(f"调用 transcribe_audio 函数 (Transformers后端),音频长度: {len(audio_segment)/1000:.2f}秒,设备: {device}")
default_model = "distil-whisper/distil-large-v3.5"
model = model_name or default_model
transcriber = TransformersDistilWhisperTranscriber(model_name=model, device=device)
return transcriber.transcribe(audio_segment)