UMLS / app.py
mgbam's picture
Update app.py
afa884d verified
raw
history blame
3 kB
import os
import json
import streamlit as st
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import faiss
# Page configuration — kept as the very first Streamlit call in the script.
st.set_page_config(layout='wide', page_title='KRISSBERT UMLS Linker')
st.title('🧬 KRISSBERT + UMLS Entity Linker (Local FAISS)')

# Model identifier passed to `from_pretrained`, and the prebuilt UMLS artifacts.
# The artifact paths are relative, so they resolve against the working directory.
MODEL_NAME = 'microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL'
METADATA_PATH = 'umls_metadata.json'
EMBED_PATH = 'umls_embeddings.npy'
INDEX_PATH = 'umls_index.faiss'
# Load model & tokenizer once; st.cache_resource reuses the same pair on every rerun.
@st.cache_resource
def load_model():
    """Return ``(tokenizer, model)`` for MODEL_NAME, with the model switched to eval mode."""
    text_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # nn.Module.eval() returns the module itself, so it can be chained here.
    encoder = AutoModel.from_pretrained(MODEL_NAME).eval()
    return text_tokenizer, encoder


tokenizer, model = load_model()
# Load UMLS FAISS index + metadata
@st.cache_resource
def load_umls_index():
    """Load the prebuilt FAISS index and the UMLS metadata that describes its rows.

    Returns:
        ``(index, meta)``: the FAISS index read from ``INDEX_PATH`` and the object
        decoded from the JSON file at ``METADATA_PATH``. The app looks entries up
        with ``meta.get(str(row_id))``, so meta is presumably a dict keyed by
        stringified row ids — confirm against the script that built the index.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding, and
    # close the file handle instead of leaking it.
    with open(METADATA_PATH, 'r', encoding='utf-8') as fh:
        meta = json.load(fh)
    # The raw embedding matrix at EMBED_PATH was previously loaded here and then
    # discarded. Nothing in this app reads it (search goes through the FAISS
    # index), so it is no longer pulled into memory.
    index = faiss.read_index(INDEX_PATH)
    return index, meta


faiss_index, umls_meta = load_umls_index()
# Embed text
# cache_data (not cache_resource): the result is plain data, so each caller gets
# its own copy rather than a shared mutable array. max_entries bounds the cache,
# which otherwise grows by one entry per distinct user sentence.
@st.cache_data(max_entries=1024)
def embed_text(text, _tokenizer, _model):
    """Embed ``text`` as the L2-normalised first-token vector of ``_model``.

    Args:
        text: Sentence to embed; truncated to the tokenizer's maximum length.
        _tokenizer: Tokenizer matching ``_model``. The leading underscore keeps
            Streamlit from hashing it into the cache key.
        _model: Encoder whose ``last_hidden_state`` provides the embedding
            (underscore-prefixed for the same reason).

    Returns:
        1-D ``float32`` numpy array, scaled to unit length. A zero vector is
        returned unchanged instead of being divided by zero.
    """
    inputs = _tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        outputs = _model(**inputs)
    # Single-sentence batch: take the first token of the only sequence -> (hidden,).
    emb = outputs.last_hidden_state[0, 0, :].cpu().numpy().astype(np.float32, copy=False)
    norm = np.linalg.norm(emb)
    # Avoid 0/0 -> NaN, which would corrupt the FAISS search results.
    return emb / norm if norm > 0 else emb
# UI: examples and input
TOP_K = 5  # number of nearest UMLS entries to retrieve per query

st.markdown('Enter a biomedical sentence to link entities via local UMLS FAISS index and KRISSBERT:')
examples = [
    'The patient was administered metformin for type 2 diabetes.',
    'ER crowding has become a widespread issue in hospitals.',
    'Tamoxifen is used in the treatment of ER-positive breast cancer.',
]
selected = st.selectbox('🔍 Example queries', ['Choose...'] + examples)
sentence = st.text_area('📝 Sentence:', value=(selected if selected != 'Choose...' else ''))
if st.button('🔗 Link Entities'):
    if not sentence.strip():
        st.warning('Please enter a sentence first.')
    else:
        with st.spinner('Embedding sentence and searching FAISS…'):
            sent_emb = embed_text(sentence, tokenizer, model).reshape(1, -1)
            # The distances/scores are returned too but are not shown in this view.
            _scores, indices = faiss_index.search(sent_emb, TOP_K)
            results = []
            for idx in indices[0]:
                # FAISS pads with -1 when fewer than TOP_K hits exist. Such slots,
                # and ids missing from the metadata, are not real candidates and
                # used to be rendered as blank rows.
                if idx < 0:
                    continue
                entry = umls_meta.get(str(idx))
                if not entry:
                    continue
                results.append({
                    'cui': entry.get('cui', ''),
                    'name': entry.get('name', ''),
                    'definition': entry.get('definition', ''),
                    'source': entry.get('source', ''),
                })
        # Display
        if results:
            st.success('Top UMLS candidates:')
            for item in results:
                st.markdown(f"**{item['name']}** (CUI: `{item['cui']}`)")
                if item['definition']:
                    st.markdown(f"> {item['definition']}\n")
                if item['source']:
                    st.markdown(f"_Source: {item['source']}_")
                # Separate call: a '---' line placed directly under text would turn
                # that text into a Markdown setext heading instead of drawing a rule.
                st.markdown('---')
        else:
            st.info('No matches found in UMLS index.')