minemaster01 commited on
Commit
c9abe45
·
verified ·
1 Parent(s): 5a64fc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -57
app.py CHANGED
@@ -56,67 +56,59 @@ model = model.to(device)
56
  model.eval()
57
 
58
  # Inference function
59
-
60
  def get_word_probabilities(text):
61
- try:
62
- text = " ".join(text.split(" ")[:2048])
63
- except Exception as e:
64
- print("Error during text preprocessing:", e)
65
- return []
66
-
67
- try:
68
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
69
- inputs = {k: v.to(device) for k, v in inputs.items()}
70
- except Exception as e:
71
- print("Error during tokenization or moving inputs to device:", e)
72
- return []
73
-
74
- try:
75
- tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
76
- except Exception as e:
77
- print("Error during token conversion:", e)
78
- return []
79
-
80
- try:
81
- with torch.no_grad():
82
- tags, emission = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
83
- except Exception as e:
84
- print("Error during model inference:", e)
85
- return []
86
-
87
- try:
88
- probs = torch.softmax(emission, dim=-1)[0, :, 1].cpu().numpy()
89
- except Exception as e:
90
- print("Error during softmax or extracting class probabilities:", e)
91
- return []
92
-
93
  word_probs = []
 
94
  current_word = ""
95
  current_probs = []
96
 
97
- try:
98
- for token, prob in zip(tokens, probs):
99
- if token in ["<s>", "</s>"]:
100
- continue
101
- if token.startswith("▁"):
102
- if current_word and current_probs:
103
- word_probs.append(sum(current_probs) / len(current_probs))
104
- current_word = token[1:] if token != "▁" else ""
105
- current_probs = [prob]
106
- else:
107
- current_word += token
108
- current_probs.append(prob)
109
- if current_word and current_probs:
110
- word_probs.append(sum(current_probs) / len(current_probs))
111
- except Exception as e:
112
- print("Error during word aggregation:", e)
113
- return []
114
- word_probs = [float(p) for p in word_probs]
115
-
116
- return word_probs
117
-
118
-
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
 
 
120
 
121
  # def get_word_classifications(text):
122
  # text = " ".join(text.split(" ")[:2048])
@@ -186,7 +178,7 @@ def infer_and_log(text_input):
186
  "id": submission_id,
187
  "timestamp": timestamp,
188
  "input": text_input,
189
- "output_tags": word_tags
190
  }
191
 
192
  os.makedirs("logs", exist_ok=True)
@@ -207,7 +199,7 @@ def infer_and_log(text_input):
207
  except Exception as e:
208
  print(f"Error uploading log: {e}")
209
 
210
- return json.dumps(word_tags, indent=2)
211
 
212
 
213
  def clear_fields():
 
56
  model.eval()
57
 
58
  # Inference function
59
+
60
  def get_word_probabilities(text):
61
+ text = " ".join(text.split(" ")[:2048])
62
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
63
+ inputs = {k: v.to(device) for k, v in inputs.items()}
64
+ tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
65
+ with torch.no_grad():
66
+ tags, emission = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
67
+ probs = torch.softmax(emission, dim=-1)[0, :, 1].cpu().numpy()
68
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  word_probs = []
70
+ word_colors = []
71
  current_word = ""
72
  current_probs = []
73
 
74
+ for token, prob in zip(tokens, probs):
75
+ if token in ["<s>", "</s>"]:
76
+ continue
77
+ if token.startswith("▁"):
78
+ if current_word and current_probs:
79
+ current_prob = sum(current_probs) / len(current_probs)
80
+ word_probs.append(current_prob)
81
+
82
+ # Determine color based on probability
83
+ color = (
84
+ "green" if current_prob < 0.25 else
85
+ "yellow" if current_prob < 0.5 else
86
+ "orange" if current_prob < 0.75 else
87
+ "red"
88
+ )
89
+ word_colors.append(color)
90
+
91
+ current_word = token[1:] if token != "▁" else ""
92
+ current_probs = [prob]
93
+ else:
94
+ current_word += token
95
+ current_probs.append(prob)
96
+
97
+ if current_word and current_probs:
98
+ current_prob = sum(current_probs) / len(current_probs)
99
+ word_probs.append(current_prob)
100
+
101
+ # Determine color for the last word
102
+ color = (
103
+ "green" if current_prob < 0.25 else
104
+ "yellow" if current_prob < 0.5 else
105
+ "orange" if current_prob < 0.75 else
106
+ "red"
107
+ )
108
+ word_colors.append(color)
109
 
110
+ word_probs = [float(p) for p in word_probs]
111
+ return word_probs,
112
 
113
  # def get_word_classifications(text):
114
  # text = " ".join(text.split(" ")[:2048])
 
178
  "id": submission_id,
179
  "timestamp": timestamp,
180
  "input": text_input,
181
+ "output_probs": word_probs
182
  }
183
 
184
  os.makedirs("logs", exist_ok=True)
 
199
  except Exception as e:
200
  print(f"Error uploading log: {e}")
201
 
202
+ return json.dumps(word_probs, indent=2)
203
 
204
 
205
  def clear_fields():