import os, zipfile, shutil, glob

import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import langchain
ZIP_NAME = "solo_leveling_faiss_ko.zip"
TARGET_DIR = "solo_leveling_faiss_ko"
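# Expected repo layout (the loader below accepts any one of these):
#   solo_leveling_faiss_ko/index.faiss + index.pkl   (preferred)
#   index.faiss + index.pkl at the repo root
#   solo_leveling_faiss_ko.zip containing both files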
def ensure_faiss_dir() -> str:
    """Locate the FAISS index files, trying fallbacks in order:
    target dir, repo root, bundled zip, then a recursive search."""
    # Already in place?
    if os.path.exists(os.path.join(TARGET_DIR, "index.faiss")) and \
       os.path.exists(os.path.join(TARGET_DIR, "index.pkl")):
        return TARGET_DIR
    # Loose files at the repo root: move them into the target dir.
    if os.path.exists("index.faiss") and os.path.exists("index.pkl"):
        os.makedirs(TARGET_DIR, exist_ok=True)
        if not os.path.exists(os.path.join(TARGET_DIR, "index.faiss")):
            shutil.move("index.faiss", os.path.join(TARGET_DIR, "index.faiss"))
        if not os.path.exists(os.path.join(TARGET_DIR, "index.pkl")):
            shutil.move("index.pkl", os.path.join(TARGET_DIR, "index.pkl"))
        return TARGET_DIR
    # Bundled zip: extract and re-check.
    if os.path.exists(ZIP_NAME):
        with zipfile.ZipFile(ZIP_NAME, "r") as z:
            z.extractall(".")
        if os.path.exists(os.path.join(TARGET_DIR, "index.faiss")) and \
           os.path.exists(os.path.join(TARGET_DIR, "index.pkl")):
            return TARGET_DIR
    # Last resort: search the whole tree and copy the first match found.
    faiss_cand = glob.glob("**/index.faiss", recursive=True)
    pkl_cand = glob.glob("**/index.pkl", recursive=True)
    if faiss_cand and pkl_cand:
        os.makedirs(TARGET_DIR, exist_ok=True)
        shutil.copy2(faiss_cand[0], os.path.join(TARGET_DIR, "index.faiss"))
        shutil.copy2(pkl_cand[0], os.path.join(TARGET_DIR, "index.pkl"))
        return TARGET_DIR
    raise FileNotFoundError("FAISS index files not found (index.faiss / index.pkl).")
# 0) Locate the FAISS index
base_dir = ensure_faiss_dir()

# 1) Vector DB
embeddings = HuggingFaceEmbeddings(model_name="jhgan/ko-sroberta-multitask")
vectorstore = FAISS.load_local(base_dir, embeddings, allow_dangerous_deserialization=True)
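# NOTE (assumption): the index is presumed to have been built offline with the
# same embedding model, roughly like this hypothetical build script:
#
#   from langchain_community.vectorstores import FAISS
#   db = FAISS.from_texts(story_chunks, embeddings)  # story_chunks: list[str] of scene text
#   db.save_local(TARGET_DIR)                        # writes index.faiss / index.pkl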
# 2) ๋ชจ๋ธ ๋ก๋ฉ (CPU) | |
model_name = "kakaocorp/kanana-nano-2.1b-instruct" | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, device_map=None) | |
# 3) ํ ์คํธ ์์ฑ ํ์ดํ๋ผ์ธ | |
lm = pipeline( | |
"text-generation", | |
model=model, | |
tokenizer=tokenizer, | |
max_new_tokens=100, | |
temperature=0.6, | |
do_sample=True, | |
top_p=0.9, | |
return_full_text=False | |
) | |
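# Optional smoke test before wiring up the UI (hypothetical prompt; uncomment to run):
#   print(lm("상태창")[0]["generated_text"])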
# ์ ํ์ง | |
choices = [ | |
"1: ํฉ๋์ ๋ฌด๋ฆฌ๋ฅผ ๋ชจ๋ ์ฒ์นํ๋ค.", | |
"2: ์งํธ๋ฅผ ํฌํจํ ํฉ๋์ ๋ฌด๋ฆฌ๋ฅผ ๋ชจ๋ ์ฒ์นํ๋ค.", | |
"3: ์ ๋ถ ๊ธฐ์ ์ํค๊ณ ์ด๋ ค๋๋ค.", | |
"4: ์์คํ ์ ๊ฑฐ๋ถํ๊ณ ๊ทธ๋ฅ ๋๋ง์น๋ค." | |
] | |
def rag_answer(message, history):
    # Parse the numeric choice; reject anything outside 1..len(choices)
    # (a bare int() would silently accept 0 or negatives via negative indexing).
    try:
        user_idx = int(message.strip()) - 1
        if not 0 <= user_idx < len(choices):
            raise IndexError
        user_choice = choices[user_idx]
    except (ValueError, IndexError):
        return "❗올바른 번호를 입력해주세요. (예: 1 ~ 4)"  # "Please enter a valid number (e.g. 1-4)"

    # Retrieve the 3 most similar story chunks for the chosen action.
    docs = vectorstore.similarity_search(user_choice, k=3)
    context = "\n".join([doc.page_content for doc in docs])

    # Korean prompt, roughly: "You are Sung Jin-woo from the webtoon
    # 'Solo Leveling'. Current situation: {context}. User choice: {user_choice}.
    # Write 1-2 concise, natural lines in Jin-woo's voice; no repetition."
    prompt = f"""당신은 웹툰 '나 혼자만 레벨업'의 성진우입니다.
현재 상황:
{context}
사용자 선택: {user_choice}
성진우의 말투로 간결하고 자연스러운 대사를 1~2문장 생성하세요.
중복된 내용이나 비슷한 문장은 만들지 마세요.
"""
    response = lm(prompt)[0]["generated_text"]

    # Keep only the last generated line and normalize the "대사:" ("Line:") prefix.
    only_dialogue = response.strip().split("\n")[-1]
    if not only_dialogue.startswith("대사:"):
        only_dialogue = "대사: " + only_dialogue
    return only_dialogue
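# Usage: the ChatInterface below calls rag_answer(message, history); e.g.
# rag_answer("1", []) retrieves three context chunks for choice 1 and returns
# a single "대사: ..." line, while non-numeric input returns the error prompt.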
# ===== UI =====
css_code = """
.quest-title {
  display:flex; align-items:center; gap:10px;
  font-weight:700; font-size:22px; margin-bottom:6px;
}
.quest-title img {
  width:72px; height:auto; opacity:.95;
}
.quest-desc { line-height:1.5; margin-bottom:14px; }
"""
# Quest header (Korean). Rough translation: "[Urgent Quest: Eliminate the
# enemies!] Those hostile to the 'Player' are nearby. Eliminate them all to
# secure your safety. If you do not follow the instructions, your heart will
# stop(!). Enemies to eliminate: 8 / eliminated: 0", then the four choices.
header_html = """
<div class="quest-title">
  [긴급 퀘스트: 적을 처치하라!]
</div>
<div class="quest-desc">
  <div style="display: flex; align-items: flex-start; gap: 20px;">
    <div>
      '플레이어'에게 적의를 가진 이들이 주위에 있습니다. 이들을 모두 처치하여 안전을 확보하십시오.<br>
      지시에 따르지 않으면 당신의 심장은 정지(!)하게 됩니다.<br>
      처치해야 할 적의 숫자: 8명 / 처치한 적의 숫자: 0명
    </div>
    <div>
      <img src="https://huggingface.co/spaces/min24ss/r-story-selection/resolve/main/system.png"
           alt="quest"
           style="width: 250px; height: auto;">
    </div>
  </div>
  <br>
  💬 선택지를 입력하세요:<br>
  1: 황동석 무리를 모두 처치한다.<br>
  2: 진호를 포함한 황동석 무리를 모두 처치한다.<br>
  3: 전부 기절시키고 도려낸다.<br>
  4: 시스템을 거부하고 그냥 도망친다.
</div>
"""
with gr.Blocks(css=css_code) as demo:
    gr.HTML(header_html)             # render the header HTML directly (keeps the image)
    gr.ChatInterface(fn=rag_answer)  # no title/description here; the header covers it

if __name__ == "__main__":
    print("Torch:", torch.__version__)
    import transformers as _t
    print("Transformers:", _t.__version__)
    print("LangChain:", langchain.__version__)
    demo.launch()