pascal-maker committed
Commit c04077b · 1 Parent(s): d478238

Add requirements.txt and medical VLM SAM-2 CheXagent demo

Files changed (3)
  1. app.py +435 -0
  2. requirements.txt +32 -0
  3. sam2 +1 -0
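
For orientation before the full diff: the core SAM-2 pattern that app.py follows is condensed below. The paths, parameters, and class names are copied from the code in the diff; treat this as a sketch rather than a standalone script.

    # Build SAM-2 once at start-up, then reuse the generator for every request.
    from sam2.build_sam import build_sam2
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

    sam2_model = build_sam2(
        "configs/sam2.1/sam2.1_hiera_l.yaml",   # resolved by Hydra relative to sam2/sam2/
        "checkpoints/sam2.1_hiera_large.pt",    # ≈2.7 GB checkpoint already in the repo
        device="cpu",                           # app.py picks cuda / mps / cpu at runtime
        apply_postprocessing=False,
    )
    mask_generator = SAM2AutomaticMaskGenerator(model=sam2_model, points_per_side=32)
    masks = mask_generator.generate(image_np)   # image_np: H x W x 3 uint8 RGB array
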
app.py ADDED
@@ -0,0 +1,435 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+
+ """
+ Combined Medical-VLM, **SAM-2 automatic masking**, and CheXagent demo.
+
+ ⭑ Changes ⭑
+ -----------
+ 1. All Segment-Anything-v1 fallback code has been removed.
+ 2. A single **SAM-2 AutomaticMaskGenerator** is built once and reused.
+ 3. The tumor-segmentation tab now runs *fully automatic* masking — no bounding-box textbox.
+ 4. The SAM-2 config path is now relative instead of absolute.
+ """
+
+ # ---------------------------------------------------------------------
+ # Standard libs
+ # ---------------------------------------------------------------------
+ import os, warnings
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # CPU fallback for missing MPS ops
+ warnings.filterwarnings("ignore", message=r".*upsample_bicubic2d.*")  # hide one-line notice
+
+ import sys
+ import uuid
+ import tempfile
+ from threading import Thread
+
+ # ---------------------------------------------------------------------
+ # Third-party libs
+ # ---------------------------------------------------------------------
+ import torch
+ import numpy as np
+ from PIL import Image, ImageDraw
+ import gradio as gr
+
+ # If you cloned facebookresearch/sam2 into the repo root, make sure it's importable
+ sys.path.append(os.path.abspath("."))
+
+ # =============================================================================
+ # Qwen-VLM imports & helper
+ # =============================================================================
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+
+
+ # =============================================================================
+ # SAM-2 imports (only SAM-2, no v1 fallback)
+ # =============================================================================
+ from sam2.build_sam import build_sam2
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
+
+ # Alternative: try direct model loading if build_sam2 continues to fail
+ try:
+     from sam2.modeling.sam2_base import SAM2Base
+     from sam2.utils.misc import get_device_index
+ except ImportError:
+     print("Could not import additional SAM2 components")
+
+
+ # =============================================================================
+ # CheXagent imports
+ # =============================================================================
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+
+
+ # ---------------------------------------------------------------------
+ # Devices
+ # ---------------------------------------------------------------------
+ def get_device():
+     if torch.cuda.is_available():
+         return torch.device("cuda")
+     if torch.backends.mps.is_available():
+         return torch.device("mps")
+     return torch.device("cpu")
+
+
+ # =============================================================================
+ # Qwen-VLM model & agent
+ # =============================================================================
+ _qwen_model = None
+ _qwen_processor = None
+ _qwen_device = None
+
+
+ def load_qwen_model_and_processor(hf_token=None):
+     global _qwen_model, _qwen_processor, _qwen_device
+     if _qwen_model is None:
+         _qwen_device = "mps" if torch.backends.mps.is_available() else "cpu"
+         print(f"[Qwen] loading model on {_qwen_device}")
+         auth_kwargs = {"use_auth_token": hf_token} if hf_token else {}
+         _qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+             "Qwen/Qwen2.5-VL-3B-Instruct",
+             trust_remote_code=True,
+             attn_implementation="eager",
+             torch_dtype=torch.float32,
+             low_cpu_mem_usage=True,
+             device_map=None,
+             **auth_kwargs,
+         ).to(_qwen_device)
+         _qwen_processor = AutoProcessor.from_pretrained(
+             "Qwen/Qwen2.5-VL-3B-Instruct",
+             trust_remote_code=True,
+             **auth_kwargs,
+         )
+     return _qwen_model, _qwen_processor, _qwen_device
+
+
+ class MedicalVLMAgent:
+     """Light wrapper around Qwen-VLM with an optional image."""
+
+     def __init__(self, model, processor, device):
+         self.model = model
+         self.processor = processor
+         self.device = device
+         self.system_prompt = (
+             "You are a medical information assistant with vision capabilities.\n"
+             "Disclaimer: I am not a licensed medical professional. "
+             "The information provided is for reference only and should not be taken as medical advice."
+         )
+
+     def run(self, user_text: str, image: Image.Image | None = None) -> str:
+         messages = [
+             {"role": "system", "content": [{"type": "text", "text": self.system_prompt}]}
+         ]
+         user_content = []
+         if image is not None:
+             tmp = f"/tmp/{uuid.uuid4()}.png"
+             image.save(tmp)
+             user_content.append({"type": "image", "image": tmp})
+         user_content.append({"type": "text", "text": user_text or "Please describe the image."})
+         messages.append({"role": "user", "content": user_content})
+
+         prompt_text = self.processor.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         img_inputs, vid_inputs = process_vision_info(messages)
+         inputs = self.processor(
+             text=[prompt_text],
+             images=img_inputs,
+             videos=vid_inputs,
+             padding=True,
+             return_tensors="pt",
+         ).to(self.device)
+
+         with torch.no_grad():
+             out = self.model.generate(**inputs, max_new_tokens=128)
+         trimmed = out[0][inputs.input_ids.shape[1] :]
+         return self.processor.decode(trimmed, skip_special_tokens=True).strip()
+
+
+ # =============================================================================
+ # SAM-2.1 model + AutomaticMaskGenerator
+ # =============================================================================
+
+ def initialize_sam2():
+     # These two files are already in your repo
+     CKPT = "checkpoints/sam2.1_hiera_large.pt"  # ≈2.7 GB
+     CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
+
+     # One chdir so Hydra's search path starts inside sam2/sam2/
+     os.chdir("sam2/sam2")
+
+     device = get_device()
+     print(f"[SAM-2] building model on {device}")
+
+     sam2_model = build_sam2(
+         CFG,   # relative to sam2/sam2/
+         CKPT,  # relative after chdir
+         device=device,
+         apply_postprocessing=False,
+     )
+
+     mask_gen = SAM2AutomaticMaskGenerator(
+         model=sam2_model,
+         points_per_side=32,
+         pred_iou_thresh=0.86,
+         stability_score_thresh=0.92,
+         crop_n_layers=0,
+     )
+     return sam2_model, mask_gen
+
+
+ # ---------------------- build once ----------------------
+ try:
+     _sam2_model, _mask_generator = initialize_sam2()
+     print("[SAM-2] Successfully initialized!")
+ except Exception as e:
+     print(f"[SAM-2] Failed to initialize: {e}")
+     _sam2_model, _mask_generator = None, None
+
+ def automatic_mask_overlay(image_np: np.ndarray) -> np.ndarray:
+     """Generate masks and alpha-blend them on top of the original image."""
+     if _mask_generator is None:
+         raise RuntimeError("SAM-2 mask generator not initialized")
+
+     anns = _mask_generator.generate(image_np)
+     if not anns:
+         return image_np
+
+     overlay = image_np.copy()
+     if overlay.ndim == 2:  # grayscale → RGB
+         overlay = np.stack([overlay] * 3, axis=2)
+
+     for ann in sorted(anns, key=lambda x: x["area"], reverse=True):
+         m = ann["segmentation"]
+         color = np.random.randint(0, 255, 3, dtype=np.uint8)
+         overlay[m] = (overlay[m] * 0.5 + color * 0.5).astype(np.uint8)
+
+     return overlay
+
+ def tumor_segmentation_interface(image: Image.Image | None):
+     if image is None:
+         return None, "Please upload an image."
+
+     if _mask_generator is None:
+         return None, "SAM-2 not properly initialized. Check the console for errors."
+
+     try:
+         img_np = np.array(image.convert("RGB"))
+         out_np = automatic_mask_overlay(img_np)
+         n_masks = len(_mask_generator.generate(img_np))
+         return Image.fromarray(out_np), f"{n_masks} masks found."
+     except Exception as e:
+         return None, f"SAM-2 error: {e}"
+
+ # =============================================================================
+ # CheXagent set-up (unchanged)
+ # =============================================================================
+ chex_name = "StanfordAIMI/CheXagent-2-3b"
+ chex_tok = AutoTokenizer.from_pretrained(chex_name, trust_remote_code=True)
+ chex_model = AutoModelForCausalLM.from_pretrained(
+     chex_name, device_map="auto", trust_remote_code=True
+ )
+ chex_model = chex_model.half() if torch.cuda.is_available() else chex_model.float()
+ chex_model.eval()
+
+
+ def get_model_device(model):
+     for p in model.parameters():
+         return p.device
+     return torch.device("cpu")
+
+
+ def clean_text(text):
+     return text.replace("</s>", "")
+
+
+ @torch.no_grad()
+ def response_report_generation(pil_image_1, pil_image_2):
+     """Structured chest-X-ray report (streaming)."""
+     streamer = TextIteratorStreamer(chex_tok, skip_prompt=True, skip_special_tokens=True)
+     paths = []
+     for im in [pil_image_1, pil_image_2]:
+         if im is None:
+             continue
+         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tfile:
+             im.save(tfile.name)
+             paths.append(tfile.name)
+
+     device = get_model_device(chex_model)
+     anatomies = [
+         "View",
+         "Airway",
+         "Breathing",
+         "Cardiac",
+         "Diaphragm",
+         "Everything else (e.g., mediastinal contours, bones, soft tissues, tubes, valves, pacemakers)",
+     ]
+     prompts = [
+         "Determine the view of this CXR",
+         *[
+             f'Provide a detailed description of "{a}" in the chest X-ray'
+             for a in anatomies[1:]
+         ],
+     ]
+
+     findings = ""
+     partial = "## Generating Findings (step-by-step):\n\n"
+     for idx, (anat, prompt) in enumerate(zip(anatomies, prompts)):
+         query = chex_tok.from_list_format(
+             [*[{"image": p} for p in paths], {"text": prompt}]
+         )
+         conv = [
+             {"from": "system", "value": "You are a helpful assistant."},
+             {"from": "human", "value": query},
+         ]
+         inp = chex_tok.apply_chat_template(
+             conv, add_generation_prompt=True, return_tensors="pt"
+         ).to(device)
+         generate_kwargs = dict(
+             input_ids=inp,
+             max_new_tokens=512,
+             do_sample=False,
+             num_beams=1,
+             streamer=streamer,
+         )
+         Thread(target=chex_model.generate, kwargs=generate_kwargs).start()
+         partial += f"**Step {idx}: {anat}...**\n\n"
+         for tok in streamer:
+             if idx:
+                 findings += tok
+             partial += tok
+             yield clean_text(partial)
+         partial += "\n\n"
+         findings += " "
+     findings = findings.strip()
+
+     # Impression
+     partial += "## Generating Impression\n\n"
+     prompt = f"Write the Impression section for the following Findings: {findings}"
+     conv = [
+         {"from": "system", "value": "You are a helpful assistant."},
+         {"from": "human", "value": chex_tok.from_list_format([{"text": prompt}])},
+     ]
+     inp = chex_tok.apply_chat_template(
+         conv, add_generation_prompt=True, return_tensors="pt"
+     ).to(device)
+     Thread(
+         target=chex_model.generate,
+         kwargs=dict(
+             input_ids=inp,
+             do_sample=False,
+             num_beams=1,
+             max_new_tokens=512,
+             streamer=streamer,
+         ),
+     ).start()
+     for tok in streamer:
+         partial += tok
+         yield clean_text(partial)
+     yield clean_text(partial)
+
+
+ @torch.no_grad()
+ def response_phrase_grounding(pil_image, prompt_text):
+     """Very simple visual-grounding placeholder."""
+     if pil_image is None:
+         return "Please upload an image.", None
+
+     with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tfile:
+         pil_image.save(tfile.name)
+         img_path = tfile.name
+
+     device = get_model_device(chex_model)
+     query = chex_tok.from_list_format([{"image": img_path}, {"text": prompt_text}])
+     conv = [
+         {"from": "system", "value": "You are a helpful assistant."},
+         {"from": "human", "value": query},
+     ]
+     inp = chex_tok.apply_chat_template(
+         conv, add_generation_prompt=True, return_tensors="pt"
+     ).to(device)
+     out = chex_model.generate(
+         input_ids=inp, do_sample=False, num_beams=1, max_new_tokens=512
+     )
+     resp = clean_text(chex_tok.decode(out[0][inp.shape[1] :]))
+
+     # simple center box (placeholder)
+     w, h = pil_image.size
+     cx, cy, sz = w // 2, h // 2, min(w, h) // 4
+     draw = ImageDraw.Draw(pil_image)
+     draw.rectangle([(cx - sz, cy - sz), (cx + sz, cy + sz)], outline="red", width=3)
+
+     return resp, pil_image
+
+
+ # =============================================================================
+ # Gradio UI
+ # =============================================================================
+ qwen_model, qwen_proc, qwen_dev = load_qwen_model_and_processor()
+ med_agent = MedicalVLMAgent(qwen_model, qwen_proc, qwen_dev)
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Combined Medical Q&A · SAM-2 Automatic Masking · CheXagent")
+
+     # ---------------------------------------------------------
+     with gr.Tab("Medical Q&A"):
+         q_in = gr.Textbox(label="Question / description", lines=3)
+         q_img = gr.Image(label="Optional image", type="pil")
+         q_btn = gr.Button("Submit")
+         q_out = gr.Textbox(label="Answer")
+         q_btn.click(fn=med_agent.run, inputs=[q_in, q_img], outputs=q_out)
+
+     # ---------------------------------------------------------
+     with gr.Tab("Automatic masking (SAM-2)"):
+         seg_img = gr.Image(label="Image", type="pil")
+         seg_btn = gr.Button("Run segmentation")
+         seg_out = gr.Image(label="Overlay", type="pil")
+         seg_status = gr.Textbox(label="Status", interactive=False)
+         seg_btn.click(
+             fn=tumor_segmentation_interface,
+             inputs=seg_img,
+             outputs=[seg_out, seg_status],
+         )
+
+     # ---------------------------------------------------------
+     with gr.Tab("CheXagent – Structured report"):
+         gr.Markdown("Upload one or two images; the report streams live.")
+         cx1 = gr.Image(label="Image 1", image_mode="L", type="pil")
+         cx2 = gr.Image(label="Image 2", image_mode="L", type="pil")
+         cx_report = gr.Markdown()
+         gr.Interface(
+             fn=response_report_generation,
+             inputs=[cx1, cx2],
+             outputs=cx_report,
+             live=True,
+         ).render()
+
+     # ---------------------------------------------------------
+     with gr.Tab("CheXagent – Visual grounding"):
+         vg_img = gr.Image(image_mode="L", type="pil")
+         vg_prompt = gr.Textbox(value="Locate the highlighted finding:")
+         vg_text = gr.Markdown()
+         vg_out_img = gr.Image()
+         gr.Interface(
+             fn=response_phrase_grounding,
+             inputs=[vg_img, vg_prompt],
+             outputs=[vg_text, vg_out_img],
+         ).render()
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
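
The Gradio tabs above are thin wrappers around plain functions, so the pieces can be exercised without launching the UI. A minimal smoke test, assuming the file above is importable as app and that the models load successfully (importing the module triggers all model loading, so the first run is slow):

    import numpy as np
    from PIL import Image

    import app  # runs the model set-up; demo.launch() stays behind the __main__ guard

    # Text-only medical Q&A through the Qwen-VLM agent
    print(app.med_agent.run("What does a pleural effusion look like on a chest X-ray?"))

    # Fully automatic SAM-2 masking on a placeholder image (substitute a real scan)
    dummy = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
    overlay, status = app.tumor_segmentation_interface(dummy)
    print(status)  # "N masks found." or an error message if SAM-2 failed to build
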
requirements.txt ADDED
@@ -0,0 +1,32 @@
+ # Core ML/AI frameworks
+ torch>=2.0.0
+ torchvision>=0.15.0
+ numpy>=1.21.0
+ pillow>=9.0.0
+
+ # Transformers and related
+ transformers>=4.40.0
+ accelerate>=0.20.0
+ qwen-vl-utils>=0.0.8
+
+ # Gradio for web interface
+ gradio>=4.0.0
+
+ # SAM-2 dependencies
+ opencv-python>=4.8.0
+ matplotlib>=3.5.0
+ hydra-core>=1.3.0
+ omegaconf>=2.3.0
+
+ # Additional utilities
+ safetensors>=0.3.0
+ tokenizers>=0.13.0
+ huggingface-hub>=0.16.0
+ sentencepiece>=0.1.99
+ protobuf>=3.20.0
+
+ # CheXagent streaming uses transformers' TextIteratorStreamer and the standard-library
+ # threading module, so no extra streaming package is required.
+
+ # Optional but recommended for better performance
+ flash-attn>=2.0.0; sys_platform != "darwin"  # Skip on macOS due to compatibility issues
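
A quick way to confirm this stack resolves before starting the app (a sketch, not part of the commit): the names below are the import-time module names of the packages listed above, and sam2 comes from the submodule in the next file rather than from PyPI.

    import importlib

    # requirements.txt entry -> importable module name (e.g. pillow -> PIL, opencv-python -> cv2)
    modules = ["torch", "torchvision", "numpy", "PIL", "transformers", "accelerate",
               "qwen_vl_utils", "gradio", "cv2", "matplotlib", "hydra", "omegaconf", "sam2"]
    for name in modules:
        try:
            importlib.import_module(name)
            print(f"OK      {name}")
        except ImportError as exc:
            print(f"MISSING {name}: {exc}")
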
sam2 ADDED
@@ -0,0 +1 @@
+ Subproject commit 2b90b9f5ceec907a1c18123530e92e794ad901a4
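
The sam2 entry is a gitlink pinned to facebookresearch/sam2 at the commit above, so a fresh clone needs the submodule checked out before app.py can import it. One way to do that from Python (equivalent to running git submodule update --init --recursive in a shell):

    import subprocess

    # Fetch the pinned facebookresearch/sam2 commit into ./sam2
    subprocess.run(["git", "submodule", "update", "--init", "--recursive"], check=True)
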