|
from crewai.tools import BaseTool |
|
from pydantic import BaseModel, Field |
|
from typing import List, Dict, Optional, Any |
|
import os |
|
import json |
|
import torch |
|
import faiss |
|
import numpy as np |
|
from pathlib import Path |
|
from .specter2_embedder import embed_texts_specter2 |
|
|
|
|
|
class PubmedQueryInput(BaseModel):
    """Argument schema for the PubMed retrieval tool."""

    # Required (``...``); the description is exposed to the calling agent
    # through the tool's ``args_schema``.
    caption: str = Field(
        ...,
        description="Radiology caption (free text) used to search for related PubMed articles.",
    )
|
|
|
|
|
class PubmedRetrievalTool(BaseTool):
    """CrewAI tool that retrieves PubMed abstracts related to a radiology caption.

    The caption is embedded with ``embed_texts_specter2`` and searched against a
    FAISS index stored at ``<DATA_DIR>/text_faiss.bin``. Each hit's row id is
    used to look up a record in ``<DATA_DIR>/raw_abstracts.jsonl``.
    NOTE(review): this assumes the JSONL records are in the same order as the
    vectors in the index — confirm against the index-building script.

    Configuration is read from ``self.metadata``:
        DATA_DIR: directory containing both files. Defaults to the ``data``
            folder under ``Path(__file__).parent.parent.parent``.
        TOP_K: number of results to request from the index (default 3).
    """

    name: str = "pubmed_retrieval_tool"

    description: str = (
        "Retrieves the most relevant PubMed articles for a given radiology caption."
    )

    args_schema: type = PubmedQueryInput

    # Per-instance configuration (see class docstring for recognised keys).
    metadata: dict = Field(default_factory=dict)

    @staticmethod
    def _text_field(entry: Dict[str, Any], key: str, default: str) -> str:
        """Return ``entry[key]`` as stripped text, or ``default`` if it is missing, null or blank."""
        value = entry.get(key)
        if value is None:
            return default
        text = str(value).strip()
        return text or default

    def _run(self, caption: Optional[str] = None, **kwargs: Any) -> str:
        """Search the local PubMed index for articles related to ``caption``.

        Args:
            caption: Radiology caption to search for; surrounding whitespace
                is stripped before embedding.
            **kwargs: Extra tool-call arguments; ignored.

        Returns:
            The formatted citations joined by ``"\\n---\\n"``, a "no results"
            message, or a human-readable error message. Failures are reported
            as strings rather than raised so the calling agent can read them.
        """
        # A ``caption=`` keyword always binds to the named parameter, so it can
        # never show up in **kwargs; no fallback lookup is needed.
        if not caption or not str(caption).strip():
            return "Error: No caption provided. Unable to search PubMed."
        caption = str(caption).strip()

        default_data_dir = str(Path(__file__).parent.parent.parent / "data")
        data_dir = self.metadata.get("DATA_DIR", default_data_dir)

        # TOP_K may come from string-valued config; faiss needs a positive int.
        raw_top_k = self.metadata.get("TOP_K", 3)
        try:
            top_k = int(raw_top_k)
        except (TypeError, ValueError):
            return f"Error: Invalid TOP_K value: {raw_top_k!r}"
        if top_k <= 0:
            return f"Error: TOP_K must be a positive integer, got {top_k}"

        try:
            index_path = os.path.join(data_dir, "text_faiss.bin")
            metadata_path = os.path.join(data_dir, "raw_abstracts.jsonl")

            if not os.path.exists(index_path):
                return f"Error: FAISS index not found at {index_path}"
            if not os.path.exists(metadata_path):
                return f"Error: Metadata file not found at {metadata_path}"

            index = faiss.read_index(index_path)

            with open(metadata_path, "r", encoding="utf-8") as f:
                # Blank lines (e.g. a trailing empty line) would make
                # json.loads raise; they carry no record, so skip them.
                records = [json.loads(line) for line in f if line.strip()]

            # faiss expects a C-contiguous float32 matrix of shape (n_queries, dim).
            query_vec = np.ascontiguousarray(
                embed_texts_specter2([caption]), dtype=np.float32
            ).reshape(1, -1)
            scores, indices = index.search(query_vec, top_k)

            # faiss pads with label -1 when fewer than top_k results exist;
            # indexing records[-1] would wrongly return the last abstract.
            hits = [
                (score, idx) for score, idx in zip(scores[0], indices[0]) if idx >= 0
            ]
            if not hits:
                return "No relevant PubMed articles found."

            formatted = []
            for rank, (score, idx) in enumerate(hits, 1):
                if idx >= len(records):
                    return (
                        f"Error: FAISS index returned row {idx} but {metadata_path} "
                        f"has only {len(records)} records"
                    )
                entry = records[idx]
                # NOTE(review): whether a higher score means "more similar"
                # depends on the index metric (inner product vs. L2 distance)
                # — confirm against how text_faiss.bin was built.
                formatted.append(
                    f"Citation {rank}:\n"
                    f"PMID: {self._text_field(entry, 'pmid', 'Unknown')}\n"
                    f"Similarity Score: {score:.3f}\n"
                    f"Title: {self._text_field(entry, 'title', 'Untitled')}\n"
                    f"Abstract: {self._text_field(entry, 'abstract', 'No abstract available.')}\n"
                )

            return "\n---\n".join(formatted)

        except Exception as e:
            # Tool boundary: report any failure to the agent as text.
            return f"Error during PubMed search: {str(e)}"