File size: 9,343 Bytes
8289369
 
 
 
 
 
 
 
7803eb5
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7803eb5
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d74f79
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
"""
说话人分离模型调用路由器
根据传递的provider参数调用不同的说话人分离实现,支持延迟加载
"""

import logging
from typing import Dict, Any, Optional, Callable
from pydub import AudioSegment
import spaces
from ..schemas import DiarizationResult
from . import diarization_pyannote_mlx
from . import diarization_pyannote_transformers

# 配置日志
logger = logging.getLogger("diarization")


class DiarizerRouter:
    """说话人分离模型调用路由器,支持多种实现的统一调用"""
    
    def __init__(self):
        """初始化路由器"""
        self._loaded_modules = {}  # 用于缓存已加载的模块
        self._diarizers = {}       # 用于缓存已实例化的分离器
        
        # 定义支持的provider配置
        self._provider_configs = {
            "pyannote_mlx": {
                "module_path": "diarization_pyannote_mlx",
                "function_name": "diarize_audio",
                "default_model": "pyannote/speaker-diarization-3.1",
                "supported_params": ["model_name", "token", "device", "segmentation_batch_size"],
                "description": "基于pyannote.audio的原生MLX实现"
            },
            "pyannote_transformers": {
                "module_path": "diarization_pyannote_transformers",
                "function_name": "diarize_audio",
                "default_model": "pyannote/speaker-diarization-3.1",
                "supported_params": ["model_name", "token", "device", "segmentation_batch_size"],
                "description": "基于transformers库调用pyannote模型"
            }
        }
    
    def _lazy_load_module(self, provider: str):
        """
        获取指定provider的模块
        
        参数:
            provider: provider名称
            
        返回:
            对应的模块
        """
        if provider not in self._provider_configs:
            raise ValueError(f"不支持的provider: {provider}")
            
        if provider not in self._loaded_modules:
            module_path = self._provider_configs[provider]["module_path"]
            logger.info(f"获取模块: {module_path}")
            
            # 根据module_path返回对应的模块
            if module_path == "diarization_pyannote_mlx":
                module = diarization_pyannote_mlx
            elif module_path == "diarization_pyannote_transformers":
                module = diarization_pyannote_transformers
            else:
                raise ImportError(f"未找到模块: {module_path}")
            
            self._loaded_modules[provider] = module
            logger.info(f"模块 {module_path} 获取成功")
        
        return self._loaded_modules[provider]
    
    def _get_diarize_function(self, provider: str) -> Callable:
        """
        获取指定provider的说话人分离函数
        
        参数:
            provider: provider名称
            
        返回:
            说话人分离函数
        """
        module = self._lazy_load_module(provider)
        function_name = self._provider_configs[provider]["function_name"]
        
        if not hasattr(module, function_name):
            raise AttributeError(f"模块中未找到函数: {function_name}")
            
        return getattr(module, function_name)
    
    def _filter_params(self, provider: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        过滤参数,只保留指定provider支持的参数
        
        参数:
            provider: provider名称
            params: 原始参数字典
            
        返回:
            过滤后的参数字典
        """
        supported_params = self._provider_configs[provider]["supported_params"]
        filtered_params = {}
        
        for param in supported_params:
            if param in params:
                filtered_params[param] = params[param]
        
        # 如果没有指定model_name,使用默认模型
        if "model_name" not in filtered_params and "model_name" in supported_params:
            filtered_params["model_name"] = self._provider_configs[provider]["default_model"]
        
        return filtered_params
    
    def diarize(
        self,
        audio_segment: AudioSegment,
        provider: str,
        **kwargs
    ) -> DiarizationResult:
        """
        统一的说话人分离接口
        
        参数:
            audio_segment: 输入的AudioSegment对象
            provider: 说话人分离提供者名称
            **kwargs: 其他参数,如model_name, token, device, segmentation_batch_size等
            
        返回:
            DiarizationResult对象
        """
        logger.info(f"使用provider '{provider}' 进行说话人分离,音频长度: {len(audio_segment)/1000:.2f}秒")
        
        if provider not in self._provider_configs:
            available_providers = list(self._provider_configs.keys())
            raise ValueError(f"不支持的provider: {provider}。支持的provider: {available_providers}")
        
        try:
            # 获取说话人分离函数
            diarize_func = self._get_diarize_function(provider)
            
            # 过滤并准备参数
            filtered_kwargs = self._filter_params(provider, kwargs)
            
            logger.debug(f"调用 {provider} 说话人分离函数,参数: {filtered_kwargs}")
            
            # 执行说话人分离
            result = diarize_func(audio_segment, **filtered_kwargs)
            
            logger.info(f"说话人分离完成,检测到 {result.num_speakers} 个说话人,生成 {len(result.segments)} 个分段")
            return result
            
        except Exception as e:
            logger.error(f"使用provider '{provider}' 进行说话人分离失败: {str(e)}", exc_info=True)
            raise RuntimeError(f"说话人分离失败: {str(e)}")
    
    def get_available_providers(self) -> Dict[str, str]:
        """
        获取所有可用的provider及其描述
        
        返回:
            provider名称到描述的映射
        """
        return {
            provider: config["description"] 
            for provider, config in self._provider_configs.items()
        }
    
    def get_provider_info(self, provider: str) -> Dict[str, Any]:
        """
        获取指定provider的详细信息
        
        参数:
            provider: provider名称
            
        返回:
            provider的配置信息
        """
        if provider not in self._provider_configs:
            raise ValueError(f"不支持的provider: {provider}")
            
        return self._provider_configs[provider].copy()


# 创建全局路由器实例
_router = DiarizerRouter()

@spaces.GPU(duration=180)
def diarize_audio(
    audio_segment: AudioSegment,
    provider: str = "pyannote_mlx",
    model_name: Optional[str] = None,
    token: Optional[str] = None,
    device: str = "cpu",
    segmentation_batch_size: int = 32,
    **kwargs
) -> DiarizationResult:
    """
    统一的音频说话人分离接口函数
    
    参数:
        audio_segment: 输入的AudioSegment对象
        provider: 说话人分离提供者,可选值:
            - "pyannote_mlx": 基于pyannote.audio的原生MLX实现
            - "pyannote_transformers": 基于transformers库调用pyannote模型
        model_name: 模型名称,如果不指定则使用默认模型
        token: Hugging Face令牌,用于访问模型
        device: 推理设备,'cpu'、'cuda'、'mps'
        segmentation_batch_size: 分割批处理大小,默认为32
        **kwargs: 其他参数
        
    返回:
        DiarizationResult对象,包含分段结果和说话人数量
        
    示例:
        # 使用默认pyannote MLX实现
        result = diarize_audio(audio_segment, provider="pyannote_mlx")
        
        # 使用transformers实现
        result = diarize_audio(
            audio_segment, 
            provider="pyannote_transformers",
        )
        
        # 使用GPU设备
        result = diarize_audio(
            audio_segment,
            provider="pyannote_mlx",
            device="cuda"
        )
        
        # 自定义批处理大小
        result = diarize_audio(
            audio_segment,
            provider="pyannote_mlx", 
            segmentation_batch_size=64
        )
    """
    # 准备参数
    params = kwargs.copy()
    if model_name is not None:
        params["model_name"] = model_name
    if token is not None:
        params["token"] = token
    if device != "cpu":
        params["device"] = device
    if segmentation_batch_size != 32:
        params["segmentation_batch_size"] = segmentation_batch_size
    
    return _router.diarize(audio_segment, provider, **params)


def get_available_providers() -> Dict[str, str]:
    """
    获取所有可用的说话人分离提供者
    
    返回:
        provider名称到描述的映射
    """
    return _router.get_available_providers()


def get_provider_info(provider: str) -> Dict[str, Any]:
    """
    获取指定provider的详细信息
    
    参数:
        provider: provider名称
        
    返回:
        provider的配置信息
    """
    return _router.get_provider_info(provider)