openfree committed
Commit 7402b8f · verified · 1 Parent(s): 4e0a318

Update app.py

Files changed (1)
  1. app.py +733 -81
app.py CHANGED
@@ -1,20 +1,369 @@
  import re
  import threading

  import gradio as gr
- import spaces
  import transformers
- from transformers import pipeline
-
- # Load the model and tokenizer
- model_name = "CohereForAI/c4ai-command-r7b-arabic-02-2025"
- if gr.NO_RELOAD:
-     pipe = pipeline(
-         "text-generation",
-         model=model_name,
-         device_map="auto",
-         torch_dtype="auto",
-     )

  # Marker used to detect the final answer
  ANSWER_MARKER = "**Answer**"
@@ -51,15 +400,50 @@ latex_delimiters = [


  def reformat_math(text):
-     """Fix the MathJax delimiters so that Gradio (KaTeX) syntax is used.
-     This is a temporary workaround for displaying math formulas in Gradio. For now,
-     I have not found another way to get the expected behaviour with other latex_delimiters...
-     """
      text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
      text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
      return text


  def user_input(message, history_original, history_thinking):
      """Append the user input to both histories and clear the input text box"""
      return "", history_original + [
@@ -84,18 +468,59 @@ def rebuild_messages(history: list):
      return messages


- @spaces.GPU
  def bot_original(
      history: list,
      max_num_tokens: int,
      do_sample: bool,
      temperature: float,
  ):
      """Let the original model answer the question (without the reasoning process)"""

      # Used later to stream tokens from the generation thread
      streamer = transformers.TextIteratorStreamer(
-         pipe.tokenizer,  # pyright: ignore
          skip_special_tokens=True,
          skip_prompt=True,
      )
@@ -133,26 +558,50 @@ def bot_original(
      yield history


- @spaces.GPU
- def bot_thinking(
      history: list,
      max_num_tokens: int,
      final_num_tokens: int,
      do_sample: bool,
      temperature: float,
  ):
-     """Let the model answer the question with a reasoning process included"""

      # Used later to stream tokens from the generation thread
      streamer = transformers.TextIteratorStreamer(
-         pipe.tokenizer,  # pyright: ignore
          skip_special_tokens=True,
          skip_prompt=True,
      )

      # Used to re-insert the question into the reasoning when needed
      question = history[-1]["content"]
-
      # Prepare the assistant message
      history.append(
          gr.ChatMessage(
@@ -165,9 +614,18 @@ def bot_thinking(
      # Reasoning process that will be displayed in the current chat
      messages = rebuild_messages(history)

      # Variable that stores the full reasoning process
      full_reasoning = ""

      # Run the reasoning steps
      for i, prepend in enumerate(rethink_prepends):
          if i > 0:
@@ -188,18 +646,57 @@ def bot_thinking(

          # Rebuild the history with the new content
          history[-1].content += prepend.format(question=question)
          for token in streamer:
              history[-1].content += token
              history[-1].content = reformat_math(history[-1].content)
              yield history
          t.join()

          # Store the result of each reasoning step in full_reasoning
          full_reasoning = history[-1].content

-     # Reasoning is complete, now generate the final answer
      history[-1].metadata = {"title": "💭 Thinking process", "status": "done"}

      # Extract the conclusion from the reasoning (roughly the last 1-2 paragraphs)
      reasoning_parts = full_reasoning.split("\n\n")
      reasoning_conclusion = "\n\n".join(reasoning_parts[-2:]) if len(reasoning_parts) > 2 else full_reasoning
@@ -230,48 +727,99 @@ def bot_thinking(
      t.start()

      # Stream the final answer
      for token in streamer:
          history[-1].content += token
          history[-1].content = reformat_math(history[-1].content)
          yield history
      t.join()

      yield history


- with gr.Blocks(fill_height=True, title="Vidraft ThinkFlow") as demo:
      # Title and description
-     gr.Markdown("# Vidraft ThinkFlow")
-     gr.Markdown("### An LLM reasoning-generation platform that automatically applies a reasoning capability to LLMs that lack one, without modifying the model")
-
-     with gr.Row(scale=1):
-         with gr.Column(scale=2):
-             gr.Markdown("## Before (Original)")
-             chatbot_original = gr.Chatbot(
-                 scale=1,
-                 type="messages",
-                 latex_delimiters=latex_delimiters,
-                 label="Original Model (No Reasoning)"
-             )

-         with gr.Column(scale=2):
-             gr.Markdown("## After (Thinking)")
-             chatbot_thinking = gr.Chatbot(
-                 scale=1,
-                 type="messages",
-                 latex_delimiters=latex_delimiters,
-                 label="Model with Reasoning"
              )
-
-     with gr.Row():
-         # Define the msg textbox first
-         msg = gr.Textbox(
-             submit_btn=True,
-             label="",
-             show_label=False,
-             placeholder="Enter your question here.",
-             autofocus=True,
-         )

      # Example section - placed after the msg variable is defined
      with gr.Accordion("EXAMPLES", open=False):
@@ -285,53 +833,157 @@ with gr.Blocks(fill_height=True, title="Vidraft ThinkFlow") as demo:
              inputs=msg
          )

-     with gr.Row():
-         with gr.Column():
-             gr.Markdown("""## Parameter tuning""")
-             num_tokens = gr.Slider(
-                 50,
-                 4000,
-                 2000,
-                 step=1,
-                 label="Maximum tokens per reasoning step",
-                 interactive=True,
-             )
-             final_num_tokens = gr.Slider(
-                 50,
-                 4000,
-                 2000,
-                 step=1,
-                 label="Maximum tokens for the final answer",
-                 interactive=True,
-             )
-             do_sample = gr.Checkbox(True, label="Use sampling")
-             temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")
-
      # When the user submits a message, both bots respond at the same time
      msg.submit(
          user_input,
          [msg, chatbot_original, chatbot_thinking],  # inputs
          [msg, chatbot_original, chatbot_thinking],  # outputs
      ).then(
-         bot_original,
          [
-             chatbot_original,
              num_tokens,
              do_sample,
              temperature,
          ],
          chatbot_original,  # save the new history from the output
      ).then(
-         bot_thinking,
          [
              chatbot_thinking,
              num_tokens,
-             final_num_tokens,
              do_sample,
              temperature,
          ],
          chatbot_thinking,  # save the new history from the output
      )

  if __name__ == "__main__":
-     demo.queue().launch()
  import re
  import threading
+ import time
+ import os
+ import logging
+ from datetime import datetime
+ import torch
+ import numpy as np
+ from typing import List, Optional, Tuple, Dict
+ import networkx as nx

  import gradio as gr
  import transformers
+ from transformers import (
+     pipeline,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BartForConditionalGeneration,
+     BartTokenizer,
+     BitsAndBytesConfig
+ )
+
+ # Logging setup
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # ===================== RLRetrievalPolicy =====================
+ class RLRetrievalPolicy:
+     def __init__(self):
+         self.policy_data = {}
+         self.alpha = 0.5  # weight between similarity and the RL score
+
+     def update_policy(self, contexts: List[str], reward: float):
+         for ctx in contexts:
+             if ctx not in self.policy_data:
+                 self.policy_data[ctx] = 0.0
+             self.policy_data[ctx] += reward
+
+     def re_rank(self, candidates: List[Tuple[float, str]]) -> List[str]:
+         reweighted = []
+         for sim, txt in candidates:
+             rl_score = self.policy_data.get(txt, 0.0)
+             reweighted_score = sim * (1 - self.alpha) + rl_score * self.alpha
+             reweighted.append((reweighted_score, txt))
+         reweighted.sort(key=lambda x: x[0], reverse=True)
+         return [t for _, t in reweighted]
+
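To see how the re-ranking blends similarity with accumulated feedback, here is a minimal sketch; the candidate texts and the reward value are made up for the example and are not part of the committed code:

    policy = RLRetrievalPolicy()
    candidates = [(0.9, "area of a circle"), (0.7, "quadratic formula")]
    policy.update_policy(["quadratic formula"], reward=1.0)  # positive user feedback
    print(policy.re_rank(candidates))
    # with alpha = 0.5: 0.7*0.5 + 1.0*0.5 = 0.85 beats 0.9*0.5 = 0.45,
    # so the rewarded context now ranks first
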
+ # ===================== GraphMemory =====================
+ class GraphMemory:
+     def __init__(self):
+         self.graph = nx.DiGraph()
+         # Add basic nodes that help with solving math problems
+         self.add_node("math", "General approaches for solving math problems")
+         self.add_node("algebra", "A branch of mathematics dealing with equations, functions and proportional relationships")
+         self.add_node("geometry", "A branch of mathematics dealing with space, shapes and angles")
+         self.add_node("arithmetic", "A field covering basic numeric operations, ratios and percentages")
+         self.add_node("probability", "A branch of mathematics that measures how likely events are")
+
+         # Set up the relationships
+         self.add_edge("algebra", "math")
+         self.add_edge("geometry", "math")
+         self.add_edge("arithmetic", "math")
+         self.add_edge("probability", "math")
+
+     def add_node(self, node_id: str, text: str = ""):
+         self.graph.add_node(node_id, text=text)
+
+     def add_edge(self, src: str, dst: str):
+         self.graph.add_edge(src, dst)
+
+     def get_text_by_node(self, node_id: str) -> str:
+         return self.graph.nodes[node_id].get('text', "")
+
+     def has_node(self, node_id: str) -> bool:
+         return node_id in self.graph.nodes
+
+     def search_nodes(self, keyword: str, max_nodes: int = 3) -> List[str]:
+         matches = []
+         for n in self.graph.nodes():
+             node_text = self.get_text_by_node(n).lower()
+             n_lower = n.lower()
+             if keyword.lower() in node_text or keyword.lower() in n_lower:
+                 score = node_text.count(keyword.lower()) + n_lower.count(keyword.lower())
+                 matches.append((score, n))
+         matches.sort(key=lambda x: x[0], reverse=True)
+         top_nodes = [m[1] for m in matches[:max_nodes]]
+         return top_nodes
+
+     def get_connected_context(self, start_node: str, steps: int = 1) -> List[str]:
+         contexts = []
+         visited = set()
+         queue = [(start_node, 0)]
+         while queue:
+             current, depth = queue.pop(0)
+             if current not in visited:
+                 visited.add(current)
+                 contexts.append(self.get_text_by_node(current))
+                 if depth < steps:
+                     for neighbor in self.graph.successors(current):
+                         queue.append((neighbor, depth + 1))
+                     for neighbor in self.graph.predecessors(current):
+                         queue.append((neighbor, depth + 1))
+         return contexts
+
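A quick sketch of how the graph is queried, assuming the default nodes seeded in __init__ above (node names as translated here):

    gm = GraphMemory()
    nodes = gm.search_nodes("algebra")           # -> ["algebra"]
    print(gm.get_connected_context(nodes[0]))    # the "algebra" text plus its neighbour "math"
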
+ # ===================== SimpleSummarizer =====================
+ class SimpleSummarizer:
+     def __init__(self, model_name="facebook/bart-large-cnn"):
+         self.model_name = model_name
+         self.model = None
+         self.tokenizer = None
+
+     def load_summarization_model(self):
+         if self.model is None:
+             try:
+                 self.tokenizer = BartTokenizer.from_pretrained(self.model_name)
+                 self.model = BartForConditionalGeneration.from_pretrained(self.model_name)
+                 if torch.cuda.is_available():
+                     self.model = self.model.cuda()
+             except Exception as e:
+                 logger.error(f"Error loading summarization model: {str(e)}")
+                 raise
+
+     def summarize_text(self, text: str, max_length: int = 100) -> str:
+         try:
+             self.load_summarization_model()
+             inputs = self.tokenizer([text], max_length=1024, return_tensors='pt', truncation=True)
+             if torch.cuda.is_available():
+                 inputs = {k: v.cuda() for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 summary_ids = self.model.generate(
+                     inputs["input_ids"],
+                     num_beams=4,
+                     max_length=max_length,
+                     early_stopping=True
+                 )
+             summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+             return summary
+         except Exception as e:
+             logger.error(f"Error in summarization: {str(e)}")
+             return "Unable to generate a summary."
+
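The summarizer loads facebook/bart-large-cnn lazily on first use; a minimal call looks like the sketch below (the input text is an arbitrary example, and the first call downloads the BART checkpoint):

    summarizer = SimpleSummarizer()
    long_text = " ".join(
        ["The derivative measures the instantaneous rate of change of a function."] * 40
    )
    print(summarizer.summarize_text(long_text, max_length=60))
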
+ # ===================== SemanticMemory =====================
+ class SemanticMemory:
+     def __init__(self, max_entries: int = 4000):
+         self.memories: List[dict] = []
+         self.max_entries = max_entries
+         self.rl_policy = RLRetrievalPolicy()
+
+     def add_memory(self, text: str, embedding: torch.Tensor):
+         if len(self.memories) >= self.max_entries:
+             self.memories.pop(0)
+         self.memories.append({
+             'text': text,
+             'embedding': embedding,
+             'timestamp': time.time()
+         })
+
+     def get_candidates(self, query_embedding: torch.Tensor) -> List[Tuple[float, str]]:
+         candidates = []
+         for mem in self.memories:
+             if mem['embedding'].shape == query_embedding.shape:
+                 sim = torch.cosine_similarity(
+                     query_embedding.float(),
+                     mem['embedding'].float(),
+                     dim=-1
+                 )
+                 candidates.append((sim.item(), mem['text']))
+         candidates.sort(key=lambda x: x[0], reverse=True)
+         return candidates
+
+     def get_relevant_context(self, query_embedding: torch.Tensor, top_k: int = 3) -> List[str]:
+         candidates = self.get_candidates(query_embedding)
+         re_ranked = self.rl_policy.re_rank(candidates)
+         return re_ranked[:top_k]
+
+     def update_retrieval_reward(self, texts: List[str], reward: float):
+         self.rl_policy.update_policy(texts, reward)
+
+     def clear(self):
+         self.memories = []
+
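A minimal sketch of storing and retrieving entries; random normalized vectors stand in for the embeddings that get_embedding_for_text (defined later in this file) would normally produce:

    import torch

    mem = SemanticMemory(max_entries=100)
    emb_a = torch.nn.functional.normalize(torch.randn(1, 768), dim=-1)
    emb_b = torch.nn.functional.normalize(torch.randn(1, 768), dim=-1)
    mem.add_memory("The triangle inequality bounds the third side.", emb_a)
    mem.add_memory("Bayes' rule relates conditional probabilities.", emb_b)
    print(mem.get_relevant_context(emb_a, top_k=1))  # the entry stored with emb_a comes back first
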
+ # ===================== GenericInferenceBuffer =====================
+ MAX_TOKEN_BUFFER = 1024
+
+ class GenericInferenceBuffer:
+     def __init__(self, layer_idx: int, compression_rank: int = 128):
+         self.layer_idx = layer_idx
+         self.key_buffer: Optional[torch.Tensor] = None
+         self.value_buffer: Optional[torch.Tensor] = None
+         self.semantic_context: Optional[torch.Tensor] = None
+         self.last_update: float = 0
+         self.compression_rank = compression_rank
+
+     def update_buffer(
+         self,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         semantic_context: Optional[torch.Tensor] = None
+     ):
+         try:
+             if self.key_buffer is None:
+                 self.key_buffer = key.detach().clone()
+                 self.value_buffer = value.detach().clone()
+                 if semantic_context is not None:
+                     self.semantic_context = semantic_context.detach().clone()
+             else:
+                 self.key_buffer = torch.cat([self.key_buffer, key.detach()], dim=2)
+                 self.value_buffer = torch.cat([self.value_buffer, value.detach()], dim=2)
+                 if semantic_context is not None and self.semantic_context is not None:
+                     self.semantic_context = torch.cat([self.semantic_context, semantic_context.detach()], dim=0)
+
+             if self.key_buffer.shape[2] > MAX_TOKEN_BUFFER:
+                 excess = self.key_buffer.shape[2] - MAX_TOKEN_BUFFER
+                 self.key_buffer = self.key_buffer[:, :, excess:, :]
+                 self.value_buffer = self.value_buffer[:, :, excess:, :]
+                 if self.semantic_context is not None:
+                     self.semantic_context = self.semantic_context[excess:, :]
+
+             self.last_update = time.time()
+
+         except Exception as e:
+             logger.error(f"Buffer update error in layer {self.layer_idx}: {str(e)}")
+
+     def compress_buffer_svd(self):
+         if self.key_buffer is None or self.value_buffer is None:
+             return
+
+         try:
+             k_shape = self.key_buffer.shape
+             v_shape = self.value_buffer.shape
+
+             k_2d = self.key_buffer.reshape(k_shape[0]*k_shape[1], k_shape[2]*k_shape[3]).float()
+             v_2d = self.value_buffer.reshape(v_shape[0]*v_shape[1], v_shape[2]*v_shape[3]).float()
+
+             device = k_2d.device
+             k_2d_cpu = k_2d.cpu()
+             v_2d_cpu = v_2d.cpu()
+
+             U_k, S_k, V_k = torch.linalg.svd(k_2d_cpu, full_matrices=False)
+             U_v, S_v, V_v = torch.linalg.svd(v_2d_cpu, full_matrices=False)
+             rank_k = min(self.compression_rank, S_k.shape[0])
+             rank_v = min(self.compression_rank, S_v.shape[0])
+             k_approx = (U_k[:, :rank_k] * S_k[:rank_k]) @ V_k[:rank_k, :]
+             v_approx = (U_v[:, :rank_v] * S_v[:rank_v]) @ V_v[:rank_v, :]
+
+             k_approx = k_approx.to(device)
+             v_approx = v_approx.to(device)
+
+             self.key_buffer = k_approx.reshape(k_shape).type(self.key_buffer.dtype)
+             self.value_buffer = v_approx.reshape(v_shape).type(self.value_buffer.dtype)
+
+         except Exception as e:
+             logger.error(f"SVD compression error in layer {self.layer_idx}: {str(e)}")
+
+     def get_buffer(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+         return self.key_buffer, self.value_buffer, self.semantic_context
+
+     def clear(self):
+         self.key_buffer = None
+         self.value_buffer = None
+         self.semantic_context = None
+         self.last_update = 0
+
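compress_buffer_svd above is a plain truncated-SVD low-rank approximation of the flattened KV tensors. The idea in isolation, on a toy matrix (a standalone sketch, not tied to any particular model shape):

    import torch

    x = torch.randn(64, 256)                      # stand-in for a flattened KV buffer
    U, S, Vh = torch.linalg.svd(x, full_matrices=False)
    rank = 16
    x_approx = (U[:, :rank] * S[:rank]) @ Vh[:rank, :]
    rel_err = torch.linalg.norm(x - x_approx) / torch.linalg.norm(x)
    print(x_approx.shape, rel_err.item())         # same shape; the error depends on how fast the spectrum decays
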
+ # ===================== InferenceBufferManager =====================
+ class InferenceBufferManager:
+     def __init__(self, num_layers: int, hidden_size: int):
+         self.num_layers = num_layers
+         self.hidden_size = hidden_size
+         self.layer_buffers = [
+             GenericInferenceBuffer(i, compression_rank=128) for i in range(num_layers)
+         ]
+         self.semantic_memory = SemanticMemory()
+         self.graph_memory = GraphMemory()
+         self.summarizer = SimpleSummarizer()
+         self.summarize_threshold = 1500
+         self.generated_tokens_count = 0
+         self.compression_interval = 512
+         self.token_count_since_compress = 0
+
+     def _compute_semantic_embedding(self, key: Optional[torch.Tensor], value: Optional[torch.Tensor]) -> torch.Tensor:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         if key is None or value is None:
+             return torch.zeros((1, self.hidden_size), dtype=torch.float32, device=device)
+         combined = key * value
+         combined = combined.mean(dim=2)
+         combined = combined.reshape(combined.shape[0], -1)
+         combined = torch.nn.functional.normalize(combined, dim=-1)
+         return combined
+
+     def update_buffer(self, layer_outputs, current_tokens: List[int], semantic_context: torch.Tensor, tokenizer):
+         try:
+             if hasattr(layer_outputs, 'past_key_values'):
+                 for layer_idx, past_kv in enumerate(layer_outputs.past_key_values):
+                     if isinstance(past_kv, tuple) and len(past_kv) == 2:
+                         key, value = past_kv
+                         if key is not None and value is not None:
+                             self.layer_buffers[layer_idx].update_buffer(
+                                 key.detach(),
+                                 value.detach(),
+                                 semantic_context
+                             )
+             self.generated_tokens_count += len(current_tokens)
+             self.token_count_since_compress += len(current_tokens)
+
+             if self.token_count_since_compress >= self.compression_interval:
+                 self.compress_all_buffers()
+                 self.token_count_since_compress = 0
+         except Exception as e:
+             logger.error(f"Buffer update error: {str(e)}")
+
+     def compress_all_buffers(self):
+         for buf in self.layer_buffers:
+             buf.compress_buffer_svd()
+
+     def finalize_semantic_memory(self, tokenizer, generated_tokens: List[int]):
+         if self.layer_buffers and len(self.layer_buffers) > 0 and self.layer_buffers[-1].key_buffer is not None:
+             text_chunk = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+             key_buffer = self.layer_buffers[-1].key_buffer
+             value_buffer = self.layer_buffers[-1].value_buffer
+             embedding = self._compute_semantic_embedding(key_buffer, value_buffer)
+             self.semantic_memory.add_memory(text_chunk, embedding)
+
+     def get_relevant_context(self, query_embedding: torch.Tensor, top_k: int = 3) -> List[str]:
+         candidates_sem = self.semantic_memory.get_candidates(query_embedding)
+
+         # Keyword extraction (simple implementation)
+         possible_keywords = ["math", "algebra", "geometry", "arithmetic", "probability"]
+         text_candidates = []
+         for kw in possible_keywords:
+             nodes = self.graph_memory.search_nodes(kw)
+             for n in nodes:
+                 context_list = self.graph_memory.get_connected_context(n, steps=1)
+                 cscore = 1.0
+                 for ctxt in context_list:
+                     text_candidates.append((cscore, ctxt))
+
+         merged_candidates = candidates_sem + text_candidates
+         re_ranked = self.semantic_memory.rl_policy.re_rank(merged_candidates)
+         return re_ranked[:top_k]
+
+     def update_retrieval_reward(self, contexts: List[str], reward: float):
+         self.semantic_memory.update_retrieval_reward(contexts, reward)
+
+     def maybe_summarize_memory(self):
+         if self.generated_tokens_count < self.summarize_threshold:
+             return
+
+         all_text = "\n".join([m['text'] for m in self.semantic_memory.memories])
+         if len(all_text) < 300:
+             return
+
+         summary = self.summarizer.summarize_text(all_text, max_length=120)
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         summary_embedding = torch.zeros((1, self.hidden_size), dtype=torch.float32, device=device)
+
+         self.semantic_memory.clear()
+         self.semantic_memory.add_memory(summary, summary_embedding)
+         self.generated_tokens_count = 0
+
+     def clear(self):
+         for layer in self.layer_buffers:
+             layer.clear()
+         self.semantic_memory.clear()
+
+ # ===================== Enhanced ThinkFlow Implementation =====================

  # Marker used to detect the final answer
  ANSWER_MARKER = "**Answer**"


  def reformat_math(text):
+     """Fix the MathJax delimiters so that Gradio (KaTeX) syntax is used."""
      text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
      text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
      return text

+ def extract_keywords(text: str) -> List[str]:
+     """Simple keyword extraction from the text"""
+     # Simple implementation - a real one could use more sophisticated NLP techniques
+     common_math_keywords = [
+         "math", "algebra", "geometry", "arithmetic", "probability", "formula", "equation",
+         "function", "integral", "derivative", "geometric", "triangle", "circle", "angle", "ratio",
+         "proportion", "mean", "variance", "standard deviation"
+     ]
+
+     keywords = []
+     for kw in common_math_keywords:
+         if kw in text:
+             keywords.append(kw)
+
+     return keywords[:5]  # return at most 5 keywords
+
+
+ def get_embedding_for_text(text: str, hidden_size: int = 768) -> torch.Tensor:
+     """
+     Temporary embedding generation for a piece of text.
+     A real implementation should use a proper language model.
+     """
+     # Temporary implementation: an embedding derived from the hash of the text
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     hash_val = hash(text)
+     np.random.seed(abs(hash_val) % (2**32))  # the seed must be a non-negative 32-bit integer
+
+     # Generate a random embedding
+     embedding = np.random.rand(1, hidden_size).astype(np.float32)
+
+     # Normalize
+     norm = np.linalg.norm(embedding)
+     if norm > 0:
+         embedding = embedding / norm
+
+     return torch.tensor(embedding, device=device)
+
+
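A quick sanity check of these two helpers; the question string is a made-up example and the keyword hits assume the keyword list as translated above:

    q = "Solve the equation 2x + 3 = 7 and explain the ratio between the two sides."
    print(extract_keywords(q))           # ['equation', 'ratio']
    emb = get_embedding_for_text(q)
    print(emb.shape, float(emb.norm()))  # torch.Size([1, 768]), ~1.0 after normalization
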
  def user_input(message, history_original, history_thinking):
      """Append the user input to both histories and clear the input text box"""
      return "", history_original + [

      return messages


+ # Model and buffer-manager initialization
+ def initialize_model_and_manager(model_name):
+     """Initialize the model pipeline and the buffer manager"""
+     try:
+         pipe = pipeline(
+             "text-generation",
+             model=model_name,
+             device_map="auto",
+             torch_dtype="auto",
+         )
+
+         # Extract the number of layers and the hidden size from the model config
+         config = pipe.model.config
+         if hasattr(config, "n_layer"):
+             num_layers = config.n_layer
+         elif hasattr(config, "num_layers"):
+             num_layers = config.num_layers
+         elif hasattr(config, "num_hidden_layers"):
+             num_layers = config.num_hidden_layers
+         else:
+             num_layers = 12  # default
+
+         if hasattr(config, "n_embd"):
+             hidden_size = config.n_embd
+         elif hasattr(config, "hidden_size"):
+             hidden_size = config.hidden_size
+         else:
+             hidden_size = 768  # default
+
+         # Initialize the buffer manager
+         buffer_manager = InferenceBufferManager(num_layers, hidden_size)
+
+         return pipe, buffer_manager
+     except Exception as e:
+         logger.error(f"Model initialization error: {str(e)}")
+         raise
+
+
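A sketch of how the initializer is meant to be called; note that this downloads the full c4ai-command-r7b checkpoint and assumes a machine with enough GPU/CPU memory:

    pipe, buffer_manager = initialize_model_and_manager(
        "CohereForAI/c4ai-command-r7b-arabic-02-2025"
    )
    print(type(pipe).__name__, buffer_manager.num_layers, buffer_manager.hidden_size)
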
  def bot_original(
      history: list,
      max_num_tokens: int,
      do_sample: bool,
      temperature: float,
+     pipe=None
  ):
      """Let the original model answer the question (without the reasoning process)"""
+     if pipe is None:
+         # In a real implementation this should be managed via a global variable or session state
+         return history

      # Used later to stream tokens from the generation thread
      streamer = transformers.TextIteratorStreamer(
+         pipe.tokenizer,
          skip_special_tokens=True,
          skip_prompt=True,
      )

      yield history


+ def bot_thinking_enhanced(
      history: list,
      max_num_tokens: int,
      final_num_tokens: int,
      do_sample: bool,
      temperature: float,
+     pipe=None,
+     buffer_manager=None
  ):
+     """Let the model answer the question with a reasoning process - integrates the DeepSeek-style features"""
+     if pipe is None or buffer_manager is None:
+         # In a real implementation this should be managed via a global variable or session state
+         return history

      # Used later to stream tokens from the generation thread
      streamer = transformers.TextIteratorStreamer(
+         pipe.tokenizer,
          skip_special_tokens=True,
          skip_prompt=True,
      )

      # Used to re-insert the question into the reasoning when needed
      question = history[-1]["content"]
+
+     # Create the query embedding
+     query_embedding = get_embedding_for_text(question, buffer_manager.hidden_size)
+
+     # Retrieve relevant context
+     relevant_contexts = buffer_manager.get_relevant_context(query_embedding, top_k=3)
+
+     # Extract keywords and fetch context from the graph memory
+     keywords = extract_keywords(question)
+     graph_contexts = []
+     for keyword in keywords:
+         nodes = buffer_manager.graph_memory.search_nodes(keyword)
+         for node in nodes:
+             contexts = buffer_manager.graph_memory.get_connected_context(node)
+             graph_contexts.extend(contexts)
+
+     # Merge all contexts
+     all_contexts = relevant_contexts + graph_contexts
+     all_contexts = list(set(all_contexts))  # remove duplicates
+     all_contexts = all_contexts[:5]  # keep at most 5 contexts
+
      # Prepare the assistant message
      history.append(
          gr.ChatMessage(

      # Reasoning process that will be displayed in the current chat
      messages = rebuild_messages(history)

+     # If there is relevant context, append it to the messages
+     if all_contexts:
+         context_str = "\n\nRelevant context:\n" + "\n".join(all_contexts)
+         messages[-1]["content"] += context_str
+         history[-1].content += context_str
+
      # Variable that stores the full reasoning process
      full_reasoning = ""

+     # Track the generated tokens
+     generated_tokens = []
+
      # Run the reasoning steps
      for i, prepend in enumerate(rethink_prepends):
          if i > 0:


          # Rebuild the history with the new content
          history[-1].content += prepend.format(question=question)
+         step_tokens = []
+
          for token in streamer:
              history[-1].content += token
              history[-1].content = reformat_math(history[-1].content)
+             step_tokens.append(token)
+             generated_tokens.append(token)
              yield history
          t.join()

          # Store the result of each reasoning step in full_reasoning
          full_reasoning = history[-1].content
+
+         # If the reasoning is getting long, produce an intermediate summary
+         if i > 0 and i % 3 == 0 and len(generated_tokens) > 500:
+             try:
+                 summary = buffer_manager.summarizer.summarize_text(full_reasoning, max_length=150)
+                 summary_text = f"\n\n**Intermediate summary:**\n{summary}\n\n"
+                 history[-1].content += summary_text
+                 messages[-1]["content"] += summary_text
+                 yield history
+             except Exception as e:
+                 logger.error(f"Summary generation error: {str(e)}")
+
+         # KV cache compression
+         if i > 0 and i % 2 == 0:
+             buffer_manager.compress_all_buffers()
+
+         # Update the semantic context
+         step_text = "".join(step_tokens)
+         step_embedding = get_embedding_for_text(step_text, buffer_manager.hidden_size)
+         buffer_manager.semantic_memory.add_memory(step_text, step_embedding)
+

+
+     # Reasoning is complete, now generate the final answer
      history[-1].metadata = {"title": "💭 Thinking process", "status": "done"}

+     # Store the reasoning in the semantic memory and the graph memory
+     full_embedding = get_embedding_for_text(full_reasoning, buffer_manager.hidden_size)
+     buffer_manager.semantic_memory.add_memory(full_reasoning, full_embedding)
+
+     # Update the graph memory for the extracted keywords
+     for keyword in keywords:
+         if not buffer_manager.graph_memory.has_node(keyword):
+             buffer_manager.graph_memory.add_node(keyword, f"Concept related to {keyword}: reasoning was performed on this topic.")
+         # Connect it to the related nodes
+         for related_kw in keywords:
+             if related_kw != keyword and buffer_manager.graph_memory.has_node(related_kw):
+                 buffer_manager.graph_memory.add_edge(keyword, related_kw)
+
      # Extract the conclusion from the reasoning (roughly the last 1-2 paragraphs)
      reasoning_parts = full_reasoning.split("\n\n")
      reasoning_conclusion = "\n\n".join(reasoning_parts[-2:]) if len(reasoning_parts) > 2 else full_reasoning

      t.start()

      # Stream the final answer
+     final_tokens = []
      for token in streamer:
          history[-1].content += token
          history[-1].content = reformat_math(history[-1].content)
+         final_tokens.append(token)
          yield history
      t.join()
+
+     # Store the final answer in the semantic memory
+     final_text = "".join(final_tokens)
+     final_embedding = get_embedding_for_text(final_text, buffer_manager.hidden_size)
+     buffer_manager.semantic_memory.add_memory(final_text, final_embedding)
+
+     # Periodically check whether the memory should be summarized
+     buffer_manager.maybe_summarize_memory()

      yield history

+ with gr.Blocks(fill_height=True, title="Enhanced ThinkFlow") as demo:
      # Title and description
+     gr.Markdown("# Enhanced ThinkFlow with DeepSeek Features")
+     gr.Markdown("### An LLM reasoning-generation platform enhanced with semantic memory, graph memory and KV cache compression")
+
+     # Initialize the model and buffer manager (manage this as session state in a real implementation)
+     model_name = "CohereForAI/c4ai-command-r7b-arabic-02-2025"
+
+     # Session variables (a real implementation would use gr.State())
+     pipe = None
+     buffer_manager = None
+     current_contexts = []
+
+     # Tab interface
+     with gr.Tabs() as tabs:
+         # Chat tab
+         with gr.TabItem("Integrated reasoning interface"):
+             with gr.Row(scale=1):
+                 with gr.Column(scale=2):
+                     gr.Markdown("## Before (Original)")
+                     chatbot_original = gr.Chatbot(
+                         scale=1,
+                         type="messages",
+                         latex_delimiters=latex_delimiters,
+                         label="Original Model (No Reasoning)"
+                     )
+
+                 with gr.Column(scale=2):
+                     gr.Markdown("## After (Enhanced Thinking)")
+                     chatbot_thinking = gr.Chatbot(
+                         scale=1,
+                         type="messages",
+                         latex_delimiters=latex_delimiters,
+                         label="Model with Enhanced Reasoning"
+                     )
+
+             with gr.Row():
+                 # Define the msg textbox first
+                 msg = gr.Textbox(
+                     submit_btn=True,
+                     label="",
+                     show_label=False,
+                     placeholder="Enter your question here.",
+                     autofocus=True,
+                 )
+
+             # Feedback buttons
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     feedback_btn_pos = gr.Button("👍 This reasoning was helpful")
+                 with gr.Column(scale=1):
+                     feedback_btn_neg = gr.Button("👎 This reasoning needs improvement")
+                 with gr.Column(scale=1):
+                     clear_memory_btn = gr.Button("🧹 Clear memory")

+         # Memory visualization tab
+         with gr.TabItem("Memory visualization"):
+             gr.Markdown("## Semantic memory contents")
+             semantic_memory_display = gr.Textbox(
+                 label="Current semantic memory contents",
+                 placeholder="No memories yet.",
+                 lines=10,
+                 max_lines=20,
+                 interactive=False
+             )
+
+             gr.Markdown("## Graph knowledge base")
+             graph_memory_display = gr.Textbox(
+                 label="Current graph memory contents",
+                 placeholder="No graph nodes yet.",
+                 lines=10,
+                 max_lines=20,
+                 interactive=False
              )

      # Example section - placed after the msg variable is defined
      with gr.Accordion("EXAMPLES", open=False):

              inputs=msg
          )

+     with gr.Accordion("Parameter tuning", open=False):
+         with gr.Row():
+             with gr.Column():
+                 model_dropdown = gr.Dropdown(
+                     ["CohereForAI/c4ai-command-r7b-arabic-02-2025", "meta-llama/Meta-Llama-3-8B-Instruct"],
+                     label="Select model",
+                     value="CohereForAI/c4ai-command-r7b-arabic-02-2025"
+                 )
+
+                 num_tokens = gr.Slider(
+                     50,
+                     4000,
+                     2000,
+                     step=1,
+                     label="Maximum tokens per reasoning step",
+                     interactive=True,
+                 )
+                 final_num_tokens = gr.Slider(
+                     50,
+                     4000,
+                     2000,
+                     step=1,
+                     label="Maximum tokens for the final answer",
+                     interactive=True,
+                 )
+
+             with gr.Column():
+                 do_sample = gr.Checkbox(True, label="Use sampling")
+                 temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")
+                 memory_weight = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="Memory weighting")
+
+     # Feedback handling functions
+     def process_positive_feedback():
+         global buffer_manager, current_contexts
+         if buffer_manager:
+             buffer_manager.update_retrieval_reward(current_contexts, reward=1.0)
+         return "Thanks for the feedback! This approach will be used more often for similar questions."
+
+     def process_negative_feedback():
+         global buffer_manager, current_contexts
+         if buffer_manager:
+             buffer_manager.update_retrieval_reward(current_contexts, reward=-0.5)
+         return "Thanks for the feedback! This approach will be improved."
+
+     def clear_memory():
+         global buffer_manager
+         if buffer_manager:
+             buffer_manager.clear()
+         return "The memory has been cleared."
+
+     def update_memory_displays():
+         global buffer_manager
+         if not buffer_manager:
+             return "The memory has not been initialized.", "The graph has not been initialized."
+
+         semantic_text = "Currently stored memories:\n\n"
+         for i, mem in enumerate(buffer_manager.semantic_memory.memories[:5]):  # show at most 5
+             semantic_text += f"{i+1}. {mem['text'][:100]}...\n\n"
+
+         graph_text = "Current graph nodes:\n\n"
+         for node in buffer_manager.graph_memory.graph.nodes():
+             node_text = buffer_manager.graph_memory.get_text_by_node(node)
+             neighbors = list(buffer_manager.graph_memory.graph.neighbors(node))
+             graph_text += f"Node: {node}\nDescription: {node_text[:50]}...\nLinks: {', '.join(neighbors[:3])}\n\n"
+
+         return semantic_text, graph_text
+
+     # Initialization function
+     def initialize_models():
+         global pipe, buffer_manager, model_name
+         try:
+             pipe, buffer_manager = initialize_model_and_manager(model_name)
+             semantic_text, graph_text = update_memory_displays()
+             return "The model has been initialized.", semantic_text, graph_text
+         except Exception as e:
+             return f"Model initialization error: {str(e)}", "", ""
+
+     # Handle a change of the selected model
+     def change_model(new_model_name):
+         global model_name
+         model_name = new_model_name
+         status, semantic_text, graph_text = initialize_models()
+         return status, semantic_text, graph_text
+
+     # Run the initialization function when the model selection changes
+     model_dropdown.change(
+         change_model,
+         [model_dropdown],
+         [gr.Textbox(visible=False), semantic_memory_display, graph_memory_display]
+     )
+
+     # Wire the feedback buttons to their handlers
+     feedback_btn_pos.click(process_positive_feedback, [], gr.Textbox(visible=False))
+     feedback_btn_neg.click(process_negative_feedback, [], gr.Textbox(visible=False))
+     clear_memory_btn.click(clear_memory, [], gr.Textbox(visible=False))
+
+     # Update the memory displays when the tab changes
+     tabs.change(update_memory_displays, [], [semantic_memory_display, graph_memory_display])
+
      # When the user submits a message, both bots respond at the same time
      msg.submit(
          user_input,
          [msg, chatbot_original, chatbot_thinking],  # inputs
          [msg, chatbot_original, chatbot_thinking],  # outputs
      ).then(
+         lambda h, n, d, t, p: bot_original(h, n, d, t, p),  # adds the pipe parameter
          [
+             chatbot_original,
              num_tokens,
              do_sample,
              temperature,
+             gr.Textbox(value=lambda: pipe, visible=False),  # passes pipe
          ],
          chatbot_original,  # save the new history from the output
      ).then(
+         lambda h, n, f, d, t, p, b: bot_thinking_enhanced(h, n, f, d, t, p, b),  # adds the extra parameters
          [
              chatbot_thinking,
              num_tokens,
+             final_num_tokens,
              do_sample,
              temperature,
+             gr.Textbox(value=lambda: pipe, visible=False),  # passes pipe
+             gr.Textbox(value=lambda: buffer_manager, visible=False),  # passes buffer_manager
          ],
          chatbot_thinking,  # save the new history from the output
+     ).then(
+         update_memory_displays,
+         [],
+         [semantic_memory_display, graph_memory_display]
      )

+ # Code to initialize the model at startup
+ def load_on_startup():
+     global pipe, buffer_manager
+     try:
+         # Initialize the default model
+         pipe, buffer_manager = initialize_model_and_manager(
+             "CohereForAI/c4ai-command-r7b-arabic-02-2025"
+         )
+         logger.info("The model and buffer manager were initialized successfully.")
+     except Exception as e:
+         logger.error(f"Model initialization at startup failed: {str(e)}")
+
  if __name__ == "__main__":
+     # Initialize the model before starting the application
+     load_on_startup()
+
+     # Start the queue and the server
+     demo.queue().launch(
+         share=False,
+         debug=True,
+         title="Enhanced ThinkFlow with DeepSeek Features"
+     )