asdc committed on
Commit
6e9b058
·
verified ·
1 Parent(s): 65003ad

Upload streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +21 -28
src/streamlit_app.py CHANGED
@@ -9,15 +9,15 @@ from transformers import AutoTokenizer, AutoModelForTokenClassification
9
 
10
  # Mapping of label to color
11
  LABEL_COLORS = {
12
- 'LABEL-0': '#cccccc', # NONE
13
- 'LABEL-1': '#ffadad', # B-DATE
14
- 'LABEL-2': '#ffd6a5', # I-DATE
15
- 'LABEL-3': '#fdffb6', # B-TIME
16
- 'LABEL-4': '#caffbf', # I-TIME
17
- 'LABEL-5': '#9bf6ff', # B-DURATION
18
- 'LABEL-6': '#a0c4ff', # I-DURATION
19
- 'LABEL-7': '#bdb2ff', # B-SET
20
- 'LABEL-8': '#ffc6ff', # I-SET
21
  }
22
 
23
  LABEL_MEANINGS = {
@@ -40,41 +40,33 @@ def load_model():
40
 
41
  def ner_with_robertime(text: str) -> List[Tuple[str, str]]:
42
  tokenizer, model = load_model()
43
- # Tokenize and get input tensors
44
  tokens = tokenizer(text, return_tensors="pt", truncation=True, is_split_into_words=False)
45
  with torch.no_grad():
46
  outputs = model(**tokens)
47
  predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
48
- # Map ids to labels
49
  labels = [model.config.id2label[pred] for pred in predictions]
50
- # Get tokens (handling subwords)
51
  word_ids = tokens.word_ids(batch_index=0)
52
- token_list = tokenizer.convert_ids_to_tokens(tokens["input_ids"][0])
53
- # Merge subwords and assign entity labels
54
  entities = []
55
- current_word = ''
56
  current_label = None
57
  last_word_id = None
58
  for idx, word_id in enumerate(word_ids):
59
  if word_id is None:
60
  continue
61
- token = token_list[idx]
62
  label = labels[idx]
63
- if token.startswith('▁') or token.startswith('##') or token.startswith('Ġ'):
64
- token = token.lstrip('▁#Ġ')
65
- if word_id != last_word_id and current_word:
66
- entities.append((current_word, current_label))
67
- current_word = token
68
  current_label = label
69
  else:
70
- if current_word:
71
- current_word += token if token.startswith("'") else f' {token}'
72
- else:
73
- current_word = token
74
  current_label = label
75
  last_word_id = word_id
76
- if current_word:
77
- entities.append((current_word, current_label))
 
78
  return entities
79
 
80
  def colorize_entities(ner_result: List[Tuple[str, str]]) -> str:
@@ -83,7 +75,8 @@ def colorize_entities(ner_result: List[Tuple[str, str]]) -> str:
83
  norm_label = label.replace('_', '-')
84
  if norm_label != 'LABEL-0':
85
  color = LABEL_COLORS.get(norm_label, '#eeeeee')
86
- html += f'<span style="background-color:{color};padding:2px 4px;border-radius:4px;margin:1px;">{token}</span> '
 
87
  else:
88
  html += f'{token} '
89
  return html
 
9
 
10
# Background highlight color for each token-classification label.
# Index i of the palette corresponds to the model label 'LABEL-i'
# (LABEL-0 = NONE, then B-/I- pairs for DATE, TIME, DURATION, SET).
LABEL_COLORS = {
    f"LABEL-{index}": hex_color
    for index, hex_color in enumerate([
        "#ffffff",  # NONE (no color)
        "#fff4e6",  # B-DATE (creamy orange)
        "#ffe9ec",  # I-DATE (creamy pink)
        "#f3ffe3",  # B-TIME (creamy green)
        "#e6f7ff",  # I-TIME (creamy blue)
        "#f9f7e8",  # B-DURATION (creamy yellow)
        "#f6eaff",  # I-DURATION (creamy purple)
        "#fdf6ec",  # B-SET (creamy beige)
        "#f6fff8",  # I-SET (creamy mint)
    ])
}
22
 
23
  LABEL_MEANINGS = {
 
40
 
41
def ner_with_robertime(text: str) -> List[Tuple[str, str]]:
    """Run token classification on *text* and return (word, label) pairs.

    Subword tokens that belong to the same word (per the fast tokenizer's
    word ids) are decoded back into a single surface word; the word's label
    is the prediction of its final subword token.
    """
    tokenizer, model = load_model()
    encoding = tokenizer(text, return_tensors="pt", truncation=True, is_split_into_words=False)
    with torch.no_grad():
        logits = model(**encoding).logits
    pred_ids = torch.argmax(logits, dim=2)[0].tolist()
    pred_labels = [model.config.id2label[p] for p in pred_ids]

    word_ids = encoding.word_ids(batch_index=0)
    ids = encoding["input_ids"][0]

    def flush(positions, label, out):
        # Decode the accumulated subword pieces back into one surface word.
        word = tokenizer.decode([ids[i] for i in positions], skip_special_tokens=True)
        out.append((word, label))

    entities: List[Tuple[str, str]] = []
    pending: List[int] = []      # token positions of the word being built
    pending_label = None         # label carried by the word being built
    prev_word_id = None
    for pos, wid in enumerate(word_ids):
        if wid is None:
            # Special tokens (CLS/SEP/pad) carry no word id — skip them.
            continue
        if wid != prev_word_id and pending:
            # Word boundary: emit the finished word, start a new one.
            flush(pending, pending_label, entities)
            pending = [pos]
        else:
            pending.append(pos)
        pending_label = pred_labels[pos]
        prev_word_id = wid
    if pending:
        flush(pending, pending_label, entities)
    return entities
71
 
72
  def colorize_entities(ner_result: List[Tuple[str, str]]) -> str:
 
75
  norm_label = label.replace('_', '-')
76
  if norm_label != 'LABEL-0':
77
  color = LABEL_COLORS.get(norm_label, '#eeeeee')
78
+ label_meaning = LABEL_MEANINGS.get(norm_label, norm_label)
79
+ html += f'<span style="background-color:{color};padding:2px 4px;border-radius:4px;margin:1px;" title="{label_meaning}">{token}</span> '
80
  else:
81
  html += f'{token} '
82
  return html