Spaces:
Running
Running
File size: 7,261 Bytes
8289369 48811fe 8289369 7803eb5 8289369 48811fe 8289369 48811fe 8289369 48811fe 8289369 7803eb5 8289369 48811fe 8289369 48811fe 8289369 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
"""
ASR模型调用路由器
根据传递的provider参数调用不同的ASR实现,支持延迟加载
"""
import logging
from typing import Dict, Any, Optional, Callable
from pydub import AudioSegment
import spaces
from .asr_base import TranscriptionResult
# Module-level logger: everything in this file logs through the "asr" logger.
logger = logging.getLogger("asr")
class ASRRouter:
    """Route transcription requests to a provider-specific ASR backend.

    Every provider has an entry in ``_provider_configs`` naming the module to
    import, the function to call inside that module, a default model name and
    the keyword arguments the function accepts. Backend modules are imported
    the first time their provider is used and are cached afterwards.
    """

    def __init__(self):
        """Create a router with empty caches and the built-in provider table."""
        # provider name -> imported backend module
        self._loaded_modules = {}
        # Reserved cache for transcriber instances; no method of this class
        # reads or writes it.
        self._transcribers = {}
        # Supported providers and how to reach their implementation.
        self._provider_configs = {
            "distil_whisper_transformers": {
                "module_path": ".asr_distil_whisper_transformers",
                "function_name": "transcribe_audio",
                "default_model": "distil-whisper/distil-large-v3.5",
                "supported_params": ["model_name", "device"],
                "description": "基于Transformers的Distil Whisper模型",
            }
        }

    def _get_config(self, provider: str) -> Dict[str, Any]:
        """Return the internal config entry for ``provider``.

        Args:
            provider: Provider name to look up.

        Returns:
            The live (not copied) config dict of the provider.

        Raises:
            ValueError: If ``provider`` is not a configured provider.
        """
        try:
            return self._provider_configs[provider]
        except KeyError:
            raise ValueError(f"不支持的provider: {provider}") from None

    def _lazy_load_module(self, provider: str):
        """Import the backend module of ``provider`` on first use and cache it.

        Args:
            provider: Provider name.

        Returns:
            The imported backend module.

        Raises:
            ValueError: If ``provider`` is not a configured provider.
        """
        config = self._get_config(provider)
        if provider not in self._loaded_modules:
            module_path = config["module_path"]
            logger.info("获取模块: %s", module_path)
            # Imported here so the backend is only loaded when it is actually
            # requested; the relative path is resolved against this package.
            import importlib
            module = importlib.import_module(module_path, package=__package__)
            self._loaded_modules[provider] = module
            logger.info("模块 %s 获取成功", module_path)
        return self._loaded_modules[provider]

    def _get_transcribe_function(self, provider: str) -> Callable:
        """Return the transcription callable exposed by the provider's module.

        Args:
            provider: Provider name.

        Returns:
            The attribute of the backend module named by the provider's
            ``function_name`` config value.

        Raises:
            ValueError: If ``provider`` is not a configured provider.
            AttributeError: If the module does not define that function.
        """
        module = self._lazy_load_module(provider)
        function_name = self._provider_configs[provider]["function_name"]
        if not hasattr(module, function_name):
            raise AttributeError(f"模块中未找到函数: {function_name}")
        return getattr(module, function_name)

    def _filter_params(self, provider: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Keep only the parameters the provider supports and fill in defaults.

        Args:
            provider: Provider name.
            params: Raw keyword arguments supplied by the caller.

        Returns:
            A new dict holding the supported subset of ``params``, plus the
            provider's default model when ``model_name`` is missing and, for
            the Transformers backend, ``device="cpu"`` when ``device`` is
            missing.

        Raises:
            ValueError: If ``provider`` is not a configured provider.
        """
        config = self._get_config(provider)
        supported_params = config["supported_params"]
        filtered_params = {
            name: params[name] for name in supported_params if name in params
        }
        # Fall back to the provider's default model when none was requested.
        if "model_name" in supported_params and "model_name" not in filtered_params:
            filtered_params["model_name"] = config["default_model"]
        # The Transformers backend runs on CPU unless a device is given.
        if (
            provider == "distil_whisper_transformers"
            and "device" in supported_params
            and "device" not in filtered_params
        ):
            filtered_params["device"] = "cpu"
        return filtered_params

    def transcribe(
        self,
        audio_segment: AudioSegment,
        provider: str,
        **kwargs
    ) -> TranscriptionResult:
        """Transcribe ``audio_segment`` with the backend selected by ``provider``.

        Args:
            audio_segment: Audio to transcribe. ``len(audio_segment) / 1000``
                is logged as the duration in seconds.
            provider: Name of the ASR provider to use.
            **kwargs: Extra options such as ``model_name`` or ``device``;
                options the provider does not support are dropped.

        Returns:
            Whatever the backend's transcription function returns.

        Raises:
            ValueError: If ``provider`` is not a configured provider.
            RuntimeError: If loading or running the backend fails; the
                original exception is attached as ``__cause__``.
        """
        logger.info(
            "使用provider '%s' 进行音频转录,音频长度: %.2f秒",
            provider,
            len(audio_segment) / 1000,
        )
        if provider not in self._provider_configs:
            available_providers = list(self._provider_configs)
            raise ValueError(f"不支持的provider: {provider}。支持的provider: {available_providers}")
        try:
            # Resolve the backend function (imports the module on first use).
            transcribe_func = self._get_transcribe_function(provider)
            # Drop unsupported options and apply defaults.
            filtered_kwargs = self._filter_params(provider, kwargs)
            logger.debug("调用 %s 转录函数,参数: %s", provider, filtered_kwargs)
            result = transcribe_func(audio_segment, **filtered_kwargs)
            logger.info("转录完成,文本长度: %d字符", len(result.text))
            return result
        except Exception as e:
            logger.error("使用provider '%s' 转录音频失败: %s", provider, e, exc_info=True)
            # Chain explicitly so the root cause stays visible to callers.
            raise RuntimeError(f"转录失败: {str(e)}") from e

    def get_available_providers(self) -> Dict[str, str]:
        """Return a mapping of every configured provider name to its description."""
        return {
            provider: config["description"]
            for provider, config in self._provider_configs.items()
        }

    def get_provider_info(self, provider: str) -> Dict[str, Any]:
        """Return an independent copy of the configuration of ``provider``.

        Args:
            provider: Provider name.

        Returns:
            A deep copy of the provider's config, so callers may mutate it
            (including the nested ``supported_params`` list) without
            affecting the router.

        Raises:
            ValueError: If ``provider`` is not a configured provider.
        """
        import copy
        return copy.deepcopy(self._get_config(provider))
# Module-level router instance shared by the functions defined below.
_router = ASRRouter()
@spaces.GPU(duration=180)
def transcribe_audio(
    audio_segment: AudioSegment,
    provider: str = "distil_whisper_transformers",
    model_name: Optional[str] = None,
    device: str = "cpu",
    **kwargs
) -> TranscriptionResult:
    """Transcribe audio through the shared router using the chosen backend.

    Args:
        audio_segment: Audio to transcribe.
        provider: Name of the ASR provider to route to.
        model_name: Model to use; omitted from the call when ``None`` so the
            router can apply the provider's default.
        device: Target device; forwarded only when it differs from ``"cpu"``.
        **kwargs: Additional options passed through to the router.

    Returns:
        The result produced by ``_router.transcribe``.
    """
    # Explicit arguments are layered on top of the pass-through options.
    explicit = {}
    if model_name is not None:
        explicit["model_name"] = model_name
    if device != "cpu":
        explicit["device"] = device
    return _router.transcribe(audio_segment, provider, **{**kwargs, **explicit})
def get_available_providers() -> Dict[str, str]:
    """List the ASR providers known to the shared router.

    Returns:
        A mapping from provider name to its description.
    """
    providers = _router.get_available_providers()
    return providers
def get_provider_info(provider: str) -> Dict[str, Any]:
    """Look up the configuration of one provider via the shared router.

    Args:
        provider: Provider name.

    Returns:
        A copy of the provider's configuration dict.

    Raises:
        ValueError: If the router does not know ``provider``.
    """
    info = _router.get_provider_info(provider)
    return info
|