"""
基于pyannote.audio库调用pyannote/speaker-diarization-3.1模型实现的说话人分离模块
"""

import os
from pydub import AudioSegment
from typing import Any, Mapping, Optional, Text
import logging
import torch

from .diarizer_base import BaseDiarizer
from ..schemas import DiarizationResult

# Configure logging
logger = logging.getLogger("diarization")


class PyannoteTransformersTranscriber(BaseDiarizer):
    """使用pyannote.audio库调用pyannote/speaker-diarization-3.1模型进行说话人分离"""
    
    def __init__(
        self, 
        model_name: str = "pyannote/speaker-diarization-3.1",
        token: Optional[str] = None,
        device: str = "cpu",
        segmentation_batch_size: int = 32,
    ):
        """
        初始化说话人分离器
        
        参数:
            model_name: 模型名称
            token: Hugging Face令牌,用于访问模型
            device: 推理设备,'cpu'或'cuda'
            segmentation_batch_size: 分割批处理大小,默认为32
        """
        super().__init__(model_name, token, device, segmentation_batch_size)
        
        # Load the model
        self._load_model()
        
    def _load_model(self):
        """使用pyannote.audio加载模型"""
        try:
            # Check that the required dependency is available
            try:
                from pyannote.audio import Pipeline
            except ImportError:
                raise ImportError("Please install pyannote.audio first: pip install pyannote.audio")
    
            logger.info(f"Loading model {self.model_name} with pyannote.audio")
            
            # Load the speaker diarization pipeline, forwarding the Hugging
            # Face token so gated models can be downloaded (this assumes
            # BaseDiarizer stores the token as self.token, mirroring how it
            # exposes self.model_name and self.device)
            self.pipeline = Pipeline.from_pretrained(
                self.model_name,
                use_auth_token=self.token,
            )
            
            # Move the pipeline to the target device
            logger.info(f"Moving model to device: {self.device}")
            self.pipeline.to(torch.device(self.device))
            
            # Configure the segmentation batch size
            if hasattr(self.pipeline, "segmentation_batch_size"):
                logger.info(f"Setting segmentation batch size: {self.segmentation_batch_size}")
                self.pipeline.segmentation_batch_size = self.segmentation_batch_size
            
            logger.info("pyannote.audio model loaded successfully")
            
        except Exception as e:
            logger.error(f"加载模型失败: {str(e)}", exc_info=True)
            raise RuntimeError(f"模型加载失败: {str(e)}")
    
    def diarize(self, audio: AudioSegment) -> DiarizationResult:
        """
        对音频进行说话人分离
        
        参数:
            audio: 要处理的AudioSegment对象
            
        返回:
            DiarizationResult对象,包含分段结果和说话人数量
        """
        logger.info(f"开始使用pyannote.audio处理 {len(audio)/1000:.2f} 秒的音频进行说话人分离")
        
        # Prepare the audio input
        temp_audio_path = self._prepare_audio(audio)
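        # (_prepare_audio is defined on BaseDiarizer; judging from the cleanup
        # in the finally block below, it presumably writes the AudioSegment to
        # a temporary file on disk and returns that file's path.)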
        
        try:
            # Run speaker diarization
            logger.debug("Starting speaker diarization")
            
            # Use a custom ProgressHook to report progress through the logger
            try:
                from pyannote.audio.pipelines.utils.hook import ProgressHook

                class CustomProgressHook(ProgressHook):
                    def __call__(
                        self,
                        step_name: Text,
                        step_artifact: Any,
                        file: Optional[Mapping] = None,
                        total: Optional[int] = None,
                        completed: Optional[int] = None,
                    ):
                        if completed is not None and total is not None:
                            percentage = completed / total * 100
                            logger.info(f"处理中 {step_name}: ({percentage:.1f}%)")
                        else:
                            logger.info(f"已完成 {step_name}")

                with CustomProgressHook() as hook:
                    diarization = self.pipeline(temp_audio_path, hook=hook)
                    
            except ImportError:
                # If ProgressHook is unavailable, run without progress reporting
                logger.info("ProgressHook not available, running speaker diarization directly")
                diarization = self.pipeline(temp_audio_path)

            # Convert the segment results
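            # (The pipeline returns a pyannote.core.Annotation; _convert_segments,
            # defined on BaseDiarizer, presumably iterates it, e.g. via
            # diarization.itertracks(yield_label=True), and builds the schema objects.)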
            segments, num_speakers = self._convert_segments(diarization)
            
            logger.info(f"说话人分离完成,检测到 {num_speakers} 个说话人,生成 {len(segments)} 个分段")
            
            return DiarizationResult(
                segments=segments,
                num_speakers=num_speakers
            )
            
        except Exception as e:
            logger.error(f"说话人分离失败: {str(e)}", exc_info=True)
            raise RuntimeError(f"说话人分离失败: {str(e)}")
        finally:
            # Clean up the temporary file
            if os.path.exists(temp_audio_path):
                os.remove(temp_audio_path)


def diarize_audio(
    audio_segment: AudioSegment,
    model_name: str = "pyannote/speaker-diarization-3.1",
    token: Optional[str] = None,
    device: str = "cpu",
    segmentation_batch_size: int = 32,
) -> DiarizationResult:
    """
    使用pyannote.audio调用pyannote模型对音频进行说话人分离
    
    参数:
        audio_segment: 输入的AudioSegment对象
        model_name: 使用的模型名称
        token: Hugging Face令牌
        device: 推理设备,'cpu'、'cuda'、'mps'
        segmentation_batch_size: 分割批处理大小,默认为32
        
    返回:
        DiarizationResult对象,包含分段和说话人数量
    """
    logger.info(f"调用pyannote.audio版本diarize_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
    transcriber = PyannoteTransformersTranscriber(
        model_name=model_name, 
        token=token, 
        device=device, 
        segmentation_batch_size=segmentation_batch_size
    )
    return transcriber.diarize(audio_segment)
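

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions: "example.wav" is a placeholder input
    # file, and the HF_TOKEN environment variable holds a Hugging Face token
    # with access to the gated pyannote/speaker-diarization-3.1 model. Because
    # this file uses relative imports, run it as a module, e.g.
    # python -m <package>.<this_module>.
    logging.basicConfig(level=logging.INFO)

    result = diarize_audio(
        AudioSegment.from_file("example.wav"),
        token=os.environ.get("HF_TOKEN"),
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    print(f"Detected {result.num_speakers} speaker(s)")
    for segment in result.segments:
        print(segment)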