# NOTE: this header replaces residue from a file-viewer scrape, which read
# "File size: 3,055 Bytes", listed revisions ab64fd6 / 2f2c452 and showed
# gutter line numbers 1-83. The residue was not valid Python.
import html
import re
from typing import List, Tuple

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Highlight colour per predicted label. Position i in this palette is the
# colour used for the label string 'LABEL-i'; the trailing comment gives the
# tag the author associated with that index.
# NOTE(review): these keys must equal the strings in model.config.id2label.
# The Transformers default naming is 'LABEL_0' (underscore); confirm this
# model's config really uses 'LABEL-0' (hyphen), otherwise every word falls
# back to the default highlight in colorize_entities.
_LABEL_PALETTE = (
    '#cccccc',  # LABEL-0: NONE
    '#ffadad',  # LABEL-1: B-DATE
    '#ffd6a5',  # LABEL-2: I-DATE
    '#fdffb6',  # LABEL-3: B-TIME
    '#caffbf',  # LABEL-4: I-TIME
    '#9bf6ff',  # LABEL-5: B-DURATION
    '#a0c4ff',  # LABEL-6: I-DURATION
    '#bdb2ff',  # LABEL-7: B-SET
    '#ffc6ff',  # LABEL-8: I-SET
)
LABEL_COLORS = {f'LABEL-{index}': color for index, color in enumerate(_LABEL_PALETTE)}

@st.cache_resource(show_spinner=True)
def load_model():
    """Load the Bio-RoBERTime checkpoint and return ``(tokenizer, model)``.

    Wrapped in ``st.cache_resource`` so repeated calls (e.g. on every
    Streamlit rerun) reuse the loaded objects instead of reloading them.
    The tokenizer is loaded before the model.
    """
    checkpoint = 'asdc/Bio-RoBERTime'
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForTokenClassification.from_pretrained(checkpoint),
    )

def ner_with_robertime(text: str) -> List[Tuple[str, str]]:
    """Label each word of *text* with the Bio-RoBERTime token classifier.

    The text is tokenized with the model's tokenizer, and the argmax label is
    taken for every token. Tokens are then regrouped into the word they came
    from via ``word_ids()``. Offset mapping and ``word_ids()`` both require a
    fast tokenizer.

    Args:
        text: Raw input text.

    Returns:
        One ``(word, label)`` pair per word, in input order. ``word`` is the
        exact slice of *text* covered by that word's tokens, so sub-word
        pieces are joined without spurious spaces and non-ASCII characters
        come back intact. ``label`` is ``model.config.id2label`` of the word's
        first sub-word token. Whitespace-only spans are omitted.

    Note:
        ``truncation=True`` caps the input at the tokenizer's maximum length;
        words beyond that point are silently absent from the result.
    """
    tokenizer, model = load_model()
    # Character offsets let us recover each word verbatim from `text`,
    # rather than reassembling it from byte-level BPE token strings.
    encoding = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        return_offsets_mapping=True,
    )
    # offset_mapping is not a model input; keep it out of the forward call.
    model_inputs = {k: v for k, v in encoding.items() if k != "offset_mapping"}
    with torch.no_grad():
        outputs = model(**model_inputs)
    predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
    offsets = encoding["offset_mapping"][0].tolist()
    word_ids = encoding.word_ids(batch_index=0)

    # One [start, end, label] entry per word. `end` is extended as later
    # sub-tokens of the same word are seen.
    spans: List[List] = []
    last_word_id = None
    for word_id, (start, end), pred in zip(word_ids, offsets, predictions):
        if word_id is None:  # special tokens belong to no word
            continue
        if word_id != last_word_id:
            # Use the first sub-token's prediction for the whole word (the
            # usual convention for word-level token classification).
            spans.append([start, end, model.config.id2label[pred]])
        else:
            spans[-1][1] = end
        last_word_id = word_id

    # Drop spans that cover only whitespace (e.g. a standalone space token).
    return [
        (text[start:end], label)
        for start, end, label in spans
        if text[start:end].strip()
    ]

def colorize_entities(ner_result: List[Tuple[str, str]]) -> str:
    """Render ``(word, label)`` pairs as an HTML string of highlighted words.

    Words labelled ``'LABEL-0'`` are emitted as plain text. Every other word
    is wrapped in a rounded ``<span>`` whose background colour comes from
    ``LABEL_COLORS``, or ``'#eeeeee'`` for labels missing from that mapping.
    Each word is followed by a single space.

    Args:
        ner_result: Pairs as returned by ``ner_with_robertime``.

    Returns:
        The concatenated HTML fragment (empty string for empty input).
    """
    pieces = []
    for token, label in ner_result:
        # The words come from user input and the result is rendered with
        # unsafe_allow_html=True, so escape them before embedding in HTML.
        safe_token = html.escape(token)
        if label == 'LABEL-0':
            pieces.append(f'{safe_token} ')
        else:
            color = LABEL_COLORS.get(label, '#eeeeee')
            pieces.append(
                f'<span style="background-color:{color};padding:2px 4px;'
                f'border-radius:4px;margin:1px;">{safe_token}</span> '
            )
    return ''.join(pieces)

st.title('LLM-powered Named Entity Recognition (NER)')

user_text = st.text_area('Enter text for NER:', height=150)

# Only run inference when there is real content. Whitespace-only input would
# feed the model nothing but special tokens and render an empty section.
if user_text and user_text.strip():
    ner_result = ner_with_robertime(user_text)
    st.markdown('#### Entities:')
    st.markdown(colorize_entities(ner_result), unsafe_allow_html=True)
    st.caption('Model: [asdc/Bio-RoBERTime](https://huggingface.co/asdc/Bio-RoBERTime)')