minemaster01 committed on
Commit
5050833
·
verified ·
1 Parent(s): 4c81942

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -11
app.py CHANGED
@@ -56,29 +56,76 @@ model = model.to(device)
56
  model.eval()
57
 
58
  # Inference function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def get_word_classifications(text):
60
  text = " ".join(text.split(" ")[:2048])
61
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
62
  inputs = {k: v.to(device) for k, v in inputs.items()}
63
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
 
64
  with torch.no_grad():
65
- tags, _ = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
 
66
  word_tags = []
 
67
  current_word = ""
68
- current_tag = None
69
- for token, tag in zip(tokens, tags[0]):
 
70
  if token in ["<s>", "</s>"]:
71
  continue
72
  if token.startswith("▁"):
73
  if current_word:
74
- word_tags.append(str(current_tag))
 
 
 
 
 
 
 
75
  current_word = token[1:] if token != "▁" else ""
76
- current_tag = tag
77
  else:
78
  current_word += token
 
 
79
  if current_word:
80
- word_tags.append(str(current_tag))
81
- return word_tags
 
 
 
 
 
 
 
 
 
 
82
 
83
  # HF logging setup
84
  def setup_hf_dataset():
@@ -93,7 +140,7 @@ def setup_hf_dataset():
93
 
94
  # Main inference + logging function
95
  def infer_and_log(text_input):
96
- word_tags = get_word_classifications(text_input)
97
  timestamp = datetime.datetime.now().isoformat()
98
  submission_id = str(uuid.uuid4())
99
 
@@ -122,7 +169,7 @@ def infer_and_log(text_input):
122
  except Exception as e:
123
  print(f"Error uploading log: {e}")
124
 
125
- return " ".join(word_tags)
126
 
127
  def clear_fields():
128
  return "", ""
@@ -136,12 +183,13 @@ with gr.Blocks() as app:
136
 
137
  with gr.Row():
138
  input_box = gr.Textbox(label="Input Text", lines=10)
139
- output_box = gr.Textbox(label="Output Tags", lines=10, interactive=False)
140
 
141
  with gr.Row():
142
  submit_btn = gr.Button("Submit")
143
  clear_btn = gr.Button("Clear")
144
-
 
145
  submit_btn.click(fn=infer_and_log, inputs=input_box, outputs=output_box)
146
  clear_btn.click(fn=clear_fields, outputs=[input_box, output_box])
147
 
 
56
  model.eval()
57
 
58
  # Inference function
59
+ # def get_word_classifications(text):
60
+ # text = " ".join(text.split(" ")[:2048])
61
+ # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
62
+ # inputs = {k: v.to(device) for k, v in inputs.items()}
63
+ # tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
64
+ # with torch.no_grad():
65
+ # tags, _ = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
66
+ # word_tags = []
67
+ # current_word = ""
68
+ # current_tag = ""
69
+ # for token, tag in zip(tokens, tags[0]):
70
+ # if token in ["<s>", "</s>"]:
71
+ # continue
72
+ # if token.startswith("▁"):
73
+ # if current_word:
74
+ # word_tags.append(str(current_tag))
75
+ # current_word = token[1:] if token != "▁" else ""
76
+ # current_tag = tag
77
+ # else:
78
+ # current_word += token
79
+ # if current_word:
80
+ # word_tags.append(str(current_tag))
81
+ # return word_tags
82
+
83
  def get_word_classifications(text):
84
  text = " ".join(text.split(" ")[:2048])
85
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
86
  inputs = {k: v.to(device) for k, v in inputs.items()}
87
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
88
+
89
  with torch.no_grad():
90
+ tags, emissions = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
91
+
92
  word_tags = []
93
+ color_output = []
94
  current_word = ""
95
+ current_prob = 0.0
96
+
97
+ for token, prob in zip(tokens, tags[0]):
98
  if token in ["<s>", "</s>"]:
99
  continue
100
  if token.startswith("▁"):
101
  if current_word:
102
+ word_tags.append(round(current_prob, 3))
103
+ color = (
104
+ "green" if current_prob < 0.25 else
105
+ "yellow" if current_prob < 0.5 else
106
+ "orange" if current_prob < 0.75 else
107
+ "red"
108
+ )
109
+ color_output.append(f'<span style="color:{color}">{current_word}</span>')
110
  current_word = token[1:] if token != "▁" else ""
111
+ current_prob = prob
112
  else:
113
  current_word += token
114
+ current_prob = max(current_prob, prob)
115
+
116
  if current_word:
117
+ word_tags.append(round(current_prob, 3))
118
+ color = (
119
+ "green" if current_prob < 0.25 else
120
+ "yellow" if current_prob < 0.5 else
121
+ "orange" if current_prob < 0.75 else
122
+ "red"
123
+ )
124
+ color_output.append(f'<span style="color:{color}">{current_word}</span>')
125
+
126
+ output = " ".join(color_output)
127
+ return output, word_tags
128
+
129
 
130
  # HF logging setup
131
  def setup_hf_dataset():
 
140
 
141
  # Main inference + logging function
142
  def infer_and_log(text_input):
143
+ output, word_tags = get_word_classifications(text_input)
144
  timestamp = datetime.datetime.now().isoformat()
145
  submission_id = str(uuid.uuid4())
146
 
 
169
  except Exception as e:
170
  print(f"Error uploading log: {e}")
171
 
172
+ return output
173
 
174
  # Click handler for the Clear button: returns one empty string for each of
  # the two components it is wired to (input_box and output_box), blanking both.
  def clear_fields():
175
  return "", ""
 
183
 
184
  with gr.Row():
185
  input_box = gr.Textbox(label="Input Text", lines=10)
186
+ output_box = gr.HTML(label="Output Tags")
187
 
188
  with gr.Row():
189
  submit_btn = gr.Button("Submit")
190
  clear_btn = gr.Button("Clear")
191
+
192
+
193
  submit_btn.click(fn=infer_and_log, inputs=input_box, outputs=output_box)
194
  clear_btn.click(fn=clear_fields, outputs=[input_box, output_box])
195