"""
说话人分离模型调用路由器
根据传递的provider参数调用不同的说话人分离实现,支持延迟加载
"""
import logging
from typing import Dict, Any, Optional, Callable
from pydub import AudioSegment
import spaces
from ..schemas import DiarizationResult
from . import diarization_pyannote_mlx
from . import diarization_pyannote_transformers
# 配置日志
logger = logging.getLogger("diarization")


class DiarizerRouter:
    """Speaker diarization router providing a unified interface over multiple implementations."""

    def __init__(self):
        """Initialize the router."""
        self._loaded_modules = {}  # Cache of already resolved modules
        self._diarizers = {}  # Cache of instantiated diarizers

        # Supported provider configurations
        self._provider_configs = {
            "pyannote_mlx": {
                "module_path": "diarization_pyannote_mlx",
                "function_name": "diarize_audio",
                "default_model": "pyannote/speaker-diarization-3.1",
                "supported_params": ["model_name", "token", "device", "segmentation_batch_size"],
                "description": "Native MLX implementation based on pyannote.audio"
            },
            "pyannote_transformers": {
                "module_path": "diarization_pyannote_transformers",
                "function_name": "diarize_audio",
                "default_model": "pyannote/speaker-diarization-3.1",
                "supported_params": ["model_name", "token", "device", "segmentation_batch_size"],
                "description": "Calls the pyannote model via the transformers library"
            }
        }
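
        # NOTE: registering another backend (hypothetical sketch, not an existing
        # provider) would mean adding an entry like the ones above to
        # self._provider_configs and mapping its "module_path" to an imported
        # module inside _lazy_load_module below.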

    def _lazy_load_module(self, provider: str):
        """
        Get the module for the given provider.

        Args:
            provider: Provider name.

        Returns:
            The corresponding module.
        """
        if provider not in self._provider_configs:
            raise ValueError(f"Unsupported provider: {provider}")

        if provider not in self._loaded_modules:
            module_path = self._provider_configs[provider]["module_path"]
            logger.info(f"Resolving module: {module_path}")

            # Map module_path to the corresponding module
            if module_path == "diarization_pyannote_mlx":
                module = diarization_pyannote_mlx
            elif module_path == "diarization_pyannote_transformers":
                module = diarization_pyannote_transformers
            else:
                raise ImportError(f"Module not found: {module_path}")

            self._loaded_modules[provider] = module
            logger.info(f"Module {module_path} resolved successfully")

        return self._loaded_modules[provider]

    def _get_diarize_function(self, provider: str) -> Callable:
        """
        Get the diarization function for the given provider.

        Args:
            provider: Provider name.

        Returns:
            The diarization function.
        """
        module = self._lazy_load_module(provider)
        function_name = self._provider_configs[provider]["function_name"]

        if not hasattr(module, function_name):
            raise AttributeError(f"Function not found in module: {function_name}")

        return getattr(module, function_name)

    def _filter_params(self, provider: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Filter parameters, keeping only those supported by the given provider.

        Args:
            provider: Provider name.
            params: Original parameter dict.

        Returns:
            Filtered parameter dict.
        """
        supported_params = self._provider_configs[provider]["supported_params"]
        filtered_params = {}

        for param in supported_params:
            if param in params:
                filtered_params[param] = params[param]

        # Fall back to the provider's default model if model_name was not specified
        if "model_name" not in filtered_params and "model_name" in supported_params:
            filtered_params["model_name"] = self._provider_configs[provider]["default_model"]

        return filtered_params

    def diarize(
        self,
        audio_segment: AudioSegment,
        provider: str,
        **kwargs
    ) -> DiarizationResult:
        """
        Unified speaker diarization interface.

        Args:
            audio_segment: Input AudioSegment object.
            provider: Name of the diarization provider.
            **kwargs: Additional parameters such as model_name, token, device,
                segmentation_batch_size, etc.

        Returns:
            A DiarizationResult object.
        """
        logger.info(f"Running diarization with provider '{provider}', audio length: {len(audio_segment)/1000:.2f}s")

        if provider not in self._provider_configs:
            available_providers = list(self._provider_configs.keys())
            raise ValueError(f"Unsupported provider: {provider}. Available providers: {available_providers}")

        try:
            # Resolve the diarization function
            diarize_func = self._get_diarize_function(provider)

            # Filter and prepare parameters
            filtered_kwargs = self._filter_params(provider, kwargs)
            logger.debug(f"Calling {provider} diarization function with params: {filtered_kwargs}")

            # Run diarization
            result = diarize_func(audio_segment, **filtered_kwargs)

            logger.info(f"Diarization finished: {result.num_speakers} speaker(s) detected, {len(result.segments)} segment(s) produced")
            return result
        except Exception as e:
            logger.error(f"Diarization with provider '{provider}' failed: {str(e)}", exc_info=True)
            raise RuntimeError(f"Diarization failed: {str(e)}") from e

    def get_available_providers(self) -> Dict[str, str]:
        """
        Get all available providers and their descriptions.

        Returns:
            Mapping of provider name to description.
        """
        return {
            provider: config["description"]
            for provider, config in self._provider_configs.items()
        }

    def get_provider_info(self, provider: str) -> Dict[str, Any]:
        """
        Get detailed information about a given provider.

        Args:
            provider: Provider name.

        Returns:
            The provider's configuration.
        """
        if provider not in self._provider_configs:
            raise ValueError(f"Unsupported provider: {provider}")
        return self._provider_configs[provider].copy()


# Create the global router instance
_router = DiarizerRouter()


@spaces.GPU(duration=180)
def diarize_audio(
    audio_segment: AudioSegment,
    provider: str = "pyannote_mlx",
    model_name: Optional[str] = None,
    token: Optional[str] = None,
    device: str = "cpu",
    segmentation_batch_size: int = 32,
    **kwargs
) -> DiarizationResult:
    """
    Unified speaker diarization entry point.

    Args:
        audio_segment: Input AudioSegment object.
        provider: Diarization provider, one of:
            - "pyannote_mlx": native MLX implementation based on pyannote.audio
            - "pyannote_transformers": calls the pyannote model via the transformers library
        model_name: Model name; the provider's default model is used if not given.
        token: Hugging Face token used to access the model.
        device: Inference device: 'cpu', 'cuda', or 'mps'.
        segmentation_batch_size: Segmentation batch size, defaults to 32.
        **kwargs: Additional parameters.

    Returns:
        A DiarizationResult object containing the segments and the number of speakers.

    Examples:
        # Use the default pyannote MLX implementation
        result = diarize_audio(audio_segment, provider="pyannote_mlx")

        # Use the transformers implementation
        result = diarize_audio(
            audio_segment,
            provider="pyannote_transformers",
        )

        # Run on a GPU device
        result = diarize_audio(
            audio_segment,
            provider="pyannote_mlx",
            device="cuda"
        )

        # Customize the batch size
        result = diarize_audio(
            audio_segment,
            provider="pyannote_mlx",
            segmentation_batch_size=64
        )
    """
    # Prepare parameters; only non-default values are forwarded to the router
    params = kwargs.copy()
    if model_name is not None:
        params["model_name"] = model_name
    if token is not None:
        params["token"] = token
    if device != "cpu":
        params["device"] = device
    if segmentation_batch_size != 32:
        params["segmentation_batch_size"] = segmentation_batch_size

    return _router.diarize(audio_segment, provider, **params)


def get_available_providers() -> Dict[str, str]:
    """
    Get all available speaker diarization providers.

    Returns:
        Mapping of provider name to description.
    """
    return _router.get_available_providers()


def get_provider_info(provider: str) -> Dict[str, Any]:
    """
    Get detailed information about a given provider.

    Args:
        provider: Provider name.

    Returns:
        The provider's configuration.
    """
    return _router.get_provider_info(provider)
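

if __name__ == "__main__":
    # Minimal usage sketch of the module-level API, not part of the routing logic.
    # Run it as a module (python -m ...) from the package root so the relative
    # imports above resolve. The audio path and the HF_TOKEN environment variable
    # below are hypothetical placeholders for illustration only.
    import os

    print("Available providers:", get_available_providers())
    print("Provider info:", get_provider_info("pyannote_mlx"))

    audio = AudioSegment.from_file("example.wav")  # hypothetical input file
    result = diarize_audio(
        audio,
        provider="pyannote_mlx",
        token=os.environ.get("HF_TOKEN"),  # hypothetical env variable holding a Hugging Face token
    )
    print(f"Detected {result.num_speakers} speaker(s) across {len(result.segments)} segment(s)")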