File size: 7,261 Bytes
8289369
 
 
 
 
 
48811fe
8289369
7803eb5
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48811fe
 
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48811fe
 
 
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48811fe
 
 
 
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7803eb5
8289369
 
 
 
 
 
 
48811fe
 
 
8289369
 
 
 
48811fe
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""
ASR模型调用路由器
根据传递的provider参数调用不同的ASR实现,支持延迟加载
"""

import logging
from typing import Dict, Any, Optional, Callable
from pydub import AudioSegment
import spaces
from .asr_base import TranscriptionResult

# 配置日志
logger = logging.getLogger("asr")


class ASRRouter:
    """ASR模型调用路由器,支持多种ASR实现的统一调用"""
    
    def __init__(self):
        """初始化路由器"""
        self._loaded_modules = {}  # 用于缓存已加载的模块
        self._transcribers = {}    # 用于缓存已实例化的转录器
        
        # 定义支持的provider配置
        self._provider_configs = {
            "distil_whisper_transformers": {
                "module_path": ".asr_distil_whisper_transformers",
                "function_name": "transcribe_audio",
                "default_model": "distil-whisper/distil-large-v3.5",
                "supported_params": ["model_name", "device"],
                "description": "基于Transformers的Distil Whisper模型"
            }
        }
    
    def _lazy_load_module(self, provider: str):
        """
        获取指定provider的模块
        
        参数:
            provider: provider名称
            
        返回:
            对应的模块
        """
        if provider not in self._provider_configs:
            raise ValueError(f"不支持的provider: {provider}")
            
        if provider not in self._loaded_modules:
            module_path = self._provider_configs[provider]["module_path"]
            logger.info(f"获取模块: {module_path}")
            
            # 使用 importlib 动态导入模块
            import importlib
            module = importlib.import_module(module_path, package=__package__)
            
            self._loaded_modules[provider] = module
            logger.info(f"模块 {module_path} 获取成功")
        
        return self._loaded_modules[provider]
    
    def _get_transcribe_function(self, provider: str) -> Callable:
        """
        获取指定provider的转录函数
        
        参数:
            provider: provider名称
            
        返回:
            转录函数
        """
        module = self._lazy_load_module(provider)
        function_name = self._provider_configs[provider]["function_name"]
        
        if not hasattr(module, function_name):
            raise AttributeError(f"模块中未找到函数: {function_name}")
            
        return getattr(module, function_name)
    
    def _filter_params(self, provider: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        过滤参数,只保留指定provider支持的参数
        
        参数:
            provider: provider名称
            params: 原始参数字典
            
        返回:
            过滤后的参数字典
        """
        supported_params = self._provider_configs[provider]["supported_params"]
        filtered_params = {}
        
        for param in supported_params:
            if param in params:
                filtered_params[param] = params[param]
        
        # 如果没有指定model_name,使用默认模型
        if "model_name" not in filtered_params and "model_name" in supported_params:
            filtered_params["model_name"] = self._provider_configs[provider]["default_model"]
        
        # 对于 Transformers backend,如果 device 未指定,则默认为 cpu
        if provider == "distil_whisper_transformers" and "device" in supported_params and "device" not in filtered_params:
            filtered_params["device"] = "cpu"
        
        return filtered_params
    
    def transcribe(
        self,
        audio_segment: AudioSegment,
        provider: str,
        **kwargs
    ) -> TranscriptionResult:
        """
        统一的音频转录接口
        
        参数:
            audio_segment: 输入的AudioSegment对象
            provider: ASR提供者名称
            **kwargs: 其他参数,如model_name, device等
            
        返回:
            TranscriptionResult对象
        """
        logger.info(f"使用provider '{provider}' 进行音频转录,音频长度: {len(audio_segment)/1000:.2f}秒")
        
        if provider not in self._provider_configs:
            available_providers = list(self._provider_configs.keys())
            raise ValueError(f"不支持的provider: {provider}。支持的provider: {available_providers}")
        
        try:
            # 获取转录函数
            transcribe_func = self._get_transcribe_function(provider)
            
            # 过滤并准备参数
            filtered_kwargs = self._filter_params(provider, kwargs)
            
            logger.debug(f"调用 {provider} 转录函数,参数: {filtered_kwargs}")
            
            # 执行转录
            result = transcribe_func(audio_segment, **filtered_kwargs)
            
            logger.info(f"转录完成,文本长度: {len(result.text)}字符")
            return result
            
        except Exception as e:
            logger.error(f"使用provider '{provider}' 转录音频失败: {str(e)}", exc_info=True)
            raise RuntimeError(f"转录失败: {str(e)}")
    
    def get_available_providers(self) -> Dict[str, str]:
        """
        获取所有可用的provider及其描述
        
        返回:
            provider名称到描述的映射
        """
        return {
            provider: config["description"] 
            for provider, config in self._provider_configs.items()
        }
    
    def get_provider_info(self, provider: str) -> Dict[str, Any]:
        """
        获取指定provider的详细信息
        
        参数:
            provider: provider名称
            
        返回:
            provider的配置信息
        """
        if provider not in self._provider_configs:
            raise ValueError(f"不支持的provider: {provider}")
            
        return self._provider_configs[provider].copy()


# 创建全局路由器实例
_router = ASRRouter()

@spaces.GPU(duration=180)
def transcribe_audio(
    audio_segment: AudioSegment,
    provider: str = "distil_whisper_transformers",
    model_name: Optional[str] = None,
    device: str = "cpu",
    **kwargs
) -> TranscriptionResult:
    """
    统一的音频转录接口,通过路由器选择后端
    """
    # 准备参数
    params = kwargs.copy()
    if model_name is not None:
        params["model_name"] = model_name
    if device != "cpu": # 只有当 device 不是默认值才传递,或者根据需要传递所有支持的参数
        params["device"] = device
    
    return _router.transcribe(audio_segment, provider, **params)


def get_available_providers() -> Dict[str, str]:
    """
    获取所有可用的ASR提供者
    
    返回:
        provider名称到描述的映射
    """
    return _router.get_available_providers()


def get_provider_info(provider: str) -> Dict[str, Any]:
    """
    获取指定provider的详细信息
    
    参数:
        provider: provider名称
        
    返回:
        provider的配置信息
    """
    return _router.get_provider_info(provider)