#!/usr/bin/env python
# coding: utf-8

# ## 1. Load the full TSV data
import pandas as pd

df = pd.read_csv("sl_webtoon_full_data_sequential.tsv", sep="\t")
print(df.head())
print("Total sentences:", len(df))
print("Columns:", df.columns.tolist())

df['row_id'] = df.index  # add an index column so every row has a stable ID
# Build one retrieval sentence per row: "[episode] #row_id type scene_text".
# ('에피소드' is the episode column in the TSV.)
df['text'] = df.apply(
    lambda x: f"[{x['에피소드']}] #{x['row_id']} {x['type']} {x['scene_text']}",
    axis=1
)
texts = df['text'].tolist()
print("Final sentence count:", len(texts))
# ## 2. Build the RAG sentences
print("First 5 examples:")
for t in df['text'].head(5).tolist():
    print("-", t)
# ## 3. Load a Korean embedding model and build the vector DB
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

embedding_model = HuggingFaceEmbeddings(model_name='jhgan/ko-sroberta-multitask')
db = FAISS.from_texts(texts, embedding_model)
print("Vector DB built. Total sentences:", len(texts))

db.save_local("solo_leveling_faiss_ko")
# Reload from disk. allow_dangerous_deserialization is required because the
# docstore is pickled -- only enable it for index files you created yourself.
db = FAISS.load_local("solo_leveling_faiss_ko", embedding_model, allow_dangerous_deserialization=True)

# Retrieval test
query = "마나석이 뭐지?"  # "What is a mana stone?"
docs = db.similarity_search(query, k=5)
for i, doc in enumerate(docs, 1):
    print(f"[{i}] {doc.page_content}")
# ## 4. Load the LLM (CPU only)
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline

# Force CPU
generator = pipeline(
    "text-generation",
    model="kakaocorp/kanana-nano-2.1b-instruct",
    device=-1  # -1 = CPU
)

embedding_model = HuggingFaceEmbeddings(model_name='jhgan/ko-sroberta-multitask')
vectorstore = FAISS.load_local("solo_leveling_faiss_ko", embedding_model, allow_dangerous_deserialization=True)

model_name = "kakaocorp/kanana-nano-2.1b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32  # use float32 on CPU; half precision is poorly supported there
).to("cpu")
llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=128)
llm = HuggingFacePipeline(pipeline=llm_pipeline)
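# Quick smoke test of the wrapped LLM before wiring it into the chain.
# HuggingFacePipeline implements the langchain Runnable interface, so invoke()
# takes a prompt string and returns the generated text.
print(llm.invoke("성진우가 누구인지 한 문장으로 설명하세요."))  # "Explain who Sung Jin-woo is in one sentence."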
# Korean QA prompt. It says: "Answer the question using the following context.
# Context: {context} / Question: {question} / Answer:"
custom_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="다음 문맥을 참고하여 질문에 답하세요.\n\n문맥:\n{context}\n\n질문:\n{question}\n\n답변:"
)
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=vectorstore.as_retriever(search_kwargs={"k": 5}),
    chain_type="stuff",
    return_source_documents=True,
    chain_type_kwargs={"prompt": custom_prompt}
)

# Question test
query = "성진우는 몇 급 헌터지?"  # "What rank of hunter is Sung Jin-woo?"
result = qa_chain({"query": query})
print("Answer:", result["result"])
print("\nSource documents:")
for doc in result["source_documents"]:
    print(doc.page_content)
# ## 5. The Hwang Dong-suk episode
# Numbering is added by enumerate() below, so the strings themselves stay clean
# and can double as retrieval queries.
choices = [
    "황동석 무리를 모두 처치한다.",                # Kill all of Hwang Dong-suk's group.
    "진호를 포함한 황동석 무리를 모두 처치한다.",  # Kill the whole group, including Jinho.
    "전부 기절시키고 살려둔다.",                   # Knock them all out and spare them.
    "시스템을 거부하고 그냥 도망친다."             # Refuse the System and just run away.
]

print("\n[Choices]")
for idx, choice in enumerate(choices, start=1):
    print(f"{idx}. {choice}")

user_idx = int(input("\nEnter a choice number: ")) - 1
user_choice = choices[user_idx]
print(f"\n[User choice]: {user_choice}")

result = qa_chain({"query": user_choice})
retrieved_context = "\n".join([doc.page_content for doc in result["source_documents"]])
print("\n[Excerpt of retrieved supporting passages]")
print(retrieved_context[:600], "...")
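# Retrieved passages can overlap, which encourages the model to repeat itself.
# An optional order-preserving dedup before building the prompt (a minimal
# sketch; dict.fromkeys keeps first occurrences in insertion order):
retrieved_context = "\n".join(
    dict.fromkeys(doc.page_content for doc in result["source_documents"])
)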
# Persona prompt, kept in Korean for the Korean model. It says: "You are Sung
# Jin-woo from the webtoon 'Solo Leveling'. Current situation: {retrieved_context}.
# User choice: {user_choice}. Write 1-2 concise, natural lines of dialogue in
# Sung Jin-woo's voice. Do not produce duplicated or near-identical sentences."
prompt = f"""
당신은 웹툰 '나 혼자만 레벨업'의 성진우입니다.
현재 상황:
{retrieved_context}
사용자 선택: {user_choice}
성진우의 말투로 간결하고 자연스러운 대사를 1~2문장 작성하세요.
중복된 내용이나 비슷한 문장을 만들지 마세요.
"""

response = generator(
    prompt,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    return_full_text=False
)[0]["generated_text"]

print("\n[Sung Jin-woo's reply]")
print(response)
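# Optional wrapper (a sketch that mirrors the cells above, not part of the
# original notebook): bundle retrieval and persona generation so a new choice
# can be replayed in one call, e.g. print(persona_reply(choices[2])).
def persona_reply(choice_text: str) -> str:
    out = qa_chain({"query": choice_text})
    context = "\n".join(dict.fromkeys(d.page_content for d in out["source_documents"]))
    p = (
        "당신은 웹툰 '나 혼자만 레벨업'의 성진우입니다.\n"
        f"현재 상황:\n{context}\n"
        f"사용자 선택: {choice_text}\n"
        "성진우의 말투로 간결하고 자연스러운 대사를 1~2문장 작성하세요.\n"
        "중복된 내용이나 비슷한 문장을 만들지 마세요."
    )
    return generator(p, max_new_tokens=200, do_sample=True, temperature=0.6,
                     top_p=0.9, return_full_text=False)[0]["generated_text"]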