Spaces:
Runtime error
Runtime error
File size: 8,695 Bytes
ae9ec05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
import whisper
import google.generativeai as genai
import os
import json
from datetime import datetime
import re
def test_speaker_separation():
"""Gemini를 사용한 화자 분리 테스트"""
# API 키 로드
from dotenv import load_dotenv
load_dotenv()
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY or GOOGLE_API_KEY == "your_google_api_key_here":
print("ERROR: Please set GOOGLE_API_KEY in .env file")
return
print("Loading models...")
try:
# Whisper 모델 로드
whisper_model = whisper.load_model("base")
print("Whisper model loaded!")
# Gemini 모델 설정
genai.configure(api_key=GOOGLE_API_KEY)
# gemini-2.0-flash: 최신 Gemini 2.0 모델, 빠르고 정확한 화자 분리
gemini_model = genai.GenerativeModel('gemini-2.0-flash')
print("Gemini 2.0 Flash model configured!")
# WAV 파일 찾기
wav_files = []
if os.path.exists("data"):
for file in os.listdir("data"):
if file.endswith(".wav"):
wav_files.append(os.path.join("data", file))
if not wav_files:
print("No WAV files found in data folder.")
return
print(f"Found {len(wav_files)} WAV file(s)")
for wav_file in wav_files[:1]: # 첫 번째 파일만 테스트
print(f"\nProcessing: {os.path.basename(wav_file)}")
# 1단계: 음성 인식
print("Step 1: Speech recognition...")
result = whisper_model.transcribe(wav_file)
full_text = result['text'].strip()
print(f"Language detected: {result['language']}")
print(f"Text length: {len(full_text)} characters")
print(f"Text preview: {full_text[:200]}...")
# 2단계: 화자 분리
print("\nStep 2: Speaker separation with Gemini...")
prompt = f"""
당신은 2명의 화자가 나누는 대화를 분석하는 전문가입니다.
주어진 텍스트를 분석하여 각 발언을 화자별로 구분해주세요.
분석 지침:
1. 대화의 맥락과 내용을 기반으로 화자를 구분하세요
2. 말투, 주제 전환, 질문과 답변의 패턴을 활용하세요
3. 화자1과 화자2로 구분하여 표시하세요
4. 각 발언 앞에 [화자1] 또는 [화자2]를 붙여주세요
5. 발언이 너무 길 경우 자연스러운 지점에서 나누어주세요
출력 형식:
[화자1] 첫 번째 발언 내용
[화자2] 두 번째 발언 내용
[화자1] 세 번째 발언 내용
...
분석할 텍스트:
{full_text}
"""
response = gemini_model.generate_content(prompt)
separated_text = response.text.strip()
print("Speaker separation completed!")
# 3단계: 맞춤법 교정
print("\nStep 3: Spell checking with Gemini...")
spelling_prompt = f"""
당신은 한국어 맞춤법 교정 전문가입니다.
주어진 텍스트에서 맞춤법 오류, 띄어쓰기 오류, 오타를 수정해주세요.
교정 지침:
1. 자연스러운 한국어 표현으로 수정하되, 원본의 의미와 말투는 유지하세요
2. [화자1], [화자2] 태그는 그대로 유지하세요
3. 전문 용어나 고유명사는 가능한 정확하게 수정하세요
4. 구어체 특성은 유지하되, 명백한 오타만 수정하세요
5. 문맥에 맞는 올바른 단어로 교체하세요
교정할 텍스트:
{separated_text}
"""
corrected_response = gemini_model.generate_content(spelling_prompt)
corrected_text = corrected_response.text.strip()
print("Spell checking completed!")
# 4단계: 결과 저장
print("\nStep 4: Saving results...")
base_name = os.path.splitext(os.path.basename(wav_file))[0]
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# output 폴더 생성
if not os.path.exists("output"):
os.makedirs("output")
# 전체 결과 저장 (원본 + 분리 + 교정)
result_path = f"output/{base_name}_complete_result_{timestamp}.txt"
with open(result_path, 'w', encoding='utf-8') as f:
f.write(f"Filename: {base_name}\n")
f.write(f"Processing time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"Language: {result['language']}\n")
f.write("="*50 + "\n\n")
f.write("Original text:\n")
f.write(full_text + "\n\n")
f.write("="*50 + "\n\n")
f.write("Speaker separated text (original):\n")
f.write(separated_text + "\n\n")
f.write("="*50 + "\n\n")
f.write("Speaker separated text (spell corrected):\n")
f.write(corrected_text + "\n")
# 교정된 텍스트에서 화자별 분리 결과 파싱
corrected_conversations = {"화자1": [], "화자2": []}
pattern = r'\[화자([12])\]\s*(.+?)(?=\[화자[12]\]|$)'
matches = re.findall(pattern, corrected_text, re.DOTALL)
for speaker_num, content in matches:
speaker = f"화자{speaker_num}"
content = content.strip()
if content:
corrected_conversations[speaker].append(content)
# 원본 화자별 분리 결과도 파싱 (비교용)
original_conversations = {"화자1": [], "화자2": []}
matches = re.findall(pattern, separated_text, re.DOTALL)
for speaker_num, content in matches:
speaker = f"화자{speaker_num}"
content = content.strip()
if content:
original_conversations[speaker].append(content)
# 교정된 화자별 개별 파일 저장
for speaker, utterances in corrected_conversations.items():
if utterances:
speaker_path = f"output/{base_name}_{speaker}_교정본_{timestamp}.txt"
with open(speaker_path, 'w', encoding='utf-8') as f:
f.write(f"Filename: {base_name}\n")
f.write(f"Speaker: {speaker}\n")
f.write(f"Processing time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"Number of utterances: {len(utterances)}\n")
f.write("="*50 + "\n\n")
for idx, utterance in enumerate(utterances, 1):
f.write(f"{idx}. {utterance}\n\n")
# JSON 저장 (원본과 교정본 모두 포함)
json_path = f"output/{base_name}_complete_data_{timestamp}.json"
json_data = {
"filename": base_name,
"processed_time": datetime.now().isoformat(),
"language": result['language'],
"original_text": full_text,
"separated_text": separated_text,
"corrected_text": corrected_text,
"conversations_by_speaker_original": original_conversations,
"conversations_by_speaker_corrected": corrected_conversations,
"segments": result.get("segments", [])
}
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(json_data, f, ensure_ascii=False, indent=2)
print(f"Results saved:")
print(f" - Complete result: {result_path}")
print(f" - JSON data: {json_path}")
for speaker in corrected_conversations:
if corrected_conversations[speaker]:
print(f" - {speaker} (교정본): {len(corrected_conversations[speaker])} utterances")
print("\nProcessing completed successfully!")
print("✓ Speech recognition with Whisper")
print("✓ Speaker separation with Gemini 2.0")
print("✓ Spell checking with Gemini 2.0")
print("✓ Results saved (original + corrected versions)")
except Exception as e:
print(f"Error occurred: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
test_speaker_separation()
|