pascal-maker commited on
Commit
f22adfd
Β·
verified Β·
1 Parent(s): 69a2a38

update app.py

Browse files
Files changed (1) hide show
  1. app.py +285 -136
app.py CHANGED
@@ -6,66 +6,101 @@ Combined Medical-VLM, **SAM-2 automatic masking**, and CheXagent demo.
6
 
7
  β­‘ Changes β­‘
8
  -----------
9
- 1. All Segment-Anything-v1 fallback code has been removed.
10
- 2. A single **SAM-2 AutomaticMaskGenerator** is built once and reused.
11
- 3. Tumor-segmentation tab now runs *fully automatic* masking β€” no bounding-box textbox.
12
- 4. Fixed SAM-2 config path to use relative path instead of absolute path.
13
  """
14
 
15
  # ---------------------------------------------------------------------
16
  # Standard libs
17
  # ---------------------------------------------------------------------
18
- # ---------------------------------------------------------------------
19
- import os, warnings
20
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # CPU fallback for missing MPS ops
21
- warnings.filterwarnings("ignore", message=r".*upsample_bicubic2d.*") # hide one-line notice
22
-
23
  import os
24
  import sys
25
  import uuid
26
  import tempfile
 
 
27
  from threading import Thread
28
 
 
 
 
 
29
  # ---------------------------------------------------------------------
30
  # Third-party libs
31
-
32
-
33
  # ---------------------------------------------------------------------
34
  import torch
35
  import numpy as np
36
  from PIL import Image, ImageDraw
37
  import gradio as gr
38
 
39
- # If you cloned facebookresearch/sam2 into the repo root, make sure it's importable
40
- sys.path.append(os.path.abspath("."))
41
-
42
  # =============================================================================
43
- # Qwen-VLM imports & helper
44
  # =============================================================================
45
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
46
- from qwen_vl_utils import process_vision_info
47
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  # =============================================================================
50
- # SAM-2 imports (only SAM-2, no v1 fallback)
51
  # =============================================================================
52
- from sam2.build_sam import build_sam2
53
- from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
54
-
55
- # Alternative: try direct model loading if build_sam2 continues to fail
56
- try:
57
- from sam2.modeling.sam2_base import SAM2Base
58
- from sam2.utils.misc import get_device_index
59
- except ImportError:
60
- print("Could not import additional SAM2 components")
61
 
 
 
 
 
 
62
 
63
  # =============================================================================
64
  # CheXagent imports
65
  # =============================================================================
66
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
67
 
68
-
69
  # ---------------------------------------------------------------------
70
  # Devices
71
  # ---------------------------------------------------------------------
@@ -76,7 +111,6 @@ def get_device():
76
  return torch.device("mps")
77
  return torch.device("cpu")
78
 
79
-
80
  # =============================================================================
81
  # Qwen-VLM model & agent
82
  # =============================================================================
@@ -84,7 +118,6 @@ _qwen_model = None
84
  _qwen_processor = None
85
  _qwen_device = None
86
 
87
-
88
  def load_qwen_model_and_processor(hf_token=None):
89
  global _qwen_model, _qwen_processor, _qwen_device
90
  if _qwen_model is None:
@@ -107,7 +140,6 @@ def load_qwen_model_and_processor(hf_token=None):
107
  )
108
  return _qwen_model, _qwen_processor, _qwen_device
109
 
110
-
111
  class MedicalVLMAgent:
112
  """Light wrapper around Qwen-VLM with an optional image."""
113
 
@@ -150,56 +182,76 @@ class MedicalVLMAgent:
150
  trimmed = out[0][inputs.input_ids.shape[1] :]
151
  return self.processor.decode(trimmed, skip_special_tokens=True).strip()
152
 
153
-
154
- # =============================================================================
155
- # SAM-2 model + AutomaticMaskGenerator
156
- # =============================================================================
157
-
158
- # =============================================================================
159
- # SAM-2.1 model + AutomaticMaskGenerator (concise version)
160
  # =============================================================================
 
161
  # =============================================================================
162
- # SAM-2.1 model + AutomaticMaskGenerator (final minimal version)
163
- # =============================================================================
164
- import os
165
- from sam2.build_sam import build_sam2
166
- from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  def initialize_sam2():
169
- # These two files are already in your repo
170
- CKPT = "checkpoints/sam2.1_hiera_large.pt" # β‰ˆ2.7 GB
171
- CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
172
-
173
- # One chdir so Hydra's search path starts inside sam2/sam2/
174
- os.chdir("sam2/sam2")
175
-
176
- device = get_device()
177
- print(f"[SAM-2] building model on {device}")
178
-
179
- sam2_model = build_sam2(
180
- CFG, # relative to sam2/sam2/
181
- CKPT, # relative after chdir
182
- device=device,
183
- apply_postprocessing=False,
184
- )
185
-
186
- mask_gen = SAM2AutomaticMaskGenerator(
187
- model=sam2_model,
188
- points_per_side=32,
189
- pred_iou_thresh=0.86,
190
- stability_score_thresh=0.92,
191
- crop_n_layers=0,
192
- )
193
- return sam2_model, mask_gen
194
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
- # ---------------------- build once ----------------------
197
- try:
 
198
  _sam2_model, _mask_generator = initialize_sam2()
199
- print("[SAM-2] Successfully initialized!")
200
- except Exception as e:
201
- print(f"[SAM-2] Failed to initialize: {e}")
202
- _sam2_model, _mask_generator = None, None
203
 
204
  def automatic_mask_overlay(image_np: np.ndarray) -> np.ndarray:
205
  """Generate masks and alpha-blend them on top of the original image."""
@@ -222,9 +274,13 @@ def automatic_mask_overlay(image_np: np.ndarray) -> np.ndarray:
222
  return overlay
223
 
224
  def tumor_segmentation_interface(image: Image.Image | None):
 
225
  if image is None:
226
  return None, "Please upload an image."
227
 
 
 
 
228
  if _mask_generator is None:
229
  return None, "SAM-2 not properly initialized. Check the console for errors."
230
 
@@ -237,30 +293,81 @@ def tumor_segmentation_interface(image: Image.Image | None):
237
  return None, f"SAM-2 error: {e}"
238
 
239
  # =============================================================================
240
- # CheXagent set-up (unchanged)
241
  # =============================================================================
242
- chex_name = "StanfordAIMI/CheXagent-2-3b"
243
- chex_tok = AutoTokenizer.from_pretrained(chex_name, trust_remote_code=True)
244
- chex_model = AutoModelForCausalLM.from_pretrained(
245
- chex_name, device_map="auto", trust_remote_code=True
246
- )
247
- chex_model = chex_model.half() if torch.cuda.is_available() else chex_model.float()
248
- chex_model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  def get_model_device(model):
 
 
252
  for p in model.parameters():
253
  return p.device
254
  return torch.device("cpu")
255
 
256
-
257
  def clean_text(text):
258
  return text.replace("</s>", "")
259
 
260
-
261
  @torch.no_grad()
262
  def response_report_generation(pil_image_1, pil_image_2):
263
  """Structured chest-X-ray report (streaming)."""
 
 
 
 
264
  streamer = TextIteratorStreamer(chex_tok, skip_prompt=True, skip_special_tokens=True)
265
  paths = []
266
  for im in [pil_image_1, pil_image_2]:
@@ -270,6 +377,10 @@ def response_report_generation(pil_image_1, pil_image_2):
270
  im.save(tfile.name)
271
  paths.append(tfile.name)
272
 
 
 
 
 
273
  device = get_model_device(chex_model)
274
  anatomies = [
275
  "View",
@@ -343,10 +454,12 @@ def response_report_generation(pil_image_1, pil_image_2):
343
  yield clean_text(partial)
344
  yield clean_text(partial)
345
 
346
-
347
  @torch.no_grad()
348
  def response_phrase_grounding(pil_image, prompt_text):
349
  """Very simple visual-grounding placeholder."""
 
 
 
350
  if pil_image is None:
351
  return "Please upload an image.", None
352
 
@@ -376,60 +489,96 @@ def response_phrase_grounding(pil_image, prompt_text):
376
 
377
  return resp, pil_image
378
 
379
-
380
  # =============================================================================
381
  # Gradio UI
382
  # =============================================================================
383
- qwen_model, qwen_proc, qwen_dev = load_qwen_model_and_processor()
384
- med_agent = MedicalVLMAgent(qwen_model, qwen_proc, qwen_dev)
385
-
386
- with gr.Blocks() as demo:
387
- gr.Markdown("# Combined Medical Q&A Β· SAM-2 Automatic Masking Β· CheXagent")
388
-
389
- # ---------------------------------------------------------
390
- with gr.Tab("Medical Q&A"):
391
- q_in = gr.Textbox(label="Question / description", lines=3)
392
- q_img = gr.Image(label="Optional image", type="pil")
393
- q_btn = gr.Button("Submit")
394
- q_out = gr.Textbox(label="Answer")
395
- q_btn.click(fn=med_agent.run, inputs=[q_in, q_img], outputs=q_out)
396
-
397
- # ---------------------------------------------------------
398
- with gr.Tab("Automatic masking (SAM-2)"):
399
- seg_img = gr.Image(label="Image", type="pil")
400
- seg_btn = gr.Button("Run segmentation")
401
- seg_out = gr.Image(label="Overlay", type="pil")
402
- seg_status = gr.Textbox(label="Status", interactive=False)
403
- seg_btn.click(
404
- fn=tumor_segmentation_interface,
405
- inputs=seg_img,
406
- outputs=[seg_out, seg_status],
407
- )
408
 
409
- # ---------------------------------------------------------
410
- with gr.Tab("CheXagent – Structured report"):
411
- gr.Markdown("Upload one or two images; the report streams live.")
412
- cx1 = gr.Image(label="Image 1", image_mode="L", type="pil")
413
- cx2 = gr.Image(label="Image 2", image_mode="L", type="pil")
414
- cx_report = gr.Markdown()
415
- gr.Interface(
416
- fn=response_report_generation,
417
- inputs=[cx1, cx2],
418
- outputs=cx_report,
419
- live=True,
420
- ).render()
421
-
422
- # ---------------------------------------------------------
423
- with gr.Tab("CheXagent – Visual grounding"):
424
- vg_img = gr.Image(image_mode="L", type="pil")
425
- vg_prompt = gr.Textbox(value="Locate the highlighted finding:")
426
- vg_text = gr.Markdown()
427
- vg_out_img = gr.Image()
428
- gr.Interface(
429
- fn=response_phrase_grounding,
430
- inputs=[vg_img, vg_prompt],
431
- outputs=[vg_text, vg_out_img],
432
- ).render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
 
434
  if __name__ == "__main__":
 
435
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
6
 
7
  β­‘ Changes β­‘
8
  -----------
9
+ 1. Fixed SAM-2 installation and import issues
10
+ 2. Added proper error handling for missing dependencies
11
+ 3. Made SAM-2 functionality optional with graceful fallback
12
+ 4. Added installation instructions and requirements check
13
  """
14
 
15
  # ---------------------------------------------------------------------
16
  # Standard libs
17
  # ---------------------------------------------------------------------
 
 
 
 
 
18
  import os
19
  import sys
20
  import uuid
21
  import tempfile
22
+ import subprocess
23
+ import warnings
24
  from threading import Thread
25
 
26
+ # Environment setup
27
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
28
+ warnings.filterwarnings("ignore", message=r".*upsample_bicubic2d.*")
29
+
30
  # ---------------------------------------------------------------------
31
  # Third-party libs
 
 
32
  # ---------------------------------------------------------------------
33
  import torch
34
  import numpy as np
35
  from PIL import Image, ImageDraw
36
  import gradio as gr
37
 
 
 
 
38
  # =============================================================================
39
+ # Dependency checker and installer
40
  # =============================================================================
41
+ def check_and_install_sam2():
42
+ """Check if SAM-2 is available and attempt installation if needed."""
43
+ try:
44
+ # Try importing SAM-2
45
+ from sam2.build_sam import build_sam2
46
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
47
+ return True, "SAM-2 already available"
48
+ except ImportError:
49
+ print("SAM-2 not found. Attempting to install...")
50
+ try:
51
+ # Clone SAM-2 repository
52
+ if not os.path.exists("segment-anything-2"):
53
+ subprocess.run([
54
+ "git", "clone",
55
+ "https://github.com/facebookresearch/segment-anything-2.git"
56
+ ], check=True)
57
+
58
+ # Install SAM-2
59
+ original_dir = os.getcwd()
60
+ os.chdir("segment-anything-2")
61
+ subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
62
+ os.chdir(original_dir)
63
+
64
+ # Add to Python path
65
+ sys.path.insert(0, os.path.abspath("segment-anything-2"))
66
+
67
+ # Try importing again
68
+ from sam2.build_sam import build_sam2
69
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
70
+ return True, "SAM-2 installed successfully"
71
+
72
+ except Exception as e:
73
+ print(f"Failed to install SAM-2: {e}")
74
+ return False, f"SAM-2 installation failed: {e}"
75
+
76
+ # Check SAM-2 availability
77
+ SAM2_AVAILABLE, SAM2_STATUS = check_and_install_sam2()
78
+ print(f"SAM-2 Status: {SAM2_STATUS}")
79
 
80
  # =============================================================================
81
+ # SAM-2 imports (conditional)
82
  # =============================================================================
83
+ if SAM2_AVAILABLE:
84
+ try:
85
+ from sam2.build_sam import build_sam2
86
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
87
+ from sam2.modeling.sam2_base import SAM2Base
88
+ from sam2.utils.misc import get_device_index
89
+ except ImportError as e:
90
+ print(f"SAM-2 import error: {e}")
91
+ SAM2_AVAILABLE = False
92
 
93
+ # =============================================================================
94
+ # Qwen-VLM imports & helper
95
+ # =============================================================================
96
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
97
+ from qwen_vl_utils import process_vision_info
98
 
99
  # =============================================================================
100
  # CheXagent imports
101
  # =============================================================================
102
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
103
 
 
104
  # ---------------------------------------------------------------------
105
  # Devices
106
  # ---------------------------------------------------------------------
 
111
  return torch.device("mps")
112
  return torch.device("cpu")
113
 
 
114
  # =============================================================================
115
  # Qwen-VLM model & agent
116
  # =============================================================================
 
118
  _qwen_processor = None
119
  _qwen_device = None
120
 
 
121
  def load_qwen_model_and_processor(hf_token=None):
122
  global _qwen_model, _qwen_processor, _qwen_device
123
  if _qwen_model is None:
 
140
  )
141
  return _qwen_model, _qwen_processor, _qwen_device
142
 
 
143
  class MedicalVLMAgent:
144
  """Light wrapper around Qwen-VLM with an optional image."""
145
 
 
182
  trimmed = out[0][inputs.input_ids.shape[1] :]
183
  return self.processor.decode(trimmed, skip_special_tokens=True).strip()
184
 
 
 
 
 
 
 
 
185
  # =============================================================================
186
+ # SAM-2 model + AutomaticMaskGenerator (conditional)
187
  # =============================================================================
188
+ def download_sam2_checkpoint():
189
+ """Download SAM-2 checkpoint if not present."""
190
+ checkpoint_dir = "checkpoints"
191
+ checkpoint_file = "sam2.1_hiera_large.pt"
192
+ checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
193
+
194
+ if not os.path.exists(checkpoint_path):
195
+ os.makedirs(checkpoint_dir, exist_ok=True)
196
+ print("Downloading SAM-2 checkpoint...")
197
+ try:
198
+ import urllib.request
199
+ url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
200
+ urllib.request.urlretrieve(url, checkpoint_path)
201
+ print("SAM-2 checkpoint downloaded successfully")
202
+ except Exception as e:
203
+ print(f"Failed to download SAM-2 checkpoint: {e}")
204
+ return None
205
+
206
+ return checkpoint_path
207
 
208
  def initialize_sam2():
209
+ """Initialize SAM-2 model and mask generator."""
210
+ if not SAM2_AVAILABLE:
211
+ return None, None
212
+
213
+ try:
214
+ # Download checkpoint if needed
215
+ checkpoint_path = download_sam2_checkpoint()
216
+ if checkpoint_path is None:
217
+ return None, None
218
+
219
+ # Config path (you may need to adjust this)
220
+ config_path = "segment-anything-2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
221
+ if not os.path.exists(config_path):
222
+ config_path = "configs/sam2.1/sam2.1_hiera_l.yaml"
223
+
224
+ device = get_device()
225
+ print(f"[SAM-2] building model on {device}")
226
+
227
+ sam2_model = build_sam2(
228
+ config_path,
229
+ checkpoint_path,
230
+ device=device,
231
+ apply_postprocessing=False,
232
+ )
 
233
 
234
+ mask_gen = SAM2AutomaticMaskGenerator(
235
+ model=sam2_model,
236
+ points_per_side=32,
237
+ pred_iou_thresh=0.86,
238
+ stability_score_thresh=0.92,
239
+ crop_n_layers=0,
240
+ )
241
+ return sam2_model, mask_gen
242
+
243
+ except Exception as e:
244
+ print(f"[SAM-2] Failed to initialize: {e}")
245
+ return None, None
246
 
247
+ # Initialize SAM-2 (conditional)
248
+ _sam2_model, _mask_generator = None, None
249
+ if SAM2_AVAILABLE:
250
  _sam2_model, _mask_generator = initialize_sam2()
251
+ if _sam2_model is not None:
252
+ print("[SAM-2] Successfully initialized!")
253
+ else:
254
+ print("[SAM-2] Initialization failed")
255
 
256
  def automatic_mask_overlay(image_np: np.ndarray) -> np.ndarray:
257
  """Generate masks and alpha-blend them on top of the original image."""
 
274
  return overlay
275
 
276
  def tumor_segmentation_interface(image: Image.Image | None):
277
+ """Tumor segmentation interface with proper error handling."""
278
  if image is None:
279
  return None, "Please upload an image."
280
 
281
+ if not SAM2_AVAILABLE:
282
+ return None, "SAM-2 is not available. Please check installation."
283
+
284
  if _mask_generator is None:
285
  return None, "SAM-2 not properly initialized. Check the console for errors."
286
 
 
293
  return None, f"SAM-2 error: {e}"
294
 
295
  # =============================================================================
296
+ # Simple fallback segmentation (when SAM-2 is not available)
297
  # =============================================================================
298
+ def simple_segmentation_fallback(image: Image.Image | None):
299
+ """Simple fallback segmentation using basic image processing."""
300
+ if image is None:
301
+ return None, "Please upload an image."
302
+
303
+ try:
304
+ import cv2
305
+ from skimage import segmentation, color
306
+
307
+ # Convert to numpy array
308
+ img_np = np.array(image.convert("RGB"))
309
+
310
+ # Simple watershed segmentation
311
+ gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
312
+ _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
313
+
314
+ # Remove noise
315
+ kernel = np.ones((3,3), np.uint8)
316
+ opening = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=2)
317
+
318
+ # Sure background area
319
+ sure_bg = cv2.dilate(opening, kernel, iterations=3)
320
+
321
+ # Finding sure foreground area
322
+ dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
323
+ _, sure_fg = cv2.threshold(dist_transform, 0.7*dist_transform.max(), 255, 0)
324
+
325
+ # Create overlay
326
+ overlay = img_np.copy()
327
+ overlay[sure_fg > 0] = [255, 0, 0] # Red overlay
328
+
329
+ # Alpha blend
330
+ result = cv2.addWeighted(img_np, 0.7, overlay, 0.3, 0)
331
+
332
+ return Image.fromarray(result), "Simple segmentation applied (SAM-2 not available)"
333
+
334
+ except Exception as e:
335
+ return None, f"Fallback segmentation error: {e}"
336
 
337
+ # =============================================================================
338
+ # CheXagent set-up
339
+ # =============================================================================
340
+ try:
341
+ chex_name = "StanfordAIMI/CheXagent-2-3b"
342
+ chex_tok = AutoTokenizer.from_pretrained(chex_name, trust_remote_code=True)
343
+ chex_model = AutoModelForCausalLM.from_pretrained(
344
+ chex_name, device_map="auto", trust_remote_code=True
345
+ )
346
+ chex_model = chex_model.half() if torch.cuda.is_available() else chex_model.float()
347
+ chex_model.eval()
348
+ CHEXAGENT_AVAILABLE = True
349
+ except Exception as e:
350
+ print(f"CheXagent not available: {e}")
351
+ CHEXAGENT_AVAILABLE = False
352
+ chex_tok, chex_model = None, None
353
 
354
  def get_model_device(model):
355
+ if model is None:
356
+ return torch.device("cpu")
357
  for p in model.parameters():
358
  return p.device
359
  return torch.device("cpu")
360
 
 
361
  def clean_text(text):
362
  return text.replace("</s>", "")
363
 
 
364
  @torch.no_grad()
365
  def response_report_generation(pil_image_1, pil_image_2):
366
  """Structured chest-X-ray report (streaming)."""
367
+ if not CHEXAGENT_AVAILABLE:
368
+ yield "CheXagent is not available. Please check installation."
369
+ return
370
+
371
  streamer = TextIteratorStreamer(chex_tok, skip_prompt=True, skip_special_tokens=True)
372
  paths = []
373
  for im in [pil_image_1, pil_image_2]:
 
377
  im.save(tfile.name)
378
  paths.append(tfile.name)
379
 
380
+ if not paths:
381
+ yield "Please upload at least one image."
382
+ return
383
+
384
  device = get_model_device(chex_model)
385
  anatomies = [
386
  "View",
 
454
  yield clean_text(partial)
455
  yield clean_text(partial)
456
 
 
457
  @torch.no_grad()
458
  def response_phrase_grounding(pil_image, prompt_text):
459
  """Very simple visual-grounding placeholder."""
460
+ if not CHEXAGENT_AVAILABLE:
461
+ return "CheXagent is not available. Please check installation.", None
462
+
463
  if pil_image is None:
464
  return "Please upload an image.", None
465
 
 
489
 
490
  return resp, pil_image
491
 
 
492
  # =============================================================================
493
  # Gradio UI
494
  # =============================================================================
495
+ def create_ui():
496
+ """Create the Gradio interface."""
497
+ # Load Qwen model
498
+ try:
499
+ qwen_model, qwen_proc, qwen_dev = load_qwen_model_and_processor()
500
+ med_agent = MedicalVLMAgent(qwen_model, qwen_proc, qwen_dev)
501
+ qwen_available = True
502
+ except Exception as e:
503
+ print(f"Qwen model not available: {e}")
504
+ qwen_available = False
505
+ med_agent = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
 
507
+ with gr.Blocks(title="Medical AI Assistant") as demo:
508
+ gr.Markdown("# Combined Medical Q&A Β· SAM-2 Automatic Masking Β· CheXagent")
509
+
510
+ # Status information
511
+ with gr.Row():
512
+ gr.Markdown(f"""
513
+ **System Status:**
514
+ - Qwen VLM: {'βœ… Available' if qwen_available else '❌ Not Available'}
515
+ - SAM-2: {'βœ… Available' if SAM2_AVAILABLE else '❌ Not Available'}
516
+ - CheXagent: {'βœ… Available' if CHEXAGENT_AVAILABLE else '❌ Not Available'}
517
+ """)
518
+
519
+ # Medical Q&A Tab
520
+ with gr.Tab("Medical Q&A"):
521
+ if qwen_available:
522
+ q_in = gr.Textbox(label="Question / description", lines=3)
523
+ q_img = gr.Image(label="Optional image", type="pil")
524
+ q_btn = gr.Button("Submit")
525
+ q_out = gr.Textbox(label="Answer")
526
+ q_btn.click(fn=med_agent.run, inputs=[q_in, q_img], outputs=q_out)
527
+ else:
528
+ gr.Markdown("❌ Medical Q&A is not available. Qwen model failed to load.")
529
+
530
+ # Segmentation Tab
531
+ with gr.Tab("Automatic masking"):
532
+ seg_img = gr.Image(label="Upload medical image", type="pil")
533
+ seg_btn = gr.Button("Run segmentation")
534
+ seg_out = gr.Image(label="Segmentation result", type="pil")
535
+ seg_status = gr.Textbox(label="Status", interactive=False)
536
+
537
+ if SAM2_AVAILABLE and _mask_generator is not None:
538
+ seg_btn.click(
539
+ fn=tumor_segmentation_interface,
540
+ inputs=seg_img,
541
+ outputs=[seg_out, seg_status],
542
+ )
543
+ else:
544
+ seg_btn.click(
545
+ fn=simple_segmentation_fallback,
546
+ inputs=seg_img,
547
+ outputs=[seg_out, seg_status],
548
+ )
549
+
550
+ # CheXagent Tabs
551
+ with gr.Tab("CheXagent – Structured report"):
552
+ if CHEXAGENT_AVAILABLE:
553
+ gr.Markdown("Upload one or two chest X-ray images; the report streams live.")
554
+ cx1 = gr.Image(label="Image 1", image_mode="L", type="pil")
555
+ cx2 = gr.Image(label="Image 2", image_mode="L", type="pil")
556
+ cx_report = gr.Markdown()
557
+ gr.Interface(
558
+ fn=response_report_generation,
559
+ inputs=[cx1, cx2],
560
+ outputs=cx_report,
561
+ live=True,
562
+ ).render()
563
+ else:
564
+ gr.Markdown("❌ CheXagent structured report is not available.")
565
+
566
+ with gr.Tab("CheXagent – Visual grounding"):
567
+ if CHEXAGENT_AVAILABLE:
568
+ vg_img = gr.Image(image_mode="L", type="pil")
569
+ vg_prompt = gr.Textbox(value="Locate the highlighted finding:")
570
+ vg_text = gr.Markdown()
571
+ vg_out_img = gr.Image()
572
+ gr.Interface(
573
+ fn=response_phrase_grounding,
574
+ inputs=[vg_img, vg_prompt],
575
+ outputs=[vg_text, vg_out_img],
576
+ ).render()
577
+ else:
578
+ gr.Markdown("❌ CheXagent visual grounding is not available.")
579
+
580
+ return demo
581
 
582
  if __name__ == "__main__":
583
+ demo = create_ui()
584
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)