# Streamlit demo: tag each word of a user-supplied text with the
# asdc/Bio-RoBERTime token-classification model and render the tagged words
# with a per-label background colour.
#
# NOTE: this header is a comment rather than a module docstring so that
# Streamlit's "magic" display of bare top-level expressions cannot pick it up.
import html
import re
from typing import List, Tuple

import streamlit as st
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Mapping of label to color
LABEL_COLORS = {
    'LABEL-0': '#cccccc',  # NONE
    'LABEL-1': '#ffadad',  # B-DATE
    'LABEL-2': '#ffd6a5',  # I-DATE
    'LABEL-3': '#fdffb6',  # B-TIME
    'LABEL-4': '#caffbf',  # I-TIME
    'LABEL-5': '#9bf6ff',  # B-DURATION
    'LABEL-6': '#a0c4ff',  # I-DURATION
    'LABEL-7': '#bdb2ff',  # B-SET
    'LABEL-8': '#ffc6ff',  # I-SET
}

# Label whose words are rendered as plain, un-highlighted text.
_NONE_LABEL = 'LABEL-0'

# Background used for labels that have no entry in LABEL_COLORS.
_FALLBACK_COLOR = '#eeeeee'


@st.cache_resource(show_spinner=True)
def load_model():
    """Load the Bio-RoBERTime tokenizer and token-classification model.

    The result is cached by Streamlit (``st.cache_resource``), so the
    weights are loaded once rather than on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained('asdc/Bio-RoBERTime')
    model = AutoModelForTokenClassification.from_pretrained('asdc/Bio-RoBERTime')
    return tokenizer, model


def _canonical_label(label: str) -> str:
    """Return *label* spelled the way the keys of ``LABEL_COLORS`` are.

    NOTE(review): Transformers' auto-generated label names use an underscore
    (``LABEL_0``) while ``LABEL_COLORS`` uses a hyphen (``LABEL-0``).  The
    model's actual ``id2label`` is not visible from this file, so both
    spellings are accepted here -- confirm against the model config.
    """
    return label.replace('_', '-')


def _merge_subwords(tokenizer, token_list, labels, word_ids) -> List[Tuple[str, str]]:
    """Collapse per-token predictions into one ``(word, label)`` pair per word.

    Args:
        tokenizer: Tokenizer that produced ``token_list``; used to turn the
            sub-word pieces of a word back into readable text.
        token_list: Token strings, one per position of the encoded input.
        labels: Predicted label name for each position.
        word_ids: Word index for each position, ``None`` for special tokens.

    Returns:
        ``(word, label)`` tuples in text order.  A word carries the label
        predicted for its first sub-word piece.
    """
    grouped = []  # one ([pieces...], label) entry per word
    last_word_id = None
    for token, label, word_id in zip(token_list, labels, word_ids):
        if word_id is None:
            # Special tokens (sequence start/end, padding) belong to no word.
            continue
        if word_id != last_word_id:
            grouped.append(([token], label))
            last_word_id = word_id
        else:
            # Continuation piece of the current word: keep the first label.
            grouped[-1][0].append(token)

    merged = []
    for pieces, label in grouped:
        # Let the tokenizer undo its own sub-word markers and byte-level
        # encoding instead of stripping marker characters by hand, which
        # inserted spaces inside words and mangled non-ASCII text.
        word = tokenizer.convert_tokens_to_string(pieces).strip()
        if word:
            merged.append((word, label))
    return merged


def ner_with_robertime(text: str) -> List[Tuple[str, str]]:
    """Tag every word of *text* with the label predicted by the model.

    Args:
        text: Raw input text.  It is tokenized with ``truncation=True``, so
            anything beyond the tokenizer's maximum length is dropped and
            does not appear in the result.

    Returns:
        A list of ``(word, label)`` tuples in text order, where ``label`` is
        the name found in ``model.config.id2label`` for the highest-scoring
        class of the word's first sub-word piece.
    """
    tokenizer, model = load_model()

    # Tokenize and get input tensors
    tokens = tokenizer(text, return_tensors="pt", truncation=True, is_split_into_words=False)
    with torch.no_grad():
        outputs = model(**tokens)
    predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()

    # Map ids to labels
    labels = [model.config.id2label[pred] for pred in predictions]

    # word_ids() aligns each token position with the word it came from.
    word_ids = tokens.word_ids(batch_index=0)
    token_list = tokenizer.convert_ids_to_tokens(tokens["input_ids"][0].tolist())

    return _merge_subwords(tokenizer, token_list, labels, word_ids)


def colorize_entities(ner_result: List[Tuple[str, str]]) -> str:
    """Render tagged words as an HTML snippet with highlighted entities.

    Args:
        ner_result: ``(word, label)`` tuples, e.g. from ``ner_with_robertime``.

    Returns:
        A string in which every word is followed by a single space.  Words
        labelled ``LABEL-0`` are emitted as plain text; every other word is
        wrapped in a ``<span>`` whose background is the label's colour from
        ``LABEL_COLORS`` (``#eeeeee`` for unknown labels).  Word text is
        HTML-escaped because the caller renders the result with
        ``unsafe_allow_html=True``.
    """
    fragments = []
    for token, label in ner_result:
        safe_token = html.escape(token)
        canonical = _canonical_label(label)
        if canonical == _NONE_LABEL:
            fragments.append(f'{safe_token} ')
            continue
        color = LABEL_COLORS.get(canonical, _FALLBACK_COLOR)
        fragments.append(
            f'<span style="background-color: {color}">{safe_token}</span> '
        )
    return ''.join(fragments)


st.title('LLM-powered Named Entity Recognition (NER)')
user_text = st.text_area('Enter text for NER:', height=150)

if user_text:
    ner_result = ner_with_robertime(user_text)
    st.markdown('#### Entities:')
    st.markdown(colorize_entities(ner_result), unsafe_allow_html=True)

# Model attribution footer, shown whether or not text has been entered.
st.caption('Model: [asdc/Bio-RoBERTime](https://huggingface.co/asdc/Bio-RoBERTime)')