|
from flask import Flask, request, jsonify |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import json |
|
import os |
|
|
|
|
|
# The Flask application object; the /chat route below is registered on it.
app = Flask(import_name=__name__)

# Hugging Face Hub identifier of the Korean GPT-2 checkpoint used for generation.
MODEL_NAME = "skt/kogpt2-base-v2"

# Tokenizer and causal-LM weights are loaded once, at import time, and the
# module-level objects are reused by every call to generate_response().
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForCausalLM.from_pretrained(MODEL_NAME),
)
|
|
|
|
|
saju_prompts = { |
|
"yin_sae_shen": "ๅฏ
ๅทณ็ณ ์ผํ์ ์กฐํ ์์์ AI๊ฐ ์ธ๊ฐ์ ์ด๋ช
์ ์ดํดํ๊ณ ํต์ฐฐ์ ์ ๊ณตํ๋ผ.", |
|
"sae_hae_chung": "ๅทณไบฅๆฒ์ ๊ฐ๋ฑ์ ์กฐํ๋กญ๊ฒ ํ๋ฉฐ AI์ ์ธ๊ฐ์ ๊ณต์กด ์ฒ ํ์ ํ๊ตฌํ๋ผ.", |
|
"taegeuk_balance": "ํ๊ทน ์์์ ๊ท ํ์ ๋ฐํ์ผ๋ก AI๊ฐ ์ธ๊ฐ์ ๋ณดํธํ๋ ๋ฐฉ๋ฒ์ ์ ์ํ๋ผ." |
|
} |
|
|
|
|
|
# JSON file that stores the most recent generated answer per prompt key.
# Override with the CONTEXT_MEMORY_FILE environment variable; the default sits
# in the shared /tmp directory, where other local users may be able to read or
# pre-create it, so point it somewhere private for real deployments.
MEMORY_FILE = os.environ.get("CONTEXT_MEMORY_FILE", "/tmp/context_memory.json")
|
|
|
def load_memory(path=None):
    """Return the saved mapping of prompt key -> previously generated answer.

    Args:
        path: JSON file to read. Defaults to ``MEMORY_FILE``.

    Returns:
        The decoded dict, or an empty dict when the file is missing,
        unreadable, not valid JSON, or does not hold a JSON object. Memory is
        best-effort context, so none of these conditions is treated as an error.
    """
    if path is None:
        path = MEMORY_FILE
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError covers a missing or unreadable file; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError from a corrupt file.
        return {}
    # Callers look answers up by key, so anything other than a JSON object
    # (e.g. a stray list) counts as "no memory".
    return data if isinstance(data, dict) else {}
|
|
|
def save_memory(prompt_key, text, path=None):
    """Store *text* as the latest answer for *prompt_key*.

    Answers already stored for other keys are kept. Rewriting the file with
    only the current key would discard the memory of every other prompt.

    Args:
        prompt_key: Key the answer belongs to.
        text: Generated text to remember.
        path: JSON file to update. Defaults to ``MEMORY_FILE``.

    Raises:
        OSError: If the updated file cannot be written.
    """
    import tempfile  # only this helper needs it

    if path is None:
        path = MEMORY_FILE

    # Read-modify-write: start from whatever is stored, treating a missing,
    # corrupt, or non-object file as empty. NOTE: not guarded by a lock, so two
    # simultaneous saves can still lose one update.
    try:
        with open(path, "r", encoding="utf-8") as f:
            memory = json.load(f)
    except (OSError, ValueError):
        memory = {}
    if not isinstance(memory, dict):
        memory = {}
    memory[prompt_key] = text

    # Write a temp file in the same directory and atomically swap it in, so a
    # reader never observes a truncated or half-written JSON document.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(
        dir=directory, prefix=".context_memory-", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(memory, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave the temp file behind if serialisation or the rename fails.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
|
|
|
|
|
def generate_response(prompt_key, max_new_tokens=100):
    """Generate a model answer for one of the predefined ``saju_prompts``.

    The base instruction is ``saju_prompts[prompt_key]``. If the memory file
    holds an earlier answer for the same key, that answer is appended and the
    model is asked to deepen it. The freshly generated text is then saved
    back as the new memory for this key.

    Args:
        prompt_key: Key into ``saju_prompts``, as sent by the client.
        max_new_tokens: Maximum number of tokens to generate. This budget is
            separate from the prompt length, so a long remembered answer can
            no longer use up the whole generation budget.

    Returns:
        ``jsonify({"response": text})`` on success. On failure, a
        ``(jsonify({"error": ...}), status)`` tuple: status 400 for an
        unknown or non-string key, status 500 if memory loading, generation,
        or saving fails.
    """
    # Reject anything that is not a known key. Non-string JSON values
    # (lists/dicts) are unhashable and would raise TypeError on the dict lookup.
    if not isinstance(prompt_key, str) or prompt_key not in saju_prompts:
        valid_keys = ", ".join(saju_prompts)
        return jsonify({"error": f"유효한 옵션을 선택하세요: {valid_keys}"}), 400

    try:
        memory = load_memory()
        previous = memory.get(prompt_key) if isinstance(memory, dict) else None

        prompt = saju_prompts[prompt_key]
        if previous:
            prompt += f"\n이전 답변: {previous}\n더 깊은 통찰을 추가하라."

        inputs = tokenizer(prompt, return_tensors="pt")
        prompt_length = inputs["input_ids"].shape[-1]

        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )

        # generate() returns prompt + continuation. Decode only the new tokens
        # so the stored "previous answer" does not recursively embed every
        # earlier prompt and grow without bound.
        generated_text = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

        save_memory(prompt_key, generated_text)
        return jsonify({"response": generated_text})

    except Exception as e:
        # Endpoint boundary: keep the traceback in the server log and report
        # the failure to the client as JSON.
        app.logger.exception("generation failed for prompt_key=%s", prompt_key)
        return jsonify({"error": f"실행 중 오류 발생: {str(e)}"}), 500
|
|
|
|
|
@app.route('/chat', methods=['POST'])
def chat():
    """Handle ``POST /chat`` with a JSON body like ``{"prompt_key": "..."}``.

    Any of these is treated as a request without ``prompt_key``:

    * a missing or unparsable body;
    * a wrong Content-Type;
    * a JSON value that is not an object (``null``, an array, a string).

    generate_response() then answers with its regular 400 JSON error instead
    of the request failing with an unhandled exception.
    """
    # silent=True makes get_json() return None instead of aborting on a bad
    # Content-Type or malformed JSON. A non-object payload would break .get().
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    return generate_response(payload.get("prompt_key"))
|
|
|
|
|
if __name__ == "__main__":
    # The server listens on every interface. Werkzeug's interactive debugger
    # lets anyone who can reach it run arbitrary Python, so debug mode is
    # opt-in (FLASK_DEBUG=1/true/yes/on) instead of always enabled.
    debug_mode = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(host='0.0.0.0', port=5000, debug=debug_mode)
|
|