Spaces:
Running
Running
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Combined Medical-VLM, SAM-2 Automatic Masking, and CheXagent Demo
=================================================================
Features:
- Qwen2.5-VL-3B-Instruct medical vision-language Q&A
- SAM-2 automatic mask generation, with a `sam_2` -> `sam2` alias patch for Hugging Face
- Simple Otsu-threshold fallback segmentation when SAM-2 is unavailable
- CheXagent structured report & visual grounding (currently placeholder outputs)
- Automatic dependency checking & installation for SAM-2
Usage:
    python medical_ai_app.py   # launches the Gradio UI on port 7860
Requires:
    torch, transformers, numpy, Pillow (PIL), gradio, opencv-python,
    ultralytics, requests, pyyaml
"""
import os | |
import sys | |
import uuid | |
import tempfile | |
import subprocess | |
import warnings | |
from threading import Thread | |
from pathlib import Path | |
# Environment setup -- kept ahead of `import torch` below so PyTorch sees it.
# PYTORCH_ENABLE_MPS_FALLBACK=1 asks PyTorch to run ops that lack an Apple MPS
# kernel on the CPU instead of raising.
os.environ.update(PYTORCH_ENABLE_MPS_FALLBACK="1")
# Ignore any warning whose message mentions upsample_bicubic2d
# (presumably the MPS/CPU-fallback notice for that op -- TODO confirm).
warnings.filterwarnings(action="ignore", message=r".*upsample_bicubic2d.*")
# Third-party libs | |
import torch | |
import numpy as np | |
import cv2 | |
from PIL import Image, ImageDraw | |
import gradio as gr | |
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor | |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
# ============================================================================= | |
# SAM-2 Alias Patch & Installer | |
# ============================================================================= | |
# Imported on its own line so the name is bound even when `sam_2` is missing
# (previously `import sam_2, importlib` left `importlib` undefined on failure).
import importlib

try:
    import sam_2
    # Expose the `sam_2` distribution under the `sam2.*` module names that the
    # SAM-2 imports further down expect.
    sys.modules['sam2'] = sam_2
    for sub in ['build_sam', 'automatic_mask_generator', 'modeling.sam2_base']:
        sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
except ImportError:
    # Not installed under this name; check_and_install_sam2() handles setup.
    pass
def check_and_install_sam2():
    """Ensure SAM-2 is importable as ``sam2``, cloning and installing it if needed.

    Steps:
      1. If ``sam2.build_sam`` already imports, return True.
      2. Otherwise clone facebookresearch/segment-anything-2 into
         ``./segment-anything-2`` (unless that directory exists) and
         ``pip install -e`` it.
      3. Retry ``sam2``; failing that, import the ``sam_2`` distribution and
         alias its modules under ``sam2.*``.

    Returns:
        bool: True if SAM-2 is importable afterwards. False otherwise,
        including when git or pip fail, so the caller can fall back gracefully.
    """
    # Local import: the module-level alias patch may not have bound `importlib`.
    import importlib

    def _sam2_importable():
        """Return True if ``sam2.build_sam.build_sam2`` can be imported."""
        try:
            from sam2.build_sam import build_sam2  # noqa: F401
        except ImportError:
            return False
        return True

    if _sam2_importable():
        return True

    repo_dir = Path("segment-anything-2")
    try:
        if not repo_dir.exists():
            subprocess.run(
                ["git", "clone", "https://github.com/facebookresearch/segment-anything-2.git"],
                check=True,
            )
        # cwd= instead of os.chdir(): the process working directory is never
        # left pointing into the repo if the install fails.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-e", "."],
            cwd=repo_dir,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as exc:
        # OSError covers a missing `git` executable (FileNotFoundError).
        print(f"SAM-2 installation failed: {exc}")
        return False

    # Drop cached "not found" results from the first attempt. A fresh editable
    # install may still only become visible after an interpreter restart; in
    # that case this returns False now and SAM-2 is picked up on next launch.
    importlib.invalidate_caches()
    if _sam2_importable():
        return True
    try:
        import sam_2
        importlib.reload(sam_2)
        sys.modules['sam2'] = sam_2
        for sub in ['build_sam', 'automatic_mask_generator', 'modeling.sam2_base']:
            sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
        return True
    except ImportError:
        return False
# Resolve SAM-2 once at import time; the UI falls back to threshold
# segmentation when it is unavailable.
SAM2_AVAILABLE = check_and_install_sam2()
if SAM2_AVAILABLE:
    # check_and_install_sam2() only verifies `sam2.build_sam`; guard the other
    # modules too so a partial install disables SAM-2 instead of crashing startup.
    try:
        from sam2.build_sam import build_sam2
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        from sam2.modeling.sam2_base import SAM2Base
    except ImportError as exc:
        print(f"SAM-2 import error: {exc}")
        SAM2_AVAILABLE = False
print(f"SAM-2 Available: {SAM2_AVAILABLE}")
# ============================================================================= | |
# Utility: device selection | |
# ============================================================================= | |
def get_device():
    """Return the preferred torch device: CUDA if present, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        kind = 'cuda'
    elif torch.backends.mps.is_available():
        kind = 'mps'
    else:
        kind = 'cpu'
    return torch.device(kind)
# ============================================================================= | |
# Qwen-VLM: loading & agent | |
# ============================================================================= | |
# Process-wide cache for the Qwen2.5-VL model, its processor and device.
_qwen_model = None
_qwen_processor = None
_qwen_device = None


def load_qwen_model_and_processor(hf_token=None):
    """Load (once) and return the Qwen2.5-VL-3B-Instruct model, processor and device.

    Args:
        hf_token: optional Hugging Face access token used for the downloads.

    Returns:
        tuple: ``(model, processor, torch.device)``. The objects are cached in
        module globals, so later calls return the same instances.
    """
    global _qwen_model, _qwen_processor, _qwen_device
    if _qwen_model is not None and _qwen_processor is not None:
        return _qwen_model, _qwen_processor, _qwen_device

    device = get_device()
    # `token` replaces the deprecated `use_auth_token` keyword in transformers.
    auth = {"token": hf_token} if hf_token else {}
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True,
        torch_dtype=torch.float32, low_cpu_mem_usage=True, **auth
    ).to(device)
    processor = AutoProcessor.from_pretrained(
        "Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True, **auth
    )
    # Publish to the cache only after both loads succeed, so a failed processor
    # download is retried on the next call instead of returning a None processor.
    _qwen_model, _qwen_processor, _qwen_device = model, processor, device
    return _qwen_model, _qwen_processor, _qwen_device
class MedicalVLMAgent:
    """Answer medical questions, optionally about an image, with a Qwen2.5-VL model.

    Args:
        model: generation model exposing ``generate(**inputs, max_new_tokens=...)``.
        processor: Qwen2.5-VL processor (chat template, tokenizer and image processor).
        device: device the model lives on; inputs are moved there before generation.
    """

    def __init__(self, model, processor, device):
        self.model = model
        self.processor = processor
        self.device = device
        self.sys_prompt = (
            "You are a medical information assistant with vision capabilities.\n"
            "Disclaimer: I am not a licensed medical professional."
        )

    def _build_messages(self, text, image=None):
        """Return the chat messages (system turn + user turn) for one question."""
        user_content = []
        if image is not None:
            # Placeholder only: the chat template emits the image tokens here,
            # and the pixels themselves are handed to the processor in run().
            user_content.append({"type": "image"})
        user_content.append({"type": "text", "text": text or ""})
        return [
            {"role": "system", "content": [{"type": "text", "text": self.sys_prompt}]},
            {"role": "user", "content": user_content},
        ]

    def run(self, text, image=None, max_new_tokens=128):
        """Generate an answer to ``text``, optionally grounded on ``image``.

        Args:
            text: user question; None is treated as an empty string.
            image: optional PIL image. It is converted to RGB and passed directly
                to the processor (previously it was only saved to /tmp and never
                reached the model).
            max_new_tokens: generation budget for the answer.

        Returns:
            str: the decoded answer with special tokens and surrounding whitespace removed.
        """
        if image is not None:
            image = image.convert("RGB")
        msgs = self._build_messages(text, image)
        prompt = self.processor.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(
            text=[prompt],
            images=[image] if image is not None else None,
            padding=True,
            return_tensors='pt',
        ).to(self.device)
        out = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + continuation; keep only the new tokens.
        resp = out[0][inputs.input_ids.shape[1]:]
        return self.processor.decode(resp, skip_special_tokens=True).strip()
# ============================================================================= | |
# SAM-2 segmentation interface | |
# ============================================================================= | |
# Build the SAM-2 model and automatic mask generator once at import time.
# Both stay None when SAM-2 is unavailable or fails to initialise.
_sam2_model, _mask_generator = (None, None)
if SAM2_AVAILABLE:
    CKPT = "checkpoints/sam2.1_hiera_large.pt"
    CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
    _prev_cwd = os.getcwd()
    try:
        # CKPT is resolved relative to this directory.
        # NOTE(review): confirm this layout; upstream keeps checkpoints/ at the repo root.
        os.chdir("segment-anything-2/sam2/sam2")
        _sam2_model = build_sam2(CFG, CKPT, device=get_device(), apply_postprocessing=False)
        _mask_generator = SAM2AutomaticMaskGenerator(_sam2_model)
    except Exception as e:
        print(f"SAM-2 init error: {e}")
        _mask_generator = None
    finally:
        # Restore the working directory so the rest of the app is unaffected.
        os.chdir(_prev_cwd)
def segmentation_interface(image):
    """Run SAM-2 automatic mask generation and overlay every mask on the image.

    Args:
        image: PIL image from the Gradio input, or None when nothing is uploaded.

    Returns:
        tuple: ``(overlay PIL image or None, status message)``.
    """
    if image is None:
        return None, "Upload an image"
    if _mask_generator is None:
        return None, "SAM-2 unavailable"
    arr = np.array(image.convert('RGB'))
    anns = _mask_generator.generate(arr)
    overlay = arr.copy()
    # Paint the largest masks first so smaller ones stay visible on top.
    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        # Assumes binary-mask output mode (boolean HxW array) -- TODO confirm.
        mask = ann['segmentation']
        # randint's upper bound is exclusive: 256 allows the full 0..255 range.
        color = np.random.randint(0, 256, 3)
        overlay[mask] = (overlay[mask] * 0.5 + color * 0.5).astype(np.uint8)
    return Image.fromarray(overlay), f"{len(anns)} masks found"
# ============================================================================= | |
# Fallback segmentation | |
# ============================================================================= | |
def fallback_segmentation(image):
    """Highlight the darker region of an image using Otsu thresholding.

    Used when SAM-2 is unavailable. Pixels at or below the automatically chosen
    Otsu threshold are tinted red at 30% opacity.

    Args:
        image: PIL image, or None when nothing is uploaded.

    Returns:
        tuple: ``(blended PIL image or None, status message)``.
    """
    if image is None:
        return None, "Upload an image"
    rgb = np.array(image.convert('RGB'))
    # THRESH_BINARY_INV marks pixels <= threshold (the darker side) as 255.
    _, foreground = cv2.threshold(
        cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY),
        0,
        255,
        cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU,
    )
    tinted = rgb.copy()
    tinted[foreground > 0] = [255, 0, 0]
    blended = cv2.addWeighted(rgb, 0.7, tinted, 0.3, 0)
    return Image.fromarray(blended), "Fallback applied"
# ============================================================================= | |
# CheXagent: structured report & grounding | |
# ============================================================================= | |
# Load CheXagent once at import time. On any failure the handlers below see
# CHEX_AVAILABLE = False, and the tokenizer/model names are still defined (None).
# NOTE: trust_remote_code=True executes Python shipped with the model repo;
# consider pinning a `revision` for reproducibility and safety.
chex_tok = None
chex_model = None
try:
    chex_tok = AutoTokenizer.from_pretrained("StanfordAIMI/CheXagent-2-3b", trust_remote_code=True)
    chex_model = AutoModelForCausalLM.from_pretrained(
        "StanfordAIMI/CheXagent-2-3b", device_map='auto', trust_remote_code=True
    )
    if torch.cuda.is_available():
        chex_model = chex_model.half()
    chex_model.eval()
    CHEX_AVAILABLE = True
except Exception as exc:
    # Best-effort: report why instead of failing silently, then degrade.
    print(f"CheXagent init error: {exc}")
    chex_tok = chex_model = None
    CHEX_AVAILABLE = False
def report_generation(im1, im2):
    """Stream a CheXagent structured report (placeholder implementation).

    Args:
        im1, im2: PIL images from the UI; currently ignored.

    Yields:
        str: a status message. Real report generation is not implemented yet.
    """
    if not CHEX_AVAILABLE:
        yield "CheXagent unavailable"
        return
    # TODO: run chex_model.generate() with a TextIteratorStreamer in a background
    # Thread and yield the accumulated text as tokens arrive.
    yield "Report streaming not fully implemented"
def phrase_grounding(image, prompt):
    """Return the prompt and an annotated copy of the image (placeholder grounding).

    The box is a fixed placeholder covering the central 50% of the image; it
    does not yet depend on ``prompt``.

    Args:
        image: PIL image, or None when nothing is uploaded.
        prompt: phrase to ground; echoed back as the text output.

    Returns:
        tuple: ``(text, annotated PIL image or None)``.
    """
    if not CHEX_AVAILABLE:
        return "CheXagent unavailable", None
    if image is None:
        return "Upload an image", None
    # convert() returns a new image, so the caller's image is not modified. RGB
    # mode also lets the red outline render in colour on grayscale ('L') X-rays.
    annotated = image.convert('RGB')
    w, h = annotated.size
    draw = ImageDraw.Draw(annotated)
    draw.rectangle([(w * 0.25, h * 0.25), (w * 0.75, h * 0.75)], outline='red', width=3)
    return prompt, annotated
# ============================================================================= | |
# Gradio UI | |
# ============================================================================= | |
def create_ui():
    """Build the Gradio Blocks app with Q&A, segmentation and CheXagent tabs.

    Every tab wires a button to its handler. Unavailable backends degrade
    gracefully: Q&A returns a notice, and segmentation uses the threshold fallback.

    Returns:
        gr.Blocks: the (not yet launched) demo.
    """
    try:
        m, p, d = load_qwen_model_and_processor()
        med = MedicalVLMAgent(m, p, d)
        qwen_ok = True
    except Exception as exc:
        print(f"Qwen init error: {exc}")
        med = None
        qwen_ok = False

    def ask(text, image):
        """Q&A handler; avoids dereferencing `med` when Qwen failed to load."""
        if med is None:
            return "Qwen model unavailable"
        return med.run(text, image)

    def mark(ok):
        """Status glyph for the header line."""
        return '✅' if ok else '❌'

    sam_ok = _mask_generator is not None
    seg_fn = segmentation_interface if sam_ok else fallback_segmentation

    with gr.Blocks() as demo:
        gr.Markdown("# Medical AI Assistant")
        gr.Markdown(
            f"- Qwen: {mark(qwen_ok)} - SAM-2: {mark(sam_ok)} - CheX: {mark(CHEX_AVAILABLE)}"
        )
        with gr.Tab("Medical Q&A"):
            txt = gr.Textbox()
            img = gr.Image(type='pil')
            out = gr.Textbox()
            gr.Button("Ask").click(ask, [txt, img], out)
        with gr.Tab("Segmentation"):
            seg = gr.Image(type='pil')
            so = gr.Image()
            ss = gr.Textbox()
            gr.Button("Segment").click(seg_fn, seg, [so, ss])
        # The CheXagent tabs use button events instead of gr.Interface(...).render():
        # an Interface would try to re-render components already placed in this layout.
        with gr.Tab("CheXagent Report"):
            c1 = gr.Image(type='pil')
            c2 = gr.Image(type='pil')
            rout = gr.Markdown()
            gr.Button("Generate report").click(report_generation, [c1, c2], rout)
        with gr.Tab("CheXagent Grounding"):
            gi = gr.Image(type='pil')
            gp = gr.Textbox()
            gout = gr.Textbox()
            goimg = gr.Image()
            gr.Button("Ground").click(phrase_grounding, [gi, gp], [gout, goimg])
    return demo
if __name__ == "__main__":
    # Listen on all interfaces at port 7860. share=True additionally asks
    # Gradio for a public share link.
    app = create_ui()
    app.launch(
        server_name='0.0.0.0',
        server_port=7860,
        share=True,
    )