min24ss committed on
Commit
b714128
·
verified ·
1 Parent(s): 077d861

Delete r-story-test.py

Files changed (1)
  1. r-story-test.py +0 -133
r-story-test.py DELETED
@@ -1,133 +0,0 @@
#!/usr/bin/env python
# coding: utf-8

# ## 1. Load the full TSV data
import pandas as pd

df = pd.read_csv("sl_webtoon_full_data_sequential.tsv", sep="\t")

print(df.head())
print("Total sentences:", len(df))
print("Columns:", df.columns.tolist())

df['row_id'] = df.index  # add an index column
df['text'] = df.apply(
    lambda x: f"[{x['์—ํ”ผ์†Œ๋“œ']}] #{x['row_id']} {x['type']} {x['scene_text']}",  # '์—ํ”ผ์†Œ๋“œ' = episode column
    axis=1
)
texts = df['text'].tolist()
print("Final sentence count:", len(texts))

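# --- Aside (a sketch, not part of the original file): the row-wise apply above
# is fine at this scale; an equivalent vectorized form using pandas string ops
# builds the same strings and is noticeably faster on large frames.
vectorized_text = ("[" + df['์—ํ”ผ์†Œ๋“œ'].astype(str) + "] #" + df['row_id'].astype(str)
                   + " " + df['type'].astype(str) + " " + df['scene_text'].astype(str))
assert vectorized_text.equals(df['text'])  # same output as the apply() version
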
# ## 2. Build the RAG sentences
print("5 examples:")
for t in df['text'].head(5).tolist():
    print("-", t)

# ## 3. Load a Korean embedding model and build the vector DB
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

embedding_model = HuggingFaceEmbeddings(model_name='jhgan/ko-sroberta-multitask')

db = FAISS.from_texts(texts, embedding_model)
print("Vector DB created. Total sentences:", len(texts))
db.save_local("solo_leveling_faiss_ko")

# Reload from disk to confirm the saved index round-trips
db = FAISS.load_local("solo_leveling_faiss_ko", embedding_model, allow_dangerous_deserialization=True)

# Retrieval test
query = "๋งˆ๋‚˜์„์ด ๋ญ์ง€?"  # "What is a mana stone?"
docs = db.similarity_search(query, k=5)
for i, doc in enumerate(docs, 1):
    print(f"[{i}] {doc.page_content}")

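# --- Aside (a sketch, not part of the original file): when tuning k it helps
# to inspect raw distances as well. The LangChain FAISS wrapper also exposes
# similarity_search_with_score, which returns (Document, score) pairs; with the
# default L2 index, smaller scores mean closer matches.
scored = db.similarity_search_with_score(query, k=5)
for rank, (doc, score) in enumerate(scored, 1):
    print(f"[{rank}] dist={score:.3f} {doc.page_content[:80]}")
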
# ## 4. Load the LLM (CPU only)
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
import torch

# Force CPU; this generator is reused for free-form dialogue in section 5
generator = pipeline(
    "text-generation",
    model="kakaocorp/kanana-nano-2.1b-instruct",
    device=-1  # CPU
)

embedding_model = HuggingFaceEmbeddings(model_name='jhgan/ko-sroberta-multitask')
vectorstore = FAISS.load_local("solo_leveling_faiss_ko", embedding_model, allow_dangerous_deserialization=True)

# Note: this loads the same checkpoint a second time; the explicit
# tokenizer/model pair below backs the LangChain QA chain.
model_name = "kakaocorp/kanana-nano-2.1b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32  # float32 on CPU
).to("cpu")

llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=128)
llm = HuggingFacePipeline(pipeline=llm_pipeline)

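# --- Aside (a sketch, not part of the original file): a transformers
# text-generation pipeline returns the prompt plus the completion by default,
# so the RetrievalQA answers below echo the entire stuffed prompt. A variant
# that keeps only the newly generated tokens:
llm_pipeline_no_echo = pipeline(
    "text-generation", model=model, tokenizer=tokenizer,
    max_new_tokens=128, return_full_text=False  # drop the echoed prompt
)
# llm = HuggingFacePipeline(pipeline=llm_pipeline_no_echo)  # optional swap-in
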
custom_prompt = PromptTemplate(
    input_variables=["context", "question"],
    # "Answer the question using the context below. Context: ... Question: ... Answer:"
    template="๋‹ค์Œ ๋ฌธ๋งฅ์„ ์ฐธ๊ณ ํ•˜์—ฌ ์งˆ๋ฌธ์— ๋‹ตํ•˜์„ธ์š”.\n\n๋ฌธ๋งฅ:\n{context}\n\n์งˆ๋ฌธ:\n{question}\n\n๋‹ต๋ณ€:"
)

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=vectorstore.as_retriever(search_kwargs={"k": 5}),
    chain_type="stuff",
    return_source_documents=True,
    chain_type_kwargs={"prompt": custom_prompt}
)

# Question test
query = "์„ฑ์ง„์šฐ๋Š” ๋ช‡ ๊ธ‰ ํ—Œํ„ฐ์ง€?"  # "What rank of hunter is Sung Jin-woo?"
result = qa_chain({"query": query})
print("Answer:", result["result"])
print("\nSource documents:")
for doc in result["source_documents"]:
    print(doc.page_content)

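# --- Aside (a sketch, not part of the original file): roughly what the "stuff"
# chain above does per query, spelled out with the objects already defined:
# retrieve the top-k documents, concatenate them into one context block, fill
# the prompt template, and make a single LLM call.
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
stuff_docs = retriever.get_relevant_documents(query)
stuff_context = "\n".join(d.page_content for d in stuff_docs)
print(llm.invoke(custom_prompt.format(context=stuff_context, question=query)))
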
# ## 5. Hwang Dong-seok episode
choices = [
    "1: ํ™ฉ๋™์„ ๋ฌด๋ฆฌ๋ฅผ ๋ชจ๋‘ ์ฒ˜์น˜ํ•œ๋‹ค.",                # kill Hwang Dong-seok's whole crew
    "2: ์ง„ํ˜ธ๋ฅผ ํฌํ•จํ•œ ํ™ฉ๋™์„ ๋ฌด๋ฆฌ๋ฅผ ๋ชจ๋‘ ์ฒ˜์น˜ํ•œ๋‹ค.",  # kill the whole crew, Jinho included
    "3: ์ „๋ถ€ ๊ธฐ์ ˆ ์‹œํ‚ค๊ณ  ์‚ด๋ ค๋‘”๋‹ค.",                  # knock them all out but spare them
    "4: ์‹œ์Šคํ…œ์„ ๊ฑฐ๋ถ€ํ•˜๊ณ  ๊ทธ๋ƒฅ ๋„๋ง์นœ๋‹ค."              # refuse the System and just run
]

print("\n[Choices]")
for choice in choices:
    print(choice)  # each entry already carries its own number

user_idx = int(input("\nEnter a choice number: ")) - 1
user_choice = choices[user_idx]
print(f"\n[User choice]: {user_choice}")

result = qa_chain({"query": user_choice})
retrieved_context = "\n".join(doc.page_content for doc in result["source_documents"])

print("\n[Sample of retrieved source documents]")
print(retrieved_context[:600], "...")

# Persona prompt (kept in Korean for the Korean model). Roughly: "You are Sung
# Jin-woo from the webtoon 'Solo Leveling'. Current situation: {context}.
# User choice: {choice}. Write 1-2 concise, natural lines in Sung Jin-woo's
# voice; do not produce duplicated or near-identical sentences."
prompt = f"""
๋‹น์‹ ์€ ์›นํˆฐ '๋‚˜ ํ˜ผ์ž๋งŒ ๋ ˆ๋ฒจ์—…'์˜ ์„ฑ์ง„์šฐ์ž…๋‹ˆ๋‹ค.
ํ˜„์žฌ ์ƒํ™ฉ:
{retrieved_context}
์‚ฌ์šฉ์ž ์„ ํƒ: {user_choice}
์„ฑ์ง„์šฐ์˜ ๋งํˆฌ๋กœ ๊ฐ„๊ฒฐํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šด ๋Œ€์‚ฌ๋ฅผ 1~2๋ฌธ์žฅ ์ƒ์„ฑํ•˜์„ธ์š”.
์ค‘๋ณต๋œ ๋‚ด์šฉ์ด๋‚˜ ๋น„์Šทํ•œ ๋ฌธ์žฅ์€ ๋งŒ๋“ค์ง€ ๋งˆ์„ธ์š”.
"""

response = generator(
    prompt,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    return_full_text=False
)[0]["generated_text"]

print("\n[Sung Jin-woo's response]")
print(response)
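
# --- Aside (a sketch, not part of the original file): the bare int(input())
# above raises on non-numeric input and mis-indexes on 0 or out-of-range
# numbers (e.g. 0 becomes index -1, the last choice). A hardened drop-in for
# the same step; ask_choice is a hypothetical helper:
def ask_choice(options):
    """Re-prompt until the user enters a valid 1-based option number."""
    while True:
        raw = input("\nEnter a choice number: ").strip()
        if raw.isdigit() and 1 <= int(raw) <= len(options):
            return options[int(raw) - 1]
        print(f"Please enter a number between 1 and {len(options)}.")

# user_choice = ask_choice(choices)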