min24ss committed
Commit ee55b6f · verified · 1 Parent(s): a7f3205

Update app.py

Files changed (1):
  1. app.py +117 -97
app.py CHANGED
@@ -1,118 +1,138 @@
- import os
- import zipfile
- import pandas as pd
  import gradio as gr
  from langchain_community.vectorstores import FAISS
  from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain.chains import RetrievalQA
- from langchain.prompts import PromptTemplate
- from langchain_community.llms import HuggingFacePipeline
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  import torch
-
- # ====== Auto-extract the ZIP ======
- zip_path = "solo_leveling_faiss_ko.zip"
- extract_dir = "solo_leveling_faiss_ko"
-
- if os.path.exists(zip_path) and not os.path.exists(extract_dir):
-     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
-         zip_ref.extractall(extract_dir)
-     print(f"[INFO] Extraction complete → {extract_dir}")
-
- # ====== Load the TSV data ======
- df = pd.read_csv("sl_webtoon_full_data_sequential.tsv", sep="\t")
- df['row_id'] = df.index
- df['text'] = df.apply(
-     lambda x: f"[{x['에피소드']}] #{x['row_id']} {x['type']} {x['scene_text']}",  # '에피소드' = episode column
-     axis=1
- )
-
- # ====== Safe FAISS load ======
- embedding_model = HuggingFaceEmbeddings(model_name='jhgan/ko-sroberta-multitask')
-
- possible_paths = [
-     extract_dir,
-     os.path.join(extract_dir, "solo_leveling_faiss_ko"),
-     os.path.join(extract_dir, "faiss_index")
- ]
-
- load_path = None
- for path in possible_paths:
-     if os.path.exists(os.path.join(path, "index.faiss")):
-         load_path = path
-         break
-
- if load_path:
-     vectorstore = FAISS.load_local(load_path, embedding_model, allow_dangerous_deserialization=True)
-     print(f"[INFO] FAISS index loaded → {load_path}")
- else:
-     raise FileNotFoundError("Could not find the FAISS index.faiss file.")
-
- # ====== Load the model (CPU only) ======
  model_name = "kakaocorp/kanana-nano-2.1b-instruct"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to("cpu")
- llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200)
- llm = HuggingFacePipeline(pipeline=llm_pipeline)
-
- # ====== QA chain ======
- custom_prompt = PromptTemplate(
-     input_variables=["context", "question"],
-     template="Answer the question with reference to the following context.\n\nContext:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
  )

- qa_chain = RetrievalQA.from_chain_type(
-     llm=llm,
-     retriever=vectorstore.as_retriever(search_kwargs={"k": 5}),
-     chain_type="stuff",
-     return_source_documents=True,
-     chain_type_kwargs={"prompt": custom_prompt}
  )

- # ====== Chat response function ======
  choices = [
-     "1. Kill all of Hwang Dongsuk's group.",
-     "2. Kill all of Hwang Dongsuk's group, including Jinho.",
-     "3. Knock them all out and let them live.",
-     "4. Refuse the system and just run away."
  ]

- def respond(message, history):
      try:
-         sel_num = int(message.strip())
-         if sel_num < 1 or sel_num > len(choices):
-             return gr.ChatMessage(role="assistant", content="❌ Please enter a valid number. (1-4)")
-     except ValueError:
-         return gr.ChatMessage(role="assistant", content="❌ Please enter a number. (e.g. 1, 2, 3, 4)")
-
-     user_choice = choices[sel_num - 1]
-     result = qa_chain({"query": user_choice})
-     retrieved_context = "\n".join([doc.page_content for doc in result["source_documents"]])
-
      prompt = f"""
- You are Sung Jinwoo from the webtoon 'Solo Leveling'.
- Current situation:
- {retrieved_context}
- User's choice: {user_choice}
- Write one or two short, natural sentences of dialogue in Sung Jinwoo's voice.
- Do not produce duplicated or near-identical sentences.
- """
-
-     response = llm_pipeline(prompt)[0]["generated_text"]
-
-     # User message (right side)
-     user_msg = gr.ChatMessage(role="user", content=f"Choice {sel_num} ({user_choice})")
-     # Sung Jinwoo's message (left side)
-     sjw_msg = gr.ChatMessage(role="assistant", content=response)
-
-     return [user_msg, sjw_msg]
-
-
  demo = gr.ChatInterface(
-     respond,
-     title="Sung Jinwoo choice simulation (KakaoTalk style, Sung Jinwoo on the left)",
-     description="Enter one of the numbers 1-4 and Sung Jinwoo's response appears in chat form."
  )

  if __name__ == "__main__":
      demo.launch()
 
+ import os, zipfile, shutil, glob
  import gradio as gr
  from langchain_community.vectorstores import FAISS
  from langchain_community.embeddings import HuggingFaceEmbeddings
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  import torch
+ import langchain
+
+ ZIP_NAME = "solo_leveling_faiss_ko.zip"
+ TARGET_DIR = "solo_leveling_faiss_ko"
+
+ def ensure_faiss_dir() -> str:
+     """Ensure the FAISS index is in a loadable location, wherever it is."""
+     if os.path.exists(os.path.join(TARGET_DIR, "index.faiss")) and \
+        os.path.exists(os.path.join(TARGET_DIR, "index.pkl")):
+         return TARGET_DIR
+
+     # Index files sitting in the repository root: move them into place.
+     if os.path.exists("index.faiss") and os.path.exists("index.pkl"):
+         os.makedirs(TARGET_DIR, exist_ok=True)
+         if not os.path.exists(os.path.join(TARGET_DIR, "index.faiss")):
+             shutil.move("index.faiss", os.path.join(TARGET_DIR, "index.faiss"))
+         if not os.path.exists(os.path.join(TARGET_DIR, "index.pkl")):
+             shutil.move("index.pkl", os.path.join(TARGET_DIR, "index.pkl"))
+         return TARGET_DIR
+
+     # Last resort: extract the ZIP and search for the index files.
+     if os.path.exists(ZIP_NAME):
+         with zipfile.ZipFile(ZIP_NAME, 'r') as z:
+             z.extractall(".")
+         if os.path.exists(os.path.join(TARGET_DIR, "index.faiss")) and \
+            os.path.exists(os.path.join(TARGET_DIR, "index.pkl")):
+             return TARGET_DIR
+         faiss_cand = glob.glob("**/index.faiss", recursive=True)
+         pkl_cand = glob.glob("**/index.pkl", recursive=True)
+         if faiss_cand and pkl_cand:
+             os.makedirs(TARGET_DIR, exist_ok=True)
+             shutil.copy2(faiss_cand[0], os.path.join(TARGET_DIR, "index.faiss"))
+             shutil.copy2(pkl_cand[0], os.path.join(TARGET_DIR, "index.pkl"))
+             return TARGET_DIR
+
+     raise FileNotFoundError("FAISS index files not found (index.faiss / index.pkl).")
+
+ # 0) Locate the FAISS index
+ base_dir = ensure_faiss_dir()
+
+ # 1) Vector DB
+ embeddings = HuggingFaceEmbeddings(model_name="jhgan/ko-sroberta-multitask")
+ vectorstore = FAISS.load_local(base_dir, embeddings, allow_dangerous_deserialization=True)
+
+ # 2) Model loading (safe options for a CPU-only environment)
  model_name = "kakaocorp/kanana-nano-2.1b-instruct"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.float32,
+     device_map=None  # keep the whole model on CPU
  )

+ # 3) Text-generation pipeline
+ pipe = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     max_new_tokens=100,
+     temperature=0.6,
+     do_sample=True,
+     top_p=0.9,
+     return_full_text=False
  )
+ lm = pipe  # alias used in rag_answer below

+ # Choices
  choices = [
+     "1: Kill all of Hwang Dongsuk's group.",
+     "2: Kill all of Hwang Dongsuk's group, including Jinho.",
+     "3: Knock them all out and let them live.",
+     "4: Refuse the system and just run away."
  ]

+ # RAG + dialogue generation function
+ def rag_answer(message, history):
      try:
+         user_idx = int(message.strip()) - 1
+         if not 0 <= user_idx < len(choices):  # reject 0, negatives, and numbers past 4
+             raise ValueError(message)
+         user_choice = choices[user_idx]
+     except ValueError:
+         return "❗ Please enter a valid number. (e.g. 1-4)"
+
+     # FAISS retrieval
+     docs = vectorstore.similarity_search(user_choice, k=3)
+     context = "\n".join([doc.page_content for doc in docs])
      prompt = f"""
+ You are Sung Jinwoo from the webtoon 'Solo Leveling'.
+ Current situation:
+ {context}
+ User's choice: {user_choice}
+ Write one or two short, natural sentences of dialogue in Sung Jinwoo's voice.
+ Do not produce duplicated or near-identical sentences.
+ """
+     response = lm(prompt)[0]["generated_text"]
+     only_dialogue = response.strip().split("\n")[-1]
+
+     # Prevent a duplicated "Line:" prefix
+     if not only_dialogue.startswith("Line:"):
+         only_dialogue = "Line: " + only_dialogue
+
+     return only_dialogue
+
+ # Background image CSS
+ css_code = """
+ body {
+     background-image: url('https://huggingface.co/spaces/min24ss/r-story-test/resolve/main/jinwoo.png');
+     background-size: cover;
+     background-position: center;
+ }
+ """
+
+ # Gradio UI
  demo = gr.ChatInterface(
+     fn=rag_answer,
+     title="[Emergency Quest: Kill the Enemies!]",
+     description=(
+         "Those with murderous intent toward the 'Player' are nearby. Kill them all to secure your safety.<br>"
+         "If you do not follow the instructions, your heart will stop(!).<br>"
+         "Enemies to kill: 8 / Enemies killed: 0<br><br>"
+         "💬 Enter a choice:<br>"
+         "1: Kill all of Hwang Dongsuk's group.<br>"
+         "2: Kill all of Hwang Dongsuk's group, including Jinho.<br>"
+         "3: Knock them all out and let them live.<br>"
+         "4: Refuse the system and just run away."
+     ),
+     css=css_code
  )

+ # Run
  if __name__ == "__main__":
+     print("Torch:", torch.__version__)
+     print("Transformers:", __import__('transformers').__version__)
+     print("LangChain:", langchain.__version__)
      demo.launch()
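
For reference, a minimal sketch of how an index compatible with the FAISS.load_local call above might be built offline, assuming the TSV layout used by the previous version of this script; build_index.py is a hypothetical helper, not part of this commit.

# build_index.py - hypothetical offline helper, not part of this commit.
# Builds the solo_leveling_faiss_ko/ folder (index.faiss + index.pkl)
# that app.py expects, using the same Korean sentence-embedding model.
import pandas as pd
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

df = pd.read_csv("sl_webtoon_full_data_sequential.tsv", sep="\t")
df["row_id"] = df.index
# '에피소드' is the episode column of the TSV.
texts = df.apply(
    lambda x: f"[{x['에피소드']}] #{x['row_id']} {x['type']} {x['scene_text']}",
    axis=1,
).tolist()

embeddings = HuggingFaceEmbeddings(model_name="jhgan/ko-sroberta-multitask")
index = FAISS.from_texts(texts, embeddings)
index.save_local("solo_leveling_faiss_ko")  # writes index.faiss and index.pkl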
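
And a possible smoke test for the retrieval-plus-generation path; importing app runs its module-level setup (index load, model download), so this assumes those assets are available. Purely illustrative.

# smoke_test.py - hypothetical check, not part of this commit.
import app  # module-level setup runs here: FAISS load + model load

# Exercise the same entry point gr.ChatInterface calls,
# with one valid and two invalid inputs.
for message in ["2", "0", "abc"]:
    print(f">> {message}")
    print(app.rag_answer(message, history=[]))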