import logging
from typing import List, Dict, Optional
import json
import re
from ..schemas import EnhancedSegment, PodcastChannel, PodcastEpisode
from ..llm import llm_router
# Configure module logger
logger = logging.getLogger("speaker_identify")
class SpeakerIdentifier:
"""
说话人识别器类,用于根据转录分段和播客元数据识别说话人的真实姓名或昵称
"""
    def __init__(self, llm_model_name: Optional[str], llm_provider: str, device: Optional[str] = None):
"""
初始化说话人识别器
参数:
llm_model_name: LLM模型名称,如果为None则使用默认模型
llm_provider: LLM提供者,默认为"gemma-mlx"
device: 计算设备,例如 "cpu", "cuda", "mps"
"""
self.llm_model_name = llm_model_name
self.llm_provider = llm_provider
self.device = device
def _clean_html(self, html_string: Optional[str]) -> str:
"""
简单地从字符串中移除HTML标签并清理多余空白。
"""
if not html_string:
return ""
        # Strip HTML tags
        text = re.sub(r'<[^>]+>', ' ', html_string)
        # Replace common HTML entities (simple version; only a few common ones)
        text = text.replace('&nbsp;', ' ').replace('&amp;', '&').replace('&lt;', '<').replace('&gt;', '>')
        # Collapse extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()
return text
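    # Illustrative example (input string invented, not from the original source):
    #   _clean_html("<p>Hello&nbsp;&amp; welcome</p>")  ->  "Hello & welcome"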
def _get_dialogue_samples(
self,
segments: List[EnhancedSegment],
        max_samples_per_speaker: int = 3,  # number of dialogue samples kept per speaker
        max_length_per_sample: int = 200   # maximum characters per sample
) -> Dict[str, List[str]]:
"""
为每个说话人提取对话样本。
"""
speaker_dialogues: Dict[str, List[str]] = {}
for segment in segments:
speaker = segment.speaker
if speaker == "UNKNOWN" or not segment.text.strip(): # 跳过未知说话人或空文本
continue
if speaker not in speaker_dialogues:
speaker_dialogues[speaker] = []
if len(speaker_dialogues[speaker]) < max_samples_per_speaker:
text_sample = segment.text.strip()[:max_length_per_sample]
if len(segment.text.strip()) > max_length_per_sample:
text_sample += "..."
speaker_dialogues[speaker].append(text_sample)
return speaker_dialogues
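    # The returned mapping has the shape (values invented for illustration):
    #   {"SPEAKER_00": ["First sample...", "Second sample..."], "SPEAKER_01": ["..."]}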
def recognize_speaker_names(
self,
segments: List[EnhancedSegment],
podcast_info: Optional[PodcastChannel],
episode_info: Optional[PodcastEpisode],
max_shownotes_length: int = 1500,
max_desc_length: int = 500
) -> Dict[str, str]:
"""
使用LLM根据转录分段和播客/剧集元数据识别说话人的真实姓名或昵称。
参数:
segments: 转录后的 EnhancedSegment 列表。
podcast_info: 包含播客元数据的 PodcastChannel 对象。
episode_info: 包含单集播客元数据的 PodcastEpisode 对象。
max_shownotes_length: 用于Prompt的 Shownotes 最大字符数。
max_desc_length: 用于Prompt的播客描述最大字符数。
返回:
一个字典,键是原始的 "SPEAKER_XX",值是识别出的说话人名称。
"""
unique_speaker_ids = sorted(list(set(seg.speaker for seg in segments if seg.speaker != "UNKNOWN" and seg.text.strip())))
if not unique_speaker_ids:
print("未能从 segments 中提取到有效的 speaker_ids。")
return {}
dialogue_samples = self._get_dialogue_samples(segments)
        # Gather per-speaker speech statistics, including utterance frequency and length
speaker_stats = {}
for segment in segments:
speaker = segment.speaker
if speaker == "UNKNOWN" or not segment.text.strip():
continue
if speaker not in speaker_stats:
speaker_stats[speaker] = {
"total_segments": 0,
"total_chars": 0,
"avg_segment_length": 0,
"intro_likely": False # 是否有介绍性质的话语
}
speaker_stats[speaker]["total_segments"] += 1
speaker_stats[speaker]["total_chars"] += len(segment.text)
            # Detect likely self-introductions or introductions of others
lower_text = segment.text.lower()
intro_patterns = [
r'欢迎来到', r'欢迎收听', r'我是', r'我叫', r'大家好', r'今天的嘉宾是', r'我们请到了',
r'welcome to', r'i\'m your host', r'this is', r'today we have', r'joining us',
r'our guest', r'my name is'
]
if any(re.search(pattern, lower_text) for pattern in intro_patterns):
speaker_stats[speaker]["intro_likely"] = True
        # Compute the average utterance length per speaker
for speaker, stats in speaker_stats.items():
if stats["total_segments"] > 0:
stats["avg_segment_length"] = stats["total_chars"] / stats["total_segments"]
        # Build enriched speaker info including the statistics (not referenced by the simplified prompt below)
speaker_info_for_prompt = []
for speaker_id in unique_speaker_ids:
samples = dialogue_samples.get(speaker_id, ["(No dialogue samples available)"])
stats = speaker_stats.get(speaker_id, {"total_segments": 0, "avg_segment_length": 0, "intro_likely": False})
speaker_info_for_prompt.append({
"speaker_id": speaker_id,
"dialogue_samples": samples,
"speech_stats": {
"total_segments": stats["total_segments"],
"avg_segment_length": round(stats["avg_segment_length"], 2),
"has_intro_pattern": stats["intro_likely"]
}
})
        # Access metadata attributes safely, falling back to defaults
podcast_title = podcast_info.title if podcast_info and podcast_info.title else "Unknown Podcast"
podcast_author = podcast_info.author if podcast_info and podcast_info.author else "Unknown"
raw_podcast_desc = podcast_info.description if podcast_info and podcast_info.description else ""
cleaned_podcast_desc = self._clean_html(raw_podcast_desc)
podcast_desc_for_prompt = cleaned_podcast_desc[:max_desc_length]
if len(cleaned_podcast_desc) > max_desc_length:
podcast_desc_for_prompt += "..."
episode_title = episode_info.title if episode_info and episode_info.title else "Unknown Episode"
raw_episode_summary = episode_info.summary if episode_info and episode_info.summary else ""
cleaned_episode_summary = self._clean_html(raw_episode_summary)
        episode_summary_for_prompt = cleaned_episode_summary[:max_desc_length]  # reuse the description length limit
if len(cleaned_episode_summary) > max_desc_length:
episode_summary_for_prompt += "..."
raw_episode_shownotes = episode_info.shownotes if episode_info and episode_info.shownotes else ""
cleaned_episode_shownotes = self._clean_html(raw_episode_shownotes)
episode_shownotes_for_prompt = cleaned_episode_shownotes[:max_shownotes_length]
if len(cleaned_episode_shownotes) > max_shownotes_length:
episode_shownotes_for_prompt += "..."
system_prompt = """You are a speaker identification expert. Return only a JSON object mapping speaker IDs to names. Start directly with { and end with }. No markdown, no explanations."""
        # Keep the prompt compact: include only the most essential information
key_info = []
for speaker_id in unique_speaker_ids:
samples = dialogue_samples.get(speaker_id, [])
stats = speaker_stats.get(speaker_id, {"total_segments": 0, "intro_likely": False})
            # Build a short description for this speaker
desc_parts = []
if stats["intro_likely"]:
desc_parts.append("intro")
if stats["total_segments"] > 0:
desc_parts.append(f"{stats['total_segments']}segs")
if samples:
                # Use only the first 50 characters of the first sample
sample_text = samples[0][:50].replace('\n', ' ').strip()
if sample_text:
desc_parts.append(f'"{sample_text}"')
key_info.append(f"{speaker_id}: {', '.join(desc_parts)}")
user_prompt_template = f"""Podcast: {podcast_title}
Host: {podcast_author}
Episode: {episode_title}
Notes: {episode_shownotes_for_prompt[:300]}
Speakers:
{chr(10).join(key_info)}
Return JSON like: {{"SPEAKER_00": "Name1", "SPEAKER_01": "Name2"}}"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt_template}
]
        # Pre-compute a default mapping, using heuristics rather than relying purely on speaker order
final_map = {}
        # Use speaking-pattern heuristics for a preliminary role assignment:
        # 1. The speaker with the most segments is likely the host.
        # 2. A speaker who uses introduction-style phrases is likely the host.
        # 3. Everyone else is provisionally labelled as a guest.
        host_candidates = []
        for speaker_id, stats in speaker_stats.items():
            # Rank each speaker by (has intro-style phrases, number of segments)
            host_candidates.append((speaker_id, 1 if stats["intro_likely"] else 0, stats["total_segments"]))
        # Sort by likelihood: introduction-style phrases first, then segment count
        host_candidates.sort(key=lambda x: (x[1], x[2]), reverse=True)
if host_candidates:
            # The most likely host
host_id = host_candidates[0][0]
final_map[host_id] = "Podcast Host"
            # Provisionally label everyone else as guests
guest_counter = 1
for speaker_id in unique_speaker_ids:
if speaker_id != host_id:
final_map[speaker_id] = f"Guest {guest_counter}"
guest_counter += 1
else:
            # With no clear cues, fall back to assigning roles by speaker order
is_host_assigned = False
guest_counter = 1
for speaker_id in unique_speaker_ids:
if not is_host_assigned:
final_map[speaker_id] = "Podcast Host"
is_host_assigned = True
else:
final_map[speaker_id] = f"Guest {guest_counter}"
guest_counter += 1
try:
response = llm_router.chat_completion(
messages=messages,
provider=self.llm_provider,
model=self.llm_model_name,
                temperature=0.2,         # slightly higher temperature
                max_tokens=300,          # extra room for the response
                top_p=0.5,               # moderately higher top_p
                device=self.device,
                repetition_penalty=1.0,  # keep repetition penalty disabled
                do_sample=True           # allow light sampling; no stop tokens
)
logger.info(f"LLM调用日志,请求参数:【{messages}】, 响应: 【{response}】")
assistant_response_content = response["choices"][0]["message"]["content"]
            # Stricter JSON extraction logic
parsed_llm_output = None
            # First, try parsing the whole response directly (in case it is already pure JSON)
try:
parsed_llm_output = json.loads(assistant_response_content.strip())
if isinstance(parsed_llm_output, dict):
print("直接解析响应为JSON成功")
else:
parsed_llm_output = None
except json.JSONDecodeError:
pass
            # If direct parsing failed, try to extract the JSON portion
if parsed_llm_output is None:
                # Try to extract JSON from a Markdown code block
json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', assistant_response_content, re.DOTALL)
if json_match:
json_str = json_match.group(1)
print("从markdown代码块中提取JSON")
else:
                    # No markdown block: fall back to the span from the first '{' to the last '}'
first_brace = assistant_response_content.find('{')
last_brace = assistant_response_content.rfind('}')
if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
json_str = assistant_response_content[first_brace : last_brace+1]
print("通过大括号位置提取JSON")
else:
print("无法找到有效的JSON结构,使用默认映射")
return final_map
try:
                    # Clean up the JSON string
                    json_str = json_str.strip()
                    # Collapse newlines and extra whitespace
                    json_str = re.sub(r'\s+', ' ', json_str)
parsed_llm_output = json.loads(json_str)
if not isinstance(parsed_llm_output, dict):
print(f"LLM返回的JSON不是一个字典: {parsed_llm_output}")
parsed_llm_output = None
else:
print("JSON解析成功")
except json.JSONDecodeError as e:
print(f"LLM返回的JSON解析失败: {e}")
print(f"用于解析的字符串: '{json_str[:200]}...'")
parsed_llm_output = None
if parsed_llm_output:
                # Use the LLM's valid output directly rather than the preset role-assignment heuristics
final_map = {}
unknown_counter = 1
                # First, apply the names identified by the LLM
for spk_id in unique_speaker_ids:
if spk_id in parsed_llm_output and isinstance(parsed_llm_output[spk_id], str) and parsed_llm_output[spk_id].strip():
final_map[spk_id] = parsed_llm_output[spk_id].strip()
else:
                        # If the LLM gave no result for this ID, fall back to "Unknown Speaker"
final_map[spk_id] = f"Unknown Speaker {unknown_counter}"
unknown_counter += 1
# 检查是否有"Host"或"主持人"标识
has_host = any("主持人" in name or "Host" in name for name in final_map.values())
# 如果没有任何主持人标识,且存在"Unknown Speaker",可以考虑将最活跃的未知说话人设为主持人
if not has_host and any("Unknown Speaker" in name for name in final_map.values()):
                    # Find the most active unknown speaker
most_active_unknown = None
max_segments = 0
for spk_id, name in final_map.items():
if "Unknown Speaker" in name and spk_id in speaker_stats:
if speaker_stats[spk_id]["total_segments"] > max_segments:
max_segments = speaker_stats[spk_id]["total_segments"]
most_active_unknown = spk_id
if most_active_unknown:
final_map[most_active_unknown] = "Podcast Host"
print(f"LLM识别结果: {final_map}")
return final_map
except Exception as e:
import traceback
print(f"调用LLM或处理响应时发生严重错误: {e}")
print(traceback.format_exc())
            # On any serious error, return the initial heuristic mapping
            return final_map
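

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal example of how this class might be driven. The provider/model names
# below are placeholders, and SimpleNamespace objects stand in for
# EnhancedSegment, PodcastChannel and PodcastEpisode, since
# recognize_speaker_names only reads the attributes used above (speaker, text,
# title, author, description, summary, shownotes). Run as a module
# (python -m ...) so the relative imports resolve; it will call the configured
# LLM provider through llm_router.
if __name__ == "__main__":
    from types import SimpleNamespace

    identifier = SpeakerIdentifier(
        llm_model_name="your-model-name",  # placeholder model name
        llm_provider="gemma-mlx",          # default provider mentioned in the docstring
    )
    demo_segments = [
        SimpleNamespace(speaker="SPEAKER_00", text="Welcome to the show, I'm your host Alice."),
        SimpleNamespace(speaker="SPEAKER_01", text="Thanks for having me, my name is Bob."),
    ]
    demo_channel = SimpleNamespace(title="Demo Podcast", author="Alice", description="A demo feed.")
    demo_episode = SimpleNamespace(title="Episode 1", summary="Alice talks to Bob.", shownotes="")

    print(identifier.recognize_speaker_names(demo_segments, demo_channel, demo_episode))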