LiamKhoaLe committed on
Commit 6db39d6 · 1 Parent(s): 3dcd314

Update Gemma3 VLM dynamic fuser

Files changed (2)
  1. app.py +24 -6
  2. vlm.py +31 -0
app.py CHANGED
@@ -12,7 +12,7 @@ from sentence_transformers import SentenceTransformer
 from sentence_transformers.util import cos_sim
 from memory import MemoryManager
 from translation import translate_query
-
+from vlm import process_medical_image
 
 # ✅ Enable Logging for Debugging
 import logging
@@ -221,7 +221,7 @@ class RAGMedicalChatbot:
         self.model_name = model_name
         self.retrieve = retrieve_function
 
-    def chat(self, user_id: str, user_query: str, lang: str = "EN") -> str:
+    def chat(self, user_id: str, user_query: str, lang: str = "EN", image_diagnosis: str = "") -> str:
         # 0. Translate the query if it is not EN; this helps our RAG system
         if lang.upper() in {"VI", "ZH"}:
             user_query = translate_query(user_query, lang.lower())
@@ -240,6 +240,13 @@ class RAGMedicalChatbot:
        parts = ["You are a medical chatbot, designed to answer medical questions."]
        parts.append("Please format your answer using MarkDown.")
        parts.append("**Bold for titles**, *italic for emphasis*, and clear headings.")
+        # Append the image diagnosis produced by the VLM
+        if image_diagnosis:
+            parts.append(
+                "The user's medical image was diagnosed by a VLM agent:\n"
+                f"{image_diagnosis}\n\n"
+                "➡️ Please incorporate the above findings in your response if medically relevant.\n\n"
+            )
        # Historical chat retrieval case
        if context:
            parts.append("Relevant context from prior conversation:\n" + "\n".join(context))
@@ -270,11 +277,22 @@ async def chat_endpoint(req: Request):
     user_id = body.get("user_id", "anonymous")
     query = body.get("query", "").strip()
     lang = body.get("lang", "EN")
-    # Error
-    if not query:
-        return JSONResponse({"response": "No query provided."})
+    image_base64 = body.get("image_base64", None)
+    # Error: neither a query nor an image was provided
+    if not query and not image_base64:
+        return JSONResponse({"response": "No query or image provided."})
     start = time.time()
-    answer = chatbot.chat(user_id, query, lang)
+    # If an image is present → diagnose it first
+    image_diagnosis = ""
+    # Reject oversized payloads before processing (base64 string > ~5 MB)
+    if image_base64 and len(image_base64.encode("utf-8")) > 5_000_000:
+        return JSONResponse({"response": "⚠️ Image too large. Please upload smaller images (<5MB)."})
+    # LLM+VLM scenario
+    if image_base64:
+        logger.info("[BOT] VLM+LLM scenario.")
+        prompt = query or "Describe and investigate any clinical findings from this medical image."
+        image_diagnosis = process_medical_image(image_base64, prompt, lang)
+    answer = chatbot.chat(user_id, query, lang, image_diagnosis)
     elapsed = time.time() - start
     # Final
     return JSONResponse({"response": f"{answer}\n\n(Response time: {elapsed:.2f}s)"})
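
For reference, a minimal client-side sketch of exercising the new image path. The route path ("/chat"), host, port, and sample file name are assumptions, since the decorator binding chat_endpoint to a route is not part of this diff:

import base64
import requests

# Hypothetical sample image; must stay under the 5,000,000-character base64 guard
with open("chest_xray.jpg", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "user_id": "demo-user",
    "query": "Is there anything abnormal in this chest X-ray?",
    "lang": "EN",
    "image_base64": image_base64,  # omit this key for a text-only (LLM) request
}
resp = requests.post("http://localhost:8000/chat", json=payload)  # assumed route and host
print(resp.json()["response"])

Note that the guard measures the base64 string itself: 5,000,000 base64 characters decode to roughly 3.7 MB of raw image bytes (base64 inflates data by about 4/3), so the "<5MB" wording in the error message is approximate.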
vlm.py ADDED
@@ -0,0 +1,31 @@
+# vlm.py
+import os
+from huggingface_hub import InferenceClient
+from translation import translate_query
+# Initialise once
+HF_TOKEN = os.getenv("HF_TOKEN")
+client = InferenceClient(provider="auto", api_key=HF_TOKEN)
+
+def process_medical_image(base64_image: str, prompt: str = None, lang: str = "EN") -> str:
+    """
+    Send a base64 image plus a prompt to MedGemma and return its output.
+    """
+    if not prompt:
+        prompt = "Describe and investigate any clinical findings from this medical image."
+    elif lang.upper() in {"VI", "ZH"}:
+        prompt = translate_query(prompt, lang.lower())
+    # Send over the Inference API
+    try:
+        response = client.chat.completions.create(
+            model="google/medgemma-4b-it",
+            messages=[{
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": prompt},
+                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
+                ]
+            }]
+        )
+        return response.choices[0].message.content.strip()
+    except Exception as e:
+        return f"⚠️ Error from image diagnosis model: {e}"