# Header residue from the Hugging Face file view (author: asdc;
# last commit f1b5517 "Update src/streamlit_app.py", verified).
import os
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
import streamlit as st
from typing import List, Tuple
import re
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Entity tag metadata in model label-id order: (label id, BIO tag, highlight colour).
# legend_html() iterates LABEL_COLORS in insertion order, so this table's order
# is also the order in which the legend swatches are drawn.
# NOTE(review): I-DATE and I-SET both use '#ebdb98', so they are visually
# indistinguishable in the legend and highlights — confirm this is intended.
_LABEL_TABLE = (
    ('LABEL-0', 'NONE', '#cccccc'),
    ('LABEL-1', 'B-DATE', '#ffadad'),
    ('LABEL-2', 'I-DATE', '#ebdb98'),
    ('LABEL-3', 'B-TIME', '#586492'),
    ('LABEL-4', 'I-TIME', '#ffb788'),
    ('LABEL-5', 'B-DURATION', '#76abbb'),
    ('LABEL-6', 'I-DURATION', '#a0c4ff'),
    ('LABEL-7', 'B-SET', '#f84252'),
    ('LABEL-8', 'I-SET', '#ebdb98'),
)
# Label id -> background colour used when highlighting that tag.
LABEL_COLORS = {label: colour for label, _tag, colour in _LABEL_TABLE}
# Label id -> human-readable BIO tag name.
LABEL_MEANINGS = {label: tag for label, tag, _colour in _LABEL_TABLE}
@st.cache_resource(show_spinner=True)
def load_model():
    """Load and return the ``(tokenizer, model)`` pair for asdc/Bio-RoBERTime.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the returned pair on
    later reruns instead of reloading the checkpoint each time.
    """
    checkpoint = 'asdc/Bio-RoBERTime'
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForTokenClassification.from_pretrained(checkpoint),
    )
def ner_with_robertime(text: str) -> List[Tuple[str, str]]:
    """Tag *text* with Bio-RoBERTime and return one ``(word, label)`` pair per word.

    The text is tokenised into sub-word pieces, each piece is classified, and
    the pieces are regrouped into words using the tokenizer's ``word_ids()``
    mapping. Each word takes the label predicted for its FIRST sub-token.

    Returns:
        A list of ``(word, label)`` tuples in text order. ``label`` is the raw
        ``model.config.id2label`` string; callers in this file normalise ``_``
        to ``-`` before looking it up in the label tables.

    NOTE: ``truncation=True`` silently drops input past the tokenizer's
    maximum length, so words beyond it receive no prediction at all.
    """
    tokenizer, model = load_model()
    tokens = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model(**tokens)
    predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
    labels = [model.config.id2label[pred] for pred in predictions]
    # word_ids() maps every sub-token to the index of the word it came from;
    # special tokens map to None and are skipped below.
    word_ids = tokens.word_ids(batch_index=0)
    input_ids = tokens["input_ids"][0].tolist()

    # Collect (sub-token indices, word label) for each word, in order.
    groups: List[Tuple[List[int], str]] = []
    last_word_id = None
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            continue
        if word_id != last_word_id:
            # First sub-token of a new word: its prediction labels the word.
            # Token-classification fine-tuning conventionally trains only the
            # first sub-token of each word (later pieces masked with -100);
            # assumes this checkpoint followed that convention — TODO confirm.
            groups.append(([idx], labels[idx]))
        else:
            groups[-1][0].append(idx)
        last_word_id = word_id

    # Sub-tokens may decode with a leading word-boundary space; strip it so a
    # highlighted span contains only the word itself.
    return [
        (
            tokenizer.decode(
                [input_ids[i] for i in indices], skip_special_tokens=True
            ).strip(),
            label,
        )
        for indices, label in groups
    ]
def colorize_entities(ner_result: List[Tuple[str, str]]) -> str:
    """Render ``(word, label)`` pairs as an HTML fragment with entities highlighted.

    Words labelled ``LABEL-0`` / ``LABEL_0`` are emitted as plain text; every
    other word is wrapped in a coloured ``ner-entity`` span whose
    ``data-tooltip`` holds the tag name. Each rendered piece is followed by a
    single space.

    The words come from user-entered text and the result is rendered with
    ``unsafe_allow_html=True``, so words and tooltip text are HTML-escaped.
    Otherwise input such as ``a < b`` or ``<img ...>`` would be interpreted
    as markup.
    """
    from html import escape

    pieces: List[str] = []
    for token, label in ner_result:
        # id2label may use '_' (LABEL_1) while the tables use '-' (LABEL-1).
        norm_label = label.replace('_', '-')
        safe_token = escape(token)
        if norm_label == 'LABEL-0':
            pieces.append(safe_token)
            continue
        color = LABEL_COLORS.get(norm_label, '#eeeeee')
        # Unknown labels fall back to the raw label text, so escape it too.
        label_meaning = escape(LABEL_MEANINGS.get(norm_label, norm_label))
        pieces.append(
            f'<span class="ner-entity" style="background-color:{color};padding:2px 4px;border-radius:4px;margin:1px;" '
            f'data-tooltip="{label_meaning}">{safe_token}</span>'
        )
    return ''.join(f'{piece} ' for piece in pieces)
def extract_entities(ner_result: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Merge runs of consecutive words that share the same non-NONE label.

    Args:
        ner_result: ``(word, label)`` pairs in text order.

    Returns:
        ``(space-joined words, label)`` for each run. ``label`` is the
        original label string of the run's first word. Words labelled
        ``LABEL-0`` / ``LABEL_0`` (no entity) are dropped and end any open run.

    Labels are compared after normalising ``_`` to ``-``, matching
    ``colorize_entities``. Without this, ``LABEL_0`` words would be reported
    as entities.
    """
    entities: List[Tuple[str, str]] = []
    current_tokens: List[str] = []
    current_label = None  # original label string of the open run
    current_norm = None   # normalised form of current_label, for comparisons
    for token, label in ner_result:
        norm_label = label.replace('_', '-')
        if norm_label != 'LABEL-0' and norm_label == current_norm:
            current_tokens.append(token)
            continue
        # Label changed or this word is NONE: close the open run, if any.
        if current_tokens:
            entities.append((' '.join(current_tokens), current_label))
        if norm_label == 'LABEL-0':
            current_tokens, current_label, current_norm = [], None, None
        else:
            current_tokens, current_label, current_norm = [token], label, norm_label
    if current_tokens:
        entities.append((' '.join(current_tokens), current_label))
    return entities
def legend_html() -> str:
    """Return the colour legend as a wrapping flex row of tag swatches.

    One swatch per entry in LABEL_COLORS, in its insertion order, labelled
    "<tag name> (<label id>)". LABEL-0 (no entity) is left out.
    """
    swatches = [
        f'<span style="background-color:{color};padding:2px 8px;border-radius:4px;">{LABEL_MEANINGS[label]} ({label})</span>'
        for label, color in LABEL_COLORS.items()
        if label != 'LABEL-0'
    ]
    return '<div style="display:flex;flex-wrap:wrap;gap:8px;">' + ''.join(swatches) + '</div>'
# --- Page layout: title, tooltip CSS, legend, input box and results. ---
from html import escape

st.title('LLM-powered Named Entity Recognition (NER)')
# CSS for the hover tooltip on highlighted spans (shows their data-tooltip).
st.markdown(
    '''
<style>
.ner-entity {
position: relative;
cursor: pointer;
}
.ner-entity[data-tooltip]:hover:after {
content: attr(data-tooltip);
position: absolute;
left: 0;
top: 100%;
background: #222;
color: #fff;
padding: 2px 8px;
border-radius: 4px;
white-space: nowrap;
z-index: 10;
font-size: 0.9em;
margin-top: 2px;
}
</style>
''',
    unsafe_allow_html=True
)
st.markdown('**Legend:**')
st.markdown(legend_html(), unsafe_allow_html=True)
user_text = st.text_area('Enter text for NER:', height=150)
if user_text:
    ner_result = ner_with_robertime(user_text)
    # id2label may spell labels with '_' (LABEL_0); normalise before comparing
    # against the '-' spelling used by the label tables.
    has_entity = any(label.replace('_', '-') != 'LABEL-0' for _, label in ner_result)
    if has_entity:
        st.markdown('#### Entities Highlighted:')
        st.markdown(colorize_entities(ner_result), unsafe_allow_html=True)
        # Defensive filter: never list "NONE" groups as detected entities.
        entities = [
            (ent, label)
            for ent, label in extract_entities(ner_result)
            if label.replace('_', '-') != 'LABEL-0'
        ]
        if entities:
            st.markdown('#### Detected Entities:')
            for ent, label in entities:
                norm_label = label.replace('_', '-')
                # .get fallbacks mirror colorize_entities, so an unexpected
                # label renders instead of raising KeyError.
                color = LABEL_COLORS.get(norm_label, '#eeeeee')
                meaning = LABEL_MEANINGS.get(norm_label, norm_label)
                # ent comes from user text and is rendered as HTML: escape it.
                st.markdown(
                    f'- <span style="background-color:{color};padding:2px 8px;border-radius:4px;">{escape(ent)}</span> '
                    f'<span style="color:#888;">({escape(meaning)})</span>',
                    unsafe_allow_html=True,
                )
        else:
            st.info('No entities detected.')
    else:
        st.info('No entities detected.')
st.caption('Model: [asdc/Bio-RoBERTime](https://huggingface.co/asdc/Bio-RoBERTime)')