"""
基于pyannote.audio库调用pyannote/speaker-diarization-3.1模型实现的说话人分离模块
"""
import os
import numpy as np
from pydub import AudioSegment
from typing import Any, Dict, List, Mapping, Text, Union, Optional, Tuple
import logging
import torch
from .diarizer_base import BaseDiarizer
from ..schemas import DiarizationResult
# 配置日志
logger = logging.getLogger("diarization")
class PyannoteTransformersTranscriber(BaseDiarizer):
    """Speaker diarization backed by pyannote.audio's pyannote/speaker-diarization-3.1 pipeline."""

    def __init__(
        self,
        model_name: str = "pyannote/speaker-diarization-3.1",
        token: Optional[str] = None,
        device: str = "cpu",
        segmentation_batch_size: int = 32,
    ):
        """
        Initialize the diarizer and eagerly load the pipeline.

        Args:
            model_name: Hugging Face model identifier.
            token: Hugging Face access token (required for gated models).
            device: Inference device, 'cpu' or 'cuda'.
            segmentation_batch_size: Batch size for the segmentation step.

        Raises:
            RuntimeError: if the model cannot be loaded.
        """
        super().__init__(model_name, token, device, segmentation_batch_size)
        # Load eagerly so a bad configuration fails fast at construction time.
        self._load_model()

    def _load_model(self):
        """Load the diarization pipeline with pyannote.audio.

        Raises:
            RuntimeError: if pyannote.audio is missing or loading fails.
        """
        try:
            # Verify the optional dependency is installed before anything else.
            try:
                from pyannote.audio import Pipeline
            except ImportError:
                raise ImportError("请先安装pyannote.audio库: pip install pyannote.audio")
            logger.info(f"开始使用pyannote.audio加载模型 {self.model_name}")
            # FIX: forward the access token — it was previously accepted but
            # never passed, so gated models could not be downloaded.
            # (Assumes BaseDiarizer stores it as self.token, consistent with
            # self.model_name / self.device used below — confirm in base class.)
            self.pipeline = Pipeline.from_pretrained(
                self.model_name,
                use_auth_token=self.token,
            )
            # Move the pipeline to the requested device.
            logger.info(f"将模型移动到设备: {self.device}")
            self.pipeline.to(torch.device(self.device))
            # Apply the segmentation batch size when the pipeline supports it.
            if hasattr(self.pipeline, "segmentation_batch_size"):
                logger.info(f"设置分割批处理大小: {self.segmentation_batch_size}")
                self.pipeline.segmentation_batch_size = self.segmentation_batch_size
            logger.info("pyannote.audio模型加载成功")
        except Exception as e:
            logger.error(f"加载模型失败: {str(e)}", exc_info=True)
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"模型加载失败: {str(e)}") from e

    def diarize(self, audio: AudioSegment) -> DiarizationResult:
        """
        Run speaker diarization on an audio segment.

        Args:
            audio: AudioSegment to process.

        Returns:
            DiarizationResult containing the segments and the speaker count.

        Raises:
            RuntimeError: if diarization fails.
        """
        logger.info(f"开始使用pyannote.audio处理 {len(audio)/1000:.2f} 秒的音频进行说话人分离")
        # Write the audio to a temporary file the pipeline can consume.
        temp_audio_path = self._prepare_audio(audio)
        try:
            logger.debug("开始执行说话人分离")
            # Prefer a ProgressHook so long runs report progress to the log.
            try:
                from pyannote.audio.pipelines.utils.hook import ProgressHook

                class CustomProgressHook(ProgressHook):
                    # Redirect pyannote's progress callbacks into our logger.
                    def __call__(
                        self,
                        step_name: Text,
                        step_artifact: Any,
                        file: Optional[Mapping] = None,
                        total: Optional[int] = None,
                        completed: Optional[int] = None,
                    ):
                        if completed is not None and total is not None:
                            percentage = completed / total * 100
                            logger.info(f"处理中 {step_name}: ({percentage:.1f}%)")
                        else:
                            logger.info(f"已完成 {step_name}")

                with CustomProgressHook() as hook:
                    diarization = self.pipeline(temp_audio_path, hook=hook)
            except ImportError:
                # ProgressHook unavailable — run without progress reporting.
                logger.info("ProgressHook不可用,直接执行说话人分离")
                diarization = self.pipeline(temp_audio_path)
            # Convert pyannote's annotation into our segment schema.
            segments, num_speakers = self._convert_segments(diarization)
            logger.info(f"说话人分离完成,检测到 {num_speakers} 个说话人,生成 {len(segments)} 个分段")
            return DiarizationResult(
                segments=segments,
                num_speakers=num_speakers
            )
        except Exception as e:
            logger.error(f"说话人分离失败: {str(e)}", exc_info=True)
            raise RuntimeError(f"说话人分离失败: {str(e)}") from e
        finally:
            # Always clean up the temporary audio file.
            if os.path.exists(temp_audio_path):
                os.remove(temp_audio_path)
def diarize_audio(
    audio_segment: AudioSegment,
    model_name: str = "pyannote/speaker-diarization-3.1",
    token: Optional[str] = None,
    device: str = "cpu",
    segmentation_batch_size: int = 32,
) -> DiarizationResult:
    """Convenience wrapper: diarize one AudioSegment with a pyannote model.

    Builds a PyannoteTransformersTranscriber and delegates to its
    ``diarize`` method.

    Args:
        audio_segment: Input audio to process.
        model_name: Hugging Face model identifier.
        token: Hugging Face access token.
        device: Inference device — 'cpu', 'cuda', or 'mps'.
        segmentation_batch_size: Segmentation batch size (default 32).

    Returns:
        DiarizationResult with the segments and speaker count.
    """
    logger.info(f"调用pyannote.audio版本diarize_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
    diarizer = PyannoteTransformersTranscriber(
        model_name=model_name,
        token=token,
        device=device,
        segmentation_batch_size=segmentation_batch_size,
    )
    result = diarizer.diarize(audio_segment)
    return result