File size: 2,996 Bytes
ef17fc5
afa884d
ef17fc5
 
 
 
afa884d
ef17fc5
 
afa884d
 
ef17fc5
afa884d
 
 
 
 
ef17fc5
afa884d
ef17fc5
 
 
 
 
 
 
 
 
afa884d
 
 
 
 
 
 
ef17fc5
afa884d
ef17fc5
afa884d
ef17fc5
72145e5
afa884d
ef17fc5
72145e5
ef17fc5
 
 
afa884d
 
ef17fc5
afa884d
 
 
ef17fc5
afa884d
 
ef17fc5
afa884d
ef17fc5
afa884d
ef17fc5
afa884d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef17fc5
afa884d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import json
import streamlit as st
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import faiss

# Page configuration (must be the first Streamlit call in the script)
st.set_page_config(page_title='KRISSBERT UMLS Linker', layout='wide')
st.title('🧬 KRISSBERT + UMLS Entity Linker (Local FAISS)')

# File paths (relative, so resolved against the directory the app is launched from)
METADATA_PATH = 'umls_metadata.json'
EMBED_PATH = 'umls_embeddings.npy'
INDEX_PATH = 'umls_index.faiss'
# Hugging Face model id passed to AutoTokenizer/AutoModel.from_pretrained
MODEL_NAME = 'microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL'

# Load model & tokenizer
@st.cache_resource
def load_model():
    """Fetch the KRISSBERT tokenizer and encoder, cached for the server's lifetime.

    Returns:
        A ``(tokenizer, model)`` pair; the model is put in eval mode so
        dropout is disabled during inference.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    encoder = AutoModel.from_pretrained(MODEL_NAME)
    encoder.eval()
    pair = (tok, encoder)
    return pair


tokenizer, model = load_model()

# Load UMLS FAISS index + metadata
@st.cache_resource
def load_umls_index():
    """Load the prebuilt FAISS index and its row metadata once per process.

    Returns:
        ``(index, meta)``: the FAISS index read from INDEX_PATH and the
        parsed JSON from METADATA_PATH. Callers look rows up with
        ``meta.get(str(row_id))``, so ``meta`` is presumably a dict keyed by
        stringified FAISS row ids -- verify against the index-building script.
    """
    # Context manager closes the handle; explicit UTF-8 keeps decoding of
    # concept names/definitions independent of the platform default codec.
    with open(METADATA_PATH, 'r', encoding='utf-8') as fh:
        meta = json.load(fh)
    # Only the index is needed for search. The raw embedding matrix
    # (EMBED_PATH) is not loaded: nothing here uses it, and reading it would
    # only cost memory and startup time.
    index = faiss.read_index(INDEX_PATH)
    return index, meta


faiss_index, umls_meta = load_umls_index()

# Embed text
# cache_data (not cache_resource): the result is per-input data, so each
# caller gets its own copy, and max_entries bounds memory as distinct
# sentences accumulate over the app's lifetime.
@st.cache_data(max_entries=256)
def embed_text(text, _tokenizer, _model):
    """Return the L2-normalised first-token embedding of ``text``.

    Args:
        text: Input sentence. This is the only argument Streamlit hashes
            for the cache key.
        _tokenizer: Hugging Face tokenizer. The leading underscore tells
            Streamlit not to hash it.
        _model: Hugging Face encoder. The leading underscore tells
            Streamlit not to hash it.

    Returns:
        A 1-D numpy vector (the batch axis is squeezed away) with unit L2
        norm. A zero vector is returned unchanged rather than divided by 0.
    """
    inputs = _tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        outputs = _model(**inputs)
    # Hidden state at token position 0 -- presumably [CLS] for this
    # BERT-style tokenizer.
    emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
    norm = np.linalg.norm(emb)
    # Dividing by a zero norm would produce NaNs that then go into FAISS.
    return emb / norm if norm > 0 else emb

# UI: examples and input
st.markdown('Enter a biomedical sentence to link entities via local UMLS FAISS index and KRISSBERT:')
# Sentinel first option meaning "no example selected".
EXAMPLE_PLACEHOLDER = 'Choose...'
examples = [
    'The patient was administered metformin for type 2 diabetes.',
    'ER crowding has become a widespread issue in hospitals.',
    'Tamoxifen is used in the treatment of ER-positive breast cancer.'
]
selected = st.selectbox('🔍 Example queries', [EXAMPLE_PLACEHOLDER] + examples)
# Pre-fill the text area with the chosen example; leave it empty otherwise.
sentence = st.text_area('📝 Sentence:', value=(selected if selected != EXAMPLE_PLACEHOLDER else ''))

# On click: embed the sentence, query FAISS for the 5 nearest UMLS rows,
# and render whichever rows have metadata.
if st.button('🔗 Link Entities'):
    if not sentence.strip():
        st.warning('Please enter a sentence first.')
    else:
        with st.spinner('Embedding sentence and searching FAISS…'):
            # FAISS search takes a 2-D (n_queries, dim) array.
            sent_emb = embed_text(sentence, tokenizer, model).reshape(1, -1)
            _distances, indices = faiss_index.search(sent_emb, 5)
            results = []
            for idx in indices[0]:
                # FAISS pads with -1 when fewer than k neighbours exist;
                # those slots are not real hits.
                if idx < 0:
                    continue
                entry = umls_meta.get(str(idx))
                # Skip rows with no metadata instead of rendering blank cards.
                if not entry:
                    continue
                # `or ''` also covers keys that are present but null in the JSON.
                results.append({
                    'cui': entry.get('cui') or '',
                    'name': entry.get('name') or '',
                    'definition': entry.get('definition') or '',
                    'source': entry.get('source') or ''
                })
        # Display
        if results:
            st.success('Top UMLS candidates:')
            for item in results:
                st.markdown('**' + item['name'] + '** (CUI: `' + item['cui'] + '`)')
                if item['definition']:
                    st.markdown('> ' + item['definition'] + '\n')
                # The blank line before '---' makes it a horizontal rule;
                # without it Markdown turns the Source line into a heading.
                st.markdown('_Source: ' + item['source'] + '_\n\n---')
        else:
            st.info('No matches found in UMLS index.')