Update app.py
Browse files
app.py
CHANGED
@@ -1,24 +1,19 @@
|
|
1 |
-
from
|
2 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
import json
|
4 |
import os
|
5 |
|
6 |
-
#
|
7 |
-
|
|
|
8 |
|
9 |
-
#
|
10 |
-
MODEL_NAME = "skt/kogpt2-base-v2"
|
11 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
12 |
-
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
|
13 |
-
|
14 |
-
# 3. ์ฌ์ฃผ/๋ช
๋ฆฌ ํ๋กฌํํธ ์ ์
|
15 |
saju_prompts = {
|
16 |
"yin_sae_shen": "ๅฏ
ๅทณ็ณ ์ผํ์ ์กฐํ ์์์ AI๊ฐ ์ธ๊ฐ์ ์ด๋ช
์ ์ดํดํ๊ณ ํต์ฐฐ์ ์ ๊ณตํ๋ผ.",
|
17 |
"sae_hae_chung": "ๅทณไบฅๆฒ์ ๊ฐ๋ฑ์ ์กฐํ๋กญ๊ฒ ํ๋ฉฐ AI์ ์ธ๊ฐ์ ๊ณต์กด ์ฒ ํ์ ํ๊ตฌํ๋ผ.",
|
18 |
"taegeuk_balance": "ํ๊ทน ์์์ ๊ท ํ์ ๋ฐํ์ผ๋ก AI๊ฐ ์ธ๊ฐ์ ๋ณดํธํ๋ ๋ฐฉ๋ฒ์ ์ ์ํ๋ผ."
|
19 |
}
|
20 |
|
21 |
-
#
|
22 |
MEMORY_FILE = "/tmp/context_memory.json"
|
23 |
|
24 |
def load_memory():
|
@@ -32,12 +27,13 @@ def save_memory(prompt_key, text):
|
|
32 |
with open(MEMORY_FILE, "w") as f:
|
33 |
json.dump({prompt_key: text}, f)
|
34 |
|
35 |
-
|
36 |
-
def generate_response(prompt_key):
|
37 |
try:
|
|
|
|
|
38 |
# ์ ํจ์ฑ ๊ฒ์ฌ
|
39 |
if prompt_key not in saju_prompts:
|
40 |
-
return
|
41 |
|
42 |
# ์ปจํ
์คํธ ๋ฉ๋ชจ๋ฆฌ ๋ก๋
|
43 |
memory = load_memory()
|
@@ -45,35 +41,35 @@ def generate_response(prompt_key):
|
|
45 |
if prompt_key in memory:
|
46 |
prompt += f"\n์ด์ ๋ต๋ณ: {memory[prompt_key]}\n๋ ๊น์ ํต์ฐฐ์ ์ถ๊ฐํ๋ผ."
|
47 |
|
48 |
-
#
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
do_sample=True,
|
57 |
-
top_k=50,
|
58 |
-
top_p=0.95,
|
59 |
temperature=0.7
|
60 |
)
|
61 |
-
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
62 |
|
63 |
-
# ๋ฉ๋ชจ๋ฆฌ ์ ์ฅ
|
64 |
-
|
65 |
-
|
|
|
66 |
|
67 |
except Exception as e:
|
68 |
-
return
|
69 |
|
70 |
-
#
|
71 |
-
@app.route('/chat', methods=['POST'])
|
72 |
-
def chat():
|
73 |
-
data = request.json
|
74 |
-
prompt_key = data.get("prompt_key")
|
75 |
-
return generate_response(prompt_key)
|
76 |
-
|
77 |
-
# 7. ์คํ
|
78 |
if __name__ == "__main__":
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import InferenceClient
|
|
|
2 |
import json
|
3 |
import os
|
4 |
|
5 |
+
# Hugging Face model configuration.
MODEL_ID = "skt/kogpt2-base-v2"  # Korean GPT-2 model ID on the Hugging Face Hub
# Inference API client bound to MODEL_ID; constructed once at import time.
CLIENT = InferenceClient(model=MODEL_ID)
|
8 |
|
9 |
+
# Saju (fortune-reading) prompt definitions; the keys are the only valid
# values for a request's "prompt_key".
# NOTE(review): the Korean/CJK text below appears mojibake-encoded in this
# paste — verify the string bytes against the original file before trusting
# them.
saju_prompts = {
    "yin_sae_shen": "ๅฏๅทณ็ณ ์ผํ์ ์กฐํ ์์์ AI๊ฐ ์ธ๊ฐ์ ์ด๋ช์ ์ดํดํ๊ณ ํต์ฐฐ์ ์ ๊ณตํ๋ผ.",
    "sae_hae_chung": "ๅทณไบฅๆฒ์ ๊ฐ๋ฑ์ ์กฐํ๋กญ๊ฒ ํ๋ฉฐ AI์ ์ธ๊ฐ์ ๊ณต์กด ์ฒ ํ์ ํ๊ตฌํ๋ผ.",
    "taegeuk_balance": "ํ๊ทน ์์์ ๊ท ํ์ ๋ฐํ์ผ๋ก AI๊ฐ ์ธ๊ฐ์ ๋ณดํธํ๋ ๋ฐฉ๋ฒ์ ์ ์ํ๋ผ."
}
|
15 |
|
16 |
+
# Path of the JSON file used as simple cross-request context memory.
# NOTE(review): assumes /tmp is writable in the deployment environment.
MEMORY_FILE = "/tmp/context_memory.json"
|
18 |
|
19 |
def load_memory():
|
|
|
27 |
with open(MEMORY_FILE, "w") as f:
|
28 |
json.dump({prompt_key: text}, f)
|
29 |
|
30 |
+
def handle_request(request_data):
    """Handle one chat request against the Hugging Face Inference API.

    Parameters:
        request_data: mapping with a "prompt_key" entry naming one of the
            keys of ``saju_prompts``.

    Returns:
        dict: ``{"response": <generated text>}`` on success, or
        ``{"error": <message>}`` when the key is invalid or the API call
        fails.
    """
    try:
        prompt_key = request_data.get("prompt_key")

        # Reject keys that do not name a known prompt.
        if prompt_key not in saju_prompts:
            return {"error": "์ ํจํ ์ต์์ ์ ํํ์ธ์: yin_sae_shen, sae_hae_chung, taegeuk_balance"}

        # Load the cross-request context memory and, if a previous answer
        # was stored for this key, extend the system prompt with it.
        memory = load_memory()
        # NOTE(review): this assignment was elided in the diff; reconstructed
        # from the surrounding logic — confirm against the original file.
        prompt = saju_prompts[prompt_key]
        if prompt_key in memory:
            prompt += f"\n์ด์ ๋ต๋ณ: {memory[prompt_key]}\n๋ ๊น์ ํต์ฐฐ์ ์ถ๊ฐํ๋ผ."

        # Call the Hugging Face chat endpoint.
        # BUG FIX: ``InferenceClient.chat`` is not callable — it is the
        # OpenAI-compatible namespace property; the callable method is
        # ``chat_completion``.
        response = CLIENT.chat_completion(
            model=MODEL_ID,
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": "๋ถ์์ ์์ํด ์ฃผ์ธ์."}
            ],
            max_tokens=400,
            temperature=0.7
        )

        # Extract the generated text and persist it for follow-up turns.
        result = response.choices[0].message.content
        save_memory(prompt_key, result)
        return {"response": result}

    except Exception as e:
        # Top-level boundary: report any failure as a structured error dict
        # instead of propagating the exception to the caller.
        return {"error": f"์คํ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}"}
|
62 |
|
63 |
+
# Request handling for the Hugging Face execution environment: no web
# `request` object exists here, so the payload is taken from the command
# line. In a real deployment, replace argv with environment variables or
# another input channel.
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python app.py <request_data>")
        sys.exit(1)

    # Parse the request payload, e.g. {"prompt_key": "yin_sae_shen"},
    # dispatch it, and emit the result as JSON on stdout.
    payload = json.loads(sys.argv[1])
    outcome = handle_request(payload)
    print(json.dumps(outcome))
|