pascal-maker committed on
Commit 07f5f6e · verified · 1 Parent(s): 6cd7b7a

Update app.py

Files changed (1)
  1. app.py +233 -559
app.py CHANGED
@@ -1,582 +1,256 @@
  #!/usr/bin/env python
- # -*- coding: utf-8 -*-
-
  """
- Combined Medical-VLM, **SAM-2 automatic masking**, and CheXagent demo.
-
- Changes
- -------
- 1. Fixed SAM-2 installation and import issues
- 2. Added proper error handling for missing dependencies
- 3. Made SAM-2 functionality optional with graceful fallback
- 4. Added installation instructions and requirements check
  """

- # ---------------------------------------------------------------------
- # Standard libs
- # ---------------------------------------------------------------------
- import os
  import sys
- import uuid
- import tempfile
  import subprocess
- import warnings
- from threading import Thread
-
- # Environment setup
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
- warnings.filterwarnings("ignore", message=r".*upsample_bicubic2d.*")

- # ---------------------------------------------------------------------
- # Third-party libs
- # ---------------------------------------------------------------------
- import torch
  import numpy as np
- from PIL import Image, ImageDraw
- import gradio as gr
-
- # =============================================================================
- # Dependency checker and installer
- # =============================================================================
- def check_and_install_sam2():
-     """Check if SAM-2 is available and attempt installation if needed."""
      try:
-         print("[SAM-2 Debug] Attempting to import SAM-2 modules...")
          from sam2.build_sam import build_sam2
-         from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
-         print("[SAM-2 Debug] Successfully imported SAM-2 modules")
          return True, "SAM-2 already available"
-     except ImportError as e:
-         print(f"[SAM-2 Debug] Import error: {str(e)}")
-         print("[SAM-2 Debug] Attempting to install SAM-2...")
          try:
-             # Clone SAM-2 repository
-             if not os.path.exists("segment-anything-2"):
-                 print("[SAM-2 Debug] Cloning SAM-2 repository...")
-                 subprocess.run([
-                     "git", "clone",
-                     "https://github.com/facebookresearch/segment-anything-2.git"
-                 ], check=True)
-                 print("[SAM-2 Debug] Repository cloned successfully")
-
-             # Install SAM-2
-             print("[SAM-2 Debug] Installing SAM-2...")
-             original_dir = os.getcwd()
-             os.chdir("segment-anything-2")
-             subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
-             os.chdir(original_dir)
-             print("[SAM-2 Debug] Installation completed")
-
-             # Add to Python path
-             sam2_path = os.path.abspath("segment-anything-2")
-             if sam2_path not in sys.path:
-                 sys.path.insert(0, sam2_path)
-                 print(f"[SAM-2 Debug] Added {sam2_path} to Python path")
-
-             # Try importing again
-             print("[SAM-2 Debug] Attempting to import SAM-2 modules again...")
-             from sam2.build_sam import build_sam2
-             from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
-             print("[SAM-2 Debug] Successfully imported SAM-2 modules after installation")
-             return True, "SAM-2 installed successfully"
-
-         except Exception as e:
-             print(f"[SAM-2 Debug] Installation failed: {str(e)}")
-             print(f"[SAM-2 Debug] Error type: {type(e).__name__}")
-             return False, f"SAM-2 installation failed: {e}"

- # Check SAM-2 availability
  SAM2_AVAILABLE, SAM2_STATUS = check_and_install_sam2()
  print(f"SAM-2 Status: {SAM2_STATUS}")
-
- # =============================================================================
- # SAM-2 imports (conditional)
- # =============================================================================
  if SAM2_AVAILABLE:
      try:
-         from sam2.build_sam import build_sam2
-         from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
-         from sam2.modeling.sam2_base import SAM2Base
-     except ImportError as e:
-         print(f"SAM-2 import error: {e}")
-         SAM2_AVAILABLE = False
-
- # =============================================================================
- # Qwen-VLM imports & helper
- # =============================================================================
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
- from qwen_vl_utils import process_vision_info
-
- # =============================================================================
- # CheXagent imports
- # =============================================================================
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
-
- # ---------------------------------------------------------------------
- # Devices
- # ---------------------------------------------------------------------
- def get_device():
-     if torch.cuda.is_available():
-         return torch.device("cuda")
-     if torch.backends.mps.is_available():
-         return torch.device("mps")
-     return torch.device("cpu")
-
- # =============================================================================
- # Qwen-VLM model & agent
- # =============================================================================
- _qwen_model = None
- _qwen_processor = None
- _qwen_device = None
-
- def load_qwen_model_and_processor(hf_token=None):
-     global _qwen_model, _qwen_processor, _qwen_device
-     if _qwen_model is None:
-         _qwen_device = "mps" if torch.backends.mps.is_available() else "cpu"
-         print(f"[Qwen] loading model on {_qwen_device}")
-         auth_kwargs = {"use_auth_token": hf_token} if hf_token else {}
-         _qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-             "Qwen/Qwen2.5-VL-3B-Instruct",
-             trust_remote_code=True,
-             attn_implementation="eager",
-             torch_dtype=torch.float32,
-             low_cpu_mem_usage=True,
-             device_map=None,
-             **auth_kwargs,
-         ).to(_qwen_device)
-         _qwen_processor = AutoProcessor.from_pretrained(
-             "Qwen/Qwen2.5-VL-3B-Instruct",
-             trust_remote_code=True,
-             **auth_kwargs,
-         )
-     return _qwen_model, _qwen_processor, _qwen_device
-
- class MedicalVLMAgent:
-     """Light wrapper around Qwen-VLM with an optional image."""
-
-     def __init__(self, model, processor, device):
-         self.model = model
-         self.processor = processor
-         self.device = device
-         self.system_prompt = (
-             "You are a medical information assistant with vision capabilities.\n"
-             "Disclaimer: I am not a licensed medical professional. "
-             "The information provided is for reference only and should not be taken as medical advice."
-         )
-
-     def run(self, user_text: str, image: Image.Image | None = None) -> str:
-         messages = [
-             {"role": "system", "content": [{"type": "text", "text": self.system_prompt}]}
-         ]
-         user_content = []
-         if image is not None:
-             tmp = f"/tmp/{uuid.uuid4()}.png"
-             image.save(tmp)
-             user_content.append({"type": "image", "image": tmp})
-         user_content.append({"type": "text", "text": user_text or "Please describe the image."})
-         messages.append({"role": "user", "content": user_content})
-
-         prompt_text = self.processor.apply_chat_template(
-             messages, tokenize=False, add_generation_prompt=True
-         )
-         img_inputs, vid_inputs = process_vision_info(messages)
-         inputs = self.processor(
-             text=[prompt_text],
-             images=img_inputs,
-             videos=vid_inputs,
-             padding=True,
-             return_tensors="pt",
-         ).to(self.device)
-
-         with torch.no_grad():
-             out = self.model.generate(**inputs, max_new_tokens=128)
-         trimmed = out[0][inputs.input_ids.shape[1]:]
-         return self.processor.decode(trimmed, skip_special_tokens=True).strip()
-
- # =============================================================================
- # SAM-2 model + AutomaticMaskGenerator (final minimal version)
- # =============================================================================
- import os
- import numpy as np
- from PIL import Image, ImageDraw
- from sam2.build_sam import build_sam2
- from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
-
- def initialize_sam2():
-     # These two files are already in your repo
-     CKPT = "checkpoints/sam2.1_hiera_large.pt"  # ≈2.7 GB
-     CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
-
-     # One chdir so Hydra's search path starts inside sam2/sam2/
-     os.chdir("sam2/sam2")
-
-     device = get_device()
-     print(f"[SAM-2] building model on {device}")
-
-     sam2_model = build_sam2(
-         CFG,   # relative to sam2/sam2/
-         CKPT,  # relative after chdir
-         device=device,
-         apply_postprocessing=False,
-     )
-
-     mask_gen = SAM2AutomaticMaskGenerator(
-         model=sam2_model,
-         points_per_side=32,
-         pred_iou_thresh=0.86,
-         stability_score_thresh=0.92,
-         crop_n_layers=0,
-     )
-     return sam2_model, mask_gen
-
- # ---------------------- build once ----------------------
- try:
-     _sam2_model, _mask_generator = initialize_sam2()
-     print("[SAM-2] Successfully initialized!")
- except Exception as e:
-     print(f"[SAM-2] Failed to initialize: {e}")
-     _sam2_model, _mask_generator = None, None
-
- def automatic_mask_overlay(image_np: np.ndarray) -> np.ndarray:
-     """Generate masks and alpha-blend them on top of the original image."""
-     if _mask_generator is None:
-         raise RuntimeError("SAM-2 mask generator not initialized")
-
-     anns = _mask_generator.generate(image_np)
-     if not anns:
-         return image_np
-
-     overlay = image_np.copy()
-     if overlay.ndim == 2:  # grayscale → RGB
-         overlay = np.stack([overlay] * 3, axis=2)
-
-     for ann in sorted(anns, key=lambda x: x["area"], reverse=True):
-         m = ann["segmentation"]
-         color = np.random.randint(0, 255, 3, dtype=np.uint8)
-         overlay[m] = (overlay[m] * 0.5 + color * 0.5).astype(np.uint8)
-
-     return overlay
-
- def tumor_segmentation_interface(image: Image.Image | None):
-     if image is None:
-         return None, "Please upload an image."
-
-     if _mask_generator is None:
-         return None, "SAM-2 not properly initialized. Check the console for errors."
-
-     try:
-         img_np = np.array(image.convert("RGB"))
-         out_np = automatic_mask_overlay(img_np)
-         n_masks = len(_mask_generator.generate(img_np))
-         return Image.fromarray(out_np), f"{n_masks} masks found."
      except Exception as e:
-         return None, f"SAM-2 error: {e}"
-
- # =============================================================================
- # Simple fallback segmentation (when SAM-2 is not available)
- # =============================================================================
- def simple_segmentation_fallback(image: Image.Image | None):
-     """Simple fallback segmentation using basic image processing."""
-     if image is None:
-         return None, "Please upload an image."
-
      try:
-         import cv2
-         from skimage import segmentation, color
-
-         # Convert to numpy array
-         img_np = np.array(image.convert("RGB"))
-
-         # Simple watershed segmentation
-         gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
-         _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
-
-         # Remove noise
-         kernel = np.ones((3, 3), np.uint8)
-         opening = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=2)
-
-         # Sure background area
-         sure_bg = cv2.dilate(opening, kernel, iterations=3)
-
-         # Finding sure foreground area
-         dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
-         _, sure_fg = cv2.threshold(dist_transform, 0.7 * dist_transform.max(), 255, 0)
-
-         # Create overlay
-         overlay = img_np.copy()
-         overlay[sure_fg > 0] = [255, 0, 0]  # Red overlay
-
-         # Alpha blend
-         result = cv2.addWeighted(img_np, 0.7, overlay, 0.3, 0)
-
-         return Image.fromarray(result), "Simple segmentation applied (SAM-2 not available)"
-
-     except Exception as e:
-         return None, f"Fallback segmentation error: {e}"
-
- # =============================================================================
- # CheXagent set-up
- # =============================================================================
- try:
-     print("[CheXagent] Starting initialization...")
-     chex_name = "StanfordAIMI/CheXagent-2-3b"
-     print(f"[CheXagent] Loading tokenizer from {chex_name}")
-     chex_tok = AutoTokenizer.from_pretrained(chex_name, trust_remote_code=True)
-     print("[CheXagent] Tokenizer loaded successfully")
-
-     print("[CheXagent] Loading model...")
-     chex_model = AutoModelForCausalLM.from_pretrained(
-         chex_name,
-         device_map="auto",
-         trust_remote_code=True,
-         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
-     )
-     print("[CheXagent] Model loaded successfully")
-
-     if torch.cuda.is_available():
-         print("[CheXagent] Converting to half precision for GPU")
-         chex_model = chex_model.half()
      else:
-         print("[CheXagent] Using full precision for CPU")
-         chex_model = chex_model.float()
-
-     chex_model.eval()
-     CHEXAGENT_AVAILABLE = True
-     print("[CheXagent] Initialization complete")
- except Exception as e:
-     print(f"[CheXagent] Initialization failed: {str(e)}")
-     print(f"[CheXagent] Error type: {type(e).__name__}")
-     CHEXAGENT_AVAILABLE = False
-     chex_tok, chex_model = None, None
-
- def get_model_device(model):
-     if model is None:
-         return torch.device("cpu")
-     for p in model.parameters():
-         return p.device
-     return torch.device("cpu")
-
- def clean_text(text):
-     return text.replace("</s>", "")
-
- @torch.no_grad()
- def response_report_generation(pil_image_1, pil_image_2):
-     """Structured chest-X-ray report (streaming)."""
-     if not CHEXAGENT_AVAILABLE:
-         yield "CheXagent is not available. Please check installation."
-         return
-
-     streamer = TextIteratorStreamer(chex_tok, skip_prompt=True, skip_special_tokens=True)
-     paths = []
-     for im in [pil_image_1, pil_image_2]:
-         if im is None:
-             continue
-         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tfile:
-             im.save(tfile.name)
-             paths.append(tfile.name)
-
-     if not paths:
-         yield "Please upload at least one image."
-         return
-
-     device = get_model_device(chex_model)
-     anatomies = [
-         "View",
-         "Airway",
-         "Breathing",
-         "Cardiac",
-         "Diaphragm",
-         "Everything else (e.g., mediastinal contours, bones, soft tissues, tubes, valves, pacemakers)",
-     ]
-     prompts = [
-         "Determine the view of this CXR",
-         *[
-             f'Provide a detailed description of "{a}" in the chest X-ray'
-             for a in anatomies[1:]
-         ],
-     ]
-
-     findings = ""
-     partial = "## Generating Findings (step-by-step):\n\n"
-     for idx, (anat, prompt) in enumerate(zip(anatomies, prompts)):
-         query = chex_tok.from_list_format(
-             [*[{"image": p} for p in paths], {"text": prompt}]
-         )
-         conv = [
-             {"from": "system", "value": "You are a helpful assistant."},
-             {"from": "human", "value": query},
-         ]
-         inp = chex_tok.apply_chat_template(
-             conv, add_generation_prompt=True, return_tensors="pt"
-         ).to(device)
-         generate_kwargs = dict(
-             input_ids=inp,
-             max_new_tokens=512,
-             do_sample=False,
-             num_beams=1,
-             streamer=streamer,
-         )
-         Thread(target=chex_model.generate, kwargs=generate_kwargs).start()
-         partial += f"**Step {idx}: {anat}...**\n\n"
-         for tok in streamer:
-             if idx:
-                 findings += tok
-             partial += tok
-             yield clean_text(partial)
-         partial += "\n\n"
-         findings += " "
-     findings = findings.strip()
-
-     # Impression
-     partial += "## Generating Impression\n\n"
-     prompt = f"Write the Impression section for the following Findings: {findings}"
-     conv = [
-         {"from": "system", "value": "You are a helpful assistant."},
-         {"from": "human", "value": chex_tok.from_list_format([{"text": prompt}])},
-     ]
-     inp = chex_tok.apply_chat_template(
-         conv, add_generation_prompt=True, return_tensors="pt"
-     ).to(device)
-     Thread(
-         target=chex_model.generate,
-         kwargs=dict(
-             input_ids=inp,
-             do_sample=False,
-             num_beams=1,
-             max_new_tokens=512,
-             streamer=streamer,
-         ),
-     ).start()
-     for tok in streamer:
-         partial += tok
-         yield clean_text(partial)
-     yield clean_text(partial)
-
- @torch.no_grad()
- def response_phrase_grounding(pil_image, prompt_text):
-     """Very simple visual-grounding placeholder."""
-     if not CHEXAGENT_AVAILABLE:
-         return "CheXagent is not available. Please check installation.", None
-
-     if pil_image is None:
-         return "Please upload an image.", None
-
-     with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tfile:
-         pil_image.save(tfile.name)
-         img_path = tfile.name
-
-     device = get_model_device(chex_model)
-     query = chex_tok.from_list_format([{"image": img_path}, {"text": prompt_text}])
-     conv = [
-         {"from": "system", "value": "You are a helpful assistant."},
-         {"from": "human", "value": query},
-     ]
-     inp = chex_tok.apply_chat_template(
-         conv, add_generation_prompt=True, return_tensors="pt"
-     ).to(device)
-     out = chex_model.generate(
-         input_ids=inp, do_sample=False, num_beams=1, max_new_tokens=512
-     )
-     resp = clean_text(chex_tok.decode(out[0][inp.shape[1]:]))
-
-     # simple center box (placeholder)
-     w, h = pil_image.size
-     cx, cy, sz = w // 2, h // 2, min(w, h) // 4
-     draw = ImageDraw.Draw(pil_image)
-     draw.rectangle([(cx - sz, cy - sz), (cx + sz, cy + sz)], outline="red", width=3)
-
-     return resp, pil_image
-
- # =============================================================================
- # Gradio UI
- # =============================================================================
- def create_ui():
-     """Create the Gradio interface."""
-     # Load Qwen model
-     try:
-         qwen_model, qwen_proc, qwen_dev = load_qwen_model_and_processor()
-         med_agent = MedicalVLMAgent(qwen_model, qwen_proc, qwen_dev)
-         qwen_available = True
-     except Exception as e:
-         print(f"Qwen model not available: {e}")
-         qwen_available = False
-         med_agent = None
-
-     with gr.Blocks(title="Medical AI Assistant") as demo:
-         gr.Markdown("# Combined Medical Q&A · SAM-2 Automatic Masking · CheXagent")
-
-         # Status information
-         with gr.Row():
-             gr.Markdown(f"""
-             **System Status:**
-             - Qwen VLM: {'✅ Available' if qwen_available else '❌ Not Available'}
-             - SAM-2: {'✅ Available' if SAM2_AVAILABLE else '❌ Not Available'}
-             - CheXagent: {'✅ Available' if CHEXAGENT_AVAILABLE else '❌ Not Available'}
-             """)
-
-         # Medical Q&A Tab
-         with gr.Tab("Medical Q&A"):
-             if qwen_available:
-                 q_in = gr.Textbox(label="Question / description", lines=3)
-                 q_img = gr.Image(label="Optional image", type="pil")
-                 q_btn = gr.Button("Submit")
-                 q_out = gr.Textbox(label="Answer")
-                 q_btn.click(fn=med_agent.run, inputs=[q_in, q_img], outputs=q_out)
-             else:
-                 gr.Markdown("❌ Medical Q&A is not available. Qwen model failed to load.")
-
-         # Segmentation Tab
-         with gr.Tab("Automatic masking"):
-             seg_img = gr.Image(label="Upload medical image", type="pil")
-             seg_btn = gr.Button("Run segmentation")
-             seg_out = gr.Image(label="Segmentation result", type="pil")
-             seg_status = gr.Textbox(label="Status", interactive=False)
-
-             if SAM2_AVAILABLE and _mask_generator is not None:
-                 seg_btn.click(
-                     fn=tumor_segmentation_interface,
-                     inputs=seg_img,
-                     outputs=[seg_out, seg_status],
-                 )
-             else:
-                 seg_btn.click(
-                     fn=simple_segmentation_fallback,
-                     inputs=seg_img,
-                     outputs=[seg_out, seg_status],
-                 )
-
-         # CheXagent Tabs
-         with gr.Tab("CheXagent – Structured report"):
-             if CHEXAGENT_AVAILABLE:
-                 gr.Markdown("Upload one or two chest X-ray images; the report streams live.")
-                 cx1 = gr.Image(label="Image 1", image_mode="L", type="pil")
-                 cx2 = gr.Image(label="Image 2", image_mode="L", type="pil")
-                 cx_report = gr.Markdown()
-                 gr.Interface(
-                     fn=response_report_generation,
-                     inputs=[cx1, cx2],
-                     outputs=cx_report,
-                     live=True,
-                 ).render()
-             else:
-                 gr.Markdown("❌ CheXagent structured report is not available.")
-
-         with gr.Tab("CheXagent – Visual grounding"):
-             if CHEXAGENT_AVAILABLE:
-                 vg_img = gr.Image(image_mode="L", type="pil")
-                 vg_prompt = gr.Textbox(value="Locate the highlighted finding:")
-                 vg_text = gr.Markdown()
-                 vg_out_img = gr.Image()
-                 gr.Interface(
-                     fn=response_phrase_grounding,
-                     inputs=[vg_img, vg_prompt],
-                     outputs=[vg_text, vg_out_img],
-                 ).render()
-             else:
-                 gr.Markdown("❌ CheXagent visual grounding is not available.")
-
-     return demo
-
- if __name__ == "__main__":
-     demo = create_ui()
-     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)

  #!/usr/bin/env python
  """
+ post_analyzer_enhanced.py · Enhanced Post Analysis Tool
+ =======================================================
+
+ Analyzes images of posts by running YOLOv8 inference, applying spatial layout rules,
+ computing a nuanced confidence score, and detecting anomalies ("afwijking").
+ Generates JSON reports for image directories and for uploaded images.
+ Includes a SAM-2 alias patch for Hugging Face compatibility.
  """
+ from __future__ import annotations
+
+ import argparse
+ import json
  import sys
+ import os
  import subprocess
+ import tempfile
+ from pathlib import Path
+ from typing import List, Union
+ from datetime import datetime
+ from urllib.parse import urlparse
+
+ import cv2
+ import yaml
  import numpy as np
+ from dataclasses import dataclass
+ from ultralytics import YOLO
+ import requests
+ from PIL import Image
+ import io
+
+ # ───── Data Classes ──────────────────────────────────────────────────────────
+ @dataclass
+ class PostPart:
+     name: str
+     x: float            # normalized center x
+     y: float            # normalized center y
+     width: float
+     height: float
+     confidence: float = 1.0
+
+ @dataclass
+ class PostAnalysis:
+     image_path: Path
+     parts: List[PostPart]
+     anomalies: List[PostPart]
+     violations: List[str]
+     is_conform: bool
+     confidence_score: float
+
+ # ───── Configuration Load ────────────────────────────────────────────────────
+ def load_yaml_config(yaml_path: Path) -> dict:
+     if not yaml_path.exists():
+         sys.exit(f"Required {yaml_path} was not found – aborting.")
+     with yaml_path.open("r", encoding="utf-8") as fh:
+         data = yaml.safe_load(fh)
+     if "names" not in data:
+         sys.exit("'names' field missing in data.yaml – unable to continue.")
+     return {
+         "names": data["names"],
+         "class_to_name": {i: n for i, n in enumerate(data["names"])},
+         "name_to_class": {n: i for i, n in enumerate(data["names"])},
+     }
+
+ # ───── SAM-2 Alias Patch ─────────────────────────────────────────────────────
+ # Maps the sam_2 package to the sam2 namespace so sam2.* imports resolve
+ try:
+     import sam_2
+     import importlib
+     sys.modules['sam2'] = sam_2
+     for sub in ['build_sam', 'automatic_mask_generator', 'modeling.sam2_base']:
+         sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
+ except ImportError:
+     pass
+
+ # ───── Dependency Checker & Installer (SAM-2) ────────────────────────────────
+ def check_and_install_sam2() -> tuple[bool, str]:
      try:
          from sam2.build_sam import build_sam2
          return True, "SAM-2 already available"
+     except ImportError:
+         # Clone if needed
+         if not os.path.exists("segment-anything-2"):
+             subprocess.run([
+                 "git", "clone",
+                 "https://github.com/facebookresearch/segment-anything-2.git"
+             ], check=True)
+         # Install editable
+         cwd = os.getcwd()
+         os.chdir("segment-anything-2")
+         subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
+         os.chdir(cwd)
+         # Add to path and re-alias
+         path = os.path.abspath("segment-anything-2")
+         if path not in sys.path:
+             sys.path.insert(0, path)
          try:
+             import sam_2, importlib
+             sys.modules['sam2'] = sam_2
+             for sub in ['build_sam', 'automatic_mask_generator', 'modeling.sam2_base']:
+                 sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
+         except ImportError:
+             return False, "SAM-2 import failed after install"
+         return True, "SAM-2 installed and aliased"

  SAM2_AVAILABLE, SAM2_STATUS = check_and_install_sam2()
  print(f"SAM-2 Status: {SAM2_STATUS}")
  if SAM2_AVAILABLE:
+     from sam2.build_sam import build_sam2
+     from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
+     from sam2.modeling.sam2_base import SAM2Base
+
+ # ───── YOLO Inference ────────────────────────────────────────────────────────
+ def infer_parts(
+     img_path: Path,
+     model: YOLO,
+     class_info: dict,
+ ) -> tuple[List[PostPart], List[PostPart]]:
+     results = model(str(img_path))
+     parts, anomalies = [], []
+     for det in results[0].boxes:
+         # xywhn gives normalized center/size, matching the PostPart fields
+         # (plain xywh would return pixel units and break check_position)
+         x, y, w, h = det.xywhn[0].tolist()
+         cls_id = int(det.cls[0].item())
+         conf = float(det.conf[0].item())
+         name = class_info['class_to_name'].get(cls_id, f"unknown-{cls_id}")
+         part = PostPart(name, x, y, w, h, conf)
+         (anomalies if name == 'afwijking' else parts).append(part)
+     return parts, anomalies
+
+ # ───── Spatial Validation ────────────────────────────────────────────────────
+ def check_position(part: PostPart, img_w: int, img_h: int) -> bool:
+     cx, cy = part.x * img_w, part.y * img_h
+     w_px, h_px = part.width * img_w, part.height * img_h
+     if part.name == 'logo':
+         # the logo must sit entirely inside the top-right corner zone
+         return (cx - w_px / 2 >= 0.75 * img_w) and (cy + h_px / 2 <= 0.25 * img_h)
+     return True
+
+ def validate_layout(parts: List[PostPart], image_shape: tuple[int, int]) -> List[str]:
+     img_h, img_w = image_shape
+     return [f"{p.name} out of expected zone" for p in parts if not check_position(p, img_w, img_h)]
+
+ # ───── Confidence Scoring ────────────────────────────────────────────────────
+ def compute_confidence(
+     parts: List[PostPart], anomalies: List[PostPart], violations: List[str]
+ ) -> float:
+     base = sum(p.confidence for p in parts) / len(parts) if parts else 0.3
+     defect_penalty = min(0.1 * len(anomalies), 0.5)
+     layout_penalty = min(0.05 * len(violations), 0.3)
+     return max(0.0, base - defect_penalty - layout_penalty)
+
+ # ───── Core Analysis ─────────────────────────────────────────────────────────
+ def analyze_post(
+     img_path: Path, model: YOLO, class_info: dict, quiet: bool = False
+ ) -> PostAnalysis:
+     parts, anomalies = infer_parts(img_path, model, class_info)
+     img = cv2.imread(str(img_path))
+     if img is None:
+         sys.exit(f"Failed to read image {img_path}")
+     violations = validate_layout(parts, img.shape[:2])
+     score = compute_confidence(parts, anomalies, violations)
+     conform = not anomalies and not violations
+     if not quiet:
+         status = 'CONFORM' if conform else 'NON-CONFORM'
+         print(f"{img_path.name}: {status} | parts={len(parts)}, anomalies={len(anomalies)}, violations={len(violations)} | score={score:.2f}")
+     return PostAnalysis(img_path, parts, anomalies, violations, conform, score)
+
+ # ───── Reporting ─────────────────────────────────────────────────────────────
+ def write_analysis_report(analyses: List[PostAnalysis], output_dir: Path) -> Path:
+     output_dir.mkdir(parents=True, exist_ok=True)
+     report = []
+     for a in analyses:
+         report.append({
+             'image': str(a.image_path), 'is_conform': a.is_conform,
+             'confidence_score': a.confidence_score, 'violations': a.violations,
+             'parts': [vars(p) for p in a.parts], 'anomalies': [vars(d) for d in a.anomalies]
+         })
+     fp = output_dir / 'post_analysis.json'
+     with fp.open('w', encoding='utf-8') as f:
+         json.dump(report, f, indent=2)
+     return fp
+
+ # ───── Image Download Helper ─────────────────────────────────────────────────
+ def download_image(url: str) -> Union[Path, None]:
      try:
+         r = requests.get(url, timeout=10)
+         r.raise_for_status()
+         parsed = urlparse(url)
+         ext = Path(parsed.path).suffix.lower() or '.jpg'
+         tmp = tempfile.NamedTemporaryFile(delete=False, suffix=ext)
+         tmp.write(r.content)
+         tmp.close()
+         return Path(tmp.name)
      except Exception as e:
+         print(f"Download error for {url}: {e}")
+         return None
+
+ # ───── Process Uploaded Image ────────────────────────────────────────────────
+ def process_uploaded_image(
+     image_data: Union[str, bytes, Path], model: YOLO, class_info: dict,
+     output_dir: Path, quiet: bool = False
+ ) -> PostAnalysis:
+     tmp = None
      try:
+         if isinstance(image_data, str) and image_data.startswith(('http://', 'https://')):
+             tmp = download_image(image_data)
+             if tmp is None:
+                 sys.exit(f"Download failed: {image_data}")
+             img_path = tmp
+         elif isinstance(image_data, bytes):
+             img = Image.open(io.BytesIO(image_data))
+             fmt = (img.format or 'png').lower()
+             ext = f".{'jpg' if fmt == 'jpeg' else fmt}"
+             tmp = tempfile.NamedTemporaryFile(delete=False, suffix=ext)
+             tmp.write(image_data)
+             tmp.close()
+             img_path = Path(tmp.name)
+         else:
+             img_path = Path(image_data)
+             if not img_path.exists():
+                 sys.exit(f"File not found: {img_path}")
+         analysis = analyze_post(img_path, model, class_info, quiet)
+         out_fp = output_dir / f"analysis_{img_path.stem}.json"
+         with out_fp.open('w', encoding='utf-8') as f:
+             json.dump({
+                 'image': str(img_path), 'is_conform': analysis.is_conform,
+                 'confidence_score': analysis.confidence_score, 'violations': analysis.violations,
+                 'parts': [vars(p) for p in analysis.parts], 'anomalies': [vars(d) for d in analysis.anomalies]
+             }, f, indent=2)
+         return analysis
+     finally:
+         # tmp is a Path for downloads and a NamedTemporaryFile for byte uploads
+         tmp_path = tmp if isinstance(tmp, Path) else (Path(tmp.name) if tmp else None)
+         if tmp_path and tmp_path.exists():
+             os.remove(tmp_path)
+
+ # ───── Process Directory & Uploads ───────────────────────────────────────────
+ def process_directory(images_dir: Path, output_dir: Path, data_yaml: Path, weights: str, quiet: bool = False):
+     ci = load_yaml_config(data_yaml)
+     model = YOLO(weights)
+     imgs = [p for p in images_dir.iterdir() if p.suffix.lower() in ['.jpg', '.jpeg', '.png']]
+     if not imgs:
+         sys.exit("No images found.")
+     output_dir.mkdir(parents=True, exist_ok=True)
+     analyses = [analyze_post(img, model, ci, quiet) for img in imgs]
+     rpt = write_analysis_report(analyses, output_dir)
+     print(f"Report written to {rpt}")
+
+ def process_uploaded_images(images: List[Union[str, bytes, Path]], output_dir: Path, data_yaml: Path, weights: str, quiet: bool = False):
+     ci = load_yaml_config(data_yaml)
+     model = YOLO(weights)
+     output_dir.mkdir(parents=True, exist_ok=True)
+     analyses = []
+     for img in images:
+         try:
+             analyses.append(process_uploaded_image(img, model, ci, output_dir, quiet))
+         except Exception as e:
+             print(f"Error: {e}")
+     print(f"Processed {len(analyses)} uploads.")
+     return analyses
+
+ # ───── CLI Entrypoint ────────────────────────────────────────────────────────
+ def main(argv=None):
+     p = argparse.ArgumentParser(description="Enhanced post analysis tool")
+     p.add_argument("--images", type=Path, help="Directory of images")
+     p.add_argument("--upload", nargs="+", help="URLs or file paths to analyze")
+     p.add_argument("--output", type=Path, default=Path("post_analysis_results"))
+     p.add_argument("--data", type=Path, default=Path("data.yaml"))
+     p.add_argument("--weights", type=str, default="yolov8n.pt")
+     p.add_argument("-q", "--quiet", action="store_true")
+     args = p.parse_args(argv)
+     if args.upload:
+         process_uploaded_images(args.upload, args.output, args.data, args.weights, args.quiet)
+     elif args.images:
+         process_directory(args.images, args.output, args.data, args.weights, args.quiet)
      else:
+         p.error("Specify --images or --upload")
+
+ if __name__ == "__main__":
+     main()
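A quick sanity check of the new scoring rule may help reviewers: compute_confidence() uses the mean detection confidence as the base score (0.3 when nothing is detected), subtracts 0.1 per anomaly capped at 0.5 and 0.05 per layout violation capped at 0.3, and clamps the result at zero. A minimal sketch of that arithmetic, with hypothetical detections:

    # Worked example of the rule in compute_confidence(); all numbers are hypothetical.
    parts_conf = [0.90, 0.80, 0.70]                # confidences of three detected parts
    n_anomalies = 1                                # one "afwijking" box
    n_violations = 2                               # two parts outside their expected zones
    base = sum(parts_conf) / len(parts_conf)       # 0.80
    score = max(0.0, base
                - min(0.1 * n_anomalies, 0.5)      # defect penalty: 0.10
                - min(0.05 * n_violations, 0.3))   # layout penalty: 0.10
    print(f"score = {score:.2f}")                  # score = 0.60

Since main() takes an optional argv list, the CLI can also be exercised from Python, e.g. main(["--images", "posts", "--weights", "yolov8n.pt"]), where posts/ is a hypothetical image directory and data.yaml is found via the default.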
256