Spaces:
Running
Running
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
""" | |
Combined Medical-VLM, SAM-2 Automatic Masking, and CheXagent Demo | |
================================================================= | |
Features: | |
- Qwen2.5-VL Instruct medical vision-language Q&A | |
- SAM-2 segmentation with alias patch for Hugging Face | |
- Simple fallback segmentation | |
- CheXagent structured report & visual grounding | |
- Automatic dependency checking & installation for SAM-2 | |
Usage: | |
HF_TOKEN=<your_token> python medical_ai_app.py # if private models require auth | |
Requires: | |
torch, transformers, PIL, gradio, ultralytics, requests, opencv-python, pyyaml | |
""" | |
import os | |
import sys | |
import uuid | |
import tempfile | |
import subprocess | |
import warnings | |
from threading import Thread | |
from pathlib import Path | |
# Optional Hugging Face access token taken from the environment
# (None when the variable is unset; only needed for private models).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Process-wide runtime configuration applied before the third-party imports.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
warnings.filterwarnings(action="ignore", message=r".*upsample_bicubic2d.*")
# Third-party libs | |
import torch | |
import numpy as np | |
import cv2 | |
from PIL import Image, ImageDraw | |
import gradio as gr | |
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor | |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
import importlib | |
# ============================================================================= | |
# SAM-2 Alias Patch & Installer | |
# ============================================================================= | |
# If a package named ``sam_2`` is importable, register it (and the submodules
# this file imports later) in ``sys.modules`` under the ``sam2`` name so that
# ``from sam2... import ...`` statements resolve. Any ImportError raised while
# doing so is ignored on purpose: the alias is strictly best effort.
try:
    import sam_2

    sys.modules['sam2'] = sam_2
    for _alias_sub in ('build_sam', 'automatic_mask_generator', 'modeling.sam2_base'):
        _alias_target = importlib.import_module('sam_2.' + _alias_sub)
        sys.modules['sam2.' + _alias_sub] = _alias_target
except ImportError:
    pass
def check_and_install_sam2():
    """Make sure SAM-2 is importable, installing it from GitHub when it is not.

    The function first tries ``from sam2.build_sam import build_sam2``. If that
    fails it clones the ``segment-anything-2`` repository into the current
    working directory (unless the directory already exists), runs an editable
    ``pip install`` inside the checkout, and retries the import. As a last
    resort it falls back to aliasing a package named ``sam_2`` as ``sam2``.

    Returns:
        bool: True when ``sam2.build_sam`` can be imported afterwards, False
        when installation or the subsequent import failed.
    """
    try:
        from sam2.build_sam import build_sam2  # noqa: F401
        return True
    except ImportError:
        pass

    repo_dir = Path("segment-anything-2")
    try:
        if not repo_dir.exists():
            subprocess.run(
                ["git", "clone", "https://github.com/facebookresearch/segment-anything-2.git"],
                check=True,
            )
        # Run pip inside the checkout via ``cwd`` instead of os.chdir so the
        # process working directory is never left changed if pip fails.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-e", "."],
            cwd=repo_dir,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError) as exc:
        # Missing git/pip or a failed command: report SAM-2 as unavailable
        # rather than aborting the whole application at import time.
        print(f"[SAM-2] Installation failed: {exc}")
        return False

    # A package installed while the interpreter is running may not be visible
    # yet: expose the checkout explicitly and drop stale import-finder caches.
    repo_path = str(repo_dir.resolve())
    if repo_path not in sys.path:
        sys.path.append(repo_path)
    importlib.invalidate_caches()

    try:
        from sam2.build_sam import build_sam2  # noqa: F401
        return True
    except ImportError:
        pass

    # Fallback kept from the original code: alias a ``sam_2`` package as ``sam2``.
    try:
        import sam_2
        sys.modules['sam2'] = sam_2
        for sub in ('build_sam', 'automatic_mask_generator', 'modeling.sam2_base'):
            sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
        return True
    except ImportError:
        return False
# Resolve SAM-2 availability once at import time and report it on stdout.
# The SAM-2 symbols below are only bound when the package is importable.
SAM2_AVAILABLE = check_and_install_sam2()
print(f"SAM-2 Available: {SAM2_AVAILABLE}")

if SAM2_AVAILABLE:
    from sam2.build_sam import build_sam2
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
    from sam2.modeling.sam2_base import SAM2Base
# ============================================================================= | |
# Utility: device selection | |
# ============================================================================= | |
def get_device():
    """Return the preferred torch device: CUDA first, then MPS, otherwise CPU."""
    if torch.cuda.is_available():
        device_name = 'cuda'
    elif torch.backends.mps.is_available():
        device_name = 'mps'
    else:
        device_name = 'cpu'
    return torch.device(device_name)
# ============================================================================= | |
# Qwen-VLM: loading & agent | |
# ============================================================================= | |
# Module-level cache filled by load_qwen_model_and_processor().
_qwen_model = None
_qwen_processor = None
_qwen_device = None


def load_qwen_model_and_processor():
    """Load the Qwen2.5-VL model and processor once and return them.

    The first successful call stores the model, processor and device in
    module-level globals; later calls return the cached objects. When loading
    fails the error is printed, the cached model and processor stay None, and
    a later call will attempt the load again.

    Returns:
        tuple: ``(model, processor, device)``. ``model`` and ``processor`` are
        None when loading failed.
    """
    global _qwen_model, _qwen_processor, _qwen_device
    if _qwen_model is not None:
        return _qwen_model, _qwen_processor, _qwen_device

    model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
    _qwen_device = get_device()
    # ``token`` is the current keyword for authenticated downloads; the older
    # ``use_auth_token`` spelling is deprecated in transformers.
    auth = {"token": HF_TOKEN} if HF_TOKEN else {}
    print(f"[Qwen] Loading model with auth={'yes' if HF_TOKEN else 'no'} on {_qwen_device}")
    try:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            **auth
        ).to(_qwen_device)
        processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            **auth
        )
    except Exception as e:
        # Deliberately broad: any load problem must not take the whole app
        # down, it only disables the Q&A feature.
        print(f"[Qwen] Model load failed: {e}")
        _qwen_model = None
        _qwen_processor = None
    else:
        # Publish both objects together so the cache never holds a model
        # without its processor.
        _qwen_model, _qwen_processor = model, processor
    return _qwen_model, _qwen_processor, _qwen_device
class MedicalVLMAgent:
    """Thin question-answering wrapper around a vision-language model.

    Args:
        model: Object exposing ``generate(**inputs, max_new_tokens=...)``.
        processor: Object exposing ``apply_chat_template``, ``decode`` and a
            call interface that turns text/images into model inputs.
        device: Device the processed inputs are moved to before generation.
    """

    def __init__(self, model, processor, device):
        self.model = model
        self.processor = processor
        self.device = device
        self.sys_prompt = (
            "You are a medical information assistant with vision capabilities.\n"
            "Disclaimer: I am not a licensed medical professional."
        )

    def run(self, text, image=None):
        """Answer ``text``, optionally grounded on ``image``.

        Args:
            text: The user's question; None is treated as an empty string.
            image: Optional image (the UI supplies a PIL image). When given it
                is handed to the processor so the model actually receives it.

        Returns:
            str: The decoded, stripped model reply (at most 128 new tokens),
            or a fixed notice when the model or processor is missing.
        """
        if self.model is None or self.processor is None:
            return "Qwen-VLM is not available"

        user_cont = []
        if image is not None:
            # The chat template only needs an image placeholder entry; the
            # pixels themselves are passed to the processor call below, so no
            # temporary file has to be written (and later cleaned up).
            user_cont.append({"type": "image"})
        user_cont.append({"type": "text", "text": text or ""})

        msgs = [
            {"role": "system", "content": [{"type": "text", "text": self.sys_prompt}]},
            {"role": "user", "content": user_cont},
        ]
        prompt = self.processor.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(
            text=[prompt],
            images=[image] if image is not None else None,
            padding=True,
            return_tensors='pt',
        ).to(self.device)
        out = self.model.generate(**inputs, max_new_tokens=128)
        # Drop the prompt tokens so only the newly generated reply is decoded.
        resp = out[0][inputs.input_ids.shape[1]:]
        return self.processor.decode(resp, skip_special_tokens=True).strip()
# ============================================================================= | |
# SAM-2 segmentation interface | |
# ============================================================================= | |
# Module-level SAM-2 handles. Both stay None when SAM-2 is unavailable;
# _mask_generator is also None when initialisation fails.
_sam2_model, _mask_generator = (None, None)
if SAM2_AVAILABLE:
    # Remember the working directory so the chdir below does not leak into the
    # rest of the process once the model has been built (or has failed).
    _sam2_prev_cwd = os.getcwd()
    try:
        CKPT = "checkpoints/sam2.1_hiera_large.pt"
        CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
        # CKPT is a relative path, so it is resolved against this directory.
        os.chdir("segment-anything-2/sam2/sam2")
        _sam2_model = build_sam2(
            CFG, CKPT, device=get_device(), apply_postprocessing=False
        )
        _mask_generator = SAM2AutomaticMaskGenerator(_sam2_model)
    except Exception as e:
        # Deliberately broad: a SAM-2 failure only disables segmentation.
        print(f"[SAM-2] Initialization error: {e}")
        _mask_generator = None
    finally:
        os.chdir(_sam2_prev_cwd)
# ============================================================================= | |
# CheXagent: structured report & grounding | |
# ============================================================================= | |
# Defaults first, so these names exist even when loading fails below.
chex_tok = None
chex_model = None
CHEX_AVAILABLE = False
try:
    print(f"[CheXagent] Loading with auth={'yes' if HF_TOKEN else 'no'}")
    _chex_model_id = "StanfordAIMI/CheXagent-2-3b"
    # NOTE(review): trust_remote_code=True executes Python code shipped in the
    # model repository; keep it only for repositories you trust.
    # ``token`` replaces the deprecated ``use_auth_token`` keyword.
    chex_tok = AutoTokenizer.from_pretrained(
        _chex_model_id, trust_remote_code=True,
        token=HF_TOKEN
    )
    chex_model = AutoModelForCausalLM.from_pretrained(
        _chex_model_id, device_map='auto', trust_remote_code=True,
        token=HF_TOKEN
    )
    if torch.cuda.is_available():
        chex_model = chex_model.half()
    chex_model.eval()
    CHEX_AVAILABLE = True
except Exception as e:
    # Deliberately broad: a load failure only disables the CheXagent tabs.
    print(f"[CheXagent] Load failed: {e}")
    chex_tok = None
    chex_model = None
    CHEX_AVAILABLE = False
def report_generation(im1, im2):
    """Yield status text for the CheXagent report tab (placeholder).

    Report generation is not implemented yet: the two images are accepted for
    interface compatibility but are not used, and a single fixed message is
    produced.

    Args:
        im1: First image from the UI (currently unused).
        im2: Second image from the UI (currently unused).

    Yields:
        str: "CheXagent unavailable" when the model failed to load, otherwise
        the placeholder streaming message.
    """
    if not CHEX_AVAILABLE:
        yield "CheXagent unavailable"
        return
    # The original built a TextIteratorStreamer here that was never used; it
    # is omitted until real streaming generation is implemented.
    yield "Streaming report... (not fully implemented)"
def phrase_grounding(image, prompt):
    """Return the prompt and a copy of the image with a fixed red box drawn on it.

    This is a placeholder: the box always spans the central half of the image
    (25%-75% of width and height) and does not depend on ``prompt``.

    Args:
        image: PIL image to annotate, or None when nothing was uploaded.
        prompt: Text echoed back as the response.

    Returns:
        tuple: ``(message, image_or_None)``. The image is an RGB copy, so the
        caller's image is never modified.
    """
    if not CHEX_AVAILABLE:
        return "CheXagent unavailable", None
    if image is None:
        # Submitting without an upload must not crash on ``image.size``.
        return "No image provided", None
    # convert() returns a new image: drawing on it leaves the input untouched,
    # and RGB guarantees the red outline is representable (e.g. on grayscale).
    annotated = image.convert("RGB")
    w, h = annotated.size
    draw = ImageDraw.Draw(annotated)
    draw.rectangle([(w * 0.25, h * 0.25), (w * 0.75, h * 0.75)], outline='red', width=3)
    return prompt, annotated
# ============================================================================= | |
# Gradio UI | |
# ============================================================================= | |
def create_ui():
    """Build and return the Gradio Blocks application.

    Loads the Qwen model (via ``load_qwen_model_and_processor``) and lays out
    four tabs: Medical Q&A, Segmentation, CheXagent Report and CheXagent
    Grounding. A tab whose backend is unavailable shows a notice instead of
    being wired to a handler.

    Returns:
        gr.Blocks: The assembled (not yet launched) demo.
    """
    m, p, d = load_qwen_model_and_processor()
    qwen_ok = m is not None and p is not None
    med = MedicalVLMAgent(m, p, d) if qwen_ok else None

    # Look the segmentation handlers up by name: a handler that is not defined
    # in this module then disables the tab instead of raising NameError while
    # the UI is being built. The SAM-2 handler is preferred when a mask
    # generator exists, otherwise the fallback is used.
    sam_fn = globals().get("segmentation_interface") if _mask_generator else None
    seg_fn = sam_fn or globals().get("fallback_segmentation")

    with gr.Blocks() as demo:
        gr.Markdown("# Medical AI Assistant")
        gr.Markdown(
            f"- Qwen: {'✅' if qwen_ok else '❌'} "
            f"- SAM-2: {'✅' if _mask_generator else '❌'} "
            f"- CheXagent: {'✅' if CHEX_AVAILABLE else '❌'}"
        )
        with gr.Tab("Medical Q&A"):
            if qwen_ok:
                txt = gr.Textbox(label="Question / description", lines=3)
                img = gr.Image(label="Optional image", type='pil')
                out = gr.Textbox(label="Answer")
                gr.Button("Ask").click(med.run, [txt, img], out)
            else:
                gr.Markdown("❌ Medical Q&A not available. Check HF_TOKEN and connectivity.")
        with gr.Tab("Segmentation"):
            if callable(seg_fn):
                seg = gr.Image(label="Upload image", type='pil')
                so = gr.Image(label="Result")
                ss = gr.Textbox(label="Status", interactive=False)
                gr.Button("Segment").click(seg_fn, [seg], [so, ss])
            else:
                gr.Markdown("❌ Segmentation not available (no segmentation handler is defined).")
        with gr.Tab("CheXagent Report"):
            c1 = gr.Image(type='pil', label="Image 1")
            c2 = gr.Image(type='pil', label="Image 2")
            rout = gr.Markdown()
            if CHEX_AVAILABLE:
                gr.Interface(report_generation, [c1, c2], rout, live=True).render()
            else:
                gr.Markdown("❌ CheXagent report not available. Check HF_TOKEN and connectivity.")
        with gr.Tab("CheXagent Grounding"):
            gi = gr.Image(type='pil', label="Image")
            gp = gr.Textbox(label="Prompt")
            gout = gr.Textbox(label="Response")
            goimg = gr.Image(label="Output Image")
            if CHEX_AVAILABLE:
                gr.Interface(phrase_grounding, [gi, gp], [gout, goimg]).render()
            else:
                gr.Markdown("❌ CheXagent grounding not available.")
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on every interface at
    # port 7860 with a share link requested.
    demo_app = create_ui()
    demo_app.launch(
        share=True,
        server_port=7860,
        server_name='0.0.0.0',
    )