minemaster01 commited on
Commit
a832ce3
·
verified ·
1 Parent(s): c039580

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -2
app.py CHANGED
@@ -54,6 +54,16 @@ model.load_state_dict(checkpoint.get("model_state_dict", checkpoint), strict=Fal
54
  model = model.to(device)
55
  model.eval()
56
 
 
 
 
 
 
 
 
 
 
 
57
  def get_word_probabilities(text):
58
  text = " ".join(text.split(" ")[:2048])
59
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
@@ -73,7 +83,7 @@ def get_word_probabilities(text):
73
  if current_word and current_probs:
74
  current_prob = sum(current_probs) / len(current_probs)
75
  word_probs.append(current_prob)
76
- color = ("green" if current_prob < 0.25 else "yellow" if current_prob < 0.5 else "orange" if current_prob < 0.75 else "red")
77
  word_colors.append(color)
78
  current_word = token[1:] if token != "▁" else ""
79
  current_probs = [prob]
@@ -83,8 +93,17 @@ def get_word_probabilities(text):
83
  if current_word and current_probs:
84
  current_prob = sum(current_probs) / len(current_probs)
85
  word_probs.append(current_prob)
86
- color = ("green" if current_prob < 0.25 else "yellow" if current_prob < 0.5 else "orange" if current_prob < 0.75 else "red")
87
  word_colors.append(color)
 
 
 
 
 
 
 
 
 
88
  word_probs = [float(p) for p in word_probs]
89
  return word_probs, word_colors
90
 
 
54
  model = model.to(device)
55
  model.eval()
56
 
57
+ def get_color(prob):
58
+ if prob < 0.25:
59
+ return "green"
60
+ elif prob < 0.5:
61
+ return "yellow"
62
+ elif prob < 0.75:
63
+ return "orange"
64
+ else:
65
+ return "red"
66
+
67
  def get_word_probabilities(text):
68
  text = " ".join(text.split(" ")[:2048])
69
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
83
  if current_word and current_probs:
84
  current_prob = sum(current_probs) / len(current_probs)
85
  word_probs.append(current_prob)
86
+ color = get_color(current_prob)
87
  word_colors.append(color)
88
  current_word = token[1:] if token != "▁" else ""
89
  current_probs = [prob]
 
93
  if current_word and current_probs:
94
  current_prob = sum(current_probs) / len(current_probs)
95
  word_probs.append(current_prob)
96
+ color = get_color(current_prob)
97
  word_colors.append(color)
98
+
99
+ ####### FOR STABLE OUTPUTS
100
+ first_avg = (word_probs[1] + word_probs[2]) / 2
101
+ word_colors[0] = get_color(first_avg)
102
+
103
+ last_avg = (word_probs[-2] + word_probs[-3]) / 2
104
+ word_colors[-1] = get_color(last_avg)
105
+ #########
106
+
107
  word_probs = [float(p) for p in word_probs]
108
  return word_probs, word_colors
109