min24ss committed
Commit 743534b · verified · 1 Parent(s): b714128

Upload r_story_test.py

Files changed (1): r_story_test.py (+133, -0)
r_story_test.py ADDED
@@ -0,0 +1,133 @@
#!/usr/bin/env python
# coding: utf-8

# ## 1. Load the full TSV data
import pandas as pd

df = pd.read_csv("sl_webtoon_full_data_sequential.tsv", sep="\t")

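# Sanity-check the schema right after loading (a sketch; it assumes the TSV
# must supply the '에피소드', 'type', and 'scene_text' columns used below):
missing = {'에피소드', 'type', 'scene_text'} - set(df.columns)
if missing:
    raise ValueError(f"TSV is missing expected columns: {missing}")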
print(df.head())
print("Total sentences:", len(df))
print("Columns:", df.columns.tolist())

df['row_id'] = df.index  # add an index column
df['text'] = df.apply(
    lambda x: f"[{x['에피소드']}] #{x['row_id']} {x['type']} {x['scene_text']}",  # '에피소드' = "episode"
    axis=1
)
texts = df['text'].tolist()
print("Final sentence count:", len(texts))

# ## 2. Build the RAG sentences
print("First 5 examples:")
for t in df['text'].head(5).tolist():
    print("-", t)

# ## 3. Load a Korean embedding model and build the vector DB
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

embedding_model = HuggingFaceEmbeddings(model_name='jhgan/ko-sroberta-multitask')

db = FAISS.from_texts(texts, embedding_model)
print("Vector DB built. Total sentences:", len(texts))
db.save_local("solo_leveling_faiss_ko")

db = FAISS.load_local("solo_leveling_faiss_ko", embedding_model, allow_dangerous_deserialization=True)

# Retrieval test
query = "마나석이 뭐지?"  # "What is a mana stone?"
docs = db.similarity_search(query, k=5)
for i, doc in enumerate(docs, 1):
    print(f"[{i}] {doc.page_content}")

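# To see how close the matches actually are, the FAISS wrapper also exposes
# scores (a sketch; LangChain returns (Document, score) pairs, where a lower
# L2 distance means a closer match):
for doc, score in db.similarity_search_with_score(query, k=5):
    print(f"{score:.4f} | {doc.page_content[:80]}")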
# ## 4. Load the LLM (CPU only)
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
import torch

# Force CPU
generator = pipeline(
    "text-generation",
    model="kakaocorp/kanana-nano-2.1b-instruct",
    device=-1  # -1 = run on CPU
)

embedding_model = HuggingFaceEmbeddings(model_name='jhgan/ko-sroberta-multitask')
vectorstore = FAISS.load_local("solo_leveling_faiss_ko", embedding_model, allow_dangerous_deserialization=True)

# Note: this loads the same 2.1B checkpoint a second time; reusing `generator`
# above would halve the memory footprint.
model_name = "kakaocorp/kanana-nano-2.1b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32  # float32 on CPU
).to("cpu")

llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=128)
llm = HuggingFacePipeline(pipeline=llm_pipeline)

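# Quick smoke test of the raw pipeline before wiring it into a chain (a
# sketch; the greeting is an arbitrary Korean "hello", and passing
# max_new_tokens at call time keeps the check fast):
print(llm_pipeline("안녕하세요?", max_new_tokens=16)[0]["generated_text"])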
custom_prompt = PromptTemplate(
    input_variables=["context", "question"],
    # Korean template: "Answer the question using the following context.
    # Context: {context} Question: {question} Answer:"
    template="다음 문맥을 참고하여 질문에 답하세요.\n\n문맥:\n{context}\n\n질문:\n{question}\n\n답변:"
)

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=vectorstore.as_retriever(search_kwargs={"k": 5}),
    chain_type="stuff",
    return_source_documents=True,
    chain_type_kwargs={"prompt": custom_prompt}
)

# Question test
query = "성진우는 몇 급 헌터지?"  # "What rank of hunter is Sung Jin-woo?"
result = qa_chain({"query": query})
print("Answer:", result["result"])
print("\nSource documents:")
for doc in result["source_documents"]:
    print(doc.page_content)

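# A small convenience wrapper for asking further questions (a sketch; it
# calls qa_chain with the same {"query": ...} dict used above):
def ask(question: str) -> str:
    return qa_chain({"query": question})["result"]

print(ask("마나석이 뭐지?"))  # reuse the earlier retrieval-test question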
# ## 5. Hwang Dong-seok episode
choices = [
    "1: 황동석 무리를 모두 처치한다.",               # kill Hwang Dong-seok's whole crew
    "2: 진호를 포함한 황동석 무리를 모두 처치한다.",  # kill the crew, Jin-ho included
    "3: 전부 기절 시키고 살려둔다.",                 # knock them all out, spare them
    "4: 시스템을 거부하고 그냥 도망친다."            # refuse the System and just run
]

print("\n[Choices]")
for idx, choice in enumerate(choices, start=1):
    print(f"{idx}. {choice}")

user_idx = int(input("\nEnter a choice number: ")) - 1
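# Guard against out-of-range picks (a sketch; int() raises on non-numeric
# input, but a negative index would silently wrap to the end of the list):
if not 0 <= user_idx < len(choices):
    raise ValueError(f"Enter a number between 1 and {len(choices)}")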
user_choice = choices[user_idx]
print(f"\n[User choice]: {user_choice}")

result = qa_chain({"query": user_choice})
retrieved_context = "\n".join([doc.page_content for doc in result["source_documents"]])

print("\n[Retrieved evidence documents (sample)]")
print(retrieved_context[:600], "...")

# Persona prompt, kept in Korean for the Korean-tuned model. Rough gloss:
# "You are Sung Jin-woo from the webtoon 'Solo Leveling'. Current situation:
# {retrieved_context}. User choice: {user_choice}. Write 1-2 concise, natural
# lines in Sung Jin-woo's voice. Do not produce duplicated or near-identical
# sentences."
prompt = f"""
당신은 웹툰 '나 혼자만 레벨업'의 성진우입니다.
현재 상황:
{retrieved_context}
사용자 선택: {user_choice}
성진우의 말투로 간결하고 자연스러운 대사를 1~2문장 생성하세요.
중복된 내용이나 비슷한 문장은 만들지 마세요.
"""

response = generator(
    prompt,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    return_full_text=False
)[0]["generated_text"]

print("\n[Sung Jin-woo response]")
print(response)
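# The prompt asks the model to avoid repetition, but small instruct models
# often repeat anyway; a minimal post-hoc dedup sketch that keeps only the
# first occurrence of each generated line:
seen = set()
unique_lines = []
for line in response.splitlines():
    line = line.strip()
    if line and line not in seen:
        seen.add(line)
        unique_lines.append(line)
print("\n".join(unique_lines))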