|
import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Optional, Dict

import numpy as np
import torch
import torch.nn.functional as F
from crewai.tools import BaseTool
from open_clip import create_model_from_pretrained
from PIL import Image
from pydantic import BaseModel, Field
|
|
|
|
|
class IUImageInput(BaseModel):
    # Argument schema with a single required field.
    # NOTE: documented with comments rather than a class docstring on purpose,
    # because pydantic copies a model's docstring into its JSON schema.

    # Required (``...`` means no default).
    image_path: str = Field(
        ...,
        description="Absolute path to the query image",
    )
|
|
|
|
|
class IUImpressionSearchTool(BaseTool):
    """Return the impression of the stored IU X-ray most similar to a query image.

    The query image is embedded with a BiomedCLIP image encoder and compared by
    cosine similarity against a matrix of pre-computed vectors loaded from a
    ``.npy`` file. Row ``i`` of that matrix is assumed to correspond to the
    ``i``-th non-blank line of the impressions JSONL file (TODO confirm against
    the script that builds both files).

    ``metadata`` may override the defaults with the keys ``VEC_PATH``,
    ``IMPR_PATH`` and ``MODEL_ID``.
    """

    name: str = "iu_impression_search_tool"
    description: str = (
        "Retrieves the most similar IU X-ray image impression based on visual similarity "
        "using BiomedCLIP embeddings."
    )
    args_schema: type = IUImageInput
    # Optional overrides: "VEC_PATH", "IMPR_PATH", "MODEL_ID".
    metadata: dict = {}

    @staticmethod
    @lru_cache(maxsize=2)
    def _load_model(model_id: str, device_type: str):
        """Load the (model, preprocess) pair once per (model id, device) and reuse it."""
        model, preprocess = create_model_from_pretrained(model_id)
        return model.to(torch.device(device_type)).eval(), preprocess

    @staticmethod
    def _read_record(impr_path: str, index: int):
        """Return the parsed ``index``-th non-blank JSONL record, or ``None`` if absent."""
        with open(impr_path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline) so they neither break
            # json.loads nor shift the record numbering.
            non_blank = (line for line in f if line.strip())
            for position, line in enumerate(non_blank):
                if position == index:
                    return json.loads(line)
        return None

    def _run(self, image_path: str) -> str:
        """Find the closest stored image and return its UUID and impression.

        Args:
            image_path: Path to the query image.

        Returns:
            A two-part string with the closest match's UUID and impression, or
            a string starting with ``"Error"`` describing what went wrong.
            Failures are reported as strings rather than raised.
        """
        base_dir = Path(__file__).parent.parent.parent
        default_vecs_path = str(base_dir / "data" / "iu_vecs.npy")
        default_impr_path = str(base_dir / "data" / "iu_impr.jsonl")

        vecs_path = self.metadata.get("VEC_PATH", default_vecs_path)
        impr_path = self.metadata.get("IMPR_PATH", default_impr_path)
        model_id = self.metadata.get(
            "MODEL_ID",
            "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Check the data files before paying for the model load.
        if not os.path.exists(vecs_path):
            return f"Error: IU vectors file not found at {vecs_path}"
        if not os.path.exists(impr_path):
            return f"Error: IU impressions file not found at {impr_path}"

        try:
            # Cached: repeated tool calls reuse the already-loaded model.
            model, preprocess = self._load_model(model_id, device.type)
        except Exception as e:
            return f"Error loading BiomedCLIP model: {e}"

        try:
            # Context manager releases the file handle; convert() returns a
            # fully loaded copy that outlives it.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            tensor_img = preprocess(image).unsqueeze(0).to(device)
            with torch.no_grad():
                query_vec = model.encode_image(tensor_img)
                query_vec = F.normalize(query_vec, dim=-1).cpu().numpy()
            query_vec = query_vec.reshape(-1)
        except Exception as e:
            return f"Error processing input image: {e}"

        try:
            iu_vecs = np.load(vecs_path)
            if iu_vecs.ndim != 2 or iu_vecs.shape[0] == 0:
                return (
                    "Error loading IU vectors: expected a non-empty 2-D array, "
                    f"got shape {iu_vecs.shape}"
                )
            norms = np.linalg.norm(iu_vecs, axis=1, keepdims=True)
            # A zero-norm row would divide to NaN, and NaN wins np.argmax;
            # dividing by 1 instead keeps it a zero vector (similarity 0).
            iu_vecs = iu_vecs / np.where(norms == 0, 1.0, norms)
        except Exception as e:
            return f"Error loading IU vectors: {e}"

        try:
            # Both sides are unit-normalised, so the dot product is cosine similarity.
            similarities = iu_vecs @ query_vec
            best_idx = int(np.argmax(similarities))
        except ValueError as e:
            # Raised when the stored vectors and the query embedding differ in size.
            return f"Error computing similarity: {e}"

        try:
            best_match = self._read_record(impr_path, best_idx)
        except Exception as e:
            return f"Error loading impressions metadata: {e}"
        if best_match is None:
            return (
                "Error loading impressions metadata: "
                f"no record at index {best_idx} in {impr_path}"
            )

        try:
            uuid = best_match["uuid"]
            impression = best_match["impression"]
        except (KeyError, TypeError) as e:
            return (
                "Error loading impressions metadata: "
                f"malformed record at index {best_idx}: {e!r}"
            )
        # A JSON null impression is shown as empty text.
        impression = "" if impression is None else str(impression).strip()

        return (
            f"Closest Match UUID: {uuid}\n\n"
            f"Impression:\n{impression}"
        )