min24ss committed on
Commit
0fbec2e
·
verified ·
1 Parent(s): fedc021

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -49
app.py CHANGED
@@ -10,11 +10,9 @@ ZIP_NAME = "solo_leveling_faiss_ko.zip"
10
  TARGET_DIR = "solo_leveling_faiss_ko"
11
 
12
  def ensure_faiss_dir() -> str:
13
- """FAISS index๊ฐ€ ์–ด๋”” ์žˆ๋“  ๋กœ๋“œ ๊ฐ€๋Šฅํ•œ ์œ„์น˜๋ฅผ ๋ณด์žฅํ•ฉ๋‹ˆ๋‹ค."""
14
  if os.path.exists(os.path.join(TARGET_DIR, "index.faiss")) and \
15
  os.path.exists(os.path.join(TARGET_DIR, "index.pkl")):
16
  return TARGET_DIR
17
-
18
  if os.path.exists("index.faiss") and os.path.exists("index.pkl"):
19
  os.makedirs(TARGET_DIR, exist_ok=True)
20
  if not os.path.exists(os.path.join(TARGET_DIR, "index.faiss")):
@@ -22,7 +20,6 @@ def ensure_faiss_dir() -> str:
22
  if not os.path.exists(os.path.join(TARGET_DIR, "index.pkl")):
23
  shutil.move("index.pkl", os.path.join(TARGET_DIR, "index.pkl"))
24
  return TARGET_DIR
25
-
26
  if os.path.exists(ZIP_NAME):
27
  with zipfile.ZipFile(ZIP_NAME, 'r') as z:
28
  z.extractall(".")
@@ -36,7 +33,6 @@ def ensure_faiss_dir() -> str:
36
  shutil.copy2(faiss_cand[0], os.path.join(TARGET_DIR, "index.faiss"))
37
  shutil.copy2(pkl_cand[0], os.path.join(TARGET_DIR, "index.pkl"))
38
  return TARGET_DIR
39
-
40
  raise FileNotFoundError("FAISS index files not found (index.faiss / index.pkl).")
41
 
42
  # 0) FAISS 인덱스 위치 확보
@@ -46,17 +42,13 @@ base_dir = ensure_faiss_dir()
46
  embeddings = HuggingFaceEmbeddings(model_name="jhgan/ko-sroberta-multitask")
47
  vectorstore = FAISS.load_local(base_dir, embeddings, allow_dangerous_deserialization=True)
48
 
49
- # 2) 모델 로딩 (CPU 환경 안전 옵션)
50
  model_name = "kakaocorp/kanana-nano-2.1b-instruct"
51
  tokenizer = AutoTokenizer.from_pretrained(model_name)
52
- model = AutoModelForCausalLM.from_pretrained(
53
- model_name,
54
- torch_dtype=torch.float32,
55
- device_map=None
56
- )
57
 
58
  # 3) 텍스트 생성 파이프라인
59
- pipe = pipeline(
60
  "text-generation",
61
  model=model,
62
  tokenizer=tokenizer,
@@ -66,7 +58,6 @@ pipe = pipeline(
66
  top_p=0.9,
67
  return_full_text=False
68
  )
69
- lm = pipe
70
 
71
  # 선택지
72
  choices = [
@@ -76,19 +67,15 @@ choices = [
76
  "4: ์‹œ์Šคํ…œ์„ ๊ฑฐ๋ถ€ํ•˜๊ณ  ๊ทธ๋ƒฅ ๋„๋ง์นœ๋‹ค."
77
  ]
78
 
79
- # RAG + 대사 생성 함수
80
  def rag_answer(message, history):
81
  try:
82
  user_idx = int(message.strip()) - 1
83
  user_choice = choices[user_idx]
84
  except:
85
  return "โ—์˜ฌ๋ฐ”๋ฅธ ๋ฒˆํ˜ธ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”. (์˜ˆ: 1 ~ 4)"
86
-
87
- # FAISS 검색
88
  docs = vectorstore.similarity_search(user_choice, k=3)
89
  context = "\n".join([doc.page_content for doc in docs])
90
- prompt = f"""
91
- ๋‹น์‹ ์€ ์›นํˆฐ '๋‚˜ ํ˜ผ์ž๋งŒ ๋ ˆ๋ฒจ์—…'์˜ ์„ฑ์ง„์šฐ์ž…๋‹ˆ๋‹ค.
92
  ํ˜„์žฌ ์ƒํ™ฉ:
93
  {context}
94
  ์‚ฌ์šฉ์ž ์„ ํƒ: {user_choice}
@@ -97,51 +84,46 @@ def rag_answer(message, history):
97
  """
98
  response = lm(prompt)[0]["generated_text"]
99
  only_dialogue = response.strip().split("\n")[-1]
100
-
101
- # "대사:" 중복 방지
102
  if not only_dialogue.startswith("๋Œ€์‚ฌ:"):
103
  only_dialogue = "๋Œ€์‚ฌ: " + only_dialogue
104
-
105
  return only_dialogue
106
 
107
- # CSS - 제목 오른쪽에 아이콘
108
  css_code = """
109
  .quest-title {
110
- display: flex;
111
- align-items: center;
112
- gap: 10px;
113
  }
114
  .quest-title img {
115
- width: 60px;
116
- height: auto;
117
  }
 
118
  """
119
 
120
- # Gradio UI
121
- demo = gr.ChatInterface(
122
- fn=rag_answer,
123
- title="""
124
- <div class="quest-title">
125
- [๊ธด๊ธ‰ ํ€˜์ŠคํŠธ: ์ ์„ ์ฒ˜์น˜ํ•˜๋ผ!]
126
- <img src="https://huggingface.co/spaces/min24ss/r-story-selection/resolve/main/system.png">
127
- </div>
128
- """,
129
- description=(
130
- "'ํ”Œ๋ ˆ์ด์–ด'์—๊ฒŒ ์‚ด์˜๋ฅผ ๊ฐ€์ง„ ์ด๋“ค์ด ์ฃผ์œ„์— ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋“ค์„ ๋ชจ๋‘ ์ฒ˜์น˜ํ•˜์—ฌ ์•ˆ์ „์„ ํ™•๋ณดํ•˜์‹ญ์‹œ์˜ค.<br>"
131
- "์ง€์‹œ์— ๋”ฐ๋ฅด์ง€ ์•Š์œผ๋ฉด ๋‹น์‹ ์˜ ์‹ฌ์žฅ์€ ์ •์ง€(!)ํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.<br>"
132
- "์ฒ˜์น˜ํ•ด์•ผ ํ•  ์ ์˜ ์ˆซ์ž: 8๋ช… / ์ฒ˜์น˜ํ•œ ์ ์˜ ์ˆซ์ž: 0๋ช…<br><br>"
133
- "๐Ÿ’ฌ ์„ ํƒ์ง€๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”:<br>"
134
- "1: ํ™ฉ๋™์„ ๋ฌด๋ฆฌ๋ฅผ ๋ชจ๋‘ ์ฒ˜์น˜ํ•œ๋‹ค.<br>"
135
- "2: ํ™ฉ๋™์„ ๋ฌด๋ฆฌ์™€ ์ง„ํ˜ธ๋ฅผ ํฌํ•จํ•˜์—ฌ ๋ชจ๋‘ ์ฒ˜์น˜ํ•œ๋‹ค.<br>"
136
- "3: ์ „๋ถ€ ๊ธฐ์ ˆ ์‹œํ‚ค๊ณ  ์‚ด๋ ค๋‘”๋‹ค.<br>"
137
- "4: ์‹œ์Šคํ…œ์„ ๊ฑฐ๋ถ€ํ•˜๊ณ  ๊ทธ๋ƒฅ ๋„๋ง์นœ๋‹ค."
138
- ),
139
- css=css_code
140
- )
141
 
142
- # 실행
143
  if __name__ == "__main__":
144
  print("Torch:", torch.__version__)
145
- print("Transformers:", __import__('transformers').__version__)
 
146
  print("LangChain:", langchain.__version__)
147
  demo.launch()
 
10
  TARGET_DIR = "solo_leveling_faiss_ko"
11
 
12
  def ensure_faiss_dir() -> str:
 
13
  if os.path.exists(os.path.join(TARGET_DIR, "index.faiss")) and \
14
  os.path.exists(os.path.join(TARGET_DIR, "index.pkl")):
15
  return TARGET_DIR
 
16
  if os.path.exists("index.faiss") and os.path.exists("index.pkl"):
17
  os.makedirs(TARGET_DIR, exist_ok=True)
18
  if not os.path.exists(os.path.join(TARGET_DIR, "index.faiss")):
 
20
  if not os.path.exists(os.path.join(TARGET_DIR, "index.pkl")):
21
  shutil.move("index.pkl", os.path.join(TARGET_DIR, "index.pkl"))
22
  return TARGET_DIR
 
23
  if os.path.exists(ZIP_NAME):
24
  with zipfile.ZipFile(ZIP_NAME, 'r') as z:
25
  z.extractall(".")
 
33
  shutil.copy2(faiss_cand[0], os.path.join(TARGET_DIR, "index.faiss"))
34
  shutil.copy2(pkl_cand[0], os.path.join(TARGET_DIR, "index.pkl"))
35
  return TARGET_DIR
 
36
  raise FileNotFoundError("FAISS index files not found (index.faiss / index.pkl).")
37
 
38
  # 0) FAISS 인덱스 위치 확보
 
42
  embeddings = HuggingFaceEmbeddings(model_name="jhgan/ko-sroberta-multitask")
43
  vectorstore = FAISS.load_local(base_dir, embeddings, allow_dangerous_deserialization=True)
44
 
45
+ # 2) 모델 로딩 (CPU)
46
  model_name = "kakaocorp/kanana-nano-2.1b-instruct"
47
  tokenizer = AutoTokenizer.from_pretrained(model_name)
48
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, device_map=None)
 
 
 
 
49
 
50
  # 3) 텍스트 생성 파이프라인
51
+ lm = pipeline(
52
  "text-generation",
53
  model=model,
54
  tokenizer=tokenizer,
 
58
  top_p=0.9,
59
  return_full_text=False
60
  )
 
61
 
62
  # 선택지
63
  choices = [
 
67
  "4: ์‹œ์Šคํ…œ์„ ๊ฑฐ๋ถ€ํ•˜๊ณ  ๊ทธ๋ƒฅ ๋„๋ง์นœ๋‹ค."
68
  ]
69
 
 
70
  def rag_answer(message, history):
71
  try:
72
  user_idx = int(message.strip()) - 1
73
  user_choice = choices[user_idx]
74
  except:
75
  return "โ—์˜ฌ๋ฐ”๋ฅธ ๋ฒˆํ˜ธ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”. (์˜ˆ: 1 ~ 4)"
 
 
76
  docs = vectorstore.similarity_search(user_choice, k=3)
77
  context = "\n".join([doc.page_content for doc in docs])
78
+ prompt = f"""๋‹น์‹ ์€ ์›นํˆฐ '๋‚˜ ํ˜ผ์ž๋งŒ ๋ ˆ๋ฒจ์—…'์˜ ์„ฑ์ง„์šฐ์ž…๋‹ˆ๋‹ค.
 
79
  ํ˜„์žฌ ์ƒํ™ฉ:
80
  {context}
81
  ์‚ฌ์šฉ์ž ์„ ํƒ: {user_choice}
 
84
  """
85
  response = lm(prompt)[0]["generated_text"]
86
  only_dialogue = response.strip().split("\n")[-1]
 
 
87
  if not only_dialogue.startswith("๋Œ€์‚ฌ:"):
88
  only_dialogue = "๋Œ€์‚ฌ: " + only_dialogue
 
89
  return only_dialogue
90
 
91
+ # ===== UI (변경 지점) =====
92
  css_code = """
93
  .quest-title {
94
+ display:flex; align-items:center; gap:10px;
95
+ font-weight:700; font-size:22px; margin-bottom:6px;
 
96
  }
97
  .quest-title img {
98
+ width:72px; height:auto; opacity:.95;
 
99
  }
100
+ .quest-desc { line-height:1.5; margin-bottom:14px; }
101
  """
102
 
103
+ header_html = """
104
+ <div class="quest-title">
105
+ [๊ธด๊ธ‰ ํ€˜์ŠคํŠธ: ์ ์„ ์ฒ˜์น˜ํ•˜๋ผ!]
106
+ <img src="https://huggingface.co/spaces/min24ss/r-story-selection/resolve/main/system.png" alt="quest">
107
+ </div>
108
+ <div class="quest-desc">
109
+ 'ํ”Œ๋ ˆ์ด์–ด'์—๊ฒŒ ์‚ด์˜๋ฅผ ๊ฐ€์ง„ ์ด๋“ค์ด ์ฃผ์œ„์— ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋“ค์„ ๋ชจ๋‘ ์ฒ˜์น˜ํ•˜์—ฌ ์•ˆ์ „์„ ํ™•๋ณดํ•˜์‹ญ์‹œ์˜ค.<br>
110
+ ์ง€์‹œ์— ๋”ฐ๋ฅด์ง€ ์•Š์œผ๋ฉด ๋‹น์‹ ์˜ ์‹ฌ์žฅ์€ ์ •์ง€(!)ํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.<br>
111
+ 처치해야 할 적의 숫자: 8명 / 처치한 적의 숫자: 0명<br><br>
112
+ ๐Ÿ’ฌ ์„ ํƒ์ง€๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”:<br>
113
+ 1: ํ™ฉ๋™์„ ๋ฌด๋ฆฌ๋ฅผ ๋ชจ๋‘ ์ฒ˜์น˜ํ•œ๋‹ค.<br>
114
+ 2: ํ™ฉ๋™์„ ๋ฌด๋ฆฌ์™€ ์ง„ํ˜ธ๋ฅผ ํฌํ•จํ•˜์—ฌ ๋ชจ๋‘ ์ฒ˜์น˜ํ•œ๋‹ค.<br>
115
+ 3: ์ „๋ถ€ ๊ธฐ์ ˆ ์‹œํ‚ค๊ณ  ์‚ด๋ ค๋‘”๋‹ค.<br>
116
+ 4: ์‹œ์Šคํ…œ์„ ๊ฑฐ๋ถ€ํ•˜๊ณ  ๊ทธ๋ƒฅ ๋„๋ง์นœ๋‹ค.
117
+ </div>
118
+ """
119
+
120
+ with gr.Blocks(css=css_code) as demo:
121
+ gr.HTML(header_html) # ← 여기서 HTML 그대로 렌더링 (이미지 보장)
122
+ gr.ChatInterface(fn=rag_answer) # title/description์€ ์“ฐ์ง€ ์•Š์Œ
 
123
 
 
124
  if __name__ == "__main__":
125
  print("Torch:", torch.__version__)
126
+ import transformers as _t
127
+ print("Transformers:", _t.__version__)
128
  print("LangChain:", langchain.__version__)
129
  demo.launch()