pascal-maker's picture
Update app.py
4d030cf verified
raw
history blame
9.47 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Combined Medical-VLM, SAM-2 Automatic Masking, and CheXagent Demo
=================================================================
Features:
- Qwen2.5-VL Instruct medical vision-language Q&A
- SAM-2 segmentation with alias patch for Hugging Face
- Simple fallback segmentation
- CheXagent structured report & visual grounding
- Automatic dependency checking & installation for SAM-2
Usage:
python medical_ai_app.py # launches Gradio UI on port 7860
Requires:
torch, transformers, PIL, gradio, ultralytics, requests, opencv-python, pyyaml
"""
import importlib
import os
import subprocess
import sys
import tempfile
import uuid
import warnings
from pathlib import Path
from threading import Thread
# Environment setup
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
warnings.filterwarnings("ignore", message=r".*upsample_bicubic2d.*")
# Third-party libs
import torch
import numpy as np
import cv2
from PIL import Image, ImageDraw
import gradio as gr
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# =============================================================================
# SAM-2 Alias Patch & Installer
# =============================================================================
# Some Hugging Face environments ship SAM-2 as the ``sam_2`` package; register
# it under the ``sam2`` module names that the rest of this file imports.
# ``importlib`` is imported at the top of the file, so it stays bound even when
# ``sam_2`` is missing (check_and_install_sam2 relies on it).
try:
    import sam_2
    _sam2_aliases = {'sam2': sam_2}
    for _sub in ['build_sam', 'automatic_mask_generator', 'modeling.sam2_base']:
        _sam2_aliases[f'sam2.{_sub}'] = importlib.import_module(f'sam_2.{_sub}')
except ImportError:
    pass
else:
    # Install the aliases only after every submodule imported, so a partially
    # importable ``sam_2`` never masquerades as a working ``sam2``.
    sys.modules.update(_sam2_aliases)
def check_and_install_sam2():
    """Ensure ``sam2.build_sam`` is importable, installing SAM-2 if necessary.

    Tries a plain import first. If that fails, clones the segment-anything-2
    repository into ``./segment-anything-2`` (unless it already exists),
    installs it in editable mode with the current interpreter's pip, and
    retries the import. As a last resort it aliases a ``sam_2`` package as
    ``sam2``.

    Returns:
        True if SAM-2 is importable afterwards. False otherwise, including
        when ``git`` or ``pip`` fail or are missing.
    """
    try:
        from sam2.build_sam import build_sam2  # noqa: F401
        return True
    except ImportError:
        pass

    repo_dir = Path("segment-anything-2")
    try:
        if not repo_dir.exists():
            subprocess.run(
                ["git", "clone",
                 "https://github.com/facebookresearch/segment-anything-2.git",
                 str(repo_dir)],
                check=True,
            )
        # cwd= instead of os.chdir: the process working directory is never
        # left pointing into the checkout, even when pip fails.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-e", "."],
            cwd=repo_dir,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as exc:
        # OSError covers a missing git executable (FileNotFoundError).
        print(f"SAM-2 installation failed: {exc}")
        return False

    # A .pth file written by the editable install is only processed at
    # interpreter start-up, so make the fresh checkout importable directly.
    importlib.invalidate_caches()
    repo_path = str(repo_dir.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
    try:
        from sam2.build_sam import build_sam2  # noqa: F401
        return True
    except ImportError:
        pass

    # Fallback: a ``sam_2`` distribution aliased under the ``sam2`` names.
    try:
        import sam_2
        importlib.reload(sam_2)
        sys.modules['sam2'] = sam_2
        for sub in ['build_sam', 'automatic_mask_generator', 'modeling.sam2_base']:
            sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
        return True
    except ImportError:
        return False
# Decide once, at import time, whether SAM-2 can be used (this may clone and
# install the repository as a side effect).
SAM2_AVAILABLE = check_and_install_sam2()
print("SAM-2 Available: " + str(SAM2_AVAILABLE))
if SAM2_AVAILABLE:
    # These modules only resolve after the availability check succeeded.
    from sam2.build_sam import build_sam2
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
    from sam2.modeling.sam2_base import SAM2Base
# =============================================================================
# Utility: device selection
# =============================================================================
def get_device():
    """Pick the preferred torch device: CUDA first, then Apple MPS, else CPU."""
    candidates = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for device_name, is_usable in candidates:
        # Probes are called lazily, in priority order, stopping at the first hit.
        if is_usable():
            return torch.device(device_name)
    return torch.device("cpu")
# =============================================================================
# Qwen-VLM: loading & agent
# =============================================================================
# Lazily initialised Qwen2.5-VL objects, shared by every caller of the loader.
_qwen_model = None
_qwen_processor = None
_qwen_device = None


def load_qwen_model_and_processor(hf_token=None):
    """Load Qwen2.5-VL-3B-Instruct once and return ``(model, processor, device)``.

    Args:
        hf_token: Optional Hugging Face access token. It is only used by the
            call that actually performs the load.

    Returns:
        The cached ``(model, processor, device)`` triple.
    """
    global _qwen_model, _qwen_processor, _qwen_device
    if _qwen_model is None:
        device = get_device()
        # ``token`` is the current name of the deprecated ``use_auth_token``.
        auth = {"token": hf_token} if hf_token else {}
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True,
            torch_dtype=torch.float32, low_cpu_mem_usage=True, **auth
        ).to(device)
        processor = AutoProcessor.from_pretrained(
            "Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True, **auth
        )
        # Publish the globals only after both loads succeeded. Otherwise a
        # failing processor load would leave a cached model with a None
        # processor that later calls would return.
        _qwen_model, _qwen_processor, _qwen_device = model, processor, device
    return _qwen_model, _qwen_processor, _qwen_device
class MedicalVLMAgent:
    """Question answering over text and an optional image with Qwen2.5-VL.

    Wraps a loaded model/processor pair. Each question is sent with a fixed
    system prompt carrying a medical disclaimer.
    """

    def __init__(self, model, processor, device):
        self.model = model
        self.processor = processor
        self.device = device
        self.sys_prompt = (
            "You are a medical information assistant with vision capabilities.\n"
            "Disclaimer: I am not a licensed medical professional."
        )

    def run(self, text, image=None, max_new_tokens=128):
        """Answer *text*, optionally about *image*.

        Args:
            text: The user question. None is treated as an empty string.
            image: Optional PIL image attached to the question.
            max_new_tokens: Generation budget passed to ``model.generate``.

        Returns:
            The decoded answer with special tokens removed and whitespace stripped.
        """
        msgs = [{"role": "system", "content": [{"type": "text", "text": self.sys_prompt}]}]
        user_cont = []
        images = None
        if image is not None:
            image = image.convert("RGB")
            # The chat template only needs an image entry here to emit the
            # vision placeholder tokens; the pixels go to the processor below.
            user_cont.append({"type": "image", "image": image})
            images = [image]
        user_cont.append({"type": "text", "text": text or ""})
        msgs.append({"role": "user", "content": user_cont})
        prompt = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        # Each image placeholder in the prompt must be matched by an image
        # passed here, or the model never sees the upload.
        inputs = self.processor(
            text=[prompt], images=images, padding=True, return_tensors='pt'
        ).to(self.device)
        out = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + continuation; keep only the new tokens.
        resp = out[0][inputs.input_ids.shape[1]:]
        return self.processor.decode(resp, skip_special_tokens=True).strip()
# =============================================================================
# SAM-2 segmentation interface
# =============================================================================
# SAM-2 model and automatic mask generator; both stay None when SAM-2 is
# unavailable or fails to initialise.
_sam2_model, _mask_generator = (None, None)
if SAM2_AVAILABLE:
    # Both paths are given relative to the directory entered below.
    CKPT = "checkpoints/sam2.1_hiera_large.pt"
    CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
    _prev_cwd = os.getcwd()
    try:
        os.chdir("segment-anything-2/sam2/sam2")
        _sam2_model = build_sam2(CFG, CKPT, device=get_device(), apply_postprocessing=False)
        _mask_generator = SAM2AutomaticMaskGenerator(_sam2_model)
    except Exception as e:
        print(f"SAM-2 init error: {e}")
        _mask_generator = None
    finally:
        # Restore the working directory so the rest of the app does not keep
        # running inside the SAM-2 checkout.
        os.chdir(_prev_cwd)
def segmentation_interface(image):
    """Run SAM-2 automatic mask generation and tint every mask on *image*.

    Args:
        image: PIL image from the UI, or None when nothing was uploaded.

    Returns:
        ``(overlay_image, status_message)``. The image is None when there is
        no input or SAM-2 is unavailable.
    """
    if image is None:
        return None, "Upload an image"
    if _mask_generator is None:
        return None, "SAM-2 unavailable"
    arr = np.array(image.convert('RGB'))
    anns = _mask_generator.generate(arr)
    overlay = arr.copy()
    # Paint the largest masks first so smaller ones remain visible on top.
    for ann in sorted(anns, key=lambda x: x['area'], reverse=True):
        m = ann['segmentation']  # assumed boolean HxW mask (binary_mask output mode)
        # randint's upper bound is exclusive: 256 makes 255 a reachable channel value.
        color = np.random.randint(0, 256, 3)
        overlay[m] = (overlay[m] * 0.5 + color * 0.5).astype(np.uint8)
    return Image.fromarray(overlay), f"{len(anns)} masks found"
# =============================================================================
# Fallback segmentation
# =============================================================================
def fallback_segmentation(image):
    """Highlight Otsu-thresholded dark regions of *image* in red.

    Used when SAM-2 is not available. Returns ``(blended_image, status)``, or
    ``(None, "Upload an image")`` when no image was supplied.
    """
    if image is None:
        return None, "Upload an image"
    rgb = np.array(image.convert('RGB'))
    grey = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    # Otsu chooses the threshold; the INV flag makes the darker class the mask.
    _, mask = cv2.threshold(grey, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
    red = np.array([255, 0, 0], dtype=rgb.dtype)
    tinted = np.where((mask > 0)[..., None], red, rgb)
    return Image.fromarray(cv2.addWeighted(rgb, 0.7, tinted, 0.3, 0)), "Fallback applied"
# =============================================================================
# CheXagent: structured report & grounding
# =============================================================================
# CheXagent-2 is optional. If it cannot be loaded, the UI reports it as
# unavailable. Pre-binding the names keeps them defined even on failure.
chex_tok = None
chex_model = None
try:
    chex_tok = AutoTokenizer.from_pretrained("StanfordAIMI/CheXagent-2-3b", trust_remote_code=True)
    chex_model = AutoModelForCausalLM.from_pretrained(
        "StanfordAIMI/CheXagent-2-3b", device_map='auto', trust_remote_code=True
    )
    if torch.cuda.is_available():
        # Converted to fp16 only when CUDA is present.
        chex_model = chex_model.half()
    chex_model.eval()
    CHEX_AVAILABLE = True
except Exception as e:
    # Log instead of silently swallowing, matching the SAM-2 init handling.
    print(f"CheXagent init error: {e}")
    CHEX_AVAILABLE = False
@torch.no_grad()
def report_generation(im1, im2):
    """Stream a CheXagent structured report for the two uploaded images.

    This is currently a placeholder: no model inference runs and both images
    are ignored.

    Yields:
        Markdown status strings for the report output panel.
    """
    if not CHEX_AVAILABLE:
        yield "CheXagent unavailable"
        return
    # TODO: build the CheXagent prompt, run chex_model.generate in a background
    # Thread with a TextIteratorStreamer, and yield the accumulated text.
    yield "Report streaming not fully implemented"
@torch.no_grad()
def phrase_grounding(image, prompt):
    """Return *prompt* and a copy of *image* annotated with a grounding box.

    Placeholder: the box always covers the central half of the image; no
    model-based grounding is performed yet.

    Returns:
        ``(text, annotated_image)``. The image is None when CheXagent is
        unavailable or no image was supplied.
    """
    if not CHEX_AVAILABLE:
        return "CheXagent unavailable", None
    if image is None:
        return "Upload an image", None
    # convert() returns a new image, so the caller's image is not drawn on.
    # RGB also keeps the 'red' outline red for grayscale (mode "L") inputs.
    annotated = image.convert("RGB")
    w, h = annotated.size
    draw = ImageDraw.Draw(annotated)
    draw.rectangle([(w * 0.25, h * 0.25), (w * 0.75, h * 0.75)], outline='red', width=3)
    return prompt, annotated
# =============================================================================
# Gradio UI
# =============================================================================
def create_ui():
    """Build the Gradio app with Q&A, segmentation and CheXagent tabs.

    The Qwen model is loaded eagerly. If loading fails, the Q&A tab still
    builds and answers with an "unavailable" message instead of crashing.

    Returns:
        The un-launched ``gr.Blocks`` demo.
    """
    try:
        m, p, d = load_qwen_model_and_processor()
        med = MedicalVLMAgent(m, p, d)
        QW = True
    except Exception as e:
        # Narrower than a bare ``except:``, so KeyboardInterrupt/SystemExit propagate.
        print(f"Qwen init error: {e}")
        QW = False
        med = None

    def _qwen_unavailable(text, image=None):
        """Stand-in for MedicalVLMAgent.run when the Qwen model failed to load."""
        return "Qwen model unavailable"

    # Wiring ``med.run`` while med is None would raise AttributeError here.
    ask_fn = med.run if med is not None else _qwen_unavailable

    def _flag(ok):
        return '✅' if ok else '❌'

    with gr.Blocks() as demo:
        gr.Markdown("# Medical AI Assistant")
        gr.Markdown(
            f"- Qwen: {_flag(QW)} - SAM-2: {_flag(_mask_generator)} - CheX: {_flag(CHEX_AVAILABLE)}"
        )
        with gr.Tab("Medical Q&A"):
            txt = gr.Textbox()
            img = gr.Image(type='pil')
            out = gr.Textbox()
            gr.Button("Ask").click(ask_fn, [txt, img], out)
        with gr.Tab("Segmentation"):
            seg = gr.Image(type='pil')
            so = gr.Image()
            ss = gr.Textbox()
            seg_fn = segmentation_interface if _mask_generator else fallback_segmentation
            gr.Button("Segment").click(seg_fn, seg, [so, ss])
        with gr.Tab("CheXagent Report"):
            c1 = gr.Image(type='pil')
            c2 = gr.Image(type='pil')
            rout = gr.Markdown()
            gr.Interface(report_generation, [c1, c2], rout, live=True).render()
        with gr.Tab("CheXagent Grounding"):
            gi = gr.Image(type='pil')
            gp = gr.Textbox()
            gout = gr.Textbox()
            goimg = gr.Image()
            gr.Interface(phrase_grounding, [gi, gp], [gout, goimg]).render()
    return demo
if __name__ == "__main__":
    # Serve on all interfaces, port 7860, and request a public share link.
    app = create_ui()
    app.launch(server_name='0.0.0.0', server_port=7860, share=True)