Spaces:
Running
Running
File size: 4,865 Bytes
8289369 48811fe 8289369 5d74f79 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
"""
语音识别模块基类
"""
import os
import numpy as np
from pydub import AudioSegment
from typing import Dict, List, Union, Optional, Tuple
# from dataclasses import dataclass # dataclass is now imported from schemas if needed or already there
import logging
from ..schemas import TranscriptionResult # Added import
# 配置日志
logger = logging.getLogger("asr")
class BaseTranscriber:
    """Unified speech-recognition base class supporting multiple backends
    (MLX, Transformers, ...).

    Subclasses must implement ``_load_model``, ``_perform_transcription`` and
    ``_convert_segments``; this base class handles audio preparation, PCM
    normalization and result assembly.
    """

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
    ):
        """Initialize the transcriber.

        Args:
            model_name: Name of the model to load.
            device: Inference device, e.g. 'cpu' or 'cuda'; ignored by
                frameworks (such as MLX) that manage devices themselves.
        """
        self.model_name = model_name
        self.device = device
        self.pipeline = None  # used by Transformers-based subclasses
        self.model = None     # used by MLX and other frameworks
        logger.info(f"初始化转录器,模型: {model_name}" + (f",设备: {device}" if device else ""))
        # Subclasses provide the actual loading logic.
        self._load_model()

    def _load_model(self):
        """Load the underlying model (must be implemented by subclasses).

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("子类必须实现_load_model方法")

    def transcribe(self, audio: AudioSegment) -> TranscriptionResult:
        """Transcribe an audio segment in a single pass (no chunking).

        Args:
            audio: The AudioSegment to transcribe.

        Returns:
            A TranscriptionResult carrying text, segments and detected language.

        Raises:
            RuntimeError: If the underlying model fails; the original
                exception is chained as the cause.
        """
        logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频")
        # Process the whole clip at once; resample/downmix first.
        processed_audio = self._prepare_audio(audio)
        # Normalize raw PCM into [-1, 1) using the actual sample width.
        # (The previous code hard-coded 32768, which is correct only for
        # 16-bit audio; _prepare_audio never forces a 16-bit sample width.)
        max_amplitude = float(1 << (8 * processed_audio.sample_width - 1))
        samples = np.array(processed_audio.get_array_of_samples(), dtype=np.float32) / max_amplitude
        try:
            model_result = self._perform_transcription(samples)
            text = self._get_text_from_result(model_result)
            segments = self._convert_segments(model_result)
            language = self._detect_language(text)
            logger.info(f"转录完成,语言: {language},文本长度: {len(text)},分段数: {len(segments)}")
            return TranscriptionResult(text=text, segments=segments, language=language)
        except Exception as e:
            logger.error(f"转录失败: {str(e)}", exc_info=True)
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"转录失败: {str(e)}") from e

    def _get_text_from_result(self, result) -> str:
        """Extract the transcript text from a raw model result.

        Args:
            result: The model's transcription result (mapping-like).

        Returns:
            The transcribed text, or "" when absent.
        """
        return result.get("text", "")

    def _perform_transcription(self, audio_data):
        """Run the model on normalized float32 samples (subclass hook)."""
        raise NotImplementedError("子类必须实现_perform_transcription方法")

    def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
        """Convert a raw model result into timestamped segments (subclass hook)."""
        raise NotImplementedError("子类必须实现_convert_segments方法")

    def _prepare_audio(self, audio: AudioSegment) -> AudioSegment:
        """Normalize audio to 16 kHz mono, as ASR models expect.

        Args:
            audio: The input AudioSegment.

        Returns:
            The processed AudioSegment (16 kHz, single channel).
        """
        logger.debug(f"准备音频数据: 时长={len(audio)/1000:.2f}秒, 采样率={audio.frame_rate}Hz, 声道数={audio.channels}")
        # Resample to the 16 kHz rate required downstream.
        if audio.frame_rate != 16000:
            logger.debug(f"重采样音频从 {audio.frame_rate}Hz 到 16000Hz")
            audio = audio.set_frame_rate(16000)
        # Downmix to a single channel.
        if audio.channels > 1:
            logger.debug(f"将{audio.channels}声道音频转换为单声道")
            audio = audio.set_channels(1)
        logger.debug(f"音频处理完成")
        return audio

    def _detect_language(self, text: str) -> str:
        """Heuristically detect the transcript language.

        A simple rule of thumb; a production system should use a proper
        language-identification library.

        Args:
            text: The recognized text.

        Returns:
            "zh" when more than 30% of the characters fall in the CJK
            Unified Ideographs range, otherwise "en".
        """
        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        chinese_ratio = chinese_chars / len(text) if text else 0
        logger.debug(f"语言检测: 中文字符比例 = {chinese_ratio:.2f}")
        return "zh" if chinese_ratio > 0.3 else "en"