File size: 4,865 Bytes
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48811fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d74f79
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
语音识别模块基类
"""

import os
import numpy as np
from pydub import AudioSegment
from typing import Dict, List, Union, Optional, Tuple
# from dataclasses import dataclass # dataclass is now imported from schemas if needed or already there
import logging

from ..schemas import TranscriptionResult # Added import

# 配置日志
logger = logging.getLogger("asr")


class BaseTranscriber:
    """统一的语音识别基类,支持MLX和Transformers等多种框架"""
    
    def __init__(
        self, 
        model_name: str,
        device: str = None,
    ):
        """
        初始化转录器
        
        参数:
            model_name: 模型名称
            device: 推理设备,'cpu'或'cuda',对于MLX框架此参数可忽略
        """
        self.model_name = model_name
        self.device = device
        self.pipeline = None  # 用于Transformers
        self.model = None     # 用于MLX等其他框架
                
        logger.info(f"初始化转录器,模型: {model_name}" + (f",设备: {device}" if device else ""))
        
        # 子类需要实现_load_model方法
        self._load_model()
        
    def _load_model(self):
        """
        加载模型(需要在子类中实现)
        """
        raise NotImplementedError("子类必须实现_load_model方法")
    
    def transcribe(self, audio: AudioSegment) -> TranscriptionResult:
        """
        转录音频,针对distil-whisper模型取消分块处理,直接处理整个音频。

        参数:
            audio: 要转录的AudioSegment对象

        返回:
            TranscriptionResult对象,包含转录结果
        """
        logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频") # 移除了模型名称,因为基类不知道具体模型

        # 直接处理整个音频,不进行分块
        processed_audio = self._prepare_audio(audio)
        samples = np.array(processed_audio.get_array_of_samples(), dtype=np.float32) / 32768.0

        try:
            model_result = self._perform_transcription(samples)
            text = self._get_text_from_result(model_result)
            segments = self._convert_segments(model_result)
            language = self._detect_language(text)

            logger.info(f"转录完成,语言: {language},文本长度: {len(text)},分段数: {len(segments)}")
            return TranscriptionResult(text=text, segments=segments, language=language)
        except Exception as e:
            logger.error(f"转录失败: {str(e)}", exc_info=True)
            raise RuntimeError(f"转录失败: {str(e)}")

    def _get_text_from_result(self, result):
        """
        从结果中获取文本

        参数:
            result: 模型的转录结果

        返回:
            转录的文本
        """
        return result.get("text", "")

    def _perform_transcription(self, audio_data):
        """执行转录的抽象方法,由子类实现"""
        raise NotImplementedError("子类必须实现_perform_transcription方法")

    def _convert_segments(self, result) -> List[Dict[str, Union[float, str]]]:
        """将模型结果转换为分段的抽象方法,由子类实现"""
        raise NotImplementedError("子类必须实现_convert_segments方法")
    
    def _prepare_audio(self, audio: AudioSegment) -> AudioSegment:
        """
        准备音频数据
        
        参数:
            audio: 输入的AudioSegment对象
            
        返回:
            处理后的AudioSegment对象
        """
        logger.debug(f"准备音频数据: 时长={len(audio)/1000:.2f}秒, 采样率={audio.frame_rate}Hz, 声道数={audio.channels}")
        
        # 确保采样率为16kHz
        if audio.frame_rate != 16000:
            logger.debug(f"重采样音频从 {audio.frame_rate}Hz 到 16000Hz")
            audio = audio.set_frame_rate(16000)
            
        # 确保是单声道
        if audio.channels > 1:
            logger.debug(f"将{audio.channels}声道音频转换为单声道")
            audio = audio.set_channels(1)
            
        logger.debug(f"音频处理完成")
        
        return audio
    
    def _detect_language(self, text: str) -> str:
        """
        简单的语言检测(基于经验规则)
        
        参数:
            text: 识别出的文本
            
        返回:
            检测到的语言代码
        """
        # 简单的规则检测,实际应用中应使用更准确的语言检测
        chinese_chars = len([c for c in text if '\u4e00' <= c <= '\u9fff'])
        chinese_ratio = chinese_chars / len(text) if text else 0
        logger.debug(f"语言检测: 中文字符比例 = {chinese_ratio:.2f}")
        
        if chinese_chars > len(text) * 0.3:
            return "zh"
        return "en"