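"""Enhanced ThinkFlow demo app.

A Gradio Space that compares a base text-generation model with an "enhanced
thinking" pipeline: step-by-step reasoning prompts combined with semantic
memory, a small graph knowledge base, and SVD-compressed KV-cache buffers.
This summary is derived from the UI text defined below; see the Blocks
definition at the bottom of the file for the actual interface.
"""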
import re
import threading
import time
import os
import logging
from datetime import datetime
import torch
import numpy as np
from typing import List, Optional, Tuple, Dict
import networkx as nx
import gradio as gr
import transformers
from transformers import (
    pipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
    BartForConditionalGeneration,
    BartTokenizer,
    BitsAndBytesConfig
)

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ===================== RLRetrievalPolicy =====================
class RLRetrievalPolicy:
    def __init__(self):
        self.policy_data = {}
        self.alpha = 0.5  # weight between similarity and RL score

    def update_policy(self, contexts: List[str], reward: float):
        for ctx in contexts:
            if ctx not in self.policy_data:
                self.policy_data[ctx] = 0.0
            self.policy_data[ctx] += reward

    def re_rank(self, candidates: List[Tuple[float, str]]) -> List[str]:
        reweighted = []
        for sim, txt in candidates:
            rl_score = self.policy_data.get(txt, 0.0)
            reweighted_score = sim * (1 - self.alpha) + rl_score * self.alpha
            reweighted.append((reweighted_score, txt))
        reweighted.sort(key=lambda x: x[0], reverse=True)
        return [t for _, t in reweighted]
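
# Illustrative check of the re-ranking rule above (a minimal sketch, cheap to run at
# import time; the context names "ctx_a"/"ctx_b" are made up). With alpha = 0.5, a
# context that has accumulated reward 1.0 outranks a slightly more similar context
# with no reward.
_demo_policy = RLRetrievalPolicy()
_demo_policy.update_policy(["ctx_b"], reward=1.0)
assert _demo_policy.re_rank([(0.9, "ctx_a"), (0.8, "ctx_b")])[0] == "ctx_b"
del _demo_policy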
# ===================== GraphMemory =====================
class GraphMemory:
    def __init__(self):
        self.graph = nx.DiGraph()
        # Add default nodes that help with math problem solving
        self.add_node("math", "General approaches to solving math problems")
        self.add_node("algebra", "A branch of mathematics dealing with equations, functions, and proportional relationships")
        self.add_node("geometry", "A branch of mathematics dealing with space, shapes, and angles")
        self.add_node("arithmetic", "A field dealing with basic number operations, ratios, and percentages")
        self.add_node("probability", "A branch of mathematics measuring how likely events are to occur")
        # Relations between the nodes
        self.add_edge("algebra", "math")
        self.add_edge("geometry", "math")
        self.add_edge("arithmetic", "math")
        self.add_edge("probability", "math")

    def add_node(self, node_id: str, text: str = ""):
        self.graph.add_node(node_id, text=text)

    def add_edge(self, src: str, dst: str):
        self.graph.add_edge(src, dst)

    def get_text_by_node(self, node_id: str) -> str:
        return self.graph.nodes[node_id].get('text', "")

    def has_node(self, node_id: str) -> bool:
        return node_id in self.graph.nodes

    def search_nodes(self, keyword: str, max_nodes: int = 3) -> List[str]:
        matches = []
        for n in self.graph.nodes():
            node_text = self.get_text_by_node(n).lower()
            n_lower = n.lower()
            if keyword.lower() in node_text or keyword.lower() in n_lower:
                score = node_text.count(keyword.lower()) + n_lower.count(keyword.lower())
                matches.append((score, n))
        matches.sort(key=lambda x: x[0], reverse=True)
        top_nodes = [m[1] for m in matches[:max_nodes]]
        return top_nodes

    def get_connected_context(self, start_node: str, steps: int = 1) -> List[str]:
        contexts = []
        visited = set()
        queue = [(start_node, 0)]
        while queue:
            current, depth = queue.pop(0)
            if current not in visited:
                visited.add(current)
                contexts.append(self.get_text_by_node(current))
                if depth < steps:
                    for neighbor in self.graph.successors(current):
                        queue.append((neighbor, depth + 1))
                    for neighbor in self.graph.predecessors(current):
                        queue.append((neighbor, depth + 1))
        return contexts
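
# Quick illustration of the lookup helpers above (a minimal sketch that runs at import
# time against the default nodes created in __init__). Searching for "algebra" returns
# the "algebra" node, and expanding one step also pulls in the parent "math" description.
_demo_graph = GraphMemory()
assert "algebra" in _demo_graph.search_nodes("algebra")
assert len(_demo_graph.get_connected_context("algebra", steps=1)) >= 2
del _demo_graph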
# ===================== SimpleSummarizer =====================
class SimpleSummarizer:
    def __init__(self, model_name="facebook/bart-large-cnn"):
        self.model_name = model_name
        self.model = None
        self.tokenizer = None

    def load_summarization_model(self):
        if self.model is None:
            try:
                self.tokenizer = BartTokenizer.from_pretrained(self.model_name)
                self.model = BartForConditionalGeneration.from_pretrained(self.model_name)
                if torch.cuda.is_available():
                    self.model = self.model.cuda()
            except Exception as e:
                logger.error(f"Error loading summarization model: {str(e)}")
                raise

    def summarize_text(self, text: str, max_length: int = 100) -> str:
        try:
            self.load_summarization_model()
            inputs = self.tokenizer([text], max_length=1024, return_tensors='pt', truncation=True)
            if torch.cuda.is_available():
                inputs = {k: v.cuda() for k, v in inputs.items()}
            with torch.no_grad():
                summary_ids = self.model.generate(
                    inputs["input_ids"],
                    num_beams=4,
                    max_length=max_length,
                    early_stopping=True
                )
            summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            return summary
        except Exception as e:
            logger.error(f"Error in summarization: {str(e)}")
            return "Unable to generate a summary."
# ===================== SemanticMemory =====================
class SemanticMemory:
    def __init__(self, max_entries: int = 4000):
        self.memories: List[dict] = []
        self.max_entries = max_entries
        self.rl_policy = RLRetrievalPolicy()

    def add_memory(self, text: str, embedding: torch.Tensor):
        if len(self.memories) >= self.max_entries:
            self.memories.pop(0)
        self.memories.append({
            'text': text,
            'embedding': embedding,
            'timestamp': time.time()
        })

    def get_candidates(self, query_embedding: torch.Tensor) -> List[Tuple[float, str]]:
        candidates = []
        for mem in self.memories:
            if mem['embedding'].shape == query_embedding.shape:
                sim = torch.cosine_similarity(
                    query_embedding.float(),
                    mem['embedding'].float(),
                    dim=-1
                )
                candidates.append((sim.item(), mem['text']))
        candidates.sort(key=lambda x: x[0], reverse=True)
        return candidates

    def get_relevant_context(self, query_embedding: torch.Tensor, top_k: int = 3) -> List[str]:
        candidates = self.get_candidates(query_embedding)
        re_ranked = self.rl_policy.re_rank(candidates)
        return re_ranked[:top_k]

    def update_retrieval_reward(self, texts: List[str], reward: float):
        self.rl_policy.update_policy(texts, reward)

    def clear(self):
        self.memories = []
# ===================== GenericInferenceBuffer =====================
MAX_TOKEN_BUFFER = 1024

class GenericInferenceBuffer:
    def __init__(self, layer_idx: int, compression_rank: int = 128):
        self.layer_idx = layer_idx
        self.key_buffer: Optional[torch.Tensor] = None
        self.value_buffer: Optional[torch.Tensor] = None
        self.semantic_context: Optional[torch.Tensor] = None
        self.last_update: float = 0
        self.compression_rank = compression_rank

    def update_buffer(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        semantic_context: Optional[torch.Tensor] = None
    ):
        try:
            if self.key_buffer is None:
                self.key_buffer = key.detach().clone()
                self.value_buffer = value.detach().clone()
                if semantic_context is not None:
                    self.semantic_context = semantic_context.detach().clone()
            else:
                self.key_buffer = torch.cat([self.key_buffer, key.detach()], dim=2)
                self.value_buffer = torch.cat([self.value_buffer, value.detach()], dim=2)
                if semantic_context is not None and self.semantic_context is not None:
                    self.semantic_context = torch.cat([self.semantic_context, semantic_context.detach()], dim=0)
            if self.key_buffer.shape[2] > MAX_TOKEN_BUFFER:
                excess = self.key_buffer.shape[2] - MAX_TOKEN_BUFFER
                self.key_buffer = self.key_buffer[:, :, excess:, :]
                self.value_buffer = self.value_buffer[:, :, excess:, :]
                if self.semantic_context is not None:
                    self.semantic_context = self.semantic_context[excess:, :]
            self.last_update = time.time()
        except Exception as e:
            logger.error(f"Buffer update error in layer {self.layer_idx}: {str(e)}")

    def compress_buffer_svd(self):
        if self.key_buffer is None or self.value_buffer is None:
            return
        try:
            k_shape = self.key_buffer.shape
            v_shape = self.value_buffer.shape
            k_2d = self.key_buffer.reshape(k_shape[0] * k_shape[1], k_shape[2] * k_shape[3]).float()
            v_2d = self.value_buffer.reshape(v_shape[0] * v_shape[1], v_shape[2] * v_shape[3]).float()
            device = k_2d.device
            k_2d_cpu = k_2d.cpu()
            v_2d_cpu = v_2d.cpu()
            U_k, S_k, V_k = torch.linalg.svd(k_2d_cpu, full_matrices=False)
            U_v, S_v, V_v = torch.linalg.svd(v_2d_cpu, full_matrices=False)
            rank_k = min(self.compression_rank, S_k.shape[0])
            rank_v = min(self.compression_rank, S_v.shape[0])
            k_approx = (U_k[:, :rank_k] * S_k[:rank_k]) @ V_k[:rank_k, :]
            v_approx = (U_v[:, :rank_v] * S_v[:rank_v]) @ V_v[:rank_v, :]
            k_approx = k_approx.to(device)
            v_approx = v_approx.to(device)
            self.key_buffer = k_approx.reshape(k_shape).type(self.key_buffer.dtype)
            self.value_buffer = v_approx.reshape(v_shape).type(self.value_buffer.dtype)
        except Exception as e:
            logger.error(f"SVD compression error in layer {self.layer_idx}: {str(e)}")

    def get_buffer(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        return self.key_buffer, self.value_buffer, self.semantic_context

    def clear(self):
        self.key_buffer = None
        self.value_buffer = None
        self.semantic_context = None
        self.last_update = 0
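
# Minimal sketch of the low-rank KV compression above (defined for illustration only;
# nothing in the app calls it). Shapes follow the [batch, heads, seq, head_dim] layout
# assumed by update_buffer; the toy sizes and the default rank are arbitrary assumptions.
def _svd_compression_demo(rank: int = 2) -> float:
    kv = torch.randn(1, 8, 64, 16)                     # toy key/value tensor
    buf = GenericInferenceBuffer(layer_idx=0, compression_rank=rank)
    buf.update_buffer(kv, kv)                          # fill key and value buffers
    buf.compress_buffer_svd()                          # rank-r approximation in place
    key_approx, _, _ = buf.get_buffer()
    return torch.norm(kv - key_approx).item()          # Frobenius reconstruction error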
# ===================== InferenceBufferManager =====================
class InferenceBufferManager:
    def __init__(self, num_layers: int, hidden_size: int):
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.layer_buffers = [
            GenericInferenceBuffer(i, compression_rank=128) for i in range(num_layers)
        ]
        self.semantic_memory = SemanticMemory()
        self.graph_memory = GraphMemory()
        self.summarizer = SimpleSummarizer()
        self.summarize_threshold = 1500
        self.generated_tokens_count = 0
        self.compression_interval = 512
        self.token_count_since_compress = 0

    def _compute_semantic_embedding(self, key: Optional[torch.Tensor], value: Optional[torch.Tensor]) -> torch.Tensor:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if key is None or value is None:
            return torch.zeros((1, self.hidden_size), dtype=torch.float32, device=device)
        combined = key * value
        combined = combined.mean(dim=2)
        combined = combined.reshape(combined.shape[0], -1)
        combined = torch.nn.functional.normalize(combined, dim=-1)
        return combined

    def update_buffer(self, layer_outputs, current_tokens: List[int], semantic_context: torch.Tensor, tokenizer):
        try:
            if hasattr(layer_outputs, 'past_key_values'):
                for layer_idx, past_kv in enumerate(layer_outputs.past_key_values):
                    if isinstance(past_kv, tuple) and len(past_kv) == 2:
                        key, value = past_kv
                        if key is not None and value is not None:
                            self.layer_buffers[layer_idx].update_buffer(
                                key.detach(),
                                value.detach(),
                                semantic_context
                            )
            self.generated_tokens_count += len(current_tokens)
            self.token_count_since_compress += len(current_tokens)
            if self.token_count_since_compress >= self.compression_interval:
                self.compress_all_buffers()
                self.token_count_since_compress = 0
        except Exception as e:
            logger.error(f"Buffer update error: {str(e)}")

    def compress_all_buffers(self):
        for buf in self.layer_buffers:
            buf.compress_buffer_svd()

    def finalize_semantic_memory(self, tokenizer, generated_tokens: List[int]):
        if self.layer_buffers and len(self.layer_buffers) > 0 and self.layer_buffers[-1].key_buffer is not None:
            text_chunk = tokenizer.decode(generated_tokens, skip_special_tokens=True)
            key_buffer = self.layer_buffers[-1].key_buffer
            value_buffer = self.layer_buffers[-1].value_buffer
            embedding = self._compute_semantic_embedding(key_buffer, value_buffer)
            self.semantic_memory.add_memory(text_chunk, embedding)

    def get_relevant_context(self, query_embedding: torch.Tensor, top_k: int = 3) -> List[str]:
        candidates_sem = self.semantic_memory.get_candidates(query_embedding)
        # Keyword extraction (simple implementation)
        possible_keywords = ["math", "algebra", "geometry", "arithmetic", "probability"]
        text_candidates = []
        for kw in possible_keywords:
            nodes = self.graph_memory.search_nodes(kw)
            for n in nodes:
                context_list = self.graph_memory.get_connected_context(n, steps=1)
                cscore = 1.0
                for ctxt in context_list:
                    text_candidates.append((cscore, ctxt))
        merged_candidates = candidates_sem + text_candidates
        re_ranked = self.semantic_memory.rl_policy.re_rank(merged_candidates)
        return re_ranked[:top_k]

    def update_retrieval_reward(self, contexts: List[str], reward: float):
        self.semantic_memory.update_retrieval_reward(contexts, reward)

    def maybe_summarize_memory(self):
        if self.generated_tokens_count < self.summarize_threshold:
            return
        all_text = "\n".join([m['text'] for m in self.semantic_memory.memories])
        if len(all_text) < 300:
            return
        summary = self.summarizer.summarize_text(all_text, max_length=120)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        summary_embedding = torch.zeros((1, self.hidden_size), dtype=torch.float32, device=device)
        self.semantic_memory.clear()
        self.semantic_memory.add_memory(summary, summary_embedding)
        self.generated_tokens_count = 0

    def clear(self):
        for layer in self.layer_buffers:
            layer.clear()
        self.semantic_memory.clear()
# ===================== Enhanced ThinkFlow Implementation =====================

# Marker used to detect the final answer
ANSWER_MARKER = "**Answer**"

# Sentences that open each step of the reasoning process
rethink_prepends = [
    "OK, now I need to figure out the following. ",
    "My first thought is that ",
    "Wait a moment, I actually think ",
    "Let me check whether the following is correct. ",
    "Another thing to remember is that ",
    "Another point worth noting is that ",
    "And I also keep the following fact in mind: ",
    "Now I think I understand the problem well enough. ",
]

# Prompt appended to generate the final answer
final_answer_prompt = """
Based on the reasoning so far, I will answer the original question in the language it was asked in:
{question}

Here is the conclusion I reached through that reasoning:
{reasoning_conclusion}

Final answer based on the reasoning above:
{ANSWER_MARKER}
"""

# Settings for rendering formulas in math problems
latex_delimiters = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
]

def reformat_math(text):
    """Convert MathJax-style delimiters to the KaTeX syntax Gradio renders, e.g. \\[x^2\\] -> $$x^2$$."""
    text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
    text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
    return text
def extract_keywords(text: str) -> List[str]:
    """Very simple keyword extraction from text."""
    # Simple implementation - a more sophisticated NLP technique could be used in practice
    common_math_keywords = [
        "math", "algebra", "geometry", "arithmetic", "probability", "formula", "equation",
        "function", "integral", "derivative", "shape", "triangle", "circle", "angle", "ratio",
        "proportion", "average", "variance", "standard deviation"
    ]
    keywords = []
    for kw in common_math_keywords:
        if kw in text:
            keywords.append(kw)
    return keywords[:5]  # return at most 5 keywords

def get_embedding_for_text(text: str, hidden_size: int = 768) -> torch.Tensor:
    """
    Generate a stand-in embedding for a piece of text.
    A real implementation should use a proper language-model encoder.
    """
    # Toy implementation: an embedding seeded by the hash of the text
    device = "cuda" if torch.cuda.is_available() else "cpu"
    hash_val = abs(hash(text)) % (2 ** 32)  # np.random.seed requires a value in [0, 2**32)
    np.random.seed(hash_val)
    # Generate a random embedding
    embedding = np.random.rand(1, hidden_size).astype(np.float32)
    # Normalize
    norm = np.linalg.norm(embedding)
    if norm > 0:
        embedding = embedding / norm
    return torch.tensor(embedding, device=device)
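
# A minimal sketch of a real text encoder, assuming the optional `sentence-transformers`
# package and the `all-MiniLM-L6-v2` checkpoint are available (both are assumptions, not
# part of the original app). Its 384-dim output differs from the default hidden_size used
# elsewhere, so the semantic-memory shapes would need to be kept consistent if adopted.
def get_embedding_for_text_st(text: str) -> torch.Tensor:
    """Hypothetical alternative embedding function; not wired into the app."""
    from sentence_transformers import SentenceTransformer  # optional dependency
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
    vec = model.encode([text], normalize_embeddings=True)  # numpy array of shape (1, 384)
    return torch.tensor(vec, device=device)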
def user_input(message, history_original, history_thinking):
    """Append the user's message to both histories and clear the input textbox."""
    return "", history_original + [
        gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
    ], history_thinking + [
        gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
    ]

def rebuild_messages(history: list):
    """Rebuild the messages the model should see from the history, skipping intermediate thinking steps."""
    messages = []
    for h in history:
        if isinstance(h, dict) and not (h.get("metadata") or {}).get("title", False):
            messages.append(h)
        elif (
            isinstance(h, gr.ChatMessage)
            and h.metadata.get("title", None) is None
            and isinstance(h.content, str)
        ):
            messages.append({"role": h.role, "content": h.content})
    return messages
# Model and buffer-manager initialization
def initialize_model_and_manager(model_name):
    """Initialize the generation pipeline and the inference buffer manager."""
    try:
        pipe = pipeline(
            "text-generation",
            model=model_name,
            device_map="auto",
            torch_dtype="auto",
        )
        # Extract layer count and hidden size from the model config
        config = pipe.model.config
        if hasattr(config, "n_layer"):
            num_layers = config.n_layer
        elif hasattr(config, "num_layers"):
            num_layers = config.num_layers
        elif hasattr(config, "num_hidden_layers"):
            num_layers = config.num_hidden_layers
        else:
            num_layers = 12  # default
        if hasattr(config, "n_embd"):
            hidden_size = config.n_embd
        elif hasattr(config, "hidden_size"):
            hidden_size = config.hidden_size
        else:
            hidden_size = 768  # default
        # Initialize the buffer manager
        buffer_manager = InferenceBufferManager(num_layers, hidden_size)
        return pipe, buffer_manager
    except Exception as e:
        logger.error(f"Model initialization error: {str(e)}")
        raise
def bot_original(
    history: list,
    max_num_tokens: int,
    do_sample: bool,
    temperature: float,
    pipe=None
):
    """Let the original model answer the question (without a reasoning process)."""
    if pipe is None:
        # In a real implementation this should be managed as a global variable or session state
        return history

    # Streamer used to pull tokens from the generation thread
    streamer = transformers.TextIteratorStreamer(
        pipe.tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,
    )

    # Prepare the assistant message
    history.append(
        gr.ChatMessage(
            role="assistant",
            content=str(""),
        )
    )

    # Messages currently shown in the chat
    messages = rebuild_messages(history[:-1])  # exclude the last (empty) message

    # The original model answers directly, without reasoning steps
    t = threading.Thread(
        target=pipe,
        args=(messages,),
        kwargs=dict(
            max_new_tokens=max_num_tokens,
            streamer=streamer,
            do_sample=do_sample,
            temperature=temperature,
        ),
    )
    t.start()

    for token in streamer:
        history[-1].content += token
        history[-1].content = reformat_math(history[-1].content)
        yield history
    t.join()
    yield history
def bot_thinking_enhanced(
    history: list,
    max_num_tokens: int,
    final_num_tokens: int,
    do_sample: bool,
    temperature: float,
    pipe=None,
    buffer_manager=None
):
    """Let the model answer with an explicit reasoning process - integrating the DeepSeek-style features."""
    if pipe is None or buffer_manager is None:
        # In a real implementation these should be managed as global variables or session state
        return history

    # Streamer used to pull tokens from the generation thread
    streamer = transformers.TextIteratorStreamer(
        pipe.tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,
    )

    # Keep the question so it can be re-inserted during reasoning if needed
    question = history[-1]["content"]

    # Build the query embedding
    query_embedding = get_embedding_for_text(question, buffer_manager.hidden_size)

    # Retrieve relevant context from semantic memory
    relevant_contexts = buffer_manager.get_relevant_context(query_embedding, top_k=3)

    # Extract keywords and fetch context from the graph memory
    keywords = extract_keywords(question)
    graph_contexts = []
    for keyword in keywords:
        nodes = buffer_manager.graph_memory.search_nodes(keyword)
        for node in nodes:
            contexts = buffer_manager.graph_memory.get_connected_context(node)
            graph_contexts.extend(contexts)

    # Merge all contexts
    all_contexts = relevant_contexts + graph_contexts
    all_contexts = list(set(all_contexts))  # remove duplicates
    all_contexts = all_contexts[:5]  # keep at most 5 contexts

    # Prepare the assistant message
    history.append(
        gr.ChatMessage(
            role="assistant",
            content=str(""),
            metadata={"title": "🧠 Thinking...", "status": "pending"},
        )
    )

    # Reasoning process shown in the current chat
    messages = rebuild_messages(history)

    # If relevant contexts were found, add them to the message
    if all_contexts:
        context_str = "\n\nRelevant context:\n" + "\n".join(all_contexts)
        messages[-1]["content"] += context_str
        history[-1].content += context_str

    # Variable holding the full reasoning trace
    full_reasoning = ""

    # Track generated tokens
    generated_tokens = []

    # Run the reasoning steps
    for i, prepend in enumerate(rethink_prepends):
        if i > 0:
            messages[-1]["content"] += "\n\n"
        messages[-1]["content"] += prepend.format(question=question)
        t = threading.Thread(
            target=pipe,
            args=(messages,),
            kwargs=dict(
                max_new_tokens=max_num_tokens,
                streamer=streamer,
                do_sample=do_sample,
                temperature=temperature,
            ),
        )
        t.start()

        # Rebuild the visible history with the new step prefix
        history[-1].content += prepend.format(question=question)
        step_tokens = []
        for token in streamer:
            history[-1].content += token
            history[-1].content = reformat_math(history[-1].content)
            step_tokens.append(token)
            generated_tokens.append(token)
            yield history
        t.join()

        # Store the result of each reasoning step in full_reasoning
        full_reasoning = history[-1].content

        # When the reasoning gets long, generate an intermediate summary
        if i > 0 and i % 3 == 0 and len(generated_tokens) > 500:
            try:
                summary = buffer_manager.summarizer.summarize_text(full_reasoning, max_length=150)
                summary_text = f"\n\n**Intermediate summary:**\n{summary}\n\n"
                history[-1].content += summary_text
                messages[-1]["content"] += summary_text
                yield history
            except Exception as e:
                logger.error(f"Summary generation error: {str(e)}")

        # Compress the KV caches
        if i > 0 and i % 2 == 0:
            buffer_manager.compress_all_buffers()

        # Update the semantic context
        step_text = "".join(step_tokens)
        step_embedding = get_embedding_for_text(step_text, buffer_manager.hidden_size)
        buffer_manager.semantic_memory.add_memory(step_text, step_embedding)

    # Reasoning finished; now generate the final answer
    history[-1].metadata = {"title": "🧠 Thought process", "status": "done"}

    # Store the reasoning trace in the semantic and graph memories
    full_embedding = get_embedding_for_text(full_reasoning, buffer_manager.hidden_size)
    buffer_manager.semantic_memory.add_memory(full_reasoning, full_embedding)

    # Update the graph memory with the extracted keywords
    for keyword in keywords:
        if not buffer_manager.graph_memory.has_node(keyword):
            buffer_manager.graph_memory.add_node(keyword, f"Concept related to {keyword}: reasoning was performed on this topic.")
        # Connect related nodes
        for related_kw in keywords:
            if related_kw != keyword and buffer_manager.graph_memory.has_node(related_kw):
                buffer_manager.graph_memory.add_edge(keyword, related_kw)

    # Extract the conclusion from the reasoning trace (roughly the last one or two paragraphs)
    reasoning_parts = full_reasoning.split("\n\n")
    reasoning_conclusion = "\n\n".join(reasoning_parts[-2:]) if len(reasoning_parts) > 2 else full_reasoning

    # Add the final answer message
    history.append(gr.ChatMessage(role="assistant", content=""))

    # Build the messages for the final answer
    final_messages = rebuild_messages(history[:-1])  # exclude the last (empty) message
    final_prompt = final_answer_prompt.format(
        question=question,
        reasoning_conclusion=reasoning_conclusion,
        ANSWER_MARKER=ANSWER_MARKER
    )
    final_messages[-1]["content"] += final_prompt

    # Generate the final answer
    t = threading.Thread(
        target=pipe,
        args=(final_messages,),
        kwargs=dict(
            max_new_tokens=final_num_tokens,
            streamer=streamer,
            do_sample=do_sample,
            temperature=temperature,
        ),
    )
    t.start()

    # Stream the final answer
    final_tokens = []
    for token in streamer:
        history[-1].content += token
        history[-1].content = reformat_math(history[-1].content)
        final_tokens.append(token)
        yield history
    t.join()

    # Store the final answer in the semantic memory
    final_text = "".join(final_tokens)
    final_embedding = get_embedding_for_text(final_text, buffer_manager.hidden_size)
    buffer_manager.semantic_memory.add_memory(final_text, final_embedding)

    # Periodic memory-summarization check
    buffer_manager.maybe_summarize_memory()

    yield history
with gr.Blocks(fill_height=True, title="Enhanced ThinkFlow") as demo:
    # Title and description
    gr.Markdown("# Enhanced ThinkFlow with DeepSeek Features")
    gr.Markdown("### An LLM reasoning-generation platform enhanced with semantic memory, graph memory, and KV cache compression")

    # Model and buffer-manager initialization (in a real implementation, manage these as session state)
    model_name = "CohereForAI/c4ai-command-r7b-arabic-02-2025"

    # Session variables (in a real implementation, use gr.State())
    pipe = None
    buffer_manager = None
    current_contexts = []

    # Tabbed interface
    with gr.Tabs() as tabs:
        # Chat tab
        with gr.TabItem("Integrated Reasoning Interface"):
            with gr.Row(scale=1):
                with gr.Column(scale=2):
                    gr.Markdown("## Before (Original)")
                    chatbot_original = gr.Chatbot(
                        scale=1,
                        type="messages",
                        latex_delimiters=latex_delimiters,
                        label="Original Model (No Reasoning)"
                    )
                with gr.Column(scale=2):
                    gr.Markdown("## After (Enhanced Thinking)")
                    chatbot_thinking = gr.Chatbot(
                        scale=1,
                        type="messages",
                        latex_delimiters=latex_delimiters,
                        label="Model with Enhanced Reasoning"
                    )

            with gr.Row():
                # Define the msg textbox first
                msg = gr.Textbox(
                    submit_btn=True,
                    label="",
                    show_label=False,
                    placeholder="Enter your question here.",
                    autofocus=True,
                )

            # Feedback buttons
            with gr.Row():
                with gr.Column(scale=1):
                    feedback_btn_pos = gr.Button("👍 This reasoning was helpful")
                with gr.Column(scale=1):
                    feedback_btn_neg = gr.Button("👎 This reasoning needs improvement")
                with gr.Column(scale=1):
                    clear_memory_btn = gr.Button("🧹 Clear memory")
        # Memory visualization tab
        with gr.TabItem("Memory Visualization"):
            gr.Markdown("## Semantic Memory Contents")
            semantic_memory_display = gr.Textbox(
                label="Current semantic memory contents",
                placeholder="No memories yet.",
                lines=10,
                max_lines=20,
                interactive=False
            )
            gr.Markdown("## Graph Knowledge Base")
            graph_memory_display = gr.Textbox(
                label="Current graph memory contents",
                placeholder="No graph nodes yet.",
                lines=10,
                max_lines=20,
                interactive=False
            )

    # Examples section - placed after the msg variable is defined
    with gr.Accordion("EXAMPLES", open=False):
        examples = gr.Examples(
            examples=[
                "[Source: MATH-500] How many of the first one hundred positive integers are divisible by 3, 4, and 5?",
                "[Source: MATH-500] In the land of Ink, the money system is unique. One Trinket is equal to 4 Blinkets, and 3 Blinkets are equal to 7 Drinkets. In Trinkets, what is the value of 56 Drinkets?",
                "[Source: MATH-500] The average age of Amy, Ben, and Chris is 6. Four years ago, Chris was the same age as Amy is now. In four years, Ben's age will be $\\frac{3}{5}$ of Amy's age at that time. How many years old is Chris now?",
                "[Source: MATH-500] A bag contains yellow and blue marbles. Currently, the ratio of blue marbles to yellow marbles is 4:3. After adding 5 blue marbles and removing 3 yellow marbles, the ratio becomes 7:3. How many blue marbles were in the bag before more were added?"
            ],
            inputs=msg
        )
    with gr.Accordion("Parameter settings", open=False):
        with gr.Row():
            with gr.Column():
                model_dropdown = gr.Dropdown(
                    ["CohereForAI/c4ai-command-r7b-arabic-02-2025", "meta-llama/Meta-Llama-3-8B-Instruct"],
                    label="Model selection",
                    value="CohereForAI/c4ai-command-r7b-arabic-02-2025"
                )
                num_tokens = gr.Slider(
                    50,
                    4000,
                    2000,
                    step=1,
                    label="Max tokens per reasoning step",
                    interactive=True,
                )
                final_num_tokens = gr.Slider(
                    50,
                    4000,
                    2000,
                    step=1,
                    label="Max tokens for the final answer",
                    interactive=True,
                )
            with gr.Column():
                do_sample = gr.Checkbox(True, label="Use sampling")
                temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")
                memory_weight = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="Memory weighting")
    # Feedback handling functions
    def process_positive_feedback():
        global buffer_manager, current_contexts  # module-level state (`nonlocal` is invalid at module scope)
        if buffer_manager:
            buffer_manager.update_retrieval_reward(current_contexts, reward=1.0)
        return "Thanks for the feedback! This approach will be favored for similar questions in the future."

    def process_negative_feedback():
        global buffer_manager, current_contexts
        if buffer_manager:
            buffer_manager.update_retrieval_reward(current_contexts, reward=-0.5)
        return "Thanks for the feedback! We will improve this approach."

    def clear_memory():
        global buffer_manager
        if buffer_manager:
            buffer_manager.clear()
        return "Memory has been cleared."

    def update_memory_displays():
        global buffer_manager
        if not buffer_manager:
            return "Memory has not been initialized.", "Graph has not been initialized."
        semantic_text = "Currently stored memories:\n\n"
        for i, mem in enumerate(buffer_manager.semantic_memory.memories[:5]):  # show at most 5 entries
            semantic_text += f"{i+1}. {mem['text'][:100]}...\n\n"
        graph_text = "Current graph nodes:\n\n"
        for node in buffer_manager.graph_memory.graph.nodes():
            node_text = buffer_manager.graph_memory.get_text_by_node(node)
            neighbors = list(buffer_manager.graph_memory.graph.neighbors(node))
            graph_text += f"Node: {node}\nDescription: {node_text[:50]}...\nConnections: {', '.join(neighbors[:3])}\n\n"
        return semantic_text, graph_text
    # Initialization function
    def initialize_models():
        global pipe, buffer_manager
        try:
            pipe, buffer_manager = initialize_model_and_manager(model_name)
            semantic_text, graph_text = update_memory_displays()
            return "Model initialized.", semantic_text, graph_text
        except Exception as e:
            return f"Model initialization error: {str(e)}", "", ""

    # Handle model selection changes
    def change_model(new_model_name):
        global model_name
        model_name = new_model_name
        status, semantic_text, graph_text = initialize_models()
        return status, semantic_text, graph_text

    # Re-initialize when the model selection changes
    model_dropdown.change(
        change_model,
        [model_dropdown],
        [gr.Textbox(visible=False), semantic_memory_display, graph_memory_display]
    )

    # Wire the feedback buttons to their handlers
    feedback_btn_pos.click(process_positive_feedback, [], gr.Textbox(visible=False))
    feedback_btn_neg.click(process_negative_feedback, [], gr.Textbox(visible=False))
    clear_memory_btn.click(clear_memory, [], gr.Textbox(visible=False))

    # Refresh the memory displays when the tab changes
    tabs.change(update_memory_displays, [], [semantic_memory_display, graph_memory_display])
    # When the user submits a message, both bots respond in sequence
    msg.submit(
        user_input,
        [msg, chatbot_original, chatbot_thinking],  # inputs
        [msg, chatbot_original, chatbot_thinking],  # outputs
    ).then(
        lambda h, n, d, t, p: bot_original(h, n, d, t, p),  # pass the pipe parameter
        [
            chatbot_original,
            num_tokens,
            do_sample,
            temperature,
            gr.Textbox(value=lambda: pipe, visible=False),  # hand over pipe
        ],
        chatbot_original,  # save the new history in the output
    ).then(
        lambda h, n, f, d, t, p, b: bot_thinking_enhanced(h, n, f, d, t, p, b),  # extra parameters
        [
            chatbot_thinking,
            num_tokens,
            final_num_tokens,
            do_sample,
            temperature,
            gr.Textbox(value=lambda: pipe, visible=False),  # hand over pipe
            gr.Textbox(value=lambda: buffer_manager, visible=False),  # hand over buffer_manager
        ],
        chatbot_thinking,  # save the new history in the output
    ).then(
        update_memory_displays,
        [],
        [semantic_memory_display, graph_memory_display]
    )
# Startup model initialization
def load_on_startup():
    global pipe, buffer_manager
    try:
        # Initialize the default model
        pipe, buffer_manager = initialize_model_and_manager(
            "CohereForAI/c4ai-command-r7b-arabic-02-2025"
        )
        logger.info("Model and buffer manager initialized successfully.")
    except Exception as e:
        logger.error(f"Startup model initialization failed: {str(e)}")

if __name__ == "__main__":
    # Initialize the model before starting the app
    load_on_startup()

    # Start the queue and the server
    # (launch() does not accept a `title` argument; the page title is set on gr.Blocks above)
    demo.queue().launch(
        share=False,
        debug=True
    )