File size: 16,408 Bytes
924aa01
8289369
 
 
 
 
 
 
924aa01
 
8289369
 
 
 
 
 
924aa01
8289369
 
 
 
 
 
924aa01
8289369
 
 
924aa01
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a751c2
8289369
5a751c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8289369
5a751c2
 
 
8289369
5a751c2
8289369
5a751c2
 
8289369
5a751c2
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a751c2
 
 
 
 
 
8289369
924aa01
8289369
 
5a751c2
8289369
 
5a751c2
8289369
5a751c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a751c2
 
8289369
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import html
import json
import logging
import re
from typing import List, Dict, Optional

from ..schemas import EnhancedSegment, PodcastChannel, PodcastEpisode
from ..llm import llm_router

# Module-level logger used by SpeakerIdentifier below.
logger = logging.getLogger(name="speaker_identify")

class SpeakerIdentifier:
    """
    Map diarization labels (e.g. "SPEAKER_00") to real names or nicknames.

    Uses transcript segments plus podcast/episode metadata to build a compact
    prompt and asks an LLM (via ``llm_router.chat_completion``) for a
    ``{speaker_id: name}`` JSON object. A heuristic "Podcast Host" / "Guest N"
    mapping is computed first and returned whenever the LLM call fails or its
    reply cannot be parsed into a non-empty JSON object.
    """

    # Phrases suggesting a self-introduction or an introduction of another
    # speaker; matched against the lower-cased segment text.
    _INTRO_PATTERNS = (
        r'欢迎来到', r'欢迎收听', r'我是', r'我叫', r'大家好', r'今天的嘉宾是', r'我们请到了',
        r'welcome to', r"i'm your host", r'this is', r'today we have', r'joining us',
        r'our guest', r'my name is',
    )
    # A single alternation matches iff any individual pattern matches; compiled once.
    _INTRO_RE = re.compile("|".join(_INTRO_PATTERNS))

    # Hard cap on how many shownotes characters go into the compact user prompt.
    _PROMPT_NOTES_LIMIT = 300

    def __init__(self, llm_model_name: str, llm_provider: str, device: Optional[str] = None):
        """
        Create a speaker identifier.

        Args:
            llm_model_name: Forwarded as ``model`` to ``llm_router.chat_completion``.
            llm_provider: Forwarded as ``provider`` to ``llm_router.chat_completion``.
            device: Optional compute device, e.g. "cpu", "cuda", "mps"; forwarded
                as ``device`` to ``llm_router.chat_completion``.
        """
        self.llm_model_name = llm_model_name
        self.llm_provider = llm_provider
        self.device = device

    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Return ``text`` cut to ``limit`` characters, with "..." appended if it was cut."""
        return text[:limit] + "..." if len(text) > limit else text

    def _clean_html(self, html_string: Optional[str]) -> str:
        """
        Strip HTML tags, decode HTML entities and collapse whitespace.

        Returns an empty string for ``None`` or empty input.
        """
        if not html_string:
            return ""
        # Replace tags with a space so words on either side of a tag stay separated.
        text = re.sub(r'<[^>]+>', ' ', html_string)
        # Decode every named/numeric entity in one pass. Doing this after tag
        # removal keeps escaped markup ("&lt;b&gt;") as literal text, and a single
        # pass avoids double-decoding sequences such as "&amp;lt;".
        text = html.unescape(text)
        # Collapse whitespace runs (this also covers the U+00A0 produced by &nbsp;).
        return re.sub(r'\s+', ' ', text).strip()

    def _get_dialogue_samples(
        self,
        segments: List[EnhancedSegment],
        max_samples_per_speaker: int = 3,
        max_length_per_sample: int = 200
    ) -> Dict[str, List[str]]:
        """
        Collect up to ``max_samples_per_speaker`` dialogue snippets per speaker.

        Segments from the "UNKNOWN" speaker and blank segments are skipped. Each
        snippet is the stripped text cut to ``max_length_per_sample`` characters,
        with "..." appended when it was cut. Snippets keep transcript order.
        """
        speaker_dialogues: Dict[str, List[str]] = {}
        for segment in segments:
            if segment.speaker == "UNKNOWN":
                continue
            text = segment.text.strip()
            if not text:
                continue
            samples = speaker_dialogues.setdefault(segment.speaker, [])
            if len(samples) < max_samples_per_speaker:
                samples.append(self._truncate(text, max_length_per_sample))
        return speaker_dialogues

    def _compute_speaker_stats(self, segments: List[EnhancedSegment]) -> Dict[str, dict]:
        """
        Collect per-speaker statistics for the prompt and the fallback heuristic.

        For each speaker other than "UNKNOWN" that has at least one non-blank
        segment, the returned dict holds ``total_segments``, ``total_chars``
        (length of the raw, unstripped text), ``avg_segment_length`` and
        ``intro_likely`` (True if any segment matches an introduction phrase).
        Keys are in order of each speaker's first appearance.
        """
        speaker_stats: Dict[str, dict] = {}
        for segment in segments:
            speaker = segment.speaker
            if speaker == "UNKNOWN" or not segment.text.strip():
                continue
            if speaker not in speaker_stats:
                speaker_stats[speaker] = {
                    "total_segments": 0,
                    "total_chars": 0,
                    "avg_segment_length": 0,
                    "intro_likely": False,
                }
            stats = speaker_stats[speaker]
            stats["total_segments"] += 1
            stats["total_chars"] += len(segment.text)
            if self._INTRO_RE.search(segment.text.lower()):
                stats["intro_likely"] = True

        # Every entry above was created together with at least one counted segment.
        for stats in speaker_stats.values():
            stats["avg_segment_length"] = stats["total_chars"] / stats["total_segments"]
        return speaker_stats

    def _heuristic_speaker_map(
        self,
        unique_speaker_ids: List[str],
        speaker_stats: Dict[str, dict],
    ) -> Dict[str, str]:
        """
        Build the fallback mapping used when the LLM gives no usable answer.

        The host is the speaker with an introduction phrase, if any, and the most
        segments. Ties keep first-appearance order, because ``sorted`` stays
        stable with ``reverse=True``. Every other speaker, in sorted-ID order,
        becomes "Guest 1", "Guest 2", ...
        """
        if not unique_speaker_ids:
            return {}
        ranked = sorted(
            speaker_stats,
            key=lambda spk: (speaker_stats[spk]["intro_likely"], speaker_stats[spk]["total_segments"]),
            reverse=True,
        )
        host_id = ranked[0] if ranked else unique_speaker_ids[0]

        fallback_map = {host_id: "Podcast Host"}
        guest_counter = 1
        for speaker_id in unique_speaker_ids:
            if speaker_id != host_id:
                fallback_map[speaker_id] = f"Guest {guest_counter}"
                guest_counter += 1
        return fallback_map

    def _build_messages(
        self,
        unique_speaker_ids: List[str],
        dialogue_samples: Dict[str, List[str]],
        speaker_stats: Dict[str, dict],
        podcast_info: Optional[PodcastChannel],
        episode_info: Optional[PodcastEpisode],
        max_shownotes_length: int,
    ) -> List[Dict[str, str]]:
        """
        Build the system/user chat messages for the LLM.

        The user prompt is deliberately compact: podcast title, author, episode
        title, a short shownotes excerpt, and one line per speaker. Each speaker
        line holds an "intro" flag, the segment count and the first 50 characters
        of the first sample.
        """
        # Either metadata object may be None, and individual fields may be empty.
        podcast_title = podcast_info.title if podcast_info and podcast_info.title else "Unknown Podcast"
        podcast_author = podcast_info.author if podcast_info and podcast_info.author else "Unknown"
        episode_title = episode_info.title if episode_info and episode_info.title else "Unknown Episode"

        raw_shownotes = episode_info.shownotes if episode_info and episode_info.shownotes else ""
        # Cut to max_shownotes_length (adding "..."), then apply the hard prompt cap.
        notes = self._truncate(self._clean_html(raw_shownotes), max_shownotes_length)[:self._PROMPT_NOTES_LIMIT]

        speaker_lines = []
        for speaker_id in unique_speaker_ids:
            samples = dialogue_samples.get(speaker_id, [])
            stats = speaker_stats.get(speaker_id, {"total_segments": 0, "intro_likely": False})

            desc_parts = []
            if stats["intro_likely"]:
                desc_parts.append("intro")
            if stats["total_segments"] > 0:
                desc_parts.append(f"{stats['total_segments']}segs")
            if samples:
                # Only the first 50 characters of the first sample, on a single line.
                sample_text = samples[0][:50].replace('\n', ' ').strip()
                if sample_text:
                    desc_parts.append(f'"{sample_text}"')
            speaker_lines.append(f"{speaker_id}: {', '.join(desc_parts)}")

        system_prompt = """You are a speaker identification expert. Return only a JSON object mapping speaker IDs to names. Start directly with { and end with }. No markdown, no explanations."""

        user_prompt = f"""Podcast: {podcast_title}
Host: {podcast_author}
Episode: {episode_title}

Notes: {notes}

Speakers:
{chr(10).join(speaker_lines)}

Return JSON like: {{"SPEAKER_00": "Name1", "SPEAKER_01": "Name2"}}"""

        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def _parse_llm_json(self, content) -> Optional[dict]:
        """
        Extract a JSON object from the LLM reply text.

        Tries, in order:
        1. the whole reply;
        2. the first ```json fenced block;
        3. the span from the first "{" to the last "}".

        Before parsing, whitespace runs in the extracted span are collapsed to
        single spaces, which also repairs raw newlines inside string values.
        Returns the parsed dict, or None if no JSON object could be obtained.
        """
        if not isinstance(content, str):
            logger.warning("LLM响应内容不是字符串: %r", content)
            return None

        try:
            parsed = json.loads(content.strip())
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            logger.debug("直接解析响应为JSON成功")
            return parsed

        fenced = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', content, re.DOTALL)
        if fenced:
            json_str = fenced.group(1)
            logger.debug("从markdown代码块中提取JSON")
        else:
            first_brace = content.find('{')
            last_brace = content.rfind('}')
            if first_brace == -1 or last_brace <= first_brace:
                logger.warning("无法找到有效的JSON结构,使用默认映射")
                return None
            json_str = content[first_brace:last_brace + 1]
            logger.debug("通过大括号位置提取JSON")

        json_str = re.sub(r'\s+', ' ', json_str.strip())
        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError as e:
            logger.warning("LLM返回的JSON解析失败: %s", e)
            logger.warning("用于解析的字符串: '%s...'", json_str[:200])
            return None
        if not isinstance(parsed, dict):
            logger.warning("LLM返回的JSON不是一个字典: %s", parsed)
            return None
        logger.debug("JSON解析成功")
        return parsed

    def _merge_llm_names(
        self,
        unique_speaker_ids: List[str],
        llm_names: dict,
        speaker_stats: Dict[str, dict],
    ) -> Dict[str, str]:
        """
        Build the final mapping from the LLM's answer.

        Each known speaker ID gets the LLM's non-blank string name, stripped.
        IDs that are missing, or answered with a non-string or blank value,
        become "Unknown Speaker N". If no resulting name contains "Host" or
        "主持人", the unknown speaker with the most segments is relabelled
        "Podcast Host".
        """
        final_map: Dict[str, str] = {}
        unknown_counter = 1
        for spk_id in unique_speaker_ids:
            name = llm_names.get(spk_id)
            if isinstance(name, str) and name.strip():
                final_map[spk_id] = name.strip()
            else:
                final_map[spk_id] = f"Unknown Speaker {unknown_counter}"
                unknown_counter += 1

        has_host = any("主持人" in name or "Host" in name for name in final_map.values())
        if not has_host:
            unknown_ids = [
                spk_id for spk_id, name in final_map.items()
                if "Unknown Speaker" in name and spk_id in speaker_stats
            ]
            if unknown_ids:
                # max() returns the first maximal element, i.e. the lowest sorted ID on ties.
                most_active = max(unknown_ids, key=lambda spk: speaker_stats[spk]["total_segments"])
                final_map[most_active] = "Podcast Host"
        return final_map

    def recognize_speaker_names(
        self,
        segments: List[EnhancedSegment],
        podcast_info: Optional[PodcastChannel],
        episode_info: Optional[PodcastEpisode],
        max_shownotes_length: int = 1500,
        max_desc_length: int = 500
    ) -> Dict[str, str]:
        """
        Identify speakers' real names or nicknames with an LLM.

        Args:
            segments: Transcribed ``EnhancedSegment`` list. Segments whose
                speaker is "UNKNOWN" or whose text is blank are ignored.
            podcast_info: Podcast-level metadata (title, author), or None.
            episode_info: Episode-level metadata (title, shownotes), or None.
            max_shownotes_length: Shownotes are cut to this many characters
                (plus "...") before the fixed 300-character prompt cap applies.
            max_desc_length: Currently unused. The compact prompt omits the
                podcast description and episode summary. Kept for backward
                compatibility.

        Returns:
            A dict mapping each original speaker ID (e.g. "SPEAKER_00") to a
            display name, or an empty dict when no usable speaker exists. When
            the LLM call fails or yields no usable JSON object, the heuristic
            Host/Guest mapping is returned instead.
        """
        unique_speaker_ids = sorted({
            seg.speaker for seg in segments
            if seg.speaker != "UNKNOWN" and seg.text.strip()
        })
        if not unique_speaker_ids:
            logger.warning("未能从 segments 中提取到有效的 speaker_ids。")
            return {}

        dialogue_samples = self._get_dialogue_samples(segments)
        speaker_stats = self._compute_speaker_stats(segments)
        messages = self._build_messages(
            unique_speaker_ids,
            dialogue_samples,
            speaker_stats,
            podcast_info,
            episode_info,
            max_shownotes_length,
        )
        # Computed up front so every failure path below can fall back to it.
        fallback_map = self._heuristic_speaker_map(unique_speaker_ids, speaker_stats)

        try:
            response = llm_router.chat_completion(
                messages=messages,
                provider=self.llm_provider,
                model=self.llm_model_name,
                temperature=0.2,         # low temperature: near-deterministic output
                max_tokens=300,          # enough for a small JSON object
                top_p=0.5,
                device=self.device,
                repetition_penalty=1.0,  # no repetition penalty
                do_sample=True,
            )
            logger.info("LLM调用日志,请求参数:【%s】, 响应: 【%s】", messages, response)
            llm_names = self._parse_llm_json(response["choices"][0]["message"]["content"])
        except Exception as e:
            # Deliberate best-effort boundary: any LLM, transport or response-shape
            # error degrades to the heuristic mapping instead of failing the caller.
            logger.exception("调用LLM或处理响应时发生严重错误: %s", e)
            return fallback_map

        if not llm_names:
            # No JSON object, or an empty one: keep the heuristic labels.
            return fallback_map

        final_map = self._merge_llm_names(unique_speaker_ids, llm_names, speaker_stats)
        logger.info("LLM识别结果: %s", final_map)
        return final_map