"""
基于pyannote.audio库调用pyannote/speaker-diarization-3.1模型实现的说话人分离模块
"""
import os
from pydub import AudioSegment
from typing import Any, Mapping, Optional, Text
import logging
import torch
from .diarizer_base import BaseDiarizer
from ..schemas import DiarizationResult
# Configure logging
logger = logging.getLogger("diarization")

class PyannoteTransformersTranscriber(BaseDiarizer):
    """Speaker diarization using the pyannote.audio library with the pyannote/speaker-diarization-3.1 model."""

    def __init__(
        self,
        model_name: str = "pyannote/speaker-diarization-3.1",
        token: Optional[str] = None,
        device: str = "cpu",
        segmentation_batch_size: int = 32,
    ):
        """
        Initialize the speaker diarizer.

        Args:
            model_name: name of the model
            token: Hugging Face token, used to access the gated model
            device: inference device, 'cpu', 'cuda', or 'mps'
            segmentation_batch_size: batch size for the segmentation step, defaults to 32
        """
        super().__init__(model_name, token, device, segmentation_batch_size)
        # Load the model
        self._load_model()

    def _load_model(self):
        """Load the model with pyannote.audio."""
        try:
            # Check that the required dependency is installed
            try:
                from pyannote.audio import Pipeline
            except ImportError:
                raise ImportError("Please install pyannote.audio first: pip install pyannote.audio")

            logger.info(f"Loading model {self.model_name} with pyannote.audio")
            # Load the speaker diarization pipeline; the token is required for gated models
            self.pipeline = Pipeline.from_pretrained(
                self.model_name,
                use_auth_token=self.token,
            )

            # Move the pipeline to the target device
            logger.info(f"Moving model to device: {self.device}")
            self.pipeline.to(torch.device(self.device))

            # Configure the segmentation batch size if the pipeline supports it
            if hasattr(self.pipeline, "segmentation_batch_size"):
                logger.info(f"Setting segmentation batch size: {self.segmentation_batch_size}")
                self.pipeline.segmentation_batch_size = self.segmentation_batch_size

            logger.info("pyannote.audio model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}", exc_info=True)
            raise RuntimeError(f"Model loading failed: {str(e)}") from e

    def diarize(self, audio: AudioSegment) -> DiarizationResult:
        """
        Run speaker diarization on the given audio.

        Args:
            audio: the AudioSegment to process

        Returns:
            A DiarizationResult containing the segments and the number of speakers.
        """
        logger.info(f"Running pyannote.audio speaker diarization on {len(audio)/1000:.2f} seconds of audio")

        # Prepare the audio input (written to a temporary file)
        temp_audio_path = self._prepare_audio(audio)

        try:
            # Run speaker diarization
            logger.debug("Starting speaker diarization")

            # Use a custom ProgressHook to report progress through the logger
            try:
                from pyannote.audio.pipelines.utils.hook import ProgressHook

                class CustomProgressHook(ProgressHook):
                    def __call__(
                        self,
                        step_name: Text,
                        step_artifact: Any,
                        file: Optional[Mapping] = None,
                        total: Optional[int] = None,
                        completed: Optional[int] = None,
                    ):
                        if completed is not None and total is not None:
                            percentage = completed / total * 100
                            logger.info(f"Processing {step_name}: ({percentage:.1f}%)")
                        else:
                            logger.info(f"Completed {step_name}")

                with CustomProgressHook() as hook:
                    diarization = self.pipeline(temp_audio_path, hook=hook)
            except ImportError:
                # Fall back to running without progress reporting if ProgressHook is unavailable
                logger.info("ProgressHook unavailable, running speaker diarization directly")
                diarization = self.pipeline(temp_audio_path)

            # Convert the pipeline output into our segment format
            segments, num_speakers = self._convert_segments(diarization)

            logger.info(f"Speaker diarization finished: detected {num_speakers} speakers, produced {len(segments)} segments")

            return DiarizationResult(
                segments=segments,
                num_speakers=num_speakers
            )
        except Exception as e:
            logger.error(f"Speaker diarization failed: {str(e)}", exc_info=True)
            raise RuntimeError(f"Speaker diarization failed: {str(e)}") from e
        finally:
            # Remove the temporary file
            if os.path.exists(temp_audio_path):
                os.remove(temp_audio_path)

def diarize_audio(
    audio_segment: AudioSegment,
    model_name: str = "pyannote/speaker-diarization-3.1",
    token: Optional[str] = None,
    device: str = "cpu",
    segmentation_batch_size: int = 32,
) -> DiarizationResult:
    """
    Run speaker diarization on audio with the pyannote model via pyannote.audio.

    Args:
        audio_segment: the input AudioSegment
        model_name: name of the model to use
        token: Hugging Face token
        device: inference device, 'cpu', 'cuda', or 'mps'
        segmentation_batch_size: batch size for the segmentation step, defaults to 32

    Returns:
        A DiarizationResult containing the segments and the number of speakers.
    """
    logger.info(f"Calling the pyannote.audio diarize_audio function, audio length: {len(audio_segment)/1000:.2f} seconds")
    transcriber = PyannoteTransformersTranscriber(
        model_name=model_name,
        token=token,
        device=device,
        segmentation_batch_size=segmentation_batch_size
    )
    return transcriber.diarize(audio_segment)
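

# A minimal usage sketch, not part of the module's public API. It assumes a
# local audio file at the hypothetical path "episode.mp3", that ffmpeg is
# available for pydub decoding, and that an HF_TOKEN environment variable
# holds a Hugging Face token for the gated pyannote/speaker-diarization-3.1
# model.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Hypothetical input file; any format pydub/ffmpeg can decode works
    audio = AudioSegment.from_file("episode.mp3")

    result = diarize_audio(
        audio,
        token=os.environ.get("HF_TOKEN"),  # assumed environment variable
        device="cpu",
    )

    # DiarizationResult exposes segments and num_speakers (see ..schemas)
    print(f"Detected {result.num_speakers} speakers across {len(result.segments)} segments")
    for seg in result.segments:
        print(seg)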