LiamKhoaLe commited on
Commit
da736bd
·
1 Parent(s): f672b04

GRADIO CLIENT MIGRATION:

Browse files
Files changed (1) hide show
  1. vlm.py +39 -55
vlm.py CHANGED
@@ -1,70 +1,54 @@
1
- # vlm.py
2
  import os, logging, traceback, json, base64
3
  from io import BytesIO
4
  from PIL import Image
5
- from huggingface_hub import InferenceClient # Render model on HF hub
6
- from transformers import pipeline # Render model on transformers
7
  from translation import translate_query
8
-
9
- # Initialise once
10
- HF_TOKEN = os.getenv("HF_TOKEN")
11
- # client = InferenceClient(provider="auto", api_key=HF_TOKEN) # comment in back
12
 
13
  logger = logging.getLogger("vlm-agent")
14
- logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True) # Change INFO to DEBUG for full-ctx JSON loader
15
 
16
# ✅ Load VLM pipeline once (lazy load: the heavy model download/initialisation
# only happens on the first image request, not at import time).
vlm_pipe = None
def load_vlm():
    """Return the shared MedGEMMA image-to-text pipeline, creating it on first call.

    Returns:
        transformers.Pipeline: an "image-to-text" pipeline for
        google/medgemma-4b-it, device-mapped automatically.
    """
    global vlm_pipe
    if vlm_pipe is None:
        logger.info("⏳ Loading MedGEMMA model via Transformers pipeline...")
        # Fix: `use_auth_token` is deprecated in recent transformers releases;
        # the supported kwarg for gated-model authentication is `token`.
        vlm_pipe = pipeline(
            "image-to-text",
            model="google/medgemma-4b-it",
            token=HF_TOKEN,
            device_map="auto",
        )
        logger.info(" MedGEMMA model ready.")
    return vlm_pipe
25
 
26
def process_medical_image(base64_image: str, prompt: str = None, lang: str = "EN") -> str:
    """
    Send a base64-encoded image + prompt to MedGEMMA and return the generated text.

    Args:
        base64_image: the image bytes, base64-encoded (no data-URL prefix).
        prompt: optional instruction for the model; a clinical default is used
            when omitted. Non-English prompts (VI/ZH) are translated first.
        lang: caller's language code; only "VI"/"ZH" trigger translation.

    Returns:
        The model's generated text, or a "[VLM] ⚠️ ..." error string on failure
        (this function never raises to the caller).
    """
    if not prompt:
        prompt = "Describe and investigate any clinical findings from this medical image."
    elif lang.upper() in {"VI", "ZH"}:
        # Fix: the original assigned `user_query = translate_query(user_query, ...)`,
        # referencing an undefined name (NameError) and never updating `prompt`.
        prompt = translate_query(prompt, lang.lower())
    try:
        # Decode base64 to a PIL image the pipeline can consume.
        image_data = base64.b64decode(base64_image)
        image = Image.open(BytesIO(image_data)).convert("RGB")
        pipe = load_vlm()
        # The image-to-text pipeline returns [{"generated_text": str}, ...];
        # `response` is therefore a plain string.
        response = pipe(image, prompt=prompt, max_new_tokens=100)[0]["generated_text"]
        # Fix: the original validated `response.choices[0].message.content` — the
        # HF-hub chat-completions shape — against this plain string, so every
        # successful pipeline call raised ValueError. Validate the string itself.
        if not response or not isinstance(response, str):
            raise ValueError("Empty or malformed response from MedGEMMA.")
        result = response.strip()
        logger.info(f"[VLM] MedGemma returned {result}")
        return result
    except Exception as e:
        logger.error(f"[VLM] ❌ Exception: {e}")
        logger.error(f"[VLM] 🔍 Traceback:\n{traceback.format_exc()}")
        # Best-effort dump of whatever we got back; `response` may be unbound
        # if decoding or inference failed before assignment.
        try:
            logger.error(f"[VLM] ⚠️ Raw response: {json.dumps(response, default=str, indent=2)}")
        except Exception:  # Fix: bare `except:` also swallowed SystemExit/KeyboardInterrupt.
            logger.warning("[VLM] ⚠️ Response not serializable.")
        return f"[VLM] ⚠️ Image diagnosis failed: {str(e)}"
 
 
1
  import os, logging, traceback, json, base64
2
  from io import BytesIO
3
  from PIL import Image
 
 
4
  from translation import translate_query
5
+ from gradio_client import Client, handle_file
6
+ import tempfile
 
 
7
 
8
  logger = logging.getLogger("vlm-agent")
9
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True)
10
 
11
# ✅ Lazily-initialised singleton connection to the MedGEMMA Gradio Space.
gr_client = None
def load_gradio_client():
    """Return the shared Gradio client, connecting to the Space on first use."""
    global gr_client
    if gr_client is not None:
        return gr_client
    logger.info("[VLM] Connecting to MedGEMMA Gradio Space...")
    gr_client = Client("warshanks/medgemma-4b-it")
    logger.info("[VLM] Gradio MedGEMMA client ready.")
    return gr_client
20
 
21
def process_medical_image(base64_image: str, prompt: str = None, lang: str = "EN") -> str:
    """
    Send a base64-encoded image + prompt to the MedGEMMA Gradio Space and
    return its textual diagnosis.

    Args:
        base64_image: the image bytes, base64-encoded (no data-URL prefix).
        prompt: optional instruction for the model; a clinical default is used
            when omitted. Non-English prompts (VI/ZH) are translated first.
        lang: caller's language code; only "VI"/"ZH" trigger translation.

    Returns:
        The Space's reply (stripped), the stringified result if the reply is
        not a string, or a "[VLM] ⚠️ ..." error string on failure (this
        function never raises to the caller).
    """
    if not prompt:
        prompt = "Describe and investigate any clinical findings from this medical image."
    elif lang.upper() in {"VI", "ZH"}:
        prompt = translate_query(prompt, lang.lower())

    image_path = None
    try:
        # 1️⃣ Decode base64 image to a temp file — gradio_client uploads files,
        # not in-memory bytes. (No tmp.flush() needed: the `with` block closes
        # the handle, flushing it, before the path is used.)
        image_data = base64.b64decode(base64_image)
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp.write(image_data)
            image_path = tmp.name

        # 2️⃣ Send to Gradio MedGEMMA
        client = load_gradio_client()
        logger.info(f"[VLM] Sending prompt: {prompt}")
        result = client.predict(
            message={"text": prompt, "files": [handle_file(image_path)]},
            # param_2/param_3 are the Space's auto-named system-prompt and
            # max-tokens inputs on the /chat endpoint.
            param_2="You analyze medical images and report abnormalities, diseases with clear diagnostic insight.",
            param_3=2048,
            api_name="/chat",
        )
        if isinstance(result, str):
            logger.info(f"[VLM] Response: {result}")
            return result.strip()
        else:
            logger.warning(f"[VLM] ⚠️ Unexpected result type: {type(result)} — {result}")
            return str(result)
    except Exception as e:
        logger.error(f"[VLM] ❌ Exception: {e}")
        logger.error(f"[VLM] 🔍 Traceback:\n{traceback.format_exc()}")
        return f"[VLM] ⚠️ Failed to process image: {e}"
    finally:
        # Fix: delete=False meant the temp file was never removed, leaking one
        # file per request for the lifetime of the container.
        if image_path:
            try:
                os.unlink(image_path)
            except OSError:
                pass