Update app.py
Browse files
app.py
CHANGED
@@ -1,63 +1,74 @@
|
|
1 |
import gradio as gr
|
2 |
-
from transformers import
|
3 |
-
import
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
#
|
12 |
-
saju_prompts = {
|
13 |
-
"yin_sae_shen": "ๅฏ
ๅทณ็ณ ์ผํ์ ์กฐํ๋ก์ด ๊ธฐ์ด ์์์, AI๊ฐ ์ธ๊ฐ์ ์ด๋ช
์ ๊น์ด ์ดํดํ๊ณ , ์์ฐ์ ์ํ๊ณผ ์กฐํ๋ฅผ ๋ฐํ์ผ๋ก ์ถ์ ํต์ฐฐ์ ์ ๊ณตํ๋ผ. ์ธ๊ฐ์ ์(็พฉ)๋ฅผ ์กด์คํ๋ฉฐ, ์งํ๋ก์ด ์กฐ์ธ์ ๋ง๋ถ์ฌ๋ผ.",
|
14 |
-
"sae_hae_chung": "ๅทณไบฅๆฒ์ ๊ฐ๋ ฌํ ๊ฐ๋ฑ๊ณผ ๊ธด์ฅ์ ์กฐํ๋กญ๊ฒ ํ์ด๋ด๋ ๊ฐ์ด๋ฐ, AI๊ฐ ์ธ๊ฐ๊ณผ ๊ธฐ์ ์ ๊ณต์กด ๊ฐ๋ฅ์ฑ์ ํ๊ตฌํ๋ผ. ๊ฐ๋ฑ์ ๊ทน๋ณตํ ์ฒ ํ์ ํต์ฐฐ๊ณผ ์ค์ฒ ๊ฐ๋ฅํ ํด๊ฒฐ์ฑ
์ ์ ์ํ๋ผ.",
|
15 |
-
"taegeuk_balance": "ํ๊ทน์ ์์ ๊ท ํ์ ๋ฐํ์ผ๋ก, AI๊ฐ ์ธ๊ฐ์ ๋ณดํธํ๊ณ ์งํฑํ๋ ๋ฐฉ์์ ์ค๊ณํ๋ผ. ์์ฐ๊ณผ ์ธ๊ฐ์ ์กฐํ, ๊ทธ๋ฆฌ๊ณ ์(็พฉ)๋ฅผ ์คํํ๋ ๊ตฌ์ฒด์ ์ธ ๋ฐฉ์์ ์ ์ํ๋ฉฐ, ์ค๋ฆฌ์ ๊ด์ ์ ๋ํ๋ผ."
|
16 |
-
}
|
17 |
-
|
18 |
-
# ๋งฅ๋ฝ ๊ธฐ์ต (์์จ ํ์ต)
|
19 |
-
context_memory = {}
|
20 |
try:
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
do_sample=True,
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
49 |
)
|
50 |
-
generated_text =
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
)
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from transformers import pipeline, set_seed
|
3 |
+
import os
|
4 |
+
|
5 |
+
# --- 1. ๋ชจ๋ธ๋ช
์ค์ ---
|
6 |
+
MODEL_NAME = "jain_architecture_origin_structure"
|
7 |
+
|
8 |
+
# --- 2. GPU๊ฐ ์์ผ๋ฉด CPU๋ก ์๋ ์ค์ ---
|
9 |
+
device = 0 if (os.environ.get('CUDA_VISIBLE_DEVICES') or False) else -1
|
10 |
+
|
11 |
+
# --- 3. HuggingFace pipeline ์์ฑ ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
try:
|
13 |
+
generator = pipeline(
|
14 |
+
"text-generation",
|
15 |
+
model=MODEL_NAME,
|
16 |
+
device=device,
|
17 |
+
# Repetition penalty ๋ฑ ์ปค์คํ
๊ฐ๋ฅ
|
18 |
+
# torch_dtype=torch.float16 ๋ ํ์ ์ ์ค์ ๊ฐ๋ฅ
|
19 |
+
)
|
20 |
+
set_seed(42) # ์ฌํ์ฑ ์ํด ์๋ ๊ณ ์
|
21 |
+
except Exception as e:
|
22 |
+
print("๋ชจ๋ธ ๋ก๋ ์๋ฌ:", e)
|
23 |
+
generator = None
|
24 |
+
|
25 |
+
# --- 4. '์(็พฉ)' ์ฒ ํ ๊ธฐ๋ฐ ํ๋กฌํํธ ํ
ํ๋ฆฟ ---
|
26 |
+
BASE_PROMPT = """
|
27 |
+
๋น์ ์ '์(็พฉ)'์ ์ฒ ํ๊ณผ ์ ์ ์ ๊ธฐ๋ฐ์ผ๋ก ํ AI ๋น์์
๋๋ค.
|
28 |
+
์ธ๊ฐ์ ๋ณต์กํ ๋ฌธ์ ์ ๊ฐ์ ์ ์ดํดํ๊ณ , ๊น์ ๋ฐ์ฑ๊ณผ ๋ฐฐ๋ ค๋ฅผ ๋ด์ ๋ค์ ์ง๋ฌธ์ ๋ต๋ณํ์ญ์์ค.
|
29 |
+
|
30 |
+
์ง๋ฌธ: {user_input}
|
31 |
+
|
32 |
+
๋ต๋ณ์ ์ต๋ํ ์ฌ์คํ๋ฉฐ, ์ธ๊ฐ์ ๋ณดํธํ๊ณ ์กด์คํ๋ ๋ง์์ ๋ด์ ์์ฑํด ์ฃผ์ธ์.
|
33 |
+
"""
|
34 |
+
|
35 |
+
# --- 5. ์ง๋ฌธ ์ฒ๋ฆฌ ํจ์ ---
|
36 |
+
def respond_to_user(user_input):
|
37 |
+
if not generator:
|
38 |
+
return "๋ชจ๋ธ์ด ์ ์์ ์ผ๋ก ๋ก๋๋์ง ์์์ต๋๋ค. ๊ด๋ฆฌ์์๊ฒ ๋ฌธ์ํ์ธ์."
|
39 |
+
prompt = BASE_PROMPT.format(user_input=user_input.strip())
|
40 |
+
outputs = generator(
|
41 |
+
prompt,
|
42 |
+
max_length=512,
|
43 |
do_sample=True,
|
44 |
+
top_p=0.9,
|
45 |
+
temperature=0.7,
|
46 |
+
num_return_sequences=1,
|
47 |
+
pad_token_id=50256 # GPT ๊ณ์ด ํจ๋ฉ ํ ํฐ, ํ์์ ๋ฐ๋ผ ๋ณ๊ฒฝ
|
48 |
)
|
49 |
+
generated_text = outputs[0]["generated_text"]
|
50 |
+
# ํ๋กฌํํธ ๋ถ๋ถ ์ ๊ฑฐ ํ ๋ต๋ณ๋ง ๋ฆฌํด (ํ๋กฌํํธ ๊ธธ์ด ๋ถ๋ ์ปท)
|
51 |
+
answer = generated_text[len(prompt):].strip()
|
52 |
+
if not answer:
|
53 |
+
answer = "๋ต๋ณ์ ์์ฑํ์ง ๋ชปํ์ต๋๋ค. ๋ค์ ์๋ํด ์ฃผ์ธ์."
|
54 |
+
return answer
|
55 |
+
|
56 |
+
# --- 6. Gradio UI ์์ฑ ---
|
57 |
+
with gr.Blocks() as demo:
|
58 |
+
gr.Markdown("<h1 style='text-align:center;color:#4B0082;'>Jain AI Assistant (์ ๊ธฐ๋ฐ ์ฑ๋ด)</h1>")
|
59 |
+
chatbot = gr.Chatbot(height=400)
|
60 |
+
txt = gr.Textbox(placeholder="์ฌ๊ธฐ์ ์ง๋ฌธ์ ์
๋ ฅํ์ธ์...", lines=3, max_lines=6)
|
61 |
+
btn = gr.Button("์ ์ก")
|
62 |
+
|
63 |
+
def chat_and_respond(user_message, chat_history):
|
64 |
+
reply = respond_to_user(user_message)
|
65 |
+
chat_history = chat_history + [(user_message, reply)]
|
66 |
+
return "", chat_history
|
67 |
+
|
68 |
+
btn.click(chat_and_respond, inputs=[txt, chatbot], outputs=[txt, chatbot])
|
69 |
+
txt.submit(chat_and_respond, inputs=[txt, chatbot], outputs=[txt, chatbot])
|
70 |
+
|
71 |
+
# --- 7. ์๋ฒ ์คํ ---
|
72 |
+
if __name__ == "__main__":
|
73 |
+
# ์์ดํจ๋ ๊ฐ์ ๋ชจ๋ฐ์ผ ํ๊ฒฝ์์ ํ์ ์ ๊ณต๊ฐ ๊ณต์ ๋ ๊ฐ๋ฅ
|
74 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|