LiamKhoaLe commited on
Commit
c29409a
·
1 Parent(s): 1ca4ee7

Migrate system using Gem Flash lite for NLP taskings

Browse files
Files changed (1) hide show
  1. memory.py +42 -42
memory.py CHANGED
@@ -5,20 +5,12 @@ import faiss
5
  from collections import defaultdict, deque
6
  from typing import List
7
  from sentence_transformers import SentenceTransformer
8
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
9
 
10
- # Embedding model (384d)
 
11
  embedding_model = SentenceTransformer("/app/model_cache", device="cpu").half()
12
 
13
- # English summarizer
14
- summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=-1)
15
-
16
- # Lightweight MarianMT translation models (VI → EN and ZH → EN)
17
- translation_models = {
18
- "VI": pipeline("translation", model="Helsinki-NLP/opus-mt-vi-en", device=-1),
19
- "ZH": pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en", device=-1)
20
- }
21
-
22
  class MemoryManager:
23
  def __init__(self, max_users=1000, history_per_user=10):
24
  self.text_cache = defaultdict(lambda: deque(maxlen=history_per_user))
@@ -32,10 +24,11 @@ class MemoryManager:
32
  oldest = self.user_queue.popleft()
33
  self._drop_user(oldest)
34
  self.user_queue.append(user_id)
35
- # Normalize info
36
  self.text_cache[user_id].append((query.strip(), response.strip()))
 
37
  chunks = self.chunk_response(response, lang)
38
- # Encode history
39
  for chunk in chunks:
40
  vec = embedding_model.encode(chunk, convert_to_numpy=True)
41
  self.chunk_index[user_id].add(np.array([vec]))
@@ -44,7 +37,7 @@ class MemoryManager:
44
  def get_relevant_chunks(self, user_id: str, query: str, top_k: int = 2):
45
  if user_id not in self.chunk_index or self.chunk_index[user_id].ntotal == 0:
46
  return []
47
- # Encode query
48
  vec = embedding_model.encode(query, convert_to_numpy=True)
49
  D, I = self.chunk_index[user_id].search(np.array([vec]), k=top_k)
50
  return [self.chunk_texts[user_id][i] for i in I[0] if i < len(self.chunk_texts[user_id])]
@@ -65,32 +58,39 @@ class MemoryManager:
65
 
66
  def chunk_response(self, response: str, lang: str) -> List[str]:
67
  """
68
- Smart multilingual chunking and summarization:
69
- - Translate VI/ZH into English for processing.
70
- - Chunk semantically.
71
- - Summarize large parts.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  """
73
- # Step 1: Translate if needed
74
- if lang.upper() in translation_models:
75
- try:
76
- translated = translation_models[lang.upper()](response, max_length=512)[0]['translation_text']
77
- except Exception:
78
- translated = response # fallback
79
- else:
80
- translated = response
81
- # Step 2: Split into rough semantic blocks
82
- raw_chunks = [c.strip() for c in re.split(
83
- r'\n{2,}|\n(?=\*\*)|(?<=\.)\s+(?=[A-Z])', translated) if c.strip()]
84
- # Step 3: Summarize long ones
85
- summarized_chunks = []
86
- for chunk in raw_chunks:
87
- if len(chunk.split()) > 50:
88
- try:
89
- summary = summarizer(chunk, max_length=60, min_length=10, do_sample=False)[0]['summary_text']
90
- summarized_chunks.append(summary.strip())
91
- except Exception:
92
- summarized_chunks.append(chunk)
93
- else:
94
- summarized_chunks.append(chunk)
95
- # Final
96
- return summarized_chunks
 
5
  from collections import defaultdict, deque
6
  from typing import List
7
  from sentence_transformers import SentenceTransformer
8
+ from google import genai # must be configured in app.py and imported globally
9
 
10
+ _LLM = "gemini-2.5-flash-lite-preview-06-17" # Small model for NLP simple tasks
11
+ # Load embedding model
12
  embedding_model = SentenceTransformer("/app/model_cache", device="cpu").half()
13
 
 
 
 
 
 
 
 
 
 
14
  class MemoryManager:
15
  def __init__(self, max_users=1000, history_per_user=10):
16
  self.text_cache = defaultdict(lambda: deque(maxlen=history_per_user))
 
24
  oldest = self.user_queue.popleft()
25
  self._drop_user(oldest)
26
  self.user_queue.append(user_id)
27
+
28
  self.text_cache[user_id].append((query.strip(), response.strip()))
29
+ # Use Gemini to summarize and chunk smartly
30
  chunks = self.chunk_response(response, lang)
31
+ # Encode chunk
32
  for chunk in chunks:
33
  vec = embedding_model.encode(chunk, convert_to_numpy=True)
34
  self.chunk_index[user_id].add(np.array([vec]))
 
37
  def get_relevant_chunks(self, user_id: str, query: str, top_k: int = 2):
38
  if user_id not in self.chunk_index or self.chunk_index[user_id].ntotal == 0:
39
  return []
40
+ # Encode user query
41
  vec = embedding_model.encode(query, convert_to_numpy=True)
42
  D, I = self.chunk_index[user_id].search(np.array([vec]), k=top_k)
43
  return [self.chunk_texts[user_id][i] for i in I[0] if i < len(self.chunk_texts[user_id])]
 
58
 
59
  def chunk_response(self, response: str, lang: str) -> List[str]:
60
  """
61
+ Use Gemini to translate (if needed), summarize, and chunk the response.
62
+ Assumes Gemini API is configured via google.genai globally in app.py.
63
+ """
64
+ # Full instruction
65
+ instructions = []
66
+ # Only add translation if necessary
67
+ if lang.upper() != "EN":
68
+ instructions.append("- Translate the response to English.")
69
+ instructions.append("- Break the translated (or original) text into semantically distinct parts, grouped by medical topic or symptom.")
70
+ instructions.append("- For each part, generate a clear, concise summary. The summary may vary in length depending on the complexity of the topic — do not omit key clinical instructions.")
71
+ instructions.append("- Separate each part using three dashes `---` on a new line.")
72
+ # Grouped sub-instructions
73
+ joined_instructions = "\n".join(instructions)
74
+ # Prompting
75
+ prompt = f"""
76
+ You are a medical assistant helping organize and condense a clinical response.
77
+ Below is the user-provided medical response written in `{lang}`:
78
+ ------------------------
79
+ {response}
80
+ ------------------------
81
+ Please perform the following tasks:
82
+ {joined_instructions}
83
+ Output only the structured summaries, separated by dashes.
84
  """
85
+ try:
86
+ client = genai.Client()
87
+ result = client.models.generate_content(
88
+ model=_LLM,
89
+ contents=prompt,
90
+ generation_config={"temperature": 0.4}
91
+ )
92
+ output = result.text.strip()
93
+ return [chunk.strip() for chunk in output.split('---') if chunk.strip()]
94
+ except Exception as e:
95
+ print(f"❌ Gemini chunking failed: {e}")
96
+ return [response.strip()]