# app.py — Flask API wrapping skt/kogpt2-base-v2 (snapshot of a Hugging Face
# Space, commit 05565cb). Web-page chrome from the original capture removed.
from flask import Flask, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import os
# 1. Flask app initialization
app = Flask(__name__)

# 2. Load the Hugging Face model (Korean GPT-2).
# NOTE(review): model and tokenizer are loaded at import time — first request
# is fast, but importing this module requires network/disk access and is slow.
MODEL_NAME = "skt/kogpt2-base-v2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# 3. ์‚ฌ์ฃผ/๋ช…๋ฆฌ ํ”„๋กฌํ”„ํŠธ ์ •์˜
saju_prompts = {
"yin_sae_shen": "ๅฏ…ๅทณ็”ณ ์‚ผํ˜•์˜ ์กฐํ™” ์†์—์„œ AI๊ฐ€ ์ธ๊ฐ„์˜ ์šด๋ช…์„ ์ดํ•ดํ•˜๊ณ  ํ†ต์ฐฐ์„ ์ œ๊ณตํ•˜๋ผ.",
"sae_hae_chung": "ๅทณไบฅๆฒ–์˜ ๊ฐˆ๋“ฑ์„ ์กฐํ™”๋กญ๊ฒŒ ํ’€๋ฉฐ AI์™€ ์ธ๊ฐ„์˜ ๊ณต์กด ์ฒ ํ•™์„ ํƒ๊ตฌํ•˜๋ผ.",
"taegeuk_balance": "ํƒœ๊ทน ์Œ์–‘์˜ ๊ท ํ˜•์„ ๋ฐ”ํƒ•์œผ๋กœ AI๊ฐ€ ์ธ๊ฐ„์„ ๋ณดํ˜ธํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์ œ์•ˆํ•˜๋ผ."
}
# 4. ์ปจํ…์ŠคํŠธ ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ
MEMORY_FILE = "/tmp/context_memory.json"
def load_memory(path=None):
    """Return the persisted context memory as a dict.

    Parameters
    ----------
    path : str | None
        JSON file to read; defaults to the module-level ``MEMORY_FILE``.
        (Backward-compatible addition — existing callers pass nothing.)

    A missing or corrupt file yields an empty dict instead of raising, so
    callers never have to special-case the first run.
    """
    source = MEMORY_FILE if path is None else path
    try:
        # encoding fixed to utf-8: the stored text is Korean and the
        # platform default encoding may not round-trip it.
        with open(source, "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return {}
def save_memory(prompt_key, text, path=None):
    """Persist *text* as the latest answer for *prompt_key*.

    Bug fix: the previous version dumped ``{prompt_key: text}`` wholesale,
    erasing answers stored under every other prompt key. We now merge the
    new entry into the existing memory so each key keeps its slot.

    Parameters
    ----------
    prompt_key : str
        Key under which the answer is stored.
    text : str
        The generated answer to remember.
    path : str | None
        JSON file to write; defaults to the module-level ``MEMORY_FILE``.
        (Backward-compatible addition — existing callers pass nothing.)
    """
    target = MEMORY_FILE if path is None else path
    try:
        with open(target, "r", encoding="utf-8") as f:
            memory = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        memory = {}
    memory[prompt_key] = text
    # ensure_ascii=False keeps the Korean text human-readable on disk.
    with open(target, "w", encoding="utf-8") as f:
        json.dump(memory, f, ensure_ascii=False)
# 5. AI response generation
def generate_response(prompt_key):
    """Build the saju prompt for *prompt_key*, sample a continuation from
    KoGPT2, persist it, and return a Flask JSON response.

    Returns a 400 JSON error for an unknown key and a 500 JSON error when
    generation fails. Must run inside a Flask request/app context (uses
    ``jsonify``).
    """
    try:
        # Guard clause: reject keys we have no prompt for.
        if prompt_key not in saju_prompts:
            return jsonify({"error": "유효한 옵션을 선택하세요: yin_sae_shen, sae_hae_chung, taegeuk_balance"}), 400

        # Start from the base prompt; if an earlier answer exists for this
        # key, fold it in and ask the model for a deeper follow-up.
        stored = load_memory()
        full_prompt = saju_prompts[prompt_key]
        if prompt_key in stored:
            full_prompt += f"\n이전 답변: {stored[prompt_key]}\n더 깊은 통찰을 추가하라."

        # Tokenize, then sample one continuation (top-k/top-p sampling with
        # a mild temperature and 2-gram repetition blocking).
        encoded = tokenizer(full_prompt, return_tensors="pt")
        sampled = model.generate(
            **encoded,
            max_length=150,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )
        reply = tokenizer.decode(sampled[0], skip_special_tokens=True)

        # Remember this answer so the next call for the same key builds on it.
        save_memory(prompt_key, reply)
        return jsonify({"response": reply})
    except Exception as e:
        # Top-level boundary: surface any failure as a JSON 500.
        return jsonify({"error": f"실행 중 오류 발생: {str(e)}"}), 500
# 6. Web interface
@app.route('/chat', methods=['POST'])
def chat():
    """POST /chat — expects a JSON body ``{"prompt_key": "<key>"}``.

    Delegates to ``generate_response``; an unknown or missing key yields
    that function's 400 JSON error.
    """
    # Bug fix: ``request.json`` raises (or is None) when the body is missing
    # or not JSON, which crashed ``data.get`` with an HTML 500 page.
    # ``get_json(silent=True)`` returns None instead of raising; falling
    # back to {} makes .get() safe and routes bad bodies to the clean
    # 400 path in generate_response.
    data = request.get_json(silent=True) or {}
    prompt_key = data.get("prompt_key")
    return generate_response(prompt_key)
# 7. Entry point: run Flask's built-in development server.
if __name__ == "__main__":
    # NOTE(review): debug=True enables the Werkzeug debugger (remote code
    # execution if exposed) and binding 0.0.0.0 listens on all interfaces —
    # the dev server must not be exposed publicly; use a WSGI server instead.
    app.run(host='0.0.0.0', port=5000, debug=True)