konieshadow's picture
优化asr
5d74f79
"""
说话人分离器基础类,包含可复用的方法
"""
import os
import logging
from abc import ABC, abstractmethod
from pydub import AudioSegment
from typing import Any, Dict, List, Union, Optional, Tuple
from ..schemas import DiarizationResult
# 配置日志
logger = logging.getLogger("diarization")
class BaseDiarizer(ABC):
"""说话人分离器基础类"""
def __init__(
self,
model_name: str,
token: Optional[str] = None,
device: str = "cpu",
segmentation_batch_size: int = 32,
):
"""
初始化说话人分离器基础参数
参数:
model_name: 模型名称
token: Hugging Face令牌,用于访问模型
device: 推理设备,'cpu'或'cuda'
segmentation_batch_size: 分割批处理大小,默认为32
"""
self.model_name = model_name
self.device = device
self.segmentation_batch_size = segmentation_batch_size
logger.info(f"初始化说话人分离器,模型: {model_name},设备: {device},分割批处理大小: {segmentation_batch_size}")
@abstractmethod
def _load_model(self):
"""加载模型,子类需要实现"""
pass
def _prepare_audio(self, audio: AudioSegment) -> str:
"""
准备音频数据,保存为临时文件
参数:
audio: 输入的AudioSegment对象
返回:
临时音频文件的路径
"""
logger.debug(f"准备音频数据: 时长={len(audio)/1000:.2f}秒, 采样率={audio.frame_rate}Hz, 声道数={audio.channels}")
# 确保采样率为16kHz (pyannote模型要求)
if audio.frame_rate != 16000:
logger.debug(f"重采样音频从 {audio.frame_rate}Hz 到 16000Hz")
audio = audio.set_frame_rate(16000)
# 确保是单声道
if audio.channels > 1:
logger.debug(f"将{audio.channels}声道音频转换为单声道")
audio = audio.set_channels(1)
# 保存为临时文件
temp_audio_path = "_temp_audio_for_diarization.wav"
audio.export(temp_audio_path, format="wav")
logger.debug(f"音频处理完成,保存至: {temp_audio_path}")
return temp_audio_path
def _convert_segments(self, diarization) -> Tuple[List[Dict[str, Union[float, str, int]]], int]:
"""
将pyannote的分段结果转换为所需格式
参数:
diarization: pyannote模型返回的分段结果
返回:
转换后的分段列表和说话人数量
"""
segments = []
speakers = set()
# 遍历说话人分离结果
for turn, _, speaker in diarization.itertracks(yield_label=True):
segments.append({
"start": turn.start,
"end": turn.end,
"speaker": speaker
})
speakers.add(speaker)
# 按开始时间排序
segments.sort(key=lambda x: x["start"])
logger.debug(f"转换了 {len(segments)} 个分段,检测到 {len(speakers)} 个说话人")
return segments, len(speakers)
@abstractmethod
def diarize(self, audio: AudioSegment) -> DiarizationResult:
"""
对音频进行说话人分离,子类需要实现
参数:
audio: 要处理的AudioSegment对象
返回:
DiarizationResult对象,包含分段结果和说话人数量
"""
pass