File size: 2,853 Bytes
3a742d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0d22d3
3a742d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import gradio as gr
from google import generativeai as genai

# The Gemini API key is read from the GOOGLE_API_KEY environment variable.
API_KEY = os.getenv("GOOGLE_API_KEY")

# Fail fast when the key is missing or empty; otherwise register it with the
# generative AI SDK and report success on stdout.
if not API_KEY:
    raise ValueError("API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. Hugging Face Spaces์˜ Repository secrets์— 'GOOGLE_API_KEY'๋ฅผ ์„ค์ •ํ•ด์ฃผ์„ธ์š”.")

genai.configure(api_key=API_KEY)
print("API ํ‚ค๊ฐ€ ์„ฑ๊ณต์ ์œผ๋กœ ์„ค์ •๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")

# Load the wellness dataset from GitHub, drop the stray 'Unnamed: 3' column if
# it is present, and keep only rows where both text columns are non-null.
df = pd.read_csv('https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv')
df = df.drop(columns=['Unnamed: 3'], errors='ignore')
df = df.dropna(subset=['์œ ์ €', '์ฑ—๋ด‡'])

# Sentence-embedding model used for both the dataset rows and incoming questions.
model = SentenceTransformer('jhgan/ko-sbert-nli')

print("๋ฐ์ดํ„ฐ์…‹ ์ž„๋ฒ ๋”ฉ์„ ๋ฏธ๋ฆฌ ๊ณ„์‚ฐ ์ค‘์ž…๋‹ˆ๋‹ค. ์ด ๊ณผ์ •์€ ์‹œ๊ฐ„์ด ์†Œ์š”๋ฉ๋‹ˆ๋‹ค...")
# Encode every utterance in one call so the encoder can batch the work itself,
# instead of paying the per-call overhead once per row. `encode` on a list
# returns one vector per input (a 2-D array by default); `list(...)` splits it
# back into one 1-D vector per row, which is assigned positionally.
df['embedding'] = list(model.encode(df['์œ ์ €'].tolist()))
print("์ž„๋ฒ ๋”ฉ ๊ณ„์‚ฐ์ด ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค! ์ด์ œ ์ฑ—๋ด‡ ์‘๋‹ต์ด ํ›จ์”ฌ ๋นจ๋ผ์ง‘๋‹ˆ๋‹ค.")

def call_gemini_api(question):
    """Send ``question`` to the 'gemini-2.0-flash' model and return the reply text.

    Any ``Exception`` raised while creating the model, generating content or
    reading the reply text is printed to stdout and converted into an apology
    string that embeds the error message, which is returned instead.
    """
    try:
        # The whole chain stays inside the ``try`` so a failure at any step is caught.
        return genai.GenerativeModel('gemini-2.0-flash').generate_content(question).text
    except Exception as err:
        print(f"API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {err}")
        failure_reply = f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {err}"
        return failure_reply

# Minimum cosine similarity for a dataset reply to be used; below this the
# question is forwarded to Gemini instead.
COSINE_SIMILARITY_THRESHOLD = 0.8


def chatbot(user_question):
    """Answer ``user_question`` from the dataset or, failing that, via Gemini.

    The question is embedded with the module-level ``model`` and compared with
    every precomputed vector in ``df['embedding']``. If the best cosine
    similarity reaches ``COSINE_SIMILARITY_THRESHOLD`` the matching row's
    stored reply is returned; otherwise the question goes to
    ``call_gemini_api``.

    Args:
        user_question: Text typed by the user. ``None``, empty or
            whitespace-only input is answered with a short prompt asking for
            a question, without touching the model or the API.

    Returns:
        The reply string. Any ``Exception`` raised during matching is printed
        to stdout and returned as an apology message that embeds the error.
    """
    # Nothing to match or send to the LLM; the reply means
    # "Please enter a question."
    if not user_question or not user_question.strip():
        return "질문을 입력해 주세요."

    try:
        user_embedding = model.encode(user_question)
        # Score every row with one vectorised call instead of one
        # scikit-learn call per row; result has one score per dataset row.
        embedding_matrix = np.vstack(df['embedding'].tolist())
        scores = cosine_similarity([user_embedding], embedding_matrix)[0]
        # argmax returns the first maximum, matching Series.idxmax on ties.
        best_position = int(np.argmax(scores))
        best_score = scores[best_position]

        if best_score >= COSINE_SIMILARITY_THRESHOLD:
            # Positional lookup: the frame's index labels are not contiguous
            # after dropna, so iloc (not loc) pairs with argmax.
            answer = df.iloc[best_position]['์ฑ—๋ด‡']
            print(f"์œ ์‚ฌ๋„ ๊ธฐ๋ฐ˜ ๋‹ต๋ณ€. ์ ์ˆ˜: {best_score}")
            return answer

        print(f"์œ ์‚ฌ๋„ ์ž„๊ณ„๊ฐ’({COSINE_SIMILARITY_THRESHOLD}) ๋ฏธ๋งŒ. Gemini ๋ชจ๋ธ์„ ํ˜ธ์ถœํ•ฉ๋‹ˆ๋‹ค. ์ ์ˆ˜: {best_score}")
        return call_gemini_api(user_question)
    except Exception as e:
        print(f"์ฑ—๋ด‡ ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
        return f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์ฑ—๋ด‡ ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"

# Input and output widgets are built first so the Interface call stays short.
_question_box = gr.Textbox(
    lines=2,
    placeholder="์งˆ๋ฌธ์„ ์ž…๋ ฅํ•ด ์ฃผ์„ธ์š”...",
    label="์งˆ๋ฌธ",
    elem_id="user_question_input",
)
_answer_box = gr.Textbox(lines=5, label="์ฑ—๋ด‡ ๋‹ต๋ณ€")

# Gradio UI that passes the question textbox to `chatbot` and shows its return
# value in the answer textbox.
demo = gr.Interface(
    fn=chatbot,
    inputs=_question_box,
    outputs=_answer_box,
    title="๋˜๋ž˜ ์ƒ๋‹ด ์ฑ—๋ด‡",
    description="5๋ถ„ ๋™์•ˆ ๋Œ€ํ™”ํ•˜์—ฌ ์ฃผ์‹œ๊ณ  ๋‹ค์Œ์˜ ๋งํฌ๋ฅผ ํด๋ฆญํ•˜์—ฌ ๊ผญ ์„ค๋ฌธ์กฐ์‚ฌ์— ์ฐธ์—ฌํ•ด์ฃผ์„ธ์š”! https://forms.gle/eWtyejQaQntKbbxG8",
)

# Start the server bound to 0.0.0.0 on port 7860 with share=False.
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)