pascal-maker committed
Commit 45c9883 · verified · 1 Parent(s): 72e2729

Update app.py

Files changed (1)
  1. app.py (+76, -66)
app.py CHANGED
@@ -12,7 +12,7 @@ Features:
   - Automatic dependency checking & installation for SAM-2
 
 Usage:
-    python medical_ai_app.py                        # launches Gradio UI on port 7860
+    HF_TOKEN=<your_token> python medical_ai_app.py  # if private models require auth
 Requires:
     torch, transformers, PIL, gradio, ultralytics, requests, opencv-python, pyyaml
 """
@@ -25,6 +25,9 @@ import warnings
 from threading import Thread
 from pathlib import Path
 
+# Hugging Face token (for private models)
+HF_TOKEN = os.getenv("HF_TOKEN")
+
 # Environment setup
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 warnings.filterwarnings("ignore", message=r".*upsample_bicubic2d.*")
@@ -37,11 +40,11 @@ from PIL import Image, ImageDraw
 import gradio as gr
 from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+import importlib
 
 # =============================================================================
 # SAM-2 Alias Patch & Installer
 # =============================================================================
-import importlib
 try:
     import sam_2
     sys.modules['sam2'] = sam_2
@@ -94,18 +97,29 @@ _qwen_model = None
 _qwen_processor = None
 _qwen_device = None
 
-def load_qwen_model_and_processor(hf_token=None):
+def load_qwen_model_and_processor():
     global _qwen_model, _qwen_processor, _qwen_device
     if _qwen_model is None:
         _qwen_device = get_device()
-        auth = {"use_auth_token": hf_token} if hf_token else {}
-        _qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True,
-            torch_dtype=torch.float32, low_cpu_mem_usage=True, **auth
-        ).to(_qwen_device)
-        _qwen_processor = AutoProcessor.from_pretrained(
-            "Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True, **auth
-        )
+        auth = {"use_auth_token": HF_TOKEN} if HF_TOKEN else {}
+        print(f"[Qwen] Loading model with auth={'yes' if HF_TOKEN else 'no'} on {_qwen_device}")
+        try:
+            _qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                "Qwen/Qwen2.5-VL-3B-Instruct",
+                trust_remote_code=True,
+                torch_dtype=torch.float32,
+                low_cpu_mem_usage=True,
+                **auth
+            ).to(_qwen_device)
+            _qwen_processor = AutoProcessor.from_pretrained(
+                "Qwen/Qwen2.5-VL-3B-Instruct",
+                trust_remote_code=True,
+                **auth
+            )
+        except Exception as e:
+            print(f"[Qwen] Model load failed: {e}")
+            _qwen_model = None
+            _qwen_processor = None
     return _qwen_model, _qwen_processor, _qwen_device
 
 class MedicalVLMAgent:
@@ -116,8 +130,8 @@ class MedicalVLMAgent:
         "Disclaimer: I am not a licensed medical professional."
     )
     def run(self, text, image=None):
-        if self.model is None:
-            return "Qwen-VLM model not loaded"
+        if not self.model or not self.processor:
+            return "Qwen-VLM is not available"
         msgs = [{"role":"system","content":[{"type":"text","text":self.sys_prompt}]}]
         user_cont = []
         if image:
@@ -125,8 +139,12 @@ class MedicalVLMAgent:
             user_cont.append({"type":"image","image":tmp})
         user_cont.append({"type":"text","text": text or ""})
         msgs.append({"role":"user","content":user_cont})
-        prompt = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
-        inputs = self.processor(text=[prompt], images=[], videos=[], padding=True, return_tensors='pt').to(self.device)
+        prompt = self.processor.apply_chat_template(
+            msgs, tokenize=False, add_generation_prompt=True
+        )
+        inputs = self.processor(
+            text=[prompt], images=[], videos=[], padding=True, return_tensors='pt'
+        ).to(self.device)
         out = self.model.generate(**inputs, max_new_tokens=128)
         resp = out[0][inputs.input_ids.shape[1]:]
         return self.processor.decode(resp, skip_special_tokens=True).strip()
@@ -137,58 +155,50 @@ class MedicalVLMAgent:
 _sam2_model, _mask_generator = (None, None)
 if SAM2_AVAILABLE:
     try:
-        CKPT="checkpoints/sam2.1_hiera_large.pt"; CFG="configs/sam2.1/sam2.1_hiera_l.yaml"
+        CKPT="checkpoints/sam2.1_hiera_large.pt"
+        CFG="configs/sam2.1/sam2.1_hiera_l.yaml"
         os.chdir("segment-anything-2/sam2/sam2")
-        _sam2_model = build_sam2(CFG, CKPT, device=get_device(), apply_postprocessing=False)
+        _sam2_model = build_sam2(
+            CFG, CKPT, device=get_device(), apply_postprocessing=False
+        )
         _mask_generator = SAM2AutomaticMaskGenerator(_sam2_model)
     except Exception as e:
-        print(f"SAM-2 init error: {e}")
+        print(f"[SAM-2] Initialization error: {e}")
         _mask_generator = None
 
-def segmentation_interface(image):
-    if image is None: return None, "Upload an image"
-    if not _mask_generator: return None, "SAM-2 unavailable"
-    arr = np.array(image.convert('RGB'))
-    anns = _mask_generator.generate(arr)
-    overlay = arr.copy()
-    for ann in sorted(anns, key=lambda x: x['area'], reverse=True):
-        m = ann['segmentation']; color=np.random.randint(0,255,3)
-        overlay[m] = (overlay[m]*0.5 + color*0.5).astype(np.uint8)
-    return Image.fromarray(overlay), f"{len(anns)} masks found"
-
-# =============================================================================
-# Fallback segmentation
-# =============================================================================
-def fallback_segmentation(image):
-    if image is None: return None, "Upload an image"
-    arr = np.array(image.convert('RGB'))
-    gray=cv2.cvtColor(arr,cv2.COLOR_RGB2GRAY)
-    _,th=cv2.threshold(gray,0,255,cv2.THRESH_BINARY_INV+cv2.THRESH_OTSU)
-    overlay=arr.copy(); overlay[th>0]=[255,0,0]
-    blended=cv2.addWeighted(arr,0.7,overlay,0.3,0)
-    return Image.fromarray(blended), "Fallback applied"
-
 # =============================================================================
 # CheXagent: structured report & grounding
 # =============================================================================
 try:
-    chex_tok = AutoTokenizer.from_pretrained("StanfordAIMI/CheXagent-2-3b", trust_remote_code=True)
-    chex_model = AutoModelForCausalLM.from_pretrained("StanfordAIMI/CheXagent-2-3b", device_map='auto', trust_remote_code=True)
+    print(f"[CheXagent] Loading with auth={'yes' if HF_TOKEN else 'no'}")
+    chex_tok = AutoTokenizer.from_pretrained(
+        "StanfordAIMI/CheXagent-2-3b", trust_remote_code=True,
+        use_auth_token=HF_TOKEN
+    )
+    chex_model = AutoModelForCausalLM.from_pretrained(
+        "StanfordAIMI/CheXagent-2-3b", device_map='auto', trust_remote_code=True,
+        use_auth_token=HF_TOKEN
+    )
     if torch.cuda.is_available(): chex_model = chex_model.half()
-    chex_model.eval(); CHEX_AVAILABLE=True
-except Exception:
-    CHEX_AVAILABLE=False
+    chex_model.eval()
+    CHEX_AVAILABLE = True
+except Exception as e:
+    print(f"[CheXagent] Load failed: {e}")
+    CHEX_AVAILABLE = False
 
 @torch.no_grad()
 def report_generation(im1, im2):
-    if not CHEX_AVAILABLE: yield "CheXagent unavailable"; return
+    if not CHEX_AVAILABLE:
+        yield "CheXagent unavailable"
+        return
     streamer = TextIteratorStreamer(chex_tok, skip_prompt=True)
-    yield "Report streaming not fully implemented"
+    yield "Streaming report... (not fully implemented)"
 
 @torch.no_grad()
 def phrase_grounding(image, prompt):
-    if not CHEX_AVAILABLE: return "CheXagent unavailable", None
-    w,h=image.size; draw=ImageDraw.Draw(image)
+    if not CHEX_AVAILABLE:
+        return "CheXagent unavailable", None
+    w,h = image.size; draw = ImageDraw.Draw(image)
     draw.rectangle([(w*0.25,h*0.25),(w*0.75,h*0.75)], outline='red', width=3)
     return prompt, image
 
@@ -196,47 +206,46 @@ def phrase_grounding(image, prompt):
 # =============================================================================
 # Gradio UI
 # =============================================================================
 def create_ui():
-    # Load Qwen agent
-    try:
-        m, p, d = load_qwen_model_and_processor()
-        med = MedicalVLMAgent(m,p,d)
-        qwen_ok = True
-    except Exception:
-        med = None
-        qwen_ok = False
+    m, p, d = load_qwen_model_and_processor()
+    qwen_ok = bool(m and p)
+    med = MedicalVLMAgent(m, p, d) if qwen_ok else None
 
     with gr.Blocks() as demo:
         gr.Markdown("# Medical AI Assistant")
-        gr.Markdown(f"- Qwen: {'✅' if qwen_ok else '❌'} - SAM-2: {'✅' if _mask_generator else '❌'} - CheX: {'✅' if CHEX_AVAILABLE else '❌'}")
+        gr.Markdown(
+            f"- Qwen: {'✅' if qwen_ok else '❌'} "
+            f"- SAM-2: {'✅' if _mask_generator else '❌'} "
+            f"- CheXagent: {'✅' if CHEX_AVAILABLE else '❌'}"
+        )
         with gr.Tab("Medical Q&A"):
-            if qwen_ok and med is not None:
+            if qwen_ok:
                 txt = gr.Textbox(label="Question / description", lines=3)
                 img = gr.Image(label="Optional image", type='pil')
                 out = gr.Textbox(label="Answer")
-                gr.Button("Ask").click(med.run, inputs=[txt, img], outputs=out)
+                gr.Button("Ask").click(med.run, [txt, img], out)
            else:
-                gr.Markdown("❌ Medical Q&A is not available.")
+                gr.Markdown("❌ Medical Q&A not available. Check HF_TOKEN and connectivity.")
         with gr.Tab("Segmentation"):
             seg = gr.Image(label="Upload image", type='pil')
             so = gr.Image(label="Result")
             ss = gr.Textbox(label="Status", interactive=False)
             fn = segmentation_interface if _mask_generator else fallback_segmentation
-            gr.Button("Segment").click(fn, inputs=[seg], outputs=[so, ss])
+            gr.Button("Segment").click(fn, [seg], [so, ss])
         with gr.Tab("CheXagent Report"):
            c1 = gr.Image(type='pil', label="Image 1")
            c2 = gr.Image(type='pil', label="Image 2")
            rout = gr.Markdown()
            if CHEX_AVAILABLE:
-                gr.Interface(fn=report_generation, inputs=[c1, c2], outputs=rout, live=True).render()
+                gr.Interface(report_generation, [c1, c2], rout, live=True).render()
            else:
-                gr.Markdown("❌ CheXagent report not available.")
+                gr.Markdown("❌ CheXagent report not available. Check HF_TOKEN and connectivity.")
        with gr.Tab("CheXagent Grounding"):
            gi = gr.Image(type='pil', label="Image")
            gp = gr.Textbox(label="Prompt")
            gout = gr.Textbox(label="Response")
            goimg = gr.Image(label="Output Image")
            if CHEX_AVAILABLE:
-                gr.Interface(fn=phrase_grounding, inputs=[gi, gp], outputs=[gout, goimg]).render()
+                gr.Interface(phrase_grounding, [gi, gp], [gout, goimg]).render()
            else:
                gr.Markdown("❌ CheXagent grounding not available.")
    return demo
@@ -244,3 +253,4 @@ def create_ui():
 if __name__ == "__main__":
     ui = create_ui()
     ui.launch(server_name='0.0.0.0', server_port=7860, share=True)
+
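
The central change in this diff is reading HF_TOKEN from the environment and threading it into every from_pretrained call. A minimal sketch of that pattern in isolation; note that recent transformers releases deprecate the use_auth_token keyword the commit uses in favor of token, so the variant below is an assumption about the installed version, not part of the commit:

    import os
    from transformers import AutoTokenizer

    HF_TOKEN = os.getenv("HF_TOKEN")                 # None when the variable is unset
    auth = {"token": HF_TOKEN} if HF_TOKEN else {}   # empty kwargs for public repos

    # Gated or private repos need the token; public repos load without it.
    tok = AutoTokenizer.from_pretrained(
        "StanfordAIMI/CheXagent-2-3b", trust_remote_code=True, **auth
    )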
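
load_qwen_model_and_processor is a lazy module-level singleton: the first call pays the load cost and caches the objects in globals, and every later call returns the cached instances (or the Nones left behind by a failed load, which create_ui then checks via qwen_ok). The shape of the pattern, reduced to a runnable toy with a hypothetical constructor:

    _model = None

    def build_model():
        # Stand-in for the expensive from_pretrained(...) call.
        return object()

    def get_model():
        global _model
        if _model is None:            # only the first call constructs the model
            _model = build_model()
        return _model

    assert get_model() is get_model()   # later calls reuse the same instance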
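
In MedicalVLMAgent.run, generate returns the prompt and the completion as a single id sequence, which is why the code slices out[0][inputs.input_ids.shape[1]:] before decoding. The same logic as a self-contained helper (a sketch; any transformers model and processor/tokenizer pair with a decode method fits):

    import torch

    def decode_completion(model, processor, inputs, max_new_tokens=128):
        """Decode only the newly generated tokens, not the echoed prompt."""
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=max_new_tokens)
        prompt_len = inputs.input_ids.shape[1]   # generate() output starts with the prompt
        return processor.decode(out[0][prompt_len:], skip_special_tokens=True).strip()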
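
One caveat worth flagging: the @@ -137,58 hunk deletes both segmentation_interface and fallback_segmentation, but create_ui still evaluates fn = segmentation_interface if _mask_generator else fallback_segmentation, so the Segmentation tab will raise a NameError whichever branch is taken. Restoring the deleted Otsu fallback (the logic below is taken from the removed lines, only reformatted) would keep the tab working:

    import cv2
    import numpy as np
    from PIL import Image

    def fallback_segmentation(image):
        """Otsu-threshold overlay, as the deleted helper did."""
        if image is None:
            return None, "Upload an image"
        arr = np.array(image.convert('RGB'))
        gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
        _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        overlay = arr.copy()
        overlay[th > 0] = [255, 0, 0]                         # mark foreground red
        blended = cv2.addWeighted(arr, 0.7, overlay, 0.3, 0)  # 70/30 blend
        return Image.fromarray(blended), "Fallback applied"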
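
report_generation constructs a TextIteratorStreamer but, as its own yield admits, never drives it. The usual wiring runs generate on a background thread while the caller iterates the streamer; the Thread import at the top of app.py suggests this was the plan. A sketch under that assumption, with model and tokenizer assumed already loaded:

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_text(model, tokenizer, prompt, max_new_tokens=256):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                        skip_special_tokens=True)
        kwargs = dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens)
        Thread(target=model.generate, kwargs=kwargs).start()  # generate off-thread
        text = ""
        for chunk in streamer:    # blocks until each decoded chunk arrives
            text += chunk
            yield text            # a Gradio generator re-renders the growing string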
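
Finally, the UI wiring drops the inputs=/outputs= keywords; gr.Button(...).click(fn, inputs, outputs) takes the same arguments positionally, so behavior is unchanged. A toy Blocks app showing the pattern, independent of the medical models:

    import gradio as gr

    def echo(text):
        return f"You said: {text}"

    with gr.Blocks() as demo:
        box = gr.Textbox(label="Input")
        out = gr.Textbox(label="Output")
        gr.Button("Go").click(echo, [box], out)   # fn, inputs, outputs (positional)

    if __name__ == "__main__":
        demo.launch()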