File size: 3,172 Bytes
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
924aa01
8289369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
924aa01
 
 
8289369
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Make the project root importable (see the sys.path.insert call below the stdlib imports)
import json
import sys
from pathlib import Path
import os

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.podcast_transcribe.rss.podcast_rss_parser import parse_rss_xml_content
from src.podcast_transcribe.schemas import EnhancedSegment, CombinedTranscriptionResult
from src.podcast_transcribe.summary.speaker_identify import SpeakerIdentifier

if __name__ == '__main__':
    # Manual smoke test for SpeakerIdentifier: load a previously dumped
    # transcription result and the podcast RSS feed from disk, then ask the
    # LLM-backed identifier to map speakers to names and print the result.
    script_dir = Path(__file__).parent
    transcribe_result_dump_file = script_dir / "output" / "lex_ai_john_carmack_1.transcription.json"
    podcast_rss_xml_file = script_dir / "input" / "lexfridman.com.rss.xml"
    # NOTE(review): "mps" is hard-coded; presumably this targets Apple-silicon
    # hardware -- change it when running on a different accelerator.
    device = "mps"

    # Load the transcription result; bail out early with a hint if it is missing.
    if not transcribe_result_dump_file.exists():
        print(f"错误:转录结果文件 '{transcribe_result_dump_file}' 不存在。请先运行 combined_transcription.py 生成结果。")
        sys.exit(1)

    with open(transcribe_result_dump_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Rebuild EnhancedSegment objects from the JSON dicts. This assumes the
    # keys of every segment dict match the EnhancedSegment fields exactly;
    # non-dict entries are reported and skipped instead of crashing on `**`.
    enhanced_segments = []
    for seg_dict in data.get("segments", []):
        if isinstance(seg_dict, dict):
            enhanced_segments.append(EnhancedSegment(**seg_dict))
        else:
            print(f"警告: 在JSON中发现非字典类型的segment: {seg_dict}")

    transcription_result = CombinedTranscriptionResult(
        segments=enhanced_segments,
        text=data.get("text", ""),
        language=data.get("language", ""),
        num_speakers=data.get("num_speakers", 0),
    )

    # Print some information about the loaded object for verification.
    print("\n成功从JSON加载 CombinedTranscriptionResult 对象:")
    print(f"类型: {type(transcription_result)}")

    # Load the podcast RSS XML file. The encoding is explicit so decoding does
    # not depend on the platform's locale default.
    with open(podcast_rss_xml_file, "r", encoding="utf-8") as f:
        podcast_rss_xml = f.read()
    mock_podcast_info = parse_rss_xml_content(podcast_rss_xml)

    # Find the episode whose title starts with "#309".
    mock_episode_info = next(
        (episode for episode in mock_podcast_info.episodes if episode.title.startswith("#309")),
        None,
    )
    if not mock_episode_info:
        raise ValueError("Could not find episode with title starting with '#309'")

    speaker_identifier = SpeakerIdentifier(
        llm_model_name="google/gemma-3-4b-it",
        llm_provider="gemma-transformers",
        device=device,
    )

    # Run the speaker-name recognition and dump the result as readable JSON.
    print("\n--- Test Case 1: Normal execution ---")
    speaker_names = speaker_identifier.recognize_speaker_names(
        transcription_result.segments, mock_podcast_info, mock_episode_info
    )
    print("\nRecognized Speaker Names (Test Case 1):")
    print(json.dumps(speaker_names, ensure_ascii=False, indent=2))