File size: 8,695 Bytes
ae9ec05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import whisper
import google.generativeai as genai
import os
import json
from datetime import datetime
import re

def test_speaker_separation():
    """Gemini를 사용한 화자 분리 테스트"""
    
    # API 키 로드
    from dotenv import load_dotenv
    load_dotenv()
    
    GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
    
    if not GOOGLE_API_KEY or GOOGLE_API_KEY == "your_google_api_key_here":
        print("ERROR: Please set GOOGLE_API_KEY in .env file")
        return
    
    print("Loading models...")
    
    try:
        # Whisper 모델 로드
        whisper_model = whisper.load_model("base")
        print("Whisper model loaded!")
        
        # Gemini 모델 설정
        genai.configure(api_key=GOOGLE_API_KEY)
        # gemini-2.0-flash: 최신 Gemini 2.0 모델, 빠르고 정확한 화자 분리
        gemini_model = genai.GenerativeModel('gemini-2.0-flash')
        print("Gemini 2.0 Flash model configured!")
        
        # WAV 파일 찾기
        wav_files = []
        if os.path.exists("data"):
            for file in os.listdir("data"):
                if file.endswith(".wav"):
                    wav_files.append(os.path.join("data", file))
        
        if not wav_files:
            print("No WAV files found in data folder.")
            return
        
        print(f"Found {len(wav_files)} WAV file(s)")
        
        for wav_file in wav_files[:1]:  # 첫 번째 파일만 테스트
            print(f"\nProcessing: {os.path.basename(wav_file)}")
            
            # 1단계: 음성 인식
            print("Step 1: Speech recognition...")
            result = whisper_model.transcribe(wav_file)
            full_text = result['text'].strip()
            
            print(f"Language detected: {result['language']}")
            print(f"Text length: {len(full_text)} characters")
            print(f"Text preview: {full_text[:200]}...")
            
            # 2단계: 화자 분리
            print("\nStep 2: Speaker separation with Gemini...")
            
            prompt = f"""
당신은 2명의 화자가 나누는 대화를 분석하는 전문가입니다. 
주어진 텍스트를 분석하여 각 발언을 화자별로 구분해주세요.

분석 지침:
1. 대화의 맥락과 내용을 기반으로 화자를 구분하세요
2. 말투, 주제 전환, 질문과 답변의 패턴을 활용하세요
3. 화자1과 화자2로 구분하여 표시하세요
4. 각 발언 앞에 [화자1] 또는 [화자2]를 붙여주세요
5. 발언이 너무 길 경우 자연스러운 지점에서 나누어주세요

출력 형식:
[화자1] 첫 번째 발언 내용
[화자2] 두 번째 발언 내용
[화자1] 세 번째 발언 내용
...

분석할 텍스트:
{full_text}
"""

            response = gemini_model.generate_content(prompt)
            separated_text = response.text.strip()
            
            print("Speaker separation completed!")
            
            # 3단계: 맞춤법 교정
            print("\nStep 3: Spell checking with Gemini...")
            
            spelling_prompt = f"""
당신은 한국어 맞춤법 교정 전문가입니다. 
주어진 텍스트에서 맞춤법 오류, 띄어쓰기 오류, 오타를 수정해주세요.

교정 지침:
1. 자연스러운 한국어 표현으로 수정하되, 원본의 의미와 말투는 유지하세요
2. [화자1], [화자2] 태그는 그대로 유지하세요
3. 전문 용어나 고유명사는 가능한 정확하게 수정하세요
4. 구어체 특성은 유지하되, 명백한 오타만 수정하세요
5. 문맥에 맞는 올바른 단어로 교체하세요

교정할 텍스트:
{separated_text}
"""

            corrected_response = gemini_model.generate_content(spelling_prompt)
            corrected_text = corrected_response.text.strip()
            
            print("Spell checking completed!")
            
            # 4단계: 결과 저장
            print("\nStep 4: Saving results...")
            
            base_name = os.path.splitext(os.path.basename(wav_file))[0]
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            
            # output 폴더 생성
            if not os.path.exists("output"):
                os.makedirs("output")
            
            # 전체 결과 저장 (원본 + 분리 + 교정)
            result_path = f"output/{base_name}_complete_result_{timestamp}.txt"
            with open(result_path, 'w', encoding='utf-8') as f:
                f.write(f"Filename: {base_name}\n")
                f.write(f"Processing time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"Language: {result['language']}\n")
                f.write("="*50 + "\n\n")
                f.write("Original text:\n")
                f.write(full_text + "\n\n")
                f.write("="*50 + "\n\n")
                f.write("Speaker separated text (original):\n")
                f.write(separated_text + "\n\n")
                f.write("="*50 + "\n\n")
                f.write("Speaker separated text (spell corrected):\n")
                f.write(corrected_text + "\n")
            
            # 교정된 텍스트에서 화자별 분리 결과 파싱
            corrected_conversations = {"화자1": [], "화자2": []}
            pattern = r'\[화자([12])\]\s*(.+?)(?=\[화자[12]\]|$)'
            matches = re.findall(pattern, corrected_text, re.DOTALL)
            
            for speaker_num, content in matches:
                speaker = f"화자{speaker_num}"
                content = content.strip()
                if content:
                    corrected_conversations[speaker].append(content)
            
            # 원본 화자별 분리 결과도 파싱 (비교용)
            original_conversations = {"화자1": [], "화자2": []}
            matches = re.findall(pattern, separated_text, re.DOTALL)
            
            for speaker_num, content in matches:
                speaker = f"화자{speaker_num}"
                content = content.strip()
                if content:
                    original_conversations[speaker].append(content)
            
            # 교정된 화자별 개별 파일 저장
            for speaker, utterances in corrected_conversations.items():
                if utterances:
                    speaker_path = f"output/{base_name}_{speaker}_교정본_{timestamp}.txt"
                    with open(speaker_path, 'w', encoding='utf-8') as f:
                        f.write(f"Filename: {base_name}\n")
                        f.write(f"Speaker: {speaker}\n")
                        f.write(f"Processing time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                        f.write(f"Number of utterances: {len(utterances)}\n")
                        f.write("="*50 + "\n\n")
                        
                        for idx, utterance in enumerate(utterances, 1):
                            f.write(f"{idx}. {utterance}\n\n")
            
            # JSON 저장 (원본과 교정본 모두 포함)
            json_path = f"output/{base_name}_complete_data_{timestamp}.json"
            json_data = {
                "filename": base_name,
                "processed_time": datetime.now().isoformat(),
                "language": result['language'],
                "original_text": full_text,
                "separated_text": separated_text,
                "corrected_text": corrected_text,
                "conversations_by_speaker_original": original_conversations,
                "conversations_by_speaker_corrected": corrected_conversations,
                "segments": result.get("segments", [])
            }
            
            with open(json_path, 'w', encoding='utf-8') as f:
                json.dump(json_data, f, ensure_ascii=False, indent=2)
            
            print(f"Results saved:")
            print(f"  - Complete result: {result_path}")
            print(f"  - JSON data: {json_path}")
            for speaker in corrected_conversations:
                if corrected_conversations[speaker]:
                    print(f"  - {speaker} (교정본): {len(corrected_conversations[speaker])} utterances")
            
            print("\nProcessing completed successfully!")
            print("✓ Speech recognition with Whisper")
            print("✓ Speaker separation with Gemini 2.0")
            print("✓ Spell checking with Gemini 2.0")
            print("✓ Results saved (original + corrected versions)")
            
    except Exception as e:
        print(f"Error occurred: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    test_speaker_separation()