minemaster01 commited on
Commit
03d2b1d
·
verified ·
1 Parent(s): 2ebed35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -0
app.py CHANGED
@@ -80,6 +80,38 @@ def get_word_classifications(text):
80
  word_tags.append(str(current_tag))
81
  return word_tags
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  # def get_word_classifications(text):
84
  # text = " ".join(text.split(" ")[:2048])
85
  # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
80
  word_tags.append(str(current_tag))
81
  return word_tags
82
 
83
+ def get_word_probabilities(text):
84
+ text = " ".join(text.split(" ")[:2048])
85
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
86
+ inputs = {k: v.to(device) for k, v in inputs.items()}
87
+ tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
88
+ with torch.no_grad():
89
+ tags, emission = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
90
+ probs = torch.softmax(emission, dim=-1)[0, :, 1].cpu().numpy()
91
+
92
+ word_probs = []
93
+ current_word = ""
94
+ current_probs = []
95
+
96
+ for token, prob in zip(tokens, probs):
97
+ if token in ["<s>", "</s>"]:
98
+ continue
99
+ if token.startswith("▁"):
100
+ if current_word and current_probs:
101
+ word_probs.append(sum(current_probs) / len(current_probs))
102
+ current_word = token[1:] if token != "▁" else ""
103
+ current_probs = [prob]
104
+ else:
105
+ current_word += token
106
+ current_probs.append(prob)
107
+
108
+ if current_word and current_probs:
109
+ word_probs.append(sum(current_probs) / len(current_probs))
110
+
111
+ return word_probs
112
+
113
+
114
+
115
  # def get_word_classifications(text):
116
  # text = " ".join(text.split(" ")[:2048])
117
  # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)