#!/usr/bin/env python
# coding: utf-8
# ## 1. Load the full TSV dataset
import pandas as pd
df = pd.read_csv("sl_webtoon_full_data_sequential.tsv", sep="\t")
print(df.head())
print("Total rows:", len(df))
print("Columns:", df.columns.tolist())
df['row_id'] = df.index  # add a sequential index column
df['text'] = df.apply(
    lambda x: f"[{x['์—ํ”ผ์†Œ๋“œ']}] #{x['row_id']} {x['type']} {x['scene_text']}",  # '์—ํ”ผ์†Œ๋“œ' = episode column
    axis=1
)
texts = df['text'].tolist()
print("Final sentence count:", len(texts))
# ## 2. Build the RAG sentences
print("First 5 examples:")
for t in df['text'].head(5).tolist():
    print("-", t)
# ## 3. Load a Korean embedding model and build the vector DB
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

embedding_model = HuggingFaceEmbeddings(model_name='jhgan/ko-sroberta-multitask')
db = FAISS.from_texts(texts, embedding_model)
print("Vector DB built. Total sentences:", len(texts))
db.save_local("solo_leveling_faiss_ko")
# Reload from disk; the flag is required because FAISS stores pickled metadata
db = FAISS.load_local("solo_leveling_faiss_ko", embedding_model, allow_dangerous_deserialization=True)
# Retrieval test
query = "๋งˆ๋‚˜์„์ด ๋ญ์ง€?"  # "What is a mana stone?"
docs = db.similarity_search(query, k=5)
for i, doc in enumerate(docs, 1):
    print(f"[{i}] {doc.page_content}")
# ## 4. Load the LLM (CPU only)
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
import torch

# Force CPU
generator = pipeline(
    "text-generation",
    model="kakaocorp/kanana-nano-2.1b-instruct",
    device=-1  # run on CPU
)
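# Quick smoke test (a sketch added for illustration, not in the original
# script): confirm the CPU pipeline produces text at all before wiring up
# the heavier RAG chain. The Korean prompt says "Introduce yourself in one
# sentence."
smoke = generator("ํ•œ ๋ฌธ์žฅ์œผ๋กœ ์ž๊ธฐ์†Œ๊ฐœ๋ฅผ ํ•ด์ฃผ์„ธ์š”.", max_new_tokens=32, do_sample=False)
print(smoke[0]["generated_text"])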
# (Re)load the embeddings and vector store so this section runs on its own
embedding_model = HuggingFaceEmbeddings(model_name='jhgan/ko-sroberta-multitask')
vectorstore = FAISS.load_local("solo_leveling_faiss_ko", embedding_model, allow_dangerous_deserialization=True)
model_name = "kakaocorp/kanana-nano-2.1b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32  # float32 on CPU
).to("cpu")
llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=128)
llm = HuggingFacePipeline(pipeline=llm_pipeline)
custom_prompt = PromptTemplate(
    input_variables=["context", "question"],
    # Korean template, kept as-is for the Korean model; roughly: "Answer the
    # question using the following context. Context: {context} Question:
    # {question} Answer:"
    template="๋‹ค์Œ ๋ฌธ๋งฅ์„ ์ฐธ๊ณ ํ•˜์—ฌ ์งˆ๋ฌธ์— ๋‹ตํ•˜์„ธ์š”.\n\n๋ฌธ๋งฅ:\n{context}\n\n์งˆ๋ฌธ:\n{question}\n\n๋‹ต๋ณ€:"
)
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=vectorstore.as_retriever(search_kwargs={"k": 5}),
    chain_type="stuff",  # concatenate all retrieved docs into one prompt
    return_source_documents=True,
    chain_type_kwargs={"prompt": custom_prompt}
)
# QA test
query = "์„ฑ์ง„์šฐ๋Š” ๋ช‡ ๊ธ‰ ํ—Œํ„ฐ์ง€?"  # "What rank hunter is Sung Jin-Woo?"
result = qa_chain.invoke({"query": query})
print("Answer:", result["result"])
print("\nSource documents:")
for doc in result["source_documents"]:
    print(doc.page_content)
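# Convenience wrapper (a sketch, not in the original script): bundle the
# chain call so later sections can reuse one entry point instead of
# repeating the invoke/unpack boilerplate.
def ask(question: str):
    out = qa_chain.invoke({"query": question})
    return out["result"], [d.page_content for d in out["source_documents"]]

answer, sources = ask(query)  # reuses the test query above
print("Answer via helper:", answer)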
# ## 5. The Hwang Dong-suk episode
# Choices stay in Korean because they are fed to the Korean model;
# English glosses are given alongside.
choices = [
    "1: ํ™ฉ๋™์„ ๋ฌด๋ฆฌ๋ฅผ ๋ชจ๋‘ ์ฒ˜์น˜ํ•œ๋‹ค.",              # Kill Hwang Dong-suk's whole party.
    "2: ์ง„ํ˜ธ๋ฅผ ํฌํ•จํ•œ ํ™ฉ๋™์„ ๋ฌด๋ฆฌ๋ฅผ ๋ชจ๋‘ ์ฒ˜์น˜ํ•œ๋‹ค.",  # Kill the whole party, including Jin-ho.
    "3: ์ „๋ถ€ ๊ธฐ์ ˆ ์‹œํ‚ค๊ณ  ์‚ด๋ ค๋‘”๋‹ค.",              # Knock them all out but spare them.
    "4: ์‹œ์Šคํ…œ์„ ๊ฑฐ๋ถ€ํ•˜๊ณ  ๊ทธ๋ƒฅ ๋„๋ง์นœ๋‹ค."           # Refuse the System and just run away.
]
print("\n[์„ ํƒ์ง€]")
for idx, choice in enumerate(choices, start=1):
print(f"{idx}. {choice}")
user_idx = int(input("\n์„ ํƒ ๋ฒˆํ˜ธ ์ž…๋ ฅ: ")) - 1
user_choice = choices[user_idx]
print(f"\n[์‚ฌ์šฉ์ž ์„ ํƒ]: {user_choice}")
result = qa_chain.invoke({"query": user_choice})
retrieved_context = "\n".join(doc.page_content for doc in result["source_documents"])
print("\n[Sample of retrieved supporting documents]")
print(retrieved_context[:600], "...")
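# Optional (sketch): cap the retrieved context by token count so the final
# prompt stays inside the model's context window; max_length=1024 is an
# assumed budget, not a value from the original script.
ctx_ids = tokenizer(retrieved_context, truncation=True, max_length=1024)["input_ids"]
retrieved_context = tokenizer.decode(ctx_ids, skip_special_tokens=True)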
prompt = f"""
๋‹น์‹ ์€ ์›นํˆฐ '๋‚˜ ํ˜ผ์ž๋งŒ ๋ ˆ๋ฒจ์—…'์˜ ์„ฑ์ง„์šฐ์ž…๋‹ˆ๋‹ค.
ํ˜„์žฌ ์ƒํ™ฉ:
{retrieved_context}
์‚ฌ์šฉ์ž ์„ ํƒ: {user_choice}
์„ฑ์ง„์šฐ์˜ ๋งํˆฌ๋กœ ๊ฐ„๊ฒฐํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šด ๋Œ€์‚ฌ๋ฅผ 1~2๋ฌธ์žฅ ์ƒ์„ฑํ•˜์„ธ์š”.
์ค‘๋ณต๋œ ๋‚ด์šฉ์ด๋‚˜ ๋น„์Šทํ•œ ๋ฌธ์žฅ์€ ๋งŒ๋“ค์ง€ ๋งˆ์„ธ์š”.
"""
response = generator(
    prompt,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    return_full_text=False  # return only the newly generated text
)[0]["generated_text"]
print("\n[Sung Jin-Woo's response]")
print(response)
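# Optional post-processing (a rough sketch): small instruct models sometimes
# repeat themselves despite the prompt, so keep only the first occurrence of
# each sentence. Splitting on ". " is crude for Korean text; treat this as a
# placeholder heuristic.
seen, deduped = set(), []
for sent in response.replace("\n", " ").split(". "):
    s = sent.strip()
    if s and s not in seen:
        seen.add(s)
        deduped.append(s)
print(". ".join(deduped))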