# app.py - copied from a Hugging Face Space page (uploader: min24ss; commit a527d95 "Update app.py", verified).
import os, zipfile, shutil, glob
import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import langchain
ZIP_NAME = "solo_leveling_faiss_ko.zip"
TARGET_DIR = "solo_leveling_faiss_ko"
# The two files that together make up a saved FAISS index directory.
_INDEX_FILES = ("index.faiss", "index.pkl")


def _has_index(directory: str) -> bool:
    """Return True when *directory* contains both FAISS index files."""
    return all(
        os.path.exists(os.path.join(directory, name)) for name in _INDEX_FILES
    )


def _copy_into_target(source: str, name: str) -> None:
    """Copy *source* to ``TARGET_DIR/name`` unless it already is that file."""
    destination = os.path.join(TARGET_DIR, name)
    # shutil.copy2 raises SameFileError when asked to copy a file onto itself.
    if os.path.exists(destination) and os.path.samefile(source, destination):
        return
    shutil.copy2(source, destination)


def ensure_faiss_dir() -> str:
    """Make sure ``TARGET_DIR`` holds ``index.faiss`` and ``index.pkl``.

    Locations are tried in this order, relative to the current working
    directory:

    1. ``TARGET_DIR`` already contains both files.
    2. Both files sit in the working directory: they are moved into
       ``TARGET_DIR`` (a file already present there is not overwritten).
    3. ``ZIP_NAME`` exists: it is extracted into the working directory and
       ``TARGET_DIR`` is checked again.
    4. A recursive search for both file names: the files are copied into
       ``TARGET_DIR``.

    Returns:
        ``TARGET_DIR``.

    Raises:
        FileNotFoundError: if no location provides both index files.
    """
    if _has_index(TARGET_DIR):
        return TARGET_DIR

    if _has_index("."):
        os.makedirs(TARGET_DIR, exist_ok=True)
        for name in _INDEX_FILES:
            destination = os.path.join(TARGET_DIR, name)
            if not os.path.exists(destination):
                shutil.move(name, destination)
        return TARGET_DIR

    if os.path.exists(ZIP_NAME):
        with zipfile.ZipFile(ZIP_NAME, "r") as archive:
            archive.extractall(".")
        if _has_index(TARGET_DIR):
            return TARGET_DIR

    # Sorted so the chosen candidate does not depend on directory scan order.
    faiss_found = sorted(glob.glob("**/index.faiss", recursive=True))
    pkl_found = sorted(glob.glob("**/index.pkl", recursive=True))
    if faiss_found and pkl_found:
        # Prefer an index.faiss / index.pkl pair from the SAME directory, so
        # files belonging to two different saved indexes are not combined.
        pkl_by_dir = {os.path.dirname(path): path for path in pkl_found}
        paired = next(
            (path for path in faiss_found if os.path.dirname(path) in pkl_by_dir),
            None,
        )
        if paired is not None:
            faiss_src = paired
            pkl_src = pkl_by_dir[os.path.dirname(paired)]
        else:
            # No directory has both files: fall back to the first match of
            # each, which is what this function has always done.
            faiss_src, pkl_src = faiss_found[0], pkl_found[0]
        os.makedirs(TARGET_DIR, exist_ok=True)
        _copy_into_target(faiss_src, "index.faiss")
        _copy_into_target(pkl_src, "index.pkl")
        return TARGET_DIR

    raise FileNotFoundError("FAISS index files not found (index.faiss / index.pkl).")
# 0) Locate the directory that holds the FAISS index files.
base_dir = ensure_faiss_dir()

# 1) Vector DB: load the saved index together with the embedding model used
#    to embed queries against it.
embeddings = HuggingFaceEmbeddings(model_name="jhgan/ko-sroberta-multitask")
# NOTE(security): allow_dangerous_deserialization=True opts in to loading the
# pickle-backed part of the saved index (see the LangChain FAISS.load_local
# docs); only do this with index files you produced or trust.
vectorstore = FAISS.load_local(
    base_dir,
    embeddings,
    allow_dangerous_deserialization=True,
)

# 2) Model loading (CPU, per the original note): float32 weights, no device map.
model_name = "kakaocorp/kanana-nano-2.1b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map=None,
)

# 3) Text-generation pipeline: sampled decoding that returns only the newly
#    generated text (return_full_text=False), not the prompt.
_generation_options = {
    "max_new_tokens": 100,
    "temperature": 0.6,
    "do_sample": True,
    "top_p": 0.9,
    "return_full_text": False,
}
lm = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    **_generation_options,
)
# ์„ ํƒ์ง€
choices = [
"1: ํ™ฉ๋™์„ ๋ฌด๋ฆฌ๋ฅผ ๋ชจ๋‘ ์ฒ˜์น˜ํ•œ๋‹ค.",
"2: ์ง„ํ˜ธ๋ฅผ ํฌํ•จํ•œ ํ™ฉ๋™์„ ๋ฌด๋ฆฌ๋ฅผ ๋ชจ๋‘ ์ฒ˜์น˜ํ•œ๋‹ค.",
"3: ์ „๋ถ€ ๊ธฐ์ ˆ ์‹œํ‚ค๊ณ  ์‚ด๋ ค๋‘”๋‹ค.",
"4: ์‹œ์Šคํ…œ์„ ๊ฑฐ๋ถ€ํ•˜๊ณ  ๊ทธ๋ƒฅ ๋„๋ง์นœ๋‹ค."
]
def rag_answer(message, history):
try:
user_idx = int(message.strip()) - 1
user_choice = choices[user_idx]
except:
return "โ—์˜ฌ๋ฐ”๋ฅธ ๋ฒˆํ˜ธ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”. (์˜ˆ: 1 ~ 4)"
docs = vectorstore.similarity_search(user_choice, k=3)
context = "\n".join([doc.page_content for doc in docs])
prompt = f"""๋‹น์‹ ์€ ์›นํˆฐ '๋‚˜ ํ˜ผ์ž๋งŒ ๋ ˆ๋ฒจ์—…'์˜ ์„ฑ์ง„์šฐ์ž…๋‹ˆ๋‹ค.
ํ˜„์žฌ ์ƒํ™ฉ:
{context}
์‚ฌ์šฉ์ž ์„ ํƒ: {user_choice}
์„ฑ์ง„์šฐ์˜ ๋งํˆฌ๋กœ ๊ฐ„๊ฒฐํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šด ๋Œ€์‚ฌ๋ฅผ 1~2๋ฌธ์žฅ ์ƒ์„ฑํ•˜์„ธ์š”.
์ค‘๋ณต๋œ ๋‚ด์šฉ์ด๋‚˜ ๋น„์Šทํ•œ ๋ฌธ์žฅ์€ ๋งŒ๋“ค์ง€ ๋งˆ์„ธ์š”.
"""
response = lm(prompt)[0]["generated_text"]
only_dialogue = response.strip().split("\n")[-1]
if not only_dialogue.startswith("๋Œ€์‚ฌ:"):
only_dialogue = "๋Œ€์‚ฌ: " + only_dialogue
return only_dialogue
# ===== UI =====
# Styles for the quest banner rendered above the chat box.
css_code = """
.quest-title {
  display:flex; align-items:center; gap:10px;
  font-weight:700; font-size:22px; margin-bottom:6px;
}
.quest-title img {
  width:72px; height:auto; opacity:.95;
}
.quest-desc { line-height:1.5; margin-bottom:14px; }
"""

# Quest banner: title, then a flex row with the quest text on the left and the
# system-window image on the right. Every <div> is closed exactly once, and the
# option texts match the entries of the `choices` list.
header_html = """
<div class="quest-title">
  [긴급 퀘스트: 적을 처치하라!]
</div>
<div style="display: flex; align-items: flex-start; gap: 20px;">
  <div class="quest-desc">
    '플레이어'에게 살의를 가진 이들이 주위에 있습니다. 이들을 모두 처치하여 안전을 확보하십시오.<br>
    지시에 따르지 않으면 당신의 심장은 정지(!)하게 됩니다.<br>
    처치해야 할 적의 숫자: 8명 / 처치한 적의 숫자: 0명<br><br>
    💬 선택지를 입력하세요:<br>
    1: 황동석 무리를 모두 처치한다.<br>
    2: 진호를 포함한 황동석 무리를 모두 처치한다.<br>
    3: 전부 기절 시키고 살려둔다.<br>
    4: 시스템을 거부하고 그냥 도망친다.
  </div>
  <div>
    <img src="https://huggingface.co/spaces/min24ss/r-story-selection/resolve/main/system.png"
         alt="quest"
         style="width: 250px; height: auto;">
  </div>
</div>
"""
# Page layout: the quest banner on top, the chat box underneath.
demo = gr.Blocks(css=css_code)
with demo:
    # Render the banner as raw HTML here (so the image is shown).
    gr.HTML(header_html)
    # ChatInterface's own title/description are intentionally not used.
    gr.ChatInterface(fn=rag_answer)
if __name__ == "__main__":
    import transformers as _t

    # Print the library versions in use before starting the app.
    for _label, _module in (
        ("Torch", torch),
        ("Transformers", _t),
        ("LangChain", langchain),
    ):
        print(f"{_label}:", _module.__version__)
    demo.launch()