"""Streamlit demo: token-level NER with the asdc/Bio-RoBERTime model.

The script tokenizes user text, predicts one label per word, highlights the
labelled words inline and lists the detected entities underneath.
"""
import os

# Redirect the Hugging Face caches to /tmp. The assignments are kept ahead of
# the `transformers` import, as in the original script, so the library sees
# them when it is first loaded.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"

import re
from html import escape
from typing import List, Tuple

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

MODEL_NAME = 'asdc/Bio-RoBERTime'

# Label that marks "not part of any entity".
NONE_LABEL = 'LABEL-0'

# Mapping of label to highlight color.
# NOTE(review): LABEL-2 (I-DATE) and LABEL-8 (I-SET) share '#ebdb98' --
# confirm whether that is intentional.
LABEL_COLORS = {
    'LABEL-0': '#cccccc',  # NONE
    'LABEL-1': '#ffadad',  # B-DATE
    'LABEL-2': '#ebdb98',  # I-DATE
    'LABEL-3': '#586492',  # B-TIME
    'LABEL-4': '#ffb788',  # I-TIME
    'LABEL-5': '#76abbb',  # B-DURATION
    'LABEL-6': '#a0c4ff',  # I-DURATION
    'LABEL-7': '#f84252',  # B-SET
    'LABEL-8': '#ebdb98',  # I-SET
}

# Mapping of label to its BIO tag.
LABEL_MEANINGS = {
    'LABEL-0': 'NONE',
    'LABEL-1': 'B-DATE',
    'LABEL-2': 'I-DATE',
    'LABEL-3': 'B-TIME',
    'LABEL-4': 'I-TIME',
    'LABEL-5': 'B-DURATION',
    'LABEL-6': 'I-DURATION',
    'LABEL-7': 'B-SET',
    'LABEL-8': 'I-SET',
}


def _normalize_label(label: str) -> str:
    """Return *label* in the hyphenated form used as key in the tables above.

    The display code has always replaced '_' with '-', so labels presumably
    arrive from the model config as e.g. ``LABEL_1`` -- confirm against the
    model's ``id2label``. The operation is idempotent.
    """
    return label.replace('_', '-')


def _split_bio(label: str) -> Tuple[str, str]:
    """Split a normalized label into ``(prefix, entity_type)``.

    ``'LABEL-1'`` -> ``('B', 'DATE')``. Labels whose meaning is not of the
    form ``B-...``/``I-...`` (or that are missing from ``LABEL_MEANINGS``)
    yield an empty prefix and the whole meaning as type, so they only group
    with identical labels.
    """
    meaning = LABEL_MEANINGS.get(label, label)
    prefix, sep, entity_type = meaning.partition('-')
    if sep and prefix in ('B', 'I'):
        return prefix, entity_type
    return '', meaning


@st.cache_resource(show_spinner=True)
def load_model():
    """Load the tokenizer and token-classification model.

    Wrapped in ``st.cache_resource`` so the pair is reused instead of being
    reloaded on each call.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
    return tokenizer, model


def ner_with_robertime(text: str) -> List[Tuple[str, str]]:
    """Run the model on *text* and return one ``(word, label)`` pair per word.

    Sub-word tokens are regrouped into words through the tokenizer's
    ``word_ids``. A word receives the label predicted for its *last*
    sub-token (behavior kept from the original implementation). Labels are
    returned in normalized, hyphenated form (``LABEL-1``), matching the keys
    of ``LABEL_COLORS`` / ``LABEL_MEANINGS``.

    Args:
        text: Raw input text.

    Returns:
        Words in input order, each paired with its predicted label.
    """
    tokenizer, model = load_model()
    # truncation=True: input beyond the model's maximum length is cut off,
    # so the result only covers the part that fit.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        is_split_into_words=False,
    )
    with torch.no_grad():
        outputs = model(**encoded)

    predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
    labels = [
        _normalize_label(model.config.id2label[pred]) for pred in predictions
    ]
    word_ids = encoded.word_ids(batch_index=0)
    input_ids = encoded["input_ids"][0].tolist()

    def decode(token_positions: List[int]) -> str:
        return tokenizer.decode(
            [input_ids[i] for i in token_positions],
            skip_special_tokens=True,
        )

    words: List[Tuple[str, str]] = []
    current_positions: List[int] = []
    current_label = None
    last_word_id = None
    for idx, word_id in enumerate(word_ids):
        if word_id is None:
            # Positions with no word id (special tokens) are skipped.
            continue
        if word_id != last_word_id and current_positions:
            # A new word starts: emit the one collected so far.
            words.append((decode(current_positions), current_label))
            current_positions = []
        current_positions.append(idx)
        current_label = labels[idx]
        last_word_id = word_id
    if current_positions:
        words.append((decode(current_positions), current_label))
    return words


def colorize_entities(ner_result: List[Tuple[str, str]]) -> str:
    """Render *ner_result* as HTML with entity words highlighted.

    Every word is HTML-escaped because the result is displayed with
    ``unsafe_allow_html=True``. Words labelled other than ``LABEL-0`` are
    wrapped in a colored ``<span>`` whose tooltip shows the label meaning.

    NOTE(review): the span markup was missing from the source this was
    restored from (``color`` and ``label_meaning`` were computed but unused);
    the styling below is a reconstruction -- adjust to taste.

    Args:
        ner_result: ``(word, label)`` pairs; labels may use '_' or '-'.

    Returns:
        An HTML fragment; each word is followed by a single space.
    """
    parts = []
    for token, label in ner_result:
        norm_label = _normalize_label(label)
        safe_token = escape(token)
        if norm_label != NONE_LABEL:
            color = LABEL_COLORS.get(norm_label, '#eeeeee')
            label_meaning = escape(LABEL_MEANINGS.get(norm_label, norm_label))
            parts.append(
                f'<span style="background-color:{color}; padding:2px 4px; '
                f'border-radius:4px;" title="{label_meaning}">'
                f'{safe_token}</span> '
            )
        else:
            parts.append(f'{safe_token} ')
    return ''.join(parts)


def extract_entities(ner_result: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Merge consecutive labelled words into entities following the BIO tags.

    A ``B-X`` word always opens a new entity; an ``I-X`` word extends the
    open entity when it has the same type ``X`` and otherwise opens a new
    one. ``LABEL-0`` words close the open entity. Labels are normalized
    ('_' -> '-') before comparison so the function agrees with
    ``colorize_entities``.

    Args:
        ner_result: ``(word, label)`` pairs; labels may use '_' or '-'.

    Returns:
        ``(entity_text, label)`` pairs where ``entity_text`` is the words
        joined by single spaces and ``label`` is the normalized label of the
        entity's first word.
    """
    entities: List[Tuple[str, str]] = []
    current_tokens: List[str] = []
    current_label = None
    current_type = None

    for token, label in ner_result:
        norm_label = _normalize_label(label)
        if norm_label == NONE_LABEL:
            if current_tokens:
                entities.append((' '.join(current_tokens), current_label))
            current_tokens, current_label, current_type = [], None, None
            continue

        prefix, entity_type = _split_bio(norm_label)
        extends_open_entity = (
            bool(current_tokens)
            and entity_type == current_type
            and prefix != 'B'
        )
        if extends_open_entity:
            current_tokens.append(token)
        else:
            if current_tokens:
                entities.append((' '.join(current_tokens), current_label))
            current_tokens = [token]
            current_label = norm_label
            current_type = entity_type

    if current_tokens:
        entities.append((' '.join(current_tokens), current_label))
    return entities


def legend_html() -> str:
    """Build the HTML color legend for every label except ``LABEL-0``.

    NOTE(review): the wrapping markup was missing from the source this was
    restored from; the ``<div>``/``<span>`` styling is a reconstruction.

    Returns:
        An HTML fragment with one colored chip per entity label.
    """
    chips = [
        f'<span style="background-color:{color}; padding:2px 6px; '
        f'margin-right:6px; border-radius:4px;">'
        f'{LABEL_MEANINGS[label]} ({label})</span>'
        for label, color in LABEL_COLORS.items()
        if label != NONE_LABEL
    ]
    return '<div style="line-height:2.2;">' + ''.join(chips) + '</div>'


st.title('LLM-powered Named Entity Recognition (NER)')
# NOTE(review): this call presumably injected a custom <style> block whose
# content is not present in the source; restore the CSS here if needed.
st.markdown(
    ''' ''',
    unsafe_allow_html=True
)
st.markdown('**Legend:**')
st.markdown(legend_html(), unsafe_allow_html=True)

user_text = st.text_area('Enter text for NER:', height=150)

if user_text:
    ner_result = ner_with_robertime(user_text)
    has_entity = any(
        _normalize_label(label) != NONE_LABEL for _, label in ner_result
    )
    if has_entity:
        st.markdown('#### Entities Highlighted:')
        st.markdown(colorize_entities(ner_result), unsafe_allow_html=True)
        entities = extract_entities(ner_result)
        if entities:
            st.markdown('#### Detected Entities:')
            for ent, label in entities:
                norm_label = _normalize_label(label)
                # .get keeps an unexpected label from raising KeyError,
                # consistent with colorize_entities.
                meaning = LABEL_MEANINGS.get(norm_label, norm_label)
                st.markdown(
                    f'- {escape(ent)} ({escape(meaning)})',
                    unsafe_allow_html=True,
                )
        else:
            st.info('No entities detected.')
    else:
        st.info('No entities detected.')

st.caption('Model: [asdc/Bio-RoBERTime](https://huggingface.co/asdc/Bio-RoBERTime)')