messy092 commited on
Commit
3a742d6
ยท
verified ยท
1 Parent(s): a3f511d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sentence_transformers import SentenceTransformer
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ import gradio as gr
7
+ from google import generativeai as genai
8
+
9
+ # API ํ‚ค๋ฅผ Hugging Face Spaces์˜ Repository secrets์—์„œ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
10
+ API_KEY = os.getenv("GOOGLE_API_KEY")
11
+
12
+ if API_KEY:
13
+ genai.configure(api_key=API_KEY)
14
+ print("API ํ‚ค๊ฐ€ ์„ฑ๊ณต์ ์œผ๋กœ ์„ค์ •๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
15
+ else:
16
+ raise ValueError("API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. Hugging Face Spaces์˜ Repository secrets์— 'GOOGLE_API_KEY'๋ฅผ ์„ค์ •ํ•ด์ฃผ์„ธ์š”.")
17
+
18
+ # ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ
19
+ df = pd.read_csv('https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv')
20
+ df = df.drop(columns=['Unnamed: 3'], errors='ignore')
21
+ df = df.dropna(subset=['์œ ์ €', '์ฑ—๋ด‡'])
22
+
23
+ # ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋กœ๋“œ
24
+ model = SentenceTransformer('jhgan/ko-sbert-nli')
25
+
26
+ # ๋ฐ์ดํ„ฐ์…‹ ์ž„๋ฒ ๋”ฉ ๋ฏธ๋ฆฌ ๊ณ„์‚ฐ
27
+ print("๋ฐ์ดํ„ฐ์…‹ ์ž„๋ฒ ๋”ฉ์„ ๋ฏธ๋ฆฌ ๊ณ„์‚ฐ ์ค‘์ž…๋‹ˆ๋‹ค. ์ด ๊ณผ์ •์€ ์‹œ๊ฐ„์ด ์†Œ์š”๋ฉ๋‹ˆ๋‹ค...")
28
+ df['embedding'] = df['์œ ์ €'].apply(lambda x: model.encode(x))
29
+ print("์ž„๋ฒ ๋”ฉ ๊ณ„์‚ฐ์ด ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค! ์ด์ œ ์ฑ—๋ด‡ ์‘๋‹ต์ด ํ›จ์”ฌ ๋นจ๋ผ์ง‘๋‹ˆ๋‹ค.")
30
+
31
+ # Gemini API ํ˜ธ์ถœ ํ•จ์ˆ˜
32
+ def call_gemini_api(question):
33
+ try:
34
+ llm_model = genai.GenerativeModel('gemini-pro')
35
+ response = llm_model.generate_content(question)
36
+ return response.text
37
+ except Exception as e:
38
+ print(f"API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
39
+ return f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"
40
+
41
+ # ์œ ์‚ฌ๋„ ์ž„๊ณ„๊ฐ’
42
+ COSINE_SIMILARITY_THRESHOLD = 0.8
43
+
44
+ # ์ฑ—๋ด‡ ํ•ต์‹ฌ ๋กœ์ง
45
+ def chatbot(user_question):
46
+ try:
47
+ user_embedding = model.encode(user_question)
48
+ similarities = df['embedding'].apply(lambda x: cosine_similarity([user_embedding], [x])[0][0])
49
+ best_match_index = similarities.idxmax()
50
+ best_score = similarities.loc[best_match_index]
51
+ best_match_row = df.loc[best_match_index]
52
+
53
+ if best_score >= COSINE_SIMILARITY_THRESHOLD:
54
+ answer = best_match_row['์ฑ—๋ด‡']
55
+ print(f"์œ ์‚ฌ๋„ ๊ธฐ๋ฐ˜ ๋‹ต๋ณ€. ์ ์ˆ˜: {best_score}")
56
+ return answer
57
+ else:
58
+ print(f"์œ ์‚ฌ๋„ ์ž„๊ณ„๊ฐ’({COSINE_SIMILARITY_THRESHOLD}) ๋ฏธ๋งŒ. Gemini ๋ชจ๋ธ์„ ํ˜ธ์ถœํ•ฉ๋‹ˆ๋‹ค. ์ ์ˆ˜: {best_score}")
59
+ return call_gemini_api(user_question)
60
+ except Exception as e:
61
+ print(f"์ฑ—๋ด‡ ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
62
+ return f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์ฑ—๋ด‡ ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"
63
+
64
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ •์˜
65
+ demo = gr.Interface(
66
+ fn=chatbot,
67
+ inputs=gr.Textbox(lines=2, placeholder="์งˆ๋ฌธ์„ ์ž…๋ ฅํ•ด ์ฃผ์„ธ์š”...", label="์งˆ๋ฌธ", elem_id="user_question_input"),
68
+ outputs=gr.Textbox(lines=5, label="์ฑ—๋ด‡ ๋‹ต๋ณ€"),
69
+ title="๋˜๋ž˜ ์ƒ๋‹ด ์ฑ—๋ด‡",
70
+ description="5๋ถ„ ๋™์•ˆ ๋Œ€ํ™”ํ•˜์—ฌ ์ฃผ์‹œ๊ณ  ๋‹ค์Œ์˜ ๋งํฌ๋ฅผ ํด๋ฆญํ•˜์—ฌ ๊ผญ ์„ค๋ฌธ์กฐ์‚ฌ์— ์ฐธ์—ฌํ•ด์ฃผ์„ธ์š”! https://forms.gle/eWtyejQaQntKbbxG8"
71
+ )
72
+
73
+ # Hugging Face Spaces ํ™˜๊ฒฝ์— ๋งž์ถฐ server_name๊ณผ server_port๋ฅผ ๋ช…์‹œํ•ฉ๋‹ˆ๋‹ค.
74
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)