#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Combined Medical-VLM, SAM-2 Automatic Masking, and CheXagent Demo
=================================================================
Features:
- Qwen2.5-VL Instruct medical vision-language Q&A
- SAM-2 segmentation with alias patch for Hugging Face
- Simple fallback segmentation
- CheXagent structured report & visual grounding
- Automatic dependency checking & installation for SAM-2
Usage:
python medical_ai_app.py # launches Gradio UI on port 7860
Requires:
    torch, transformers, pillow, gradio, numpy, opencv-python
    (plus git and pyyaml for the SAM-2 install)
"""
import os
import sys
import uuid
import tempfile
import subprocess
import warnings
from threading import Thread
from pathlib import Path
# Environment setup
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
warnings.filterwarnings("ignore", message=r".*upsample_bicubic2d.*")
# Third-party libs
import torch
import numpy as np
import cv2
from PIL import Image, ImageDraw
import gradio as gr
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# =============================================================================
# SAM-2 Alias Patch & Installer
# =============================================================================
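# The packaged build of SAM-2 installs under the module name `sam_2`, while the
# upstream sources import `sam2`; mirroring the package and the submodules we
# need into sys.modules lets both spellings resolve.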
import importlib
try:
    import sam_2
    sys.modules['sam2'] = sam_2
    for sub in ['build_sam', 'automatic_mask_generator', 'modeling.sam2_base']:
        sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
except ImportError:
    pass
def check_and_install_sam2():
    """Return True once SAM-2 imports, cloning and pip-installing it if needed."""
    try:
        from sam2.build_sam import build_sam2  # noqa: F401
        return True
    except ImportError:
        repo_dir = Path("segment-anything-2")
        if not repo_dir.exists():
            subprocess.run(["git", "clone",
                            "https://github.com/facebookresearch/segment-anything-2.git"],
                           check=True)
        # Install from the repo path directly instead of chdir-ing into it, so
        # the working directory is left untouched even if pip fails.
        subprocess.run([sys.executable, "-m", "pip", "install", "-e", str(repo_dir)],
                       check=True)
        try:
            # The editable install may expose the package as `sam2` directly...
            from sam2.build_sam import build_sam2  # noqa: F401
            return True
        except ImportError:
            pass
        try:
            # ...or as `sam_2`, in which case re-apply the alias patch.
            import sam_2
            importlib.reload(sam_2)
            sys.modules['sam2'] = sam_2
            for sub in ['build_sam', 'automatic_mask_generator', 'modeling.sam2_base']:
                sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
            return True
        except ImportError:
            return False
SAM2_AVAILABLE = check_and_install_sam2()
print(f"SAM-2 Available: {SAM2_AVAILABLE}")
if SAM2_AVAILABLE:
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.modeling.sam2_base import SAM2Base
# =============================================================================
# Utility: device selection
# =============================================================================
def get_device():
if torch.cuda.is_available(): return torch.device('cuda')
if torch.backends.mps.is_available(): return torch.device('mps')
return torch.device('cpu')
# =============================================================================
# Qwen-VLM: loading & agent
# =============================================================================
_qwen_model = None
_qwen_processor = None
_qwen_device = None
def load_qwen_model_and_processor(hf_token=None):
global _qwen_model, _qwen_processor, _qwen_device
if _qwen_model is None:
_qwen_device = get_device()
auth = {"use_auth_token": hf_token} if hf_token else {}
_qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True,
torch_dtype=torch.float32, low_cpu_mem_usage=True, **auth
).to(_qwen_device)
_qwen_processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True, **auth
)
return _qwen_model, _qwen_processor, _qwen_device
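# Note: Qwen2_5_VLForConditionalGeneration only exists in recent transformers
# releases; with an older version the import at the top of this file fails
# before this loader is ever reached.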
class MedicalVLMAgent:
def __init__(self, model, processor, device):
self.model = model; self.processor = processor; self.device = device
self.sys_prompt = (
"You are a medical information assistant with vision capabilities.\n"
"Disclaimer: I am not a licensed medical professional."
)
def run(self, text, image=None):
msgs = [{"role":"system","content":[{"type":"text","text":self.sys_prompt}]}]
user_cont = []
        pil_images = []
        if image is not None:
            tmp = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.png")
            image.save(tmp)
            user_cont.append({"type": "image", "image": tmp})
            pil_images.append(image)
        user_cont.append({"type": "text", "text": text or ""})
        msgs.append({"role": "user", "content": user_cont})
        prompt = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        # Pass the PIL image through to the processor; the original images=[]
        # left the prompt's image placeholder without any pixel data.
        inputs = self.processor(text=[prompt], images=pil_images or None,
                                padding=True, return_tensors='pt').to(self.device)
out = self.model.generate(**inputs, max_new_tokens=128)
resp = out[0][inputs.input_ids.shape[1]:]
return self.processor.decode(resp, skip_special_tokens=True).strip()
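# Minimal usage sketch (illustrative only -- "chest_xray.png" is a hypothetical
# local file, and the weights must download successfully):
#   model, proc, dev = load_qwen_model_and_processor()
#   agent = MedicalVLMAgent(model, proc, dev)
#   print(agent.run("Describe any abnormality.", Image.open("chest_xray.png")))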
# =============================================================================
# SAM-2 segmentation interface
# =============================================================================
_sam2_model, _mask_generator = (None, None)
if SAM2_AVAILABLE:
    try:
        CKPT = "checkpoints/sam2.1_hiera_large.pt"
        CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
        # Hydra resolves CFG relative to the package directory, so build from
        # inside the repo, then restore the working directory afterwards.
        prev_cwd = Path.cwd()
        os.chdir("segment-anything-2/sam2/sam2")
        try:
            _sam2_model = build_sam2(CFG, CKPT, device=get_device(), apply_postprocessing=False)
            _mask_generator = SAM2AutomaticMaskGenerator(_sam2_model)
        finally:
            os.chdir(prev_cwd)
    except Exception as e:
        print(f"SAM-2 init error: {e}")
        _mask_generator = None
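# Note: sam2.1_hiera_large.pt is not bundled here; the upstream repo ships a
# download script (checkpoints/download_ckpts.sh, per the SAM-2 README) that
# must be run before the block above can succeed.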
def segmentation_interface(image):
if image is None: return None, "Upload an image"
if not _mask_generator: return None, "SAM-2 unavailable"
arr = np.array(image.convert('RGB'))
anns = _mask_generator.generate(arr)
overlay = arr.copy()
    # Paint the largest masks first so smaller structures remain visible on top.
    for ann in sorted(anns, key=lambda x: x['area'], reverse=True):
        m = ann['segmentation']
        color = np.random.randint(0, 255, 3)
        overlay[m] = (overlay[m] * 0.5 + color * 0.5).astype(np.uint8)
return Image.fromarray(overlay), f"{len(anns)} masks found"
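# The generator above uses upstream defaults; SAM-2's automatic mask generator
# accepts tuning knobs mirroring SAM-1 (an API assumption worth verifying), e.g.:
#   SAM2AutomaticMaskGenerator(_sam2_model, points_per_side=16, pred_iou_thresh=0.8)
# Fewer sample points trades mask recall for speed on large images.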
# =============================================================================
# Fallback segmentation
# =============================================================================
def fallback_segmentation(image):
if image is None: return None, "Upload an image"
arr = np.array(image.convert('RGB'))
    gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
    # Otsu picks the threshold automatically; BINARY_INV flags darker regions.
    _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    overlay = arr.copy()
    overlay[th > 0] = [255, 0, 0]   # mark thresholded pixels in red
    blended = cv2.addWeighted(arr, 0.7, overlay, 0.3, 0)
return Image.fromarray(blended), "Fallback applied"
# =============================================================================
# CheXagent: structured report & grounding
# =============================================================================
try:
chex_tok = AutoTokenizer.from_pretrained("StanfordAIMI/CheXagent-2-3b", trust_remote_code=True)
chex_model = AutoModelForCausalLM.from_pretrained("StanfordAIMI/CheXagent-2-3b", device_map='auto', trust_remote_code=True)
if torch.cuda.is_available(): chex_model = chex_model.half()
chex_model.eval(); CHEX_AVAILABLE=True
except Exception:
CHEX_AVAILABLE=False
@torch.no_grad()
def report_generation(im1, im2):
    if not CHEX_AVAILABLE:
        yield "CheXagent unavailable"
        return
    # Placeholder: hooking TextIteratorStreamer and a background Thread into
    # chex_model.generate is left unimplemented in this demo.
    yield "Report streaming not fully implemented"
@torch.no_grad()
def phrase_grounding(image, prompt):
    if not CHEX_AVAILABLE: return "CheXagent unavailable", None
    # Placeholder grounding: draw a fixed, centered box instead of a real
    # CheXagent prediction; copy first so the caller's image is not mutated.
    image = image.copy()
    w, h = image.size
    draw = ImageDraw.Draw(image)
    draw.rectangle([(w * 0.25, h * 0.25), (w * 0.75, h * 0.75)], outline='red', width=3)
    return prompt, image
# =============================================================================
# Gradio UI
# =============================================================================
def create_ui():
try:
m, p, d = load_qwen_model_and_processor()
med = MedicalVLMAgent(m,p,d); QW=True
    except Exception:
        QW = False; med = None
with gr.Blocks() as demo:
gr.Markdown("# Medical AI Assistant")
gr.Markdown(f"- Qwen: {'β
' if QW else 'β'} - SAM-2: {'β
' if _mask_generator else 'β'} - CheX: {'β
' if CHEX_AVAILABLE else 'β'}")
with gr.Tab("Medical Q&A"):
            txt = gr.Textbox(); img = gr.Image(type='pil'); out = gr.Textbox()
            # Fall back to a stub handler so the button works if Qwen failed to load.
            qa_fn = med.run if med else (lambda t, i: "Qwen model unavailable")
            gr.Button("Ask").click(qa_fn, [txt, img], out)
        with gr.Tab("Segmentation"):
            seg = gr.Image(type='pil'); so = gr.Image(); ss = gr.Textbox()
            fn = segmentation_interface if _mask_generator else fallback_segmentation
            gr.Button("Segment").click(fn, seg, [so, ss])
with gr.Tab("CheXagent Report"):
c1=gr.Image(type='pil');c2=gr.Image(type='pil'); rout=gr.Markdown(); gr.Interface(report_generation,[c1,c2],rout,live=True).render()
with gr.Tab("CheXagent Grounding"):
gi=gr.Image(type='pil'); gp=gr.Textbox(); gout=gr.Textbox(); goimg=gr.Image(); gr.Interface(phrase_grounding,[gi,gp],[gout,goimg]).render()
return demo
if __name__ == "__main__":
ui=create_ui(); ui.launch(server_name='0.0.0.0',server_port=7860,share=True)