Spaces:
Sleeping
Sleeping
Upload streamlit_app.py
Browse files- src/streamlit_app.py +21 -28
src/streamlit_app.py
CHANGED
@@ -9,15 +9,15 @@ from transformers import AutoTokenizer, AutoModelForTokenClassification
|
|
9 |
|
10 |
# Mapping of label to color
|
11 |
LABEL_COLORS = {
|
12 |
-
'LABEL-0': '#
|
13 |
-
'LABEL-1': '#
|
14 |
-
'LABEL-2': '#
|
15 |
-
'LABEL-3': '#
|
16 |
-
'LABEL-4': '#
|
17 |
-
'LABEL-5': '#
|
18 |
-
'LABEL-6': '#
|
19 |
-
'LABEL-7': '#
|
20 |
-
'LABEL-8': '#
|
21 |
}
|
22 |
|
23 |
LABEL_MEANINGS = {
|
@@ -40,41 +40,33 @@ def load_model():
|
|
40 |
|
41 |
def ner_with_robertime(text: str) -> List[Tuple[str, str]]:
|
42 |
tokenizer, model = load_model()
|
43 |
-
# Tokenize and get input tensors
|
44 |
tokens = tokenizer(text, return_tensors="pt", truncation=True, is_split_into_words=False)
|
45 |
with torch.no_grad():
|
46 |
outputs = model(**tokens)
|
47 |
predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
|
48 |
-
# Map ids to labels
|
49 |
labels = [model.config.id2label[pred] for pred in predictions]
|
50 |
-
# Get tokens (handling subwords)
|
51 |
word_ids = tokens.word_ids(batch_index=0)
|
52 |
-
|
53 |
-
# Merge subwords and assign entity labels
|
54 |
entities = []
|
55 |
-
|
56 |
current_label = None
|
57 |
last_word_id = None
|
58 |
for idx, word_id in enumerate(word_ids):
|
59 |
if word_id is None:
|
60 |
continue
|
61 |
-
token = token_list[idx]
|
62 |
label = labels[idx]
|
63 |
-
if
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
current_word = token
|
68 |
current_label = label
|
69 |
else:
|
70 |
-
|
71 |
-
current_word += token if token.startswith("'") else f' {token}'
|
72 |
-
else:
|
73 |
-
current_word = token
|
74 |
current_label = label
|
75 |
last_word_id = word_id
|
76 |
-
if
|
77 |
-
|
|
|
78 |
return entities
|
79 |
|
80 |
def colorize_entities(ner_result: List[Tuple[str, str]]) -> str:
|
@@ -83,7 +75,8 @@ def colorize_entities(ner_result: List[Tuple[str, str]]) -> str:
|
|
83 |
norm_label = label.replace('_', '-')
|
84 |
if norm_label != 'LABEL-0':
|
85 |
color = LABEL_COLORS.get(norm_label, '#eeeeee')
|
86 |
-
|
|
|
87 |
else:
|
88 |
html += f'{token} '
|
89 |
return html
|
|
|
9 |
|
10 |
# Mapping of label to color
|
11 |
LABEL_COLORS = {
|
12 |
+
'LABEL-0': '#ffffff', # NONE (no color)
|
13 |
+
'LABEL-1': '#fff4e6', # B-DATE (creamy orange)
|
14 |
+
'LABEL-2': '#ffe9ec', # I-DATE (creamy pink)
|
15 |
+
'LABEL-3': '#f3ffe3', # B-TIME (creamy green)
|
16 |
+
'LABEL-4': '#e6f7ff', # I-TIME (creamy blue)
|
17 |
+
'LABEL-5': '#f9f7e8', # B-DURATION (creamy yellow)
|
18 |
+
'LABEL-6': '#f6eaff', # I-DURATION (creamy purple)
|
19 |
+
'LABEL-7': '#fdf6ec', # B-SET (creamy beige)
|
20 |
+
'LABEL-8': '#f6fff8', # I-SET (creamy mint)
|
21 |
}
|
22 |
|
23 |
LABEL_MEANINGS = {
|
|
|
40 |
|
41 |
def ner_with_robertime(text: str) -> List[Tuple[str, str]]:
|
42 |
tokenizer, model = load_model()
|
|
|
43 |
tokens = tokenizer(text, return_tensors="pt", truncation=True, is_split_into_words=False)
|
44 |
with torch.no_grad():
|
45 |
outputs = model(**tokens)
|
46 |
predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
|
|
|
47 |
labels = [model.config.id2label[pred] for pred in predictions]
|
|
|
48 |
word_ids = tokens.word_ids(batch_index=0)
|
49 |
+
input_ids = tokens["input_ids"][0]
|
|
|
50 |
entities = []
|
51 |
+
current_word_ids = []
|
52 |
current_label = None
|
53 |
last_word_id = None
|
54 |
for idx, word_id in enumerate(word_ids):
|
55 |
if word_id is None:
|
56 |
continue
|
|
|
57 |
label = labels[idx]
|
58 |
+
if word_id != last_word_id and current_word_ids:
|
59 |
+
word = tokenizer.decode([input_ids[i] for i in current_word_ids], skip_special_tokens=True)
|
60 |
+
entities.append((word, current_label))
|
61 |
+
current_word_ids = [idx]
|
|
|
62 |
current_label = label
|
63 |
else:
|
64 |
+
current_word_ids.append(idx)
|
|
|
|
|
|
|
65 |
current_label = label
|
66 |
last_word_id = word_id
|
67 |
+
if current_word_ids:
|
68 |
+
word = tokenizer.decode([input_ids[i] for i in current_word_ids], skip_special_tokens=True)
|
69 |
+
entities.append((word, current_label))
|
70 |
return entities
|
71 |
|
72 |
def colorize_entities(ner_result: List[Tuple[str, str]]) -> str:
|
|
|
75 |
norm_label = label.replace('_', '-')
|
76 |
if norm_label != 'LABEL-0':
|
77 |
color = LABEL_COLORS.get(norm_label, '#eeeeee')
|
78 |
+
label_meaning = LABEL_MEANINGS.get(norm_label, norm_label)
|
79 |
+
html += f'<span style="background-color:{color};padding:2px 4px;border-radius:4px;margin:1px;" title="{label_meaning}">{token}</span> '
|
80 |
else:
|
81 |
html += f'{token} '
|
82 |
return html
|