Spaces:
Runtime error
Runtime error
Jeongsoo1975
Initial commit: Gradio text-based speaker separation app for Hugging Face Spaces
ae9ec05
import whisper | |
import google.generativeai as genai | |
import os | |
import json | |
from datetime import datetime | |
import re | |
def test_speaker_separation(): | |
"""Gemini를 사용한 화자 분리 테스트""" | |
# API 키 로드 | |
from dotenv import load_dotenv | |
load_dotenv() | |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") | |
if not GOOGLE_API_KEY or GOOGLE_API_KEY == "your_google_api_key_here": | |
print("ERROR: Please set GOOGLE_API_KEY in .env file") | |
return | |
print("Loading models...") | |
try: | |
# Whisper 모델 로드 | |
whisper_model = whisper.load_model("base") | |
print("Whisper model loaded!") | |
# Gemini 모델 설정 | |
genai.configure(api_key=GOOGLE_API_KEY) | |
# gemini-2.0-flash: 최신 Gemini 2.0 모델, 빠르고 정확한 화자 분리 | |
gemini_model = genai.GenerativeModel('gemini-2.0-flash') | |
print("Gemini 2.0 Flash model configured!") | |
# WAV 파일 찾기 | |
wav_files = [] | |
if os.path.exists("data"): | |
for file in os.listdir("data"): | |
if file.endswith(".wav"): | |
wav_files.append(os.path.join("data", file)) | |
if not wav_files: | |
print("No WAV files found in data folder.") | |
return | |
print(f"Found {len(wav_files)} WAV file(s)") | |
for wav_file in wav_files[:1]: # 첫 번째 파일만 테스트 | |
print(f"\nProcessing: {os.path.basename(wav_file)}") | |
# 1단계: 음성 인식 | |
print("Step 1: Speech recognition...") | |
result = whisper_model.transcribe(wav_file) | |
full_text = result['text'].strip() | |
print(f"Language detected: {result['language']}") | |
print(f"Text length: {len(full_text)} characters") | |
print(f"Text preview: {full_text[:200]}...") | |
# 2단계: 화자 분리 | |
print("\nStep 2: Speaker separation with Gemini...") | |
prompt = f""" | |
당신은 2명의 화자가 나누는 대화를 분석하는 전문가입니다. | |
주어진 텍스트를 분석하여 각 발언을 화자별로 구분해주세요. | |
분석 지침: | |
1. 대화의 맥락과 내용을 기반으로 화자를 구분하세요 | |
2. 말투, 주제 전환, 질문과 답변의 패턴을 활용하세요 | |
3. 화자1과 화자2로 구분하여 표시하세요 | |
4. 각 발언 앞에 [화자1] 또는 [화자2]를 붙여주세요 | |
5. 발언이 너무 길 경우 자연스러운 지점에서 나누어주세요 | |
출력 형식: | |
[화자1] 첫 번째 발언 내용 | |
[화자2] 두 번째 발언 내용 | |
[화자1] 세 번째 발언 내용 | |
... | |
분석할 텍스트: | |
{full_text} | |
""" | |
response = gemini_model.generate_content(prompt) | |
separated_text = response.text.strip() | |
print("Speaker separation completed!") | |
# 3단계: 맞춤법 교정 | |
print("\nStep 3: Spell checking with Gemini...") | |
spelling_prompt = f""" | |
당신은 한국어 맞춤법 교정 전문가입니다. | |
주어진 텍스트에서 맞춤법 오류, 띄어쓰기 오류, 오타를 수정해주세요. | |
교정 지침: | |
1. 자연스러운 한국어 표현으로 수정하되, 원본의 의미와 말투는 유지하세요 | |
2. [화자1], [화자2] 태그는 그대로 유지하세요 | |
3. 전문 용어나 고유명사는 가능한 정확하게 수정하세요 | |
4. 구어체 특성은 유지하되, 명백한 오타만 수정하세요 | |
5. 문맥에 맞는 올바른 단어로 교체하세요 | |
교정할 텍스트: | |
{separated_text} | |
""" | |
corrected_response = gemini_model.generate_content(spelling_prompt) | |
corrected_text = corrected_response.text.strip() | |
print("Spell checking completed!") | |
# 4단계: 결과 저장 | |
print("\nStep 4: Saving results...") | |
base_name = os.path.splitext(os.path.basename(wav_file))[0] | |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
# output 폴더 생성 | |
if not os.path.exists("output"): | |
os.makedirs("output") | |
# 전체 결과 저장 (원본 + 분리 + 교정) | |
result_path = f"output/{base_name}_complete_result_{timestamp}.txt" | |
with open(result_path, 'w', encoding='utf-8') as f: | |
f.write(f"Filename: {base_name}\n") | |
f.write(f"Processing time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") | |
f.write(f"Language: {result['language']}\n") | |
f.write("="*50 + "\n\n") | |
f.write("Original text:\n") | |
f.write(full_text + "\n\n") | |
f.write("="*50 + "\n\n") | |
f.write("Speaker separated text (original):\n") | |
f.write(separated_text + "\n\n") | |
f.write("="*50 + "\n\n") | |
f.write("Speaker separated text (spell corrected):\n") | |
f.write(corrected_text + "\n") | |
# 교정된 텍스트에서 화자별 분리 결과 파싱 | |
corrected_conversations = {"화자1": [], "화자2": []} | |
pattern = r'\[화자([12])\]\s*(.+?)(?=\[화자[12]\]|$)' | |
matches = re.findall(pattern, corrected_text, re.DOTALL) | |
for speaker_num, content in matches: | |
speaker = f"화자{speaker_num}" | |
content = content.strip() | |
if content: | |
corrected_conversations[speaker].append(content) | |
# 원본 화자별 분리 결과도 파싱 (비교용) | |
original_conversations = {"화자1": [], "화자2": []} | |
matches = re.findall(pattern, separated_text, re.DOTALL) | |
for speaker_num, content in matches: | |
speaker = f"화자{speaker_num}" | |
content = content.strip() | |
if content: | |
original_conversations[speaker].append(content) | |
# 교정된 화자별 개별 파일 저장 | |
for speaker, utterances in corrected_conversations.items(): | |
if utterances: | |
speaker_path = f"output/{base_name}_{speaker}_교정본_{timestamp}.txt" | |
with open(speaker_path, 'w', encoding='utf-8') as f: | |
f.write(f"Filename: {base_name}\n") | |
f.write(f"Speaker: {speaker}\n") | |
f.write(f"Processing time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") | |
f.write(f"Number of utterances: {len(utterances)}\n") | |
f.write("="*50 + "\n\n") | |
for idx, utterance in enumerate(utterances, 1): | |
f.write(f"{idx}. {utterance}\n\n") | |
# JSON 저장 (원본과 교정본 모두 포함) | |
json_path = f"output/{base_name}_complete_data_{timestamp}.json" | |
json_data = { | |
"filename": base_name, | |
"processed_time": datetime.now().isoformat(), | |
"language": result['language'], | |
"original_text": full_text, | |
"separated_text": separated_text, | |
"corrected_text": corrected_text, | |
"conversations_by_speaker_original": original_conversations, | |
"conversations_by_speaker_corrected": corrected_conversations, | |
"segments": result.get("segments", []) | |
} | |
with open(json_path, 'w', encoding='utf-8') as f: | |
json.dump(json_data, f, ensure_ascii=False, indent=2) | |
print(f"Results saved:") | |
print(f" - Complete result: {result_path}") | |
print(f" - JSON data: {json_path}") | |
for speaker in corrected_conversations: | |
if corrected_conversations[speaker]: | |
print(f" - {speaker} (교정본): {len(corrected_conversations[speaker])} utterances") | |
print("\nProcessing completed successfully!") | |
print("✓ Speech recognition with Whisper") | |
print("✓ Speaker separation with Gemini 2.0") | |
print("✓ Spell checking with Gemini 2.0") | |
print("✓ Results saved (original + corrected versions)") | |
except Exception as e: | |
print(f"Error occurred: {e}") | |
import traceback | |
traceback.print_exc() | |
if __name__ == "__main__": | |
test_speaker_separation() | |