Spaces:
Sleeping
Sleeping
import html
import os
import re
from typing import List, Tuple

# The cache locations are assigned before the Hugging Face imports below;
# keep this ordering when editing the import block.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Background colour used to highlight each label (keys use the "LABEL-<n>"
# spelling; callers normalise "LABEL_<n>" to this form before looking up).
# Every label has its own colour so that no two classes look identical in the
# legend or in the highlighted text.
LABEL_COLORS = {
    'LABEL-0': '#cccccc',  # NONE
    'LABEL-1': '#ffadad',  # B-DATE
    'LABEL-2': '#ebdb98',  # I-DATE
    'LABEL-3': '#586492',  # B-TIME
    'LABEL-4': '#ffb788',  # I-TIME
    'LABEL-5': '#76abbb',  # B-DURATION
    'LABEL-6': '#a0c4ff',  # I-DURATION
    'LABEL-7': '#f84252',  # B-SET
    'LABEL-8': '#caffbf',  # I-SET (previously the same colour as I-DATE)
}
# Human-readable tag for each "LABEL-<n>" key; the position of a tag in the
# tuple below is the <n> of its key.
LABEL_MEANINGS = {
    f'LABEL-{index}': tag
    for index, tag in enumerate((
        'NONE',
        'B-DATE',
        'I-DATE',
        'B-TIME',
        'I-TIME',
        'B-DURATION',
        'I-DURATION',
        'B-SET',
        'I-SET',
    ))
}
@st.cache_resource
def load_model():
    """Load the Bio-RoBERTime tokenizer and token-classification model.

    Streamlit re-runs the script on every interaction, so without caching the
    tokenizer and model would be loaded again for each submitted text.
    ``st.cache_resource`` keeps a single shared instance across reruns.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    model_id = 'asdc/Bio-RoBERTime'
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForTokenClassification.from_pretrained(model_id)
    return tokenizer, model
def ner_with_robertime(text: str) -> List[Tuple[str, str]]:
    """Tag every word of ``text`` with the label predicted by the model.

    The text is tokenised (with truncation enabled), the model is run without
    gradient tracking, and the highest-scoring label is taken for each token.
    Consecutive tokens that share a word id are then decoded together into one
    word.

    Args:
        text: Raw input text.

    Returns:
        A list of ``(word, label)`` pairs in reading order, where ``label`` is
        the model's ``id2label`` name. A word's label is the one predicted for
        its last sub-token.
    """
    tokenizer, model = load_model()
    encoding = tokenizer(text, return_tensors="pt", truncation=True, is_split_into_words=False)
    with torch.no_grad():
        logits = model(**encoding).logits

    id2label = model.config.id2label
    token_labels = [id2label[class_id] for class_id in torch.argmax(logits, dim=2)[0].tolist()]
    token_ids = encoding["input_ids"][0]

    # Collect runs of token positions belonging to the same word. Positions
    # whose word id is None (presumably special tokens) are left out.
    word_groups = []
    for position, word_id in enumerate(encoding.word_ids(batch_index=0)):
        if word_id is None:
            continue
        if word_groups and word_groups[-1][0] == word_id:
            word_groups[-1][1].append(position)
        else:
            word_groups.append((word_id, [position]))

    tagged_words = []
    for _, positions in word_groups:
        word = tokenizer.decode([token_ids[i] for i in positions], skip_special_tokens=True)
        # The label of the final sub-token stands for the whole word.
        tagged_words.append((word, token_labels[positions[-1]]))
    return tagged_words
def colorize_entities(ner_result: List[Tuple[str, str]]) -> str:
    """Render tagged words as HTML, highlighting every non-NONE word.

    Words whose label normalises to ``LABEL-0`` are emitted as plain text;
    all others are wrapped in a coloured ``<span class="ner-entity">`` whose
    ``data-tooltip`` attribute carries the label's meaning. Each word is
    followed by a single space.

    The word text comes from user input and the result is rendered with
    ``unsafe_allow_html=True``, so both the word and the tooltip text are
    HTML-escaped before being interpolated.

    Args:
        ner_result: ``(word, label)`` pairs; labels may use either
            ``LABEL_<n>`` or ``LABEL-<n>``.

    Returns:
        The HTML fragment (empty string for empty input).
    """
    pieces = []
    for token, label in ner_result:
        norm_label = label.replace('_', '-')
        safe_token = html.escape(token)
        if norm_label == 'LABEL-0':
            pieces.append(f'{safe_token} ')
            continue
        # Unknown labels fall back to a neutral colour and their own name.
        color = LABEL_COLORS.get(norm_label, '#eeeeee')
        label_meaning = html.escape(LABEL_MEANINGS.get(norm_label, norm_label))
        pieces.append(
            f'<span class="ner-entity" style="background-color:{color};padding:2px 4px;border-radius:4px;margin:1px;" '
            f'data-tooltip="{label_meaning}">{safe_token}</span> '
        )
    return ''.join(pieces)
def extract_entities(ner_result: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Merge runs of consecutive words that carry the same entity label.

    A word counts as "no entity" when its label normalises to ``LABEL-0``;
    both ``LABEL_0`` and ``LABEL-0`` are recognised, matching the
    normalisation applied when the words are highlighted.

    NOTE(review): words are merged only when their labels are identical, so a
    B-<type> word followed by I-<type> words is returned as separate spans —
    confirm whether those should be joined into one entity.

    Args:
        ner_result: ``(word, label)`` pairs in reading order.

    Returns:
        ``(text, label)`` pairs, where ``text`` is the run's words joined by
        single spaces and ``label`` is the run's label exactly as received.
    """
    entities = []
    span_tokens = []
    span_label = None
    for token, label in ner_result:
        is_entity = label.replace('_', '-') != 'LABEL-0'
        if is_entity and label == span_label:
            span_tokens.append(token)
            continue
        # The label changed or a non-entity word was reached: close the open span.
        if span_tokens:
            entities.append((' '.join(span_tokens), span_label))
        span_tokens = [token] if is_entity else []
        span_label = label if is_entity else None
    if span_tokens:
        entities.append((' '.join(span_tokens), span_label))
    return entities
def legend_html() -> str:
    """Build the colour legend as an HTML flex container.

    One coloured chip is produced per entry of ``LABEL_COLORS`` except
    ``LABEL-0``; each chip shows the label's meaning followed by the label
    key in parentheses.
    """
    chips = [
        f'<span style="background-color:{swatch};padding:2px 8px;border-radius:4px;">{LABEL_MEANINGS[tag]} ({tag})</span>'
        for tag, swatch in LABEL_COLORS.items()
        if tag != 'LABEL-0'
    ]
    return '<div style="display:flex;flex-wrap:wrap;gap:8px;">' + ''.join(chips) + '</div>'
# Stylesheet for highlighted words: on hover, the value of the element's
# data-tooltip attribute is shown in a dark box directly below the word.
_TOOLTIP_CSS = '''
<style>
.ner-entity {
position: relative;
cursor: pointer;
}
.ner-entity[data-tooltip]:hover:after {
content: attr(data-tooltip);
position: absolute;
left: 0;
top: 100%;
background: #222;
color: #fff;
padding: 2px 8px;
border-radius: 4px;
white-space: nowrap;
z-index: 10;
font-size: 0.9em;
margin-top: 2px;
}
</style>
'''

# Page header: title, tooltip stylesheet, then the colour legend.
st.title('LLM-powered Named Entity Recognition (NER)')
st.markdown(_TOOLTIP_CSS, unsafe_allow_html=True)
st.markdown('**Legend:**')
st.markdown(legend_html(), unsafe_allow_html=True)
# Main interaction: read the text, tag it, then show the highlighted text and
# the list of detected entities.
user_text = st.text_area('Enter text for NER:', height=150)
if user_text:
    ner_result = ner_with_robertime(user_text)
    # Labels may be spelled "LABEL_0" or "LABEL-0"; normalise before testing
    # for NONE so that plain text is not mistaken for an entity.
    has_entity = any(label.replace('_', '-') != 'LABEL-0' for _, label in ner_result)
    if has_entity:
        st.markdown('#### Entities Highlighted:')
        st.markdown(colorize_entities(ner_result), unsafe_allow_html=True)
        # Defensive filter: spans labelled NONE are never listed as entities.
        entities = [
            (ent, label)
            for ent, label in extract_entities(ner_result)
            if label.replace('_', '-') != 'LABEL-0'
        ]
        if entities:
            st.markdown('#### Detected Entities:')
            for ent, label in entities:
                norm_label = label.replace('_', '-')
                # Same fallbacks as the highlighted view, so an unknown label
                # is displayed instead of raising KeyError.
                color = LABEL_COLORS.get(norm_label, '#eeeeee')
                meaning = LABEL_MEANINGS.get(norm_label, norm_label)
                # Entity text originates from user input: escape it before
                # embedding it in raw HTML.
                st.markdown(
                    f'- <span style="background-color:{color};padding:2px 8px;border-radius:4px;">{html.escape(ent)}</span> '
                    f'<span style="color:#888;">({html.escape(meaning)})</span>',
                    unsafe_allow_html=True,
                )
        else:
            st.info('No entities detected.')
    else:
        st.info('No entities detected.')
st.caption('Model: [asdc/Bio-RoBERTime](https://huggingface.co/asdc/Bio-RoBERTime)')