Update app.py
app.py CHANGED
@@ -56,67 +56,59 @@ model = model.to(device)
 model.eval()
 
 # Inference function
-
+
 def get_word_probabilities(text):
-
-
-
-
-
-
-
-
-        inputs = {k: v.to(device) for k, v in inputs.items()}
-    except Exception as e:
-        print("Error during tokenization or moving inputs to device:", e)
-        return []
-
-    try:
-        tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
-    except Exception as e:
-        print("Error during token conversion:", e)
-        return []
-
-    try:
-        with torch.no_grad():
-            tags, emission = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
-    except Exception as e:
-        print("Error during model inference:", e)
-        return []
-
-    try:
-        probs = torch.softmax(emission, dim=-1)[0, :, 1].cpu().numpy()
-    except Exception as e:
-        print("Error during softmax or extracting class probabilities:", e)
-        return []
-
+    text = " ".join(text.split(" ")[:2048])
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+    with torch.no_grad():
+        tags, emission = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
+    probs = torch.softmax(emission, dim=-1)[0, :, 1].cpu().numpy()
+
     word_probs = []
+    word_colors = []
     current_word = ""
     current_probs = []
 
-
-
-
-
-        if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    for token, prob in zip(tokens, probs):
+        if token in ["<s>", "</s>"]:
+            continue
+        if token.startswith("▁"):
+            if current_word and current_probs:
+                current_prob = sum(current_probs) / len(current_probs)
+                word_probs.append(current_prob)
+
+                # Determine color based on probability
+                color = (
+                    "green" if current_prob < 0.25 else
+                    "yellow" if current_prob < 0.5 else
+                    "orange" if current_prob < 0.75 else
+                    "red"
+                )
+                word_colors.append(color)
+
+            current_word = token[1:] if token != "▁" else ""
+            current_probs = [prob]
+        else:
+            current_word += token
+            current_probs.append(prob)
+
+    if current_word and current_probs:
+        current_prob = sum(current_probs) / len(current_probs)
+        word_probs.append(current_prob)
+
+        # Determine color for the last word
+        color = (
+            "green" if current_prob < 0.25 else
+            "yellow" if current_prob < 0.5 else
+            "orange" if current_prob < 0.75 else
+            "red"
+        )
+        word_colors.append(color)
 
+    word_probs = [float(p) for p in word_probs]
+    return word_probs,
 
 # def get_word_classifications(text):
 #     text = " ".join(text.split(" ")[:2048])

@@ -186,7 +178,7 @@ def infer_and_log(text_input):
         "id": submission_id,
         "timestamp": timestamp,
         "input": text_input,
-        "
+        "output_probs": word_probs
     }
 
     os.makedirs("logs", exist_ok=True)

@@ -207,7 +199,7 @@ def infer_and_log(text_input):
     except Exception as e:
         print(f"Error uploading log: {e}")
 
-    return json.dumps(
+    return json.dumps(word_probs, indent=2)
 
 
 def clear_fields():
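
For context, here is a minimal usage sketch of the updated get_word_probabilities. It is not part of the commit: it assumes it runs inside app.py, where tokenizer, model, and device are already defined at module level, and the sample text is purely illustrative. Note that, as committed, the function ends with "return word_probs," (trailing comma), so callers receive a one-element tuple.

    # Minimal usage sketch (assumption: executed inside app.py, where
    # `tokenizer`, `model`, and `device` already exist at module level).
    sample_text = "Example input text to score word by word."  # hypothetical input

    # The trailing comma in `return word_probs,` makes the result a 1-tuple.
    (word_probs,) = get_word_probabilities(sample_text)

    for i, p in enumerate(word_probs):
        # Same thresholds as the committed code:
        # green < 0.25, yellow < 0.5, orange < 0.75, otherwise red.
        color = (
            "green" if p < 0.25 else
            "yellow" if p < 0.5 else
            "orange" if p < 0.75 else
            "red"
        )
        print(f"word {i}: prob={p:.3f} -> {color}")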