Spaces:
Sleeping
Sleeping
File size: 5,495 Bytes
b6ac0c6 ab64fd6 2f2c452 ab64fd6 2f2c452 ab64fd6 d9a6a4e 2f2c452 ab64fd6 2f2c452 ab64fd6 2f2c452 65003ad 2f2c452 ab64fd6 d9a6a4e 2f2c452 ab64fd6 d9a6a4e 2f2c452 ab64fd6 2f2c452 d9a6a4e 6c2c9be d9a6a4e 2f2c452 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
import streamlit as st
from typing import List, Tuple
import re
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Tag table for the model's output labels, in label-index order: entry i
# describes label "LABEL-i" as (BIO tag meaning, highlight colour).
_TAG_TABLE = (
    ('NONE', '#cccccc'),
    ('B-DATE', '#ffadad'),
    ('I-DATE', '#ffd6a5'),
    ('B-TIME', '#fdffb6'),
    ('I-TIME', '#caffbf'),
    ('B-DURATION', '#9bf6ff'),
    ('I-DURATION', '#a0c4ff'),
    ('B-SET', '#bdb2ff'),
    ('I-SET', '#ffc6ff'),
)

# "LABEL-i" -> highlight colour used when rendering that label.
LABEL_COLORS = {
    f'LABEL-{index}': colour for index, (_, colour) in enumerate(_TAG_TABLE)
}

# "LABEL-i" -> human-readable BIO tag name.
LABEL_MEANINGS = {
    f'LABEL-{index}': meaning for index, (meaning, _) in enumerate(_TAG_TABLE)
}
@st.cache_resource(show_spinner=True)
def load_model():
    """Load the Bio-RoBERTime tokenizer and token-classification model.

    Wrapped in ``st.cache_resource`` so Streamlit keeps the loaded objects
    and reuses them on later calls instead of loading them again.

    Returns:
        A ``(tokenizer, model)`` tuple for the ``asdc/Bio-RoBERTime`` checkpoint.
    """
    checkpoint = 'asdc/Bio-RoBERTime'
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForTokenClassification.from_pretrained(checkpoint),
    )
def ner_with_robertime(text: str) -> List[Tuple[str, str]]:
    """Run Bio-RoBERTime token classification and return one label per word.

    The model predicts a label for every subword token. Tokens are then
    regrouped into words with the (fast) tokenizer's ``word_ids``:
    - Each word takes the label predicted for its first subword.
    - Each word's text is sliced from ``text`` using the token character
      offsets. Subwords are therefore re-joined exactly as they appear in
      the input: no space is inserted inside a word, and no text is
      dropped or overwritten.
    - Pieces the tokenizer treats as separate words keep their own label.

    Input longer than the model's maximum length is truncated
    (``truncation=True``), so words past that point are not returned.

    Args:
        text: Raw input text.

    Returns:
        ``(word, label)`` pairs in input order. ``label`` is the raw string
        from ``model.config.id2label``.
    """
    tokenizer, model = load_model()

    encoding = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        return_offsets_mapping=True,
    )
    # The model does not accept offsets as an input; remove them before the
    # forward pass. word_ids() still works because it reads the underlying
    # tokenizer encodings, not this dict entry.
    offsets = encoding.pop("offset_mapping")[0].tolist()
    word_ids = encoding.word_ids(batch_index=0)

    with torch.no_grad():
        logits = model(**encoding).logits
    predictions = torch.argmax(logits, dim=2)[0].tolist()
    id2label = model.config.id2label

    # Collect one [start_char, end_char, label] span per word.
    spans: List[list] = []
    previous_word_id = None
    for (start, end), word_id, pred in zip(offsets, word_ids, predictions):
        if word_id is None:
            continue  # special tokens such as <s> / </s> belong to no word
        if word_id != previous_word_id:
            spans.append([start, end, id2label[pred]])
        else:
            spans[-1][1] = end  # continuation subword: extend the word's span
        previous_word_id = word_id

    entities: List[Tuple[str, str]] = []
    for start, end, label in spans:
        # strip() guards against offsets that include surrounding whitespace.
        word = text[start:end].strip()
        if word:
            entities.append((word, label))
    return entities
def colorize_entities(ner_result: List[Tuple[str, str]]) -> str:
    """Render ``(word, label)`` pairs as HTML, highlighting entity words.

    Each label is normalized by replacing ``'_'`` with ``'-'``:
    - A word whose normalized label is ``LABEL-0`` is emitted as plain text.
    - Every other word is wrapped in a coloured ``<span>``, using the colour
      from ``LABEL_COLORS`` or ``#eeeeee`` for an unknown label.

    Word text is HTML-escaped because this app renders the result with
    ``unsafe_allow_html=True``. Without escaping, input such as ``<b>`` or
    ``&`` would be interpreted as markup.

    Args:
        ner_result: ``(word, label)`` pairs in display order.

    Returns:
        The HTML fragment, with every word followed by a single space.
    """
    from html import escape

    pieces = []
    for token, label in ner_result:
        text = escape(token)
        norm_label = label.replace('_', '-')
        if norm_label == 'LABEL-0':
            pieces.append(f'{text} ')
        else:
            color = LABEL_COLORS.get(norm_label, '#eeeeee')
            pieces.append(f'<span style="background-color:{color};padding:2px 4px;border-radius:4px;margin:1px;">{text}</span> ')
    return ''.join(pieces)
def _bio_tag(label: str):
    """Classify a model label as outside (``None``) or ``(entity_type, is_begin)``.

    Labels of the form ``LABEL-<n>`` / ``LABEL_<n>`` follow the BIO table:
    - ``n == 0`` is outside any entity.
    - Odd ``n`` is a ``B-`` (begin) tag.
    - ``n + 1`` is the matching ``I-`` (inside) tag of the same type.

    Any other label string is treated as its own entity type that continues
    only while the same label repeats.
    """
    norm = label.replace('_', '-')
    prefix, _, number = norm.rpartition('-')
    if prefix == 'LABEL' and number.isdecimal():
        index = int(number)
        if index == 0:
            return None
        return (index + 1) // 2, index % 2 == 1
    return norm, False


def extract_entities(ner_result: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Group per-word predictions into entity spans using BIO tagging.

    Grouping rules:
    - A word whose label is the outside label (``LABEL-0`` or ``LABEL_0``)
      closes any open entity.
    - A ``B-`` tag always starts a new entity.
    - An ``I-`` tag extends the open entity when that entity has the same
      type; otherwise it starts a new one.

    As a result, "B-DATE I-DATE I-DATE" yields a single date, and two
    adjacent ``B-DATE`` words yield two dates.

    Args:
        ner_result: ``(word, label)`` pairs in input order.

    Returns:
        ``(entity_text, label)`` pairs. ``entity_text`` is the member words
        joined by single spaces. ``label`` is the raw label of the entity's
        first word.
    """
    entities: List[Tuple[str, str]] = []
    words: List[str] = []
    first_label = None
    current_type = None

    for word, label in ner_result:
        tag = _bio_tag(label)
        if tag is None:
            # Outside any entity: flush the open span, if there is one.
            if words:
                entities.append((' '.join(words), first_label))
            words, first_label, current_type = [], None, None
            continue

        entity_type, is_begin = tag
        if words and not is_begin and entity_type == current_type:
            words.append(word)
            continue

        # A begin tag or a type change starts a new entity.
        if words:
            entities.append((' '.join(words), first_label))
        words, first_label, current_type = [word], label, entity_type

    if words:
        entities.append((' '.join(words), first_label))
    return entities
def legend_html() -> str:
    """Build the HTML legend showing one coloured chip per entity label.

    ``LABEL-0`` (meaning ``NONE``) is skipped. Each chip shows the tag
    meaning from ``LABEL_MEANINGS`` followed by the raw label in
    parentheses. The chips sit inside a wrapping flex-row ``<div>``.
    """
    chips = [
        f'<span style="background-color:{color};padding:2px 8px;border-radius:4px;">{LABEL_MEANINGS[label]} ({label})</span>'
        for label, color in LABEL_COLORS.items()
        if label != 'LABEL-0'
    ]
    return '<div style="display:flex;flex-wrap:wrap;gap:8px;">' + ''.join(chips) + '</div>'
# Streamlit page body. escape() is applied to user-derived text because it is
# rendered with unsafe_allow_html=True; raw '<' or '&' in the input would
# otherwise be interpreted as markup.
from html import escape

st.title('LLM-powered Named Entity Recognition (NER)')
st.markdown('**Legend:**')
st.markdown(legend_html(), unsafe_allow_html=True)

user_text = st.text_area('Enter text for NER:', height=150)
if user_text:
    ner_result = ner_with_robertime(user_text)
    # Raw model labels may be spelled 'LABEL_0'. Normalize '_' -> '-' (as the
    # rendering code does) so the no-entity label is recognized here too.
    has_entity = any(
        label.replace('_', '-') != 'LABEL-0' for _, label in ner_result
    )
    if has_entity:
        st.markdown('#### Entities Highlighted:')
        st.markdown(colorize_entities(ner_result), unsafe_allow_html=True)
        entities = extract_entities(ner_result)
        if entities:
            st.markdown('#### Detected Entities:')
            for ent, label in entities:
                norm_label = label.replace('_', '-')
                # Fall back gracefully if the model emits a label missing from
                # the tables, instead of raising KeyError mid-render.
                color = LABEL_COLORS.get(norm_label, '#eeeeee')
                meaning = LABEL_MEANINGS.get(norm_label, norm_label)
                st.markdown(f'- <span style="background-color:{color};padding:2px 8px;border-radius:4px;">{escape(ent)}</span> <span style="color:#888;">({escape(meaning)})</span>', unsafe_allow_html=True)
        else:
            st.info('No entities detected.')
    else:
        st.info('No entities detected.')
st.caption('Model: [asdc/Bio-RoBERTime](https://huggingface.co/asdc/Bio-RoBERTime)')