LiamKhoaLe committed on
Commit
3cca1c2
·
1 Parent(s): 4a82770

Upd smartRAG

Browse files
Files changed (1) hide show
  1. memory.py +97 -44
memory.py CHANGED
@@ -1,79 +1,118 @@
1
  # memory.py
2
- import re
 
 
3
  import numpy as np
4
  import faiss
5
- from collections import defaultdict, deque
6
- from typing import List
7
  from sentence_transformers import SentenceTransformer
8
  from google import genai # must be configured in app.py and imported globally
9
  import logging
10
 
11
- _LLM = "gemini-2.5-flash-lite-preview-06-17" # Small model for NLP simple tasks
12
  # Load embedding model
13
- embedding_model = SentenceTransformer("/app/model_cache", device="cpu").half()
14
  logger = logging.getLogger("medical-chatbot")
15
 
16
  class MemoryManager:
17
- def __init__(self, max_users=1000, history_per_user=10):
18
- self.text_cache = defaultdict(lambda: deque(maxlen=history_per_user))
19
- self.chunk_index = defaultdict(lambda: faiss.IndexFlatL2(384))
20
- self.chunk_texts = defaultdict(list)
21
- self.user_queue = deque(maxlen=max_users)
 
 
22
 
 
23
  def add_exchange(self, user_id: str, query: str, response: str, lang: str = "EN"):
24
- if user_id not in self.text_cache:
25
- if len(self.user_queue) >= self.user_queue.maxlen:
26
- oldest = self.user_queue.popleft()
27
- self._drop_user(oldest)
28
- self.user_queue.append(user_id)
29
- # Normalize
30
  self.text_cache[user_id].append((query.strip(), response.strip()))
31
- # Use Gemini to summarize and chunk smartly
32
- chunks = self.chunk_response(response, lang)
33
- # Encode chunk
 
 
 
 
 
34
  for chunk in chunks:
35
- vec = embedding_model.encode(chunk, convert_to_numpy=True)
36
  self.chunk_index[user_id].add(np.array([vec]))
37
- self.chunk_texts[user_id].append(chunk)
 
 
 
38
 
39
- def get_relevant_chunks(self, user_id: str, query: str, top_k: int = 2):
40
- if user_id not in self.chunk_index or self.chunk_index[user_id].ntotal == 0:
 
41
  return []
42
- # Encode user query
43
- vec = embedding_model.encode(query, convert_to_numpy=True)
44
- D, I = self.chunk_index[user_id].search(np.array([vec]), k=top_k)
45
- return [self.chunk_texts[user_id][i] for i in I[0] if i < len(self.chunk_texts[user_id])]
 
 
 
 
46
 
47
- def get_context(self, user_id: str, num_turns: int = 3):
48
  history = list(self.text_cache.get(user_id, []))[-num_turns:]
49
  return "\n".join(f"User: {q}\nBot: {r}" for q, r in history)
50
 
51
  def reset(self, user_id: str):
52
  self._drop_user(user_id)
 
 
 
 
 
53
  if user_id in self.user_queue:
54
  self.user_queue.remove(user_id)
 
55
 
56
- def _drop_user(self, user_id):
57
  self.text_cache.pop(user_id, None)
58
  self.chunk_index.pop(user_id, None)
59
- self.chunk_texts.pop(user_id, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- def chunk_response(self, response: str, lang: str) -> List[str]:
 
 
 
 
 
 
62
  """
63
- Use Gemini to translate (if needed), summarize, and chunk the response.
64
- Assumes Gemini API is configured via google.genai globally in app.py.
 
 
 
65
  """
66
- # Full instruction
67
  instructions = []
68
- # Only add translation if necessary
69
  if lang.upper() != "EN":
70
  instructions.append("- Translate the response to English.")
71
  instructions.append("- Break the translated (or original) text into semantically distinct parts, grouped by medical topic or symptom.")
72
  instructions.append("- For each part, generate a clear, concise summary. The summary may vary in length depending on the complexity of the topic — do not omit key clinical instructions.")
73
  instructions.append("- Separate each part using three dashes `---` on a new line.")
74
- # Grouped sub-instructions
75
- joined_instructions = "\n".join(instructions)
76
- # Prompting
77
  prompt = f"""
78
  You are a medical assistant helping organize and condense a clinical response.
79
  Below is the user-provided medical response written in `{lang}`:
@@ -81,19 +120,33 @@ class MemoryManager:
81
  {response}
82
  ------------------------
83
  Please perform the following tasks:
84
- {joined_instructions}
 
85
  Output only the structured summaries, separated by dashes.
86
  """
87
  try:
88
  client = genai.Client()
89
  result = client.models.generate_content(
90
- model=_LLM,
91
  contents=prompt,
92
  generation_config={"temperature": 0.4}
93
  )
94
  output = result.text.strip()
95
- logger.info(f"Reasoned RAG result: {output}")
96
- return [chunk.strip() for chunk in output.split('---') if chunk.strip()]
 
 
 
97
  except Exception as e:
98
- print(f"❌ Gemini chunking failed: {e}")
99
- return [response.strip()]
 
 
 
 
 
 
 
 
 
 
 
1
  # memory.py
2
+ import re, time, hashlib, asyncio
3
+ from collections import defaultdict, deque
4
+ from typing import List, Dict
5
  import numpy as np
6
  import faiss
 
 
7
  from sentence_transformers import SentenceTransformer
8
  from google import genai # must be configured in app.py and imported globally
9
  import logging
10
 
11
# Model identifier passed as `model=` to the Gemini generate_content call in chunk_response.
+ _LLM_SMALL = "gemini-2.5-flash-lite-preview-06-17"
12
  # Load embedding model
13
# Sentence-embedding model loaded from a local path, placed on CPU and cast to half precision.
# NOTE(review): _new_index hard-codes dimension 384 — presumably this model's output size; confirm.
+ EMBED = SentenceTransformer("/app/model_cache", device="cpu").half()
14
# Module-level logger registered under the "medical-chatbot" name.
  logger = logging.getLogger("medical-chatbot")
15
 
16
  class MemoryManager:
17
+ def __init__(self, max_users=1000, history_per_user=10, max_chunks=30):
18
+ self.text_cache = defaultdict(lambda: deque(maxlen=history_per_user))
19
+ self.chunk_index = defaultdict(self._new_index) # user_id -> faiss index
20
+ self.chunk_meta = defaultdict(list) # '' -> list[{text,tag}]
21
+ self.user_queue = deque(maxlen=max_users) # LRU of users
22
+ self.max_chunks = max_chunks # hard cap per user
23
+ self.chunk_cache = {} # hash(query+resp) -> [chunks]
24
 
25
+ # ---------- Public API ----------
26
  def add_exchange(self, user_id: str, query: str, response: str, lang: str = "EN"):
27
+ self._touch_user(user_id)
 
 
 
 
 
28
  self.text_cache[user_id].append((query.strip(), response.strip()))
29
+ # Avoid re-chunking identical response
30
+ cache_key = hashlib.md5((query + response).encode()).hexdigest()
31
+ if cache_key in self.chunk_cache:
32
+ chunks = self.chunk_cache[cache_key]
33
+ else:
34
+ chunks = self._chunk_and_tag(response, lang)
35
+ self.chunk_cache[cache_key] = chunks
36
+ # Store chunks → faiss
37
  for chunk in chunks:
38
+ vec = self._embed(chunk["text"])
39
  self.chunk_index[user_id].add(np.array([vec]))
40
+ self.chunk_meta[user_id].append(chunk)
41
+ # Trim to max_chunks to keep latency O(1)
42
+ if len(self.chunk_meta[user_id]) > self.max_chunks:
43
+ self._rebuild_index(user_id, keep_last=self.max_chunks)
44
 
45
+ def get_relevant_chunks(self, user_id: str, query: str, top_k: int = 3, min_sim: float = 0.30) -> List[str]:
46
+ """Return texts of chunks whose cosine similarity min_sim."""
47
+ if self.chunk_index[user_id].ntotal == 0:
48
  return []
49
+ # Encode chunk
50
+ qvec = self._embed(query)
51
+ sims, idxs = self.chunk_index[user_id].search(np.array([qvec]), k=top_k)
52
+ results = []
53
+ for sim, idx in zip(sims[0], idxs[0]):
54
+ if idx < len(self.chunk_meta[user_id]) and sim >= min_sim:
55
+ results.append(self.chunk_meta[user_id][idx]["text"])
56
+ return results
57
 
58
+ def get_context(self, user_id: str, num_turns: int = 3) -> str:
59
  history = list(self.text_cache.get(user_id, []))[-num_turns:]
60
  return "\n".join(f"User: {q}\nBot: {r}" for q, r in history)
61
 
62
  def reset(self, user_id: str):
63
  self._drop_user(user_id)
64
+
65
+ # ---------- Internal helpers ----------
66
+ def _touch_user(self, user_id: str):
67
+ if user_id not in self.text_cache and len(self.user_queue) >= self.user_queue.maxlen:
68
+ self._drop_user(self.user_queue.popleft())
69
  if user_id in self.user_queue:
70
  self.user_queue.remove(user_id)
71
+ self.user_queue.append(user_id)
72
 
73
+ def _drop_user(self, user_id: str):
74
  self.text_cache.pop(user_id, None)
75
  self.chunk_index.pop(user_id, None)
76
+ self.chunk_meta.pop(user_id, None)
77
+ if user_id in self.user_queue:
78
+ self.user_queue.remove(user_id)
79
+
80
+ def _rebuild_index(self, user_id: str, keep_last: int):
81
+ """Trim chunk list + rebuild FAISS index for user."""
82
+ self.chunk_meta[user_id] = self.chunk_meta[user_id][-keep_last:]
83
+ index = self._new_index()
84
+ for chunk in self.chunk_meta[user_id]:
85
+ vec = self._embed(chunk["text"])
86
+ index.add(np.array([vec]))
87
+ self.chunk_index[user_id] = index
88
+
89
+ @staticmethod
90
+ def _new_index():
91
+ # Use cosine similarity (vectors must be L2-normalised)
92
+ return faiss.IndexFlatIP(384)
93
 
94
+ @staticmethod
95
+ def _embed(text: str):
96
+ vec = EMBED.encode(text, convert_to_numpy=True)
97
+ # L2 normalise for cosine on IndexFlatIP
98
+ return vec / (np.linalg.norm(vec) + 1e-9)
99
+
100
+ def chunk_response(self, response: str, lang: str) -> List[Dict]:
101
  """
102
+ Calls Gemini to:
103
+ - Translate (if needed)
104
+ - Chunk by context/topic
105
+ - Summarise
106
+ Returns: [{"tag": ..., "text": ...}, ...]
107
  """
108
+ # Gemini instruction
109
  instructions = []
 
110
  if lang.upper() != "EN":
111
  instructions.append("- Translate the response to English.")
112
  instructions.append("- Break the translated (or original) text into semantically distinct parts, grouped by medical topic or symptom.")
113
  instructions.append("- For each part, generate a clear, concise summary. The summary may vary in length depending on the complexity of the topic — do not omit key clinical instructions.")
114
  instructions.append("- Separate each part using three dashes `---` on a new line.")
115
+ # Gemini prompt
 
 
116
  prompt = f"""
117
  You are a medical assistant helping organize and condense a clinical response.
118
  Below is the user-provided medical response written in `{lang}`:
 
120
  {response}
121
  ------------------------
122
  Please perform the following tasks:
123
+ {chr(10).join(instructions)}
124
+
125
  Output only the structured summaries, separated by dashes.
126
  """
127
  try:
128
  client = genai.Client()
129
  result = client.models.generate_content(
130
+ model=_LLM_SMALL,
131
  contents=prompt,
132
  generation_config={"temperature": 0.4}
133
  )
134
  output = result.text.strip()
135
+ logger.info(f"📦 Gemini summarized chunk output: {output}")
136
+ return [
137
+ {"tag": self._quick_extract_topic(chunk), "text": chunk.strip()}
138
+ for chunk in output.split('---') if chunk.strip()
139
+ ]
140
  except Exception as e:
141
+ logger.warning(f"❌ Gemini chunking failed: {e}")
142
+ return [{"tag": "general", "text": response.strip()}]
143
+
144
+ @staticmethod
145
+ def _quick_extract_topic(chunk: str) -> str:
146
+ """Heuristically extract the topic from a chunk (title line or first 3 words)."""
147
+ lines = chunk.strip().splitlines()
148
+ for line in lines:
149
+ if len(line.split()) <= 8 and line.strip().endswith(":"):
150
+ return line.strip().rstrip(":")
151
+ return " ".join(chunk.split()[:3]).rstrip(":.,")
152
+