minemaster01 commited on
Commit
1f36bf0
·
verified ·
1 Parent(s): 142cc53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -22
app.py CHANGED
@@ -58,35 +58,63 @@ model.eval()
58
  # Inference function
59
 
60
  def get_word_probabilities(text):
61
- text = " ".join(text.split(" ")[:2048])
62
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
63
- inputs = {k: v.to(device) for k, v in inputs.items()}
64
- tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
65
- with torch.no_grad():
66
- tags, emission = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
67
- probs = torch.softmax(emission, dim=-1)[0, :, 1].cpu().numpy()
68
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  word_probs = []
70
  current_word = ""
71
  current_probs = []
72
 
73
- for token, prob in zip(tokens, probs):
74
- if token in ["<s>", "</s>"]:
75
- continue
76
- if token.startswith("▁"):
77
- if current_word and current_probs:
78
- word_probs.append(sum(current_probs) / len(current_probs))
79
- current_word = token[1:] if token != "▁" else ""
80
- current_probs = [prob]
81
- else:
82
- current_word += token
83
- current_probs.append(prob)
84
-
85
- if current_word and current_probs:
86
- word_probs.append(sum(current_probs) / len(current_probs))
 
 
 
87
 
88
  return word_probs
89
 
 
90
 
91
 
92
  # def get_word_classifications(text):
 
58
  # Inference function
59
 
60
  def get_word_probabilities(text):
61
+ try:
62
+ text = " ".join(text.split(" ")[:2048])
63
+ except Exception as e:
64
+ print("Error during text preprocessing:", e)
65
+ return []
66
+
67
+ try:
68
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
69
+ inputs = {k: v.to(device) for k, v in inputs.items()}
70
+ except Exception as e:
71
+ print("Error during tokenization or moving inputs to device:", e)
72
+ return []
73
+
74
+ try:
75
+ tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
76
+ except Exception as e:
77
+ print("Error during token conversion:", e)
78
+ return []
79
+
80
+ try:
81
+ with torch.no_grad():
82
+ tags, emission = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
83
+ except Exception as e:
84
+ print("Error during model inference:", e)
85
+ return []
86
+
87
+ try:
88
+ probs = torch.softmax(emission, dim=-1)[0, :, 1].cpu().numpy()
89
+ except Exception as e:
90
+ print("Error during softmax or extracting class probabilities:", e)
91
+ return []
92
+
93
  word_probs = []
94
  current_word = ""
95
  current_probs = []
96
 
97
+ try:
98
+ for token, prob in zip(tokens, probs):
99
+ if token in ["<s>", "</s>"]:
100
+ continue
101
+ if token.startswith("▁"):
102
+ if current_word and current_probs:
103
+ word_probs.append(sum(current_probs) / len(current_probs))
104
+ current_word = token[1:] if token != "▁" else ""
105
+ current_probs = [prob]
106
+ else:
107
+ current_word += token
108
+ current_probs.append(prob)
109
+ if current_word and current_probs:
110
+ word_probs.append(sum(current_probs) / len(current_probs))
111
+ except Exception as e:
112
+ print("Error during word aggregation:", e)
113
+ return []
114
 
115
  return word_probs
116
 
117
+
118
 
119
 
120
  # def get_word_classifications(text):