import os import pandas as pd import numpy as np from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity import gradio as gr from google import generativeai as genai API_KEY = os.getenv("GOOGLE_API_KEY") if API_KEY: genai.configure(api_key=API_KEY) print("API 키가 성공적으로 설정되었습니다.") else: raise ValueError("API 키가 설정되지 않았습니다. Hugging Face Spaces의 Repository secrets에 'GOOGLE_API_KEY'를 설정해주세요.") df = pd.read_csv('https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv') df = df.drop(columns=['Unnamed: 3'], errors='ignore') df = df.dropna(subset=['유저', '챗봇']) model = SentenceTransformer('jhgan/ko-sbert-nli') print("데이터셋 임베딩을 미리 계산 중입니다. 이 과정은 시간이 소요됩니다...") df['embedding'] = df['유저'].apply(lambda x: model.encode(x)) print("임베딩 계산이 완료되었습니다! 이제 챗봇 응답이 훨씬 빨라집니다.") def call_gemini_api(question): try: llm_model = genai.GenerativeModel('gemini-2.0-flash') response = llm_model.generate_content(question) return response.text except Exception as e: print(f"API 호출 중 오류 발생: {e}") return f"죄송합니다. API 호출 중 오류가 발생했습니다: {e}" COSINE_SIMILARITY_THRESHOLD = 0.8 def chatbot(user_question): try: user_embedding = model.encode(user_question) similarities = df['embedding'].apply(lambda x: cosine_similarity([user_embedding], [x])[0][0]) best_match_index = similarities.idxmax() best_score = similarities.loc[best_match_index] best_match_row = df.loc[best_match_index] if best_score >= COSINE_SIMILARITY_THRESHOLD: answer = best_match_row['챗봇'] print(f"유사도 기반 답변. 점수: {best_score}") return answer else: print(f"유사도 임계값({COSINE_SIMILARITY_THRESHOLD}) 미만. Gemini 모델을 호출합니다. 점수: {best_score}") return call_gemini_api(user_question) except Exception as e: print(f"챗봇 실행 중 오류 발생: {e}") return f"죄송합니다. 챗봇 실행 중 오류가 발생했습니다: {e}" demo = gr.Interface( fn=chatbot, inputs=gr.Textbox(lines=2, placeholder="질문을 입력해 주세요...", label="질문", elem_id="user_question_input"), outputs=gr.Textbox(lines=5, label="챗봇 답변"), title="또래 상담 챗봇", description="5분 동안 대화하여 주시고 다음의 링크를 클릭하여 꼭 설문조사에 참여해주세요! https://forms.gle/eWtyejQaQntKbbxG8" ) demo.launch(server_name="0.0.0.0", server_port=7860, share=False)