""" | |
说话人分离器基础类,包含可复用的方法 | |
""" | |
import logging
import os
import tempfile
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union

from pydub import AudioSegment

from ..schemas import DiarizationResult
# Configure module-level logger
logger = logging.getLogger("diarization")
class BaseDiarizer(ABC):
    """Base class for speaker diarizers.

    Holds the configuration and helper methods shared by concrete
    diarizer implementations. Subclasses must implement :meth:`diarize`
    and will typically override :meth:`_load_model`.
    """

    def __init__(
        self,
        model_name: str,
        token: Optional[str] = None,
        device: str = "cpu",
        segmentation_batch_size: int = 32,
    ):
        """
        Initialize the common diarizer parameters.

        Args:
            model_name: Name of the diarization model.
            token: Hugging Face token used to access gated models.
            device: Inference device, ``'cpu'`` or ``'cuda'``.
            segmentation_batch_size: Segmentation batch size (default 32).
        """
        self.model_name = model_name
        # BUG FIX: `token` was accepted but never stored, so subclasses
        # could not use it when loading gated Hugging Face models.
        self.token = token
        self.device = device
        self.segmentation_batch_size = segmentation_batch_size
        logger.info(f"初始化说话人分离器,模型: {model_name},设备: {device},分割批处理大小: {segmentation_batch_size}")

    def _load_model(self):
        """Load the underlying model.

        No-op hook by default; subclasses that need model loading
        override this.
        """
        pass

    def _prepare_audio(self, audio: AudioSegment) -> str:
        """
        Normalize audio and write it to a temporary WAV file.

        The audio is resampled to 16 kHz and down-mixed to mono, as
        required by pyannote models.

        Args:
            audio: Input AudioSegment.

        Returns:
            Path to the temporary WAV file. The caller is responsible
            for removing the file when it is no longer needed.
        """
        logger.debug(f"准备音频数据: 时长={len(audio)/1000:.2f}秒, 采样率={audio.frame_rate}Hz, 声道数={audio.channels}")
        # Resample to 16 kHz (required by pyannote models)
        if audio.frame_rate != 16000:
            logger.debug(f"重采样音频从 {audio.frame_rate}Hz 到 16000Hz")
            audio = audio.set_frame_rate(16000)
        # Down-mix to mono
        if audio.channels > 1:
            logger.debug(f"将{audio.channels}声道音频转换为单声道")
            audio = audio.set_channels(1)
        # BUG FIX: the previous fixed filename in the CWD collided across
        # concurrent runs; create a unique temp file instead.
        fd, temp_audio_path = tempfile.mkstemp(
            prefix="diarization_", suffix=".wav"
        )
        os.close(fd)  # only the path is needed; pydub reopens the file
        audio.export(temp_audio_path, format="wav")
        logger.debug(f"音频处理完成,保存至: {temp_audio_path}")
        return temp_audio_path

    def _convert_segments(self, diarization) -> Tuple[List[Dict[str, Union[float, str, int]]], int]:
        """
        Convert pyannote diarization tracks into the required format.

        Args:
            diarization: Annotation-like object returned by a pyannote
                pipeline, exposing ``itertracks(yield_label=True)``.

        Returns:
            A tuple ``(segments, num_speakers)`` where ``segments`` is a
            list of ``{"start": float, "end": float, "speaker": str}``
            dicts sorted by start time.
        """
        segments: List[Dict[str, Union[float, str, int]]] = []
        speakers = set()
        # Walk every (turn, track, speaker-label) triple in the result
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segments.append({
                "start": turn.start,
                "end": turn.end,
                "speaker": speaker
            })
            speakers.add(speaker)
        # Sort chronologically by segment start
        segments.sort(key=lambda seg: seg["start"])
        logger.debug(f"转换了 {len(segments)} 个分段,检测到 {len(speakers)} 个说话人")
        return segments, len(speakers)

    @abstractmethod
    def diarize(self, audio: AudioSegment) -> DiarizationResult:
        """
        Run speaker diarization on the given audio.

        BUG FIX: marked abstract — the base implementation was ``pass``,
        which silently returned ``None`` against the declared
        ``DiarizationResult`` return type; ``abstractmethod`` was
        imported but unused.

        Args:
            audio: AudioSegment to process.

        Returns:
            DiarizationResult with the segment list and speaker count.
        """