MCP_Res / mcp /umls.py
mgbam's picture
Update mcp/umls.py
4f7b321 verified
import asyncio
import logging
import os
import re
from functools import lru_cache
from pathlib import Path
from typing import List, Optional, Dict, Any

import httpx
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# API key for the UMLS Terminology Services, read once at import time from
# the UMLS_KEY environment variable (None when the variable is unset).
UMLS_API_KEY = os.environ.get("UMLS_KEY")
# Endpoint the API key is posted to in order to start authentication.
UMLS_AUTH_URL = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
# Endpoint queried for term searches.
UMLS_SEARCH_URL = "https://uts-ws.nlm.nih.gov/rest/search/current"
# ---------------------------------------------------------------------------
# Named types
# ---------------------------------------------------------------------------
class UMLSResult(Dict[str, Optional[str]]):
    """
    Dictionary type describing the outcome of one UMLS lookup.

    Expected keys: ``term``, ``cui``, ``name`` and ``definition``. The class
    adds no behaviour of its own on top of ``dict``; it exists to give the
    result shape a name in type annotations.
    """
# ---------------------------------------------------------------------------
# NLP model loading with caching
# ---------------------------------------------------------------------------
@lru_cache(maxsize=None)
def _load_spacy_model(model_name: str):
    """
    Load the spaCy pipeline named ``model_name``.

    Results are memoised per model name, so each pipeline is loaded at most
    once per process. Failed loads raise and are not cached.
    """
    # Imported inside the function so that importing this module does not
    # itself require spaCy.
    from spacy import load as load_pipeline

    return load_pipeline(model_name)
@lru_cache(maxsize=None)
def _load_scispacy_model():
    """
    Return the biomedical spaCy pipeline used for concept extraction.

    ``en_ner_bionlp13cg_md`` is preferred; if loading it fails for any
    reason, ``en_core_sci_sm`` is loaded instead. A failure of the fallback
    propagates to the caller.
    """
    try:
        pipeline = _load_spacy_model("en_ner_bionlp13cg_md")
    except Exception:
        # BioNLP model unavailable: use the smaller sci model.
        pipeline = _load_spacy_model("en_core_sci_sm")
    return pipeline
@lru_cache(maxsize=None)
def _load_general_spacy():
    """
    Return the general-purpose English spaCy pipeline (``en_core_web_sm``).
    """
    model_name = "en_core_web_sm"
    return _load_spacy_model(model_name)
# ---------------------------------------------------------------------------
# Concept extraction utilities
# ---------------------------------------------------------------------------
def _extract_entities(nlp, text: str, min_length: int) -> List[str]:
"""
Run a spaCy nlp pipeline over text and return unique entity texts
of at least min_length.
"""
doc = nlp(text)
ents = {ent.text.strip() for ent in doc.ents if len(ent.text.strip()) >= min_length}
return list(ents)
def _regex_fallback(text: str, min_length: int) -> List[str]:
"""
Simple regex-based token extraction for fallback.
"""
tokens = re.findall(r"\b[a-zA-Z0-9\-]+\b", text)
return list({t for t in tokens if len(t) >= min_length})
def extract_umls_concepts(text: str, min_length: int = 3) -> List[str]:
    """
    Extract biomedical concepts from text in priority order:
    1. SciSpaCy (en_ner_bionlp13cg_md or en_core_sci_sm)
    2. spaCy general NER (en_core_web_sm)
    3. Regex tokens

    A strategy is skipped when its pipeline cannot be loaded or run, and
    also when it finds no entities; the regex fallback always runs last.

    Args:
        text: Input text. Empty, blank or ``None`` input yields ``[]``
            without loading any model.
        min_length: Minimum length a concept string must have to be kept.

    Returns:
        A list of unique concept strings.
    """
    # Nothing to extract: avoid loading NLP models for empty input.
    if not text or not text.strip():
        return []

    log = logging.getLogger(__name__)
    strategies = (
        ("SciSpaCy", _load_scispacy_model),
        ("spaCy", _load_general_spacy),
    )
    for label, load_pipeline in strategies:
        try:
            entities = _extract_entities(load_pipeline(), text, min_length)
        except Exception as exc:
            # Best-effort by design: a missing package or model, or a
            # pipeline error, only moves on to the next strategy.
            log.debug(
                "%s extraction unavailable (%s); trying next strategy",
                label,
                type(exc).__name__,
            )
            continue
        if entities:
            return entities

    return _regex_fallback(text, min_length)
# ---------------------------------------------------------------------------
# UMLS API integration
# ---------------------------------------------------------------------------
async def _get_umls_ticket() -> Optional[str]:
    """
    Obtain a UMLS service ticket for subsequent queries.

    Two requests are made: the API key is posted to ``UMLS_AUTH_URL``, and
    the URL found in the ``action="..."`` attribute of that response is then
    posted to with the UMLS service name. The body of the second response
    is the ticket.

    Returns:
        The ticket string, or ``None`` if the API key is missing, either
        request fails, the first response cannot be parsed, or the ticket
        body is empty.
    """
    if not UMLS_API_KEY:
        return None
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            tgt_resp = await client.post(
                UMLS_AUTH_URL, data={"apikey": UMLS_API_KEY}
            )
            tgt_resp.raise_for_status()
            # Raises IndexError (handled below) when no action="..."
            # attribute is present in the response body.
            tgt_url = tgt_resp.text.split('action="')[1].split('"')[0]
            service_resp = await client.post(
                tgt_url, data={"service": "http://umlsks.nlm.nih.gov"}
            )
            # Without this check the body of an error response would be
            # handed back as if it were a valid ticket.
            service_resp.raise_for_status()
            ticket = service_resp.text.strip()
            return ticket or None
    except Exception as exc:
        # Only the exception type is logged: the message may embed the
        # ticket-granting URL.
        logging.getLogger(__name__).warning(
            "UMLS authentication failed (%s)", type(exc).__name__
        )
        return None
# Completed lookups keyed by term, most recently used last.
# functools.lru_cache is not usable here: on an ``async def`` it caches the
# coroutine object, and a coroutine can only be awaited once, so a repeated
# term would raise RuntimeError instead of returning the cached result.
_UMLS_LOOKUP_CACHE: Dict[str, Dict[str, Optional[str]]] = {}
_UMLS_LOOKUP_CACHE_SIZE = 512


async def lookup_umls(term: str) -> UMLSResult:
    """
    Look up a term in the UMLS API.

    Args:
        term: The string to search for.

    Returns:
        A dict containing the original term, its CUI, preferred name, and
        definition (the ``definition`` field of the first search hit, or
        its ``rootSource`` when no definition is present). When no ticket
        can be obtained or the request fails, all values except ``term``
        are ``None``.

    Completed searches are cached (up to 512 terms, least recently used
    evicted first). Failures are not cached, so a later call retries.
    """
    cached = _UMLS_LOOKUP_CACHE.pop(term, None)
    if cached is not None:
        # Re-insert to mark the entry as most recently used; hand out a
        # copy so callers cannot mutate the cached value.
        _UMLS_LOOKUP_CACHE[term] = cached
        return dict(cached)

    ticket = await _get_umls_ticket()
    if not ticket:
        return {"term": term, "cui": None, "name": None, "definition": None}

    params = {"string": term, "ticket": ticket, "pageSize": 1}
    try:
        async with httpx.AsyncClient(timeout=8) as client:
            resp = await client.get(UMLS_SEARCH_URL, params=params)
            resp.raise_for_status()
            results = resp.json().get("result", {}).get("results", [])
            first = results[0] if results else {}
            result = {
                "term": term,
                "cui": first.get("ui"),
                "name": first.get("name"),
                "definition": first.get("definition") or first.get("rootSource"),
            }
    except Exception as exc:
        # Only the exception type is logged: the message may embed the
        # request URL, which carries the service ticket.
        logging.getLogger(__name__).warning(
            "UMLS lookup failed (%s)", type(exc).__name__
        )
        return {"term": term, "cui": None, "name": None, "definition": None}

    if term not in _UMLS_LOOKUP_CACHE and len(_UMLS_LOOKUP_CACHE) >= _UMLS_LOOKUP_CACHE_SIZE:
        # Dicts keep insertion order, so the first key is the least
        # recently used entry.
        _UMLS_LOOKUP_CACHE.pop(next(iter(_UMLS_LOOKUP_CACHE)))
    _UMLS_LOOKUP_CACHE[term] = result
    return dict(result)


# Kept for callers that relied on the attribute the lru_cache wrapper exposed.
lookup_umls.cache_clear = _UMLS_LOOKUP_CACHE.clear