pascal-maker's picture
Update app.py
45c9883 verified
raw
history blame
10.2 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Combined Medical-VLM, SAM-2 Automatic Masking, and CheXagent Demo
=================================================================
Features:
- Qwen2.5-VL Instruct medical vision-language Q&A
- SAM-2 segmentation with alias patch for Hugging Face
- Simple fallback segmentation
- CheXagent structured report & visual grounding
- Automatic dependency checking & installation for SAM-2
Usage:
HF_TOKEN=<your_token> python medical_ai_app.py # if private models require auth
Requires:
torch, transformers, PIL, gradio, ultralytics, requests, opencv-python, pyyaml
"""
import os
import sys
import uuid
import tempfile
import subprocess
import warnings
from threading import Thread
from pathlib import Path
# Hugging Face token (for private models)
HF_TOKEN = os.getenv("HF_TOKEN")
# Environment setup
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
warnings.filterwarnings("ignore", message=r".*upsample_bicubic2d.*")
# Third-party libs
import torch
import numpy as np
import cv2
from PIL import Image, ImageDraw
import gradio as gr
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import importlib
# =============================================================================
# SAM-2 Alias Patch & Installer
# =============================================================================
# Best-effort alias: when the package is importable under the name ``sam_2``,
# register it (and the three submodules this app imports) under ``sam2`` in
# ``sys.modules`` so the later ``from sam2... import ...`` statements resolve.
try:
    import sam_2
    sys.modules['sam2'] = sam_2
    # Submodules are registered one at a time, in this order, so each alias is
    # already in place before the next submodule is imported.
    for sub in ('build_sam', 'automatic_mask_generator', 'modeling.sam2_base'):
        sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
except ImportError:
    # Nothing to alias; check_and_install_sam2() below handles availability.
    pass
def check_and_install_sam2():
    """Make sure ``sam2.build_sam`` is importable, installing SAM-2 if needed.

    If the import fails, the SAM-2 repository is cloned into
    ``./segment-anything-2`` (unless that directory already exists) and
    installed in editable mode with pip, after which the import is retried.
    As a last resort a package importable as ``sam_2`` is aliased to ``sam2``.

    Returns:
        bool: True when ``sam2.build_sam`` can be imported afterwards, False
        when cloning, installing, or importing failed.
    """
    try:
        from sam2.build_sam import build_sam2  # noqa: F401 - availability probe
        return True
    except ImportError:
        pass

    repo_dir = Path("segment-anything-2")
    try:
        if not repo_dir.exists():
            subprocess.run(
                ["git", "clone", "https://github.com/facebookresearch/segment-anything-2.git"],
                check=True,
            )
        # Run pip inside the checkout through ``cwd`` instead of os.chdir(), so
        # the process working directory is never left changed when pip fails.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-e", "."],
            cwd=repo_dir,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as exc:
        # Report "unavailable" instead of aborting the import of the whole app;
        # OSError covers a missing ``git`` executable.
        print(f"[SAM-2] Installation failed: {exc}")
        return False

    # The import system caches directory listings; drop them so a package
    # installed a moment ago can be found by this running interpreter.
    importlib.invalidate_caches()
    try:
        # Retry the same import that was probed at the top of the function.
        from sam2.build_sam import build_sam2  # noqa: F401
        return True
    except ImportError:
        pass

    try:
        # Fallback: expose a package importable as ``sam_2`` under ``sam2``.
        import sam_2
        sys.modules['sam2'] = sam_2
        for sub in ['build_sam', 'automatic_mask_generator', 'modeling.sam2_base']:
            sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
        return True
    except ImportError:
        return False
# Probe for SAM-2 once at import time (this may clone and pip-install it).
SAM2_AVAILABLE = check_and_install_sam2()
print(f"SAM-2 Available: {SAM2_AVAILABLE}")
if SAM2_AVAILABLE:
    # Bind the SAM-2 entry points only when the probe succeeded, so a missing
    # package does not raise ImportError while this module is being imported.
    from sam2.build_sam import build_sam2
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
    from sam2.modeling.sam2_base import SAM2Base
# =============================================================================
# Utility: device selection
# =============================================================================
def get_device():
    """Select the torch device to run models on.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        torch.device: the first backend in that order that reports itself
        as available.
    """
    if torch.cuda.is_available():
        backend = 'cuda'
    elif torch.backends.mps.is_available():
        backend = 'mps'
    else:
        backend = 'cpu'
    return torch.device(backend)
# =============================================================================
# Qwen-VLM: loading & agent
# =============================================================================
# Lazily populated cache for the Qwen model, its processor and their device.
_qwen_model = None
_qwen_processor = None
_qwen_device = None


def load_qwen_model_and_processor():
    """Load (once) and return the Qwen2.5-VL model, processor and device.

    The first successful call caches the objects in module globals; later
    calls return the cached objects. When loading fails the error is printed
    and the model and processor are returned as ``None``, so the caller can
    disable the feature; the next call will try again.

    Returns:
        tuple: ``(model, processor, device)``; model and processor are
        ``None`` when loading failed.
    """
    global _qwen_model, _qwen_processor, _qwen_device
    if _qwen_model is None:
        _qwen_device = get_device()
        # ``token`` is the supported keyword for Hub authentication;
        # ``use_auth_token`` is deprecated in transformers.
        auth = {"token": HF_TOKEN} if HF_TOKEN else {}
        print(f"[Qwen] Loading model with auth={'yes' if HF_TOKEN else 'no'} on {_qwen_device}")
        try:
            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                "Qwen/Qwen2.5-VL-3B-Instruct",
                trust_remote_code=True,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                **auth
            ).to(_qwen_device)
            processor = AutoProcessor.from_pretrained(
                "Qwen/Qwen2.5-VL-3B-Instruct",
                trust_remote_code=True,
                **auth
            )
        except Exception as e:
            print(f"[Qwen] Model load failed: {e}")
            _qwen_model = None
            _qwen_processor = None
        else:
            # Publish both objects together so the cache never holds a model
            # without its processor.
            _qwen_model = model
            _qwen_processor = processor
    return _qwen_model, _qwen_processor, _qwen_device
class MedicalVLMAgent:
    """Question-answering wrapper around a vision-language model and processor."""

    def __init__(self, model, processor, device):
        """Store the model, its processor, the target device and the system prompt."""
        self.model = model
        self.processor = processor
        self.device = device
        self.sys_prompt = (
            "You are a medical information assistant with vision capabilities.\n"
            "Disclaimer: I am not a licensed medical professional."
        )

    def run(self, text, image=None):
        """Answer *text*, optionally grounded on *image*.

        Args:
            text: The user's question; ``None`` is treated as an empty string.
            image: Optional image (a PIL image when called from the UI). It is
                handed to the processor directly, so no temporary file is
                written.

        Returns:
            str: The decoded model reply with the prompt tokens removed, or a
            fixed notice when the model or processor is missing.
        """
        if self.model is None or self.processor is None:
            return "Qwen-VLM is not available"

        has_image = image is not None
        user_cont = []
        if has_image:
            # Only a placeholder goes into the chat template; the pixels are
            # supplied to the processor call below.
            user_cont.append({"type": "image"})
        user_cont.append({"type": "text", "text": text or ""})
        msgs = [
            {"role": "system", "content": [{"type": "text", "text": self.sys_prompt}]},
            {"role": "user", "content": user_cont},
        ]
        prompt = self.processor.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )

        proc_kwargs = {"text": [prompt], "padding": True, "return_tensors": 'pt'}
        if has_image:
            # Previously ``images=[]`` was always passed, so the uploaded
            # image never reached the model.
            proc_kwargs["images"] = [image]
        inputs = self.processor(**proc_kwargs).to(self.device)

        out = self.model.generate(**inputs, max_new_tokens=128)
        # The generated sequence starts with the prompt; keep only what follows.
        prompt_len = inputs.input_ids.shape[1]
        resp = out[0][prompt_len:]
        return self.processor.decode(resp, skip_special_tokens=True).strip()
# =============================================================================
# SAM-2 segmentation interface
# =============================================================================
# Module-level SAM-2 handles; both stay None when SAM-2 cannot be set up.
_sam2_model = None
_mask_generator = None
if SAM2_AVAILABLE:
    try:
        CKPT = "checkpoints/sam2.1_hiera_large.pt"
        CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
        # NOTE: this changes the process working directory and does not change
        # it back; everything executed afterwards runs from inside the checkout.
        os.chdir("segment-anything-2/sam2/sam2")
        _sam2_model = build_sam2(
            CFG,
            CKPT,
            device=get_device(),
            apply_postprocessing=False,
        )
        _mask_generator = SAM2AutomaticMaskGenerator(_sam2_model)
    except Exception as e:
        # Any failure only disables the mask generator; the UI then wires the
        # fallback segmentation path instead.
        print(f"[SAM-2] Initialization error: {e}")
        _mask_generator = None
# =============================================================================
# CheXagent: structured report & grounding
# =============================================================================
# Bound up front so both names exist even when loading fails below.
chex_tok = None
chex_model = None
try:
    print(f"[CheXagent] Loading with auth={'yes' if HF_TOKEN else 'no'}")
    # ``token`` is the supported keyword for Hub authentication;
    # ``use_auth_token`` is deprecated in transformers.
    chex_tok = AutoTokenizer.from_pretrained(
        "StanfordAIMI/CheXagent-2-3b", trust_remote_code=True,
        token=HF_TOKEN
    )
    chex_model = AutoModelForCausalLM.from_pretrained(
        "StanfordAIMI/CheXagent-2-3b", device_map='auto', trust_remote_code=True,
        token=HF_TOKEN
    )
    # Half precision is only applied when a CUDA device is present.
    if torch.cuda.is_available():
        chex_model = chex_model.half()
    chex_model.eval()
    CHEX_AVAILABLE = True
except Exception as e:
    # Any failure disables the CheXagent tabs instead of crashing the app.
    print(f"[CheXagent] Load failed: {e}")
    CHEX_AVAILABLE = False
@torch.no_grad()
def report_generation(im1, im2):
    """Yield the text shown in the CheXagent report tab.

    Args:
        im1: First image input (not used by the current placeholder).
        im2: Second image input (not used by the current placeholder).

    Yields:
        str: A single status message -- an "unavailable" notice when
        CheXagent did not load, otherwise a placeholder line.
    """
    if CHEX_AVAILABLE:
        # Placeholder: the streamer is created but nothing consumes it yet.
        streamer = TextIteratorStreamer(chex_tok, skip_prompt=True)
        yield "Streaming report... (not fully implemented)"
    else:
        yield "CheXagent unavailable"
@torch.no_grad()
def phrase_grounding(image, prompt):
    """Return *prompt* together with a copy of *image* carrying a red box.

    The box is a fixed rectangle covering the central half of the image in
    each dimension; no model inference is performed here.

    Args:
        image: PIL image to annotate, or ``None`` when nothing was uploaded.
        prompt: Text echoed back as the response.

    Returns:
        tuple: ``(text, annotated_image)``; the image is ``None`` when
        CheXagent is unavailable or no image was supplied.
    """
    if not CHEX_AVAILABLE:
        return "CheXagent unavailable", None
    if image is None:
        # Without this guard a click with no image raised AttributeError.
        return "Please provide an image", None
    # Draw on a copy so the caller's image object is left untouched.
    annotated = image.copy()
    w, h = annotated.size
    draw = ImageDraw.Draw(annotated)
    draw.rectangle([(w*0.25, h*0.25), (w*0.75, h*0.75)], outline='red', width=3)
    return prompt, annotated
# =============================================================================
# Gradio UI
# =============================================================================
def fallback_segmentation(image):
    """Segment *image* with a global Otsu threshold (used when SAM-2 is absent).

    Args:
        image: PIL image from the Gradio input, or ``None``.

    Returns:
        tuple: ``(overlay, status)`` where *overlay* is a PIL image with the
        thresholded foreground tinted red, or ``None`` when no image was given.
    """
    if image is None:
        return None, "Please upload an image"
    rgb = np.array(image.convert("RGB"))
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    _, mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    foreground = mask > 0
    blended = rgb.astype(np.float32)
    tint = np.array([255, 0, 0], dtype=np.float32)
    blended[foreground] = 0.5 * blended[foreground] + 0.5 * tint
    return Image.fromarray(blended.astype(np.uint8)), "Fallback segmentation (Otsu threshold)"


def segmentation_interface(image):
    """Segment *image* with the SAM-2 automatic mask generator.

    Args:
        image: PIL image from the Gradio input, or ``None``.

    Returns:
        tuple: ``(overlay, status)`` where *overlay* is a PIL image with each
        mask blended in its own colour. Falls back to
        ``fallback_segmentation`` when no mask generator is loaded.
    """
    if image is None:
        return None, "Please upload an image"
    if _mask_generator is None:
        return fallback_segmentation(image)
    rgb = np.array(image.convert("RGB"))
    try:
        masks = _mask_generator.generate(rgb)
    except Exception as e:
        # Show the failure in the status box instead of breaking the request.
        return None, f"SAM-2 segmentation failed: {e}"
    if not masks:
        return image, "SAM-2 found no masks"
    blended = rgb.astype(np.float32)
    rng = np.random.default_rng(0)  # fixed seed: same colours on every run
    # NOTE(review): assumes the generator's default output, i.e. a list of
    # dicts with an H x W boolean "segmentation" array and an "area" value --
    # confirm against the installed SAM-2 version.
    for entry in sorted(masks, key=lambda m: m["area"], reverse=True):
        region = np.asarray(entry["segmentation"], dtype=bool)
        color = rng.integers(0, 256, size=3).astype(np.float32)
        blended[region] = 0.5 * blended[region] + 0.5 * color
    return Image.fromarray(blended.astype(np.uint8)), f"SAM-2 found {len(masks)} masks"


def create_ui():
    """Build the Gradio Blocks app with one tab per capability.

    Loads the Qwen model (through ``load_qwen_model_and_processor``) and wires
    the Q&A, segmentation and CheXagent tabs. A tab whose backend is missing
    shows a notice instead of its controls; the segmentation tab instead
    switches to the threshold-based fallback.

    Returns:
        gr.Blocks: the assembled, not yet launched, application.
    """
    m, p, d = load_qwen_model_and_processor()
    qwen_ok = bool(m and p)
    med = MedicalVLMAgent(m, p, d) if qwen_ok else None
    with gr.Blocks() as demo:
        gr.Markdown("# Medical AI Assistant")
        # One-line status summary of which backends loaded.
        gr.Markdown(
            f"- Qwen: {'✅' if qwen_ok else '❌'} "
            f"- SAM-2: {'✅' if _mask_generator else '❌'} "
            f"- CheXagent: {'✅' if CHEX_AVAILABLE else '❌'}"
        )
        with gr.Tab("Medical Q&A"):
            if qwen_ok:
                txt = gr.Textbox(label="Question / description", lines=3)
                img = gr.Image(label="Optional image", type='pil')
                out = gr.Textbox(label="Answer")
                gr.Button("Ask").click(med.run, [txt, img], out)
            else:
                gr.Markdown("❌ Medical Q&A not available. Check HF_TOKEN and connectivity.")
        with gr.Tab("Segmentation"):
            seg = gr.Image(label="Upload image", type='pil')
            so = gr.Image(label="Result")
            ss = gr.Textbox(label="Status", interactive=False)
            # Use SAM-2 when its mask generator loaded, otherwise the fallback.
            fn = segmentation_interface if _mask_generator else fallback_segmentation
            gr.Button("Segment").click(fn, [seg], [so, ss])
        with gr.Tab("CheXagent Report"):
            c1 = gr.Image(type='pil', label="Image 1")
            c2 = gr.Image(type='pil', label="Image 2")
            rout = gr.Markdown()
            if CHEX_AVAILABLE:
                gr.Interface(report_generation, [c1, c2], rout, live=True).render()
            else:
                gr.Markdown("❌ CheXagent report not available. Check HF_TOKEN and connectivity.")
        with gr.Tab("CheXagent Grounding"):
            gi = gr.Image(type='pil', label="Image")
            gp = gr.Textbox(label="Prompt")
            gout = gr.Textbox(label="Response")
            goimg = gr.Image(label="Output Image")
            if CHEX_AVAILABLE:
                gr.Interface(phrase_grounding, [gi, gp], [gout, goimg]).render()
            else:
                gr.Markdown("❌ CheXagent grounding not available.")
    return demo
if __name__ == "__main__":
    # Build the app and serve it on every interface, port 7860, with a
    # public share link requested.
    create_ui().launch(server_name='0.0.0.0', server_port=7860, share=True)