minemaster01 commited on
Commit
ec5bad4
Β·
verified Β·
1 Parent(s): 5050833

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -54
app.py CHANGED
@@ -56,75 +56,75 @@ model = model.to(device)
56
  model.eval()
57
 
58
  # Inference function
59
- # def get_word_classifications(text):
60
- # text = " ".join(text.split(" ")[:2048])
61
- # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
62
- # inputs = {k: v.to(device) for k, v in inputs.items()}
63
- # tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
64
- # with torch.no_grad():
65
- # tags, _ = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
66
- # word_tags = []
67
- # current_word = ""
68
- # current_tag = ""
69
- # for token, tag in zip(tokens, tags[0]):
70
- # if token in ["<s>", "</s>"]:
71
- # continue
72
- # if token.startswith("▁"):
73
- # if current_word:
74
- # word_tags.append(str(current_tag))
75
- # current_word = token[1:] if token != "▁" else ""
76
- # current_tag = tag
77
- # else:
78
- # current_word += token
79
- # if current_word:
80
- # word_tags.append(str(current_tag))
81
- # return word_tags
82
-
83
  def get_word_classifications(text):
84
  text = " ".join(text.split(" ")[:2048])
85
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
86
  inputs = {k: v.to(device) for k, v in inputs.items()}
87
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
88
-
89
  with torch.no_grad():
90
- tags, emissions = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
91
-
92
  word_tags = []
93
- color_output = []
94
  current_word = ""
95
- current_prob = 0.0
96
-
97
- for token, prob in zip(tokens, tags[0]):
98
  if token in ["<s>", "</s>"]:
99
  continue
100
  if token.startswith("▁"):
101
  if current_word:
102
- word_tags.append(round(current_prob, 3))
103
- color = (
104
- "green" if current_prob < 0.25 else
105
- "yellow" if current_prob < 0.5 else
106
- "orange" if current_prob < 0.75 else
107
- "red"
108
- )
109
- color_output.append(f'<span style="color:{color}">{current_word}</span>')
110
  current_word = token[1:] if token != "▁" else ""
111
- current_prob = prob
112
  else:
113
  current_word += token
114
- current_prob = max(current_prob, prob)
115
-
116
  if current_word:
117
- word_tags.append(round(current_prob, 3))
118
- color = (
119
- "green" if current_prob < 0.25 else
120
- "yellow" if current_prob < 0.5 else
121
- "orange" if current_prob < 0.75 else
122
- "red"
123
- )
124
- color_output.append(f'<span style="color:{color}">{current_word}</span>')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
- output = " ".join(color_output)
127
- return output, word_tags
128
 
129
 
130
  # HF logging setup
@@ -140,7 +140,7 @@ def setup_hf_dataset():
140
 
141
  # Main inference + logging function
142
  def infer_and_log(text_input):
143
- output, word_tags = get_word_classifications(text_input)
144
  timestamp = datetime.datetime.now().isoformat()
145
  submission_id = str(uuid.uuid4())
146
 
@@ -169,7 +169,7 @@ def infer_and_log(text_input):
169
  except Exception as e:
170
  print(f"Error uploading log: {e}")
171
 
172
- return output
173
 
174
  def clear_fields():
175
  return "", ""
 
56
  model.eval()
57
 
58
  # Inference function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def get_word_classifications(text):
60
  text = " ".join(text.split(" ")[:2048])
61
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
62
  inputs = {k: v.to(device) for k, v in inputs.items()}
63
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
 
64
  with torch.no_grad():
65
+ tags, _ = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
 
66
  word_tags = []
 
67
  current_word = ""
68
+ current_tag = ""
69
+ for token, tag in zip(tokens, tags[0]):
 
70
  if token in ["<s>", "</s>"]:
71
  continue
72
  if token.startswith("▁"):
73
  if current_word:
74
+ word_tags.append(str(current_tag))
 
 
 
 
 
 
 
75
  current_word = token[1:] if token != "▁" else ""
76
+ current_tag = tag
77
  else:
78
  current_word += token
 
 
79
  if current_word:
80
+ word_tags.append(str(current_tag))
81
+ return word_tags
82
+
83
+ # def get_word_classifications(text):
84
+ # text = " ".join(text.split(" ")[:2048])
85
+ # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
86
+ # inputs = {k: v.to(device) for k, v in inputs.items()}
87
+ # tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
88
+
89
+ # with torch.no_grad():
90
+ # tags, emissions = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
91
+
92
+ # word_tags = []
93
+ # color_output = []
94
+ # current_word = ""
95
+ # current_prob = 0.0
96
+
97
+ # for token, prob in zip(tokens, tags[0]):
98
+ # if token in ["<s>", "</s>"]:
99
+ # continue
100
+ # if token.startswith("▁"):
101
+ # if current_word:
102
+ # word_tags.append(round(current_prob, 3))
103
+ # color = (
104
+ # "green" if current_prob < 0.25 else
105
+ # "yellow" if current_prob < 0.5 else
106
+ # "orange" if current_prob < 0.75 else
107
+ # "red"
108
+ # )
109
+ # color_output.append(f'<span style="color:{color}">{current_word}</span>')
110
+ # current_word = token[1:] if token != "▁" else ""
111
+ # current_prob = prob
112
+ # else:
113
+ # current_word += token
114
+ # current_prob = max(current_prob, prob)
115
+
116
+ # if current_word:
117
+ # word_tags.append(round(current_prob, 3))
118
+ # color = (
119
+ # "green" if current_prob < 0.25 else
120
+ # "yellow" if current_prob < 0.5 else
121
+ # "orange" if current_prob < 0.75 else
122
+ # "red"
123
+ # )
124
+ # color_output.append(f'<span style="color:{color}">{current_word}</span>')
125
 
126
+ # output = " ".join(color_output)
127
+ # return output, word_tags
128
 
129
 
130
  # HF logging setup
 
140
 
141
  # Main inference + logging function
142
  def infer_and_log(text_input):
143
+ word_tags = get_word_classifications(text_input)
144
  timestamp = datetime.datetime.now().isoformat()
145
  submission_id = str(uuid.uuid4())
146
 
 
169
  except Exception as e:
170
  print(f"Error uploading log: {e}")
171
 
172
+ return "".join(word_tags)
173
 
174
  def clear_fields():
175
  return "", ""