Spaces:
Sleeping
Sleeping
Chidam Gopal
commited on
updates for onnx
Browse files- infer_intent.py +19 -14
infer_intent.py
CHANGED
|
@@ -41,20 +41,25 @@ class IntentClassifier:
|
|
| 41 |
truncation=True, # Truncate if the text is too long
|
| 42 |
max_length=64)
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
def main():
|
|
|
|
| 41 |
truncation=True, # Truncate if the text is too long
|
| 42 |
max_length=64)
|
| 43 |
|
| 44 |
+
# Convert inputs to NumPy arrays
|
| 45 |
+
onnx_inputs = {k: v for k, v in inputs.items()}
|
| 46 |
+
|
| 47 |
+
# Run the ONNX model
|
| 48 |
+
logits = self.ort_session.run(None, onnx_inputs)[0]
|
| 49 |
+
|
| 50 |
+
# Get the prediction
|
| 51 |
+
prediction = np.argmax(logits, axis=1)[0]
|
| 52 |
+
probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
|
| 53 |
+
rounded_probabilities = np.round(probabilities, decimals=3)
|
| 54 |
+
|
| 55 |
+
pred_result = self.id2label[prediction]
|
| 56 |
+
proba_result = dict(zip(self.label2id.keys(), rounded_probabilities[0].tolist()))
|
| 57 |
+
|
| 58 |
+
if verbose:
|
| 59 |
+
print(sequence + " -> " + pred_result)
|
| 60 |
+
print(proba_result, "\n")
|
| 61 |
+
|
| 62 |
+
return pred_result, proba_result
|
| 63 |
|
| 64 |
|
| 65 |
def main():
|