Update app.py
Browse files
app.py
CHANGED
@@ -54,6 +54,16 @@ model.load_state_dict(checkpoint.get("model_state_dict", checkpoint), strict=Fal
|
|
54 |
model = model.to(device)
|
55 |
model.eval()
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
def get_word_probabilities(text):
|
58 |
text = " ".join(text.split(" ")[:2048])
|
59 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
@@ -73,7 +83,7 @@ def get_word_probabilities(text):
|
|
73 |
if current_word and current_probs:
|
74 |
current_prob = sum(current_probs) / len(current_probs)
|
75 |
word_probs.append(current_prob)
|
76 |
-
color = (
|
77 |
word_colors.append(color)
|
78 |
current_word = token[1:] if token != "▁" else ""
|
79 |
current_probs = [prob]
|
@@ -83,8 +93,17 @@ def get_word_probabilities(text):
|
|
83 |
if current_word and current_probs:
|
84 |
current_prob = sum(current_probs) / len(current_probs)
|
85 |
word_probs.append(current_prob)
|
86 |
-
color = (
|
87 |
word_colors.append(color)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
word_probs = [float(p) for p in word_probs]
|
89 |
return word_probs, word_colors
|
90 |
|
|
|
54 |
model = model.to(device)
|
55 |
model.eval()
|
56 |
|
57 |
+
def get_color(prob):
|
58 |
+
if prob < 0.25:
|
59 |
+
return "green"
|
60 |
+
elif prob < 0.5:
|
61 |
+
return "yellow"
|
62 |
+
elif prob < 0.75:
|
63 |
+
return "orange"
|
64 |
+
else:
|
65 |
+
return "red"
|
66 |
+
|
67 |
def get_word_probabilities(text):
|
68 |
text = " ".join(text.split(" ")[:2048])
|
69 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
|
|
83 |
if current_word and current_probs:
|
84 |
current_prob = sum(current_probs) / len(current_probs)
|
85 |
word_probs.append(current_prob)
|
86 |
+
color = get_color(current_prob)
|
87 |
word_colors.append(color)
|
88 |
current_word = token[1:] if token != "▁" else ""
|
89 |
current_probs = [prob]
|
|
|
93 |
if current_word and current_probs:
|
94 |
current_prob = sum(current_probs) / len(current_probs)
|
95 |
word_probs.append(current_prob)
|
96 |
+
color = get_color(current_prob)
|
97 |
word_colors.append(color)
|
98 |
+
|
99 |
+
####### FOR STABLE OUTPUTS
|
100 |
+
first_avg = (word_probs[1] + word_probs[2]) / 2
|
101 |
+
word_colors[0] = get_color(first_avg)
|
102 |
+
|
103 |
+
last_avg = (word_probs[-2] + word_probs[-3]) / 2
|
104 |
+
word_colors[-1] = get_color(last_avg)
|
105 |
+
#########
|
106 |
+
|
107 |
word_probs = [float(p) for p in word_probs]
|
108 |
return word_probs, word_colors
|
109 |
|