rectified onnx inference code
Browse files
README.md
CHANGED
|
@@ -106,7 +106,7 @@ predicted_class_id = logits.argmax().item()
|
|
| 106 |
loaded_model.config.id2label[predicted_class_id]
|
| 107 |
```
|
| 108 |
|
| 109 |
-
Optimum with ONNX
|
| 110 |
|
| 111 |
Loading the model requires the 🤗 Optimum library installed.
|
| 112 |
```shell
|
|
@@ -115,12 +115,12 @@ pip install transformers optimum[onnxruntime] optimum
|
|
| 115 |
|
| 116 |
```python
|
| 117 |
model_path = "philomath-1209/programming-language-identification"
|
| 118 |
-
|
| 119 |
from transformers import pipeline, AutoTokenizer
|
| 120 |
from optimum.onnxruntime import ORTModelForSequenceClassification
|
| 121 |
|
| 122 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 123 |
-
model = ORTModelForSequenceClassification.from_pretrained(model_path, export=
|
| 124 |
|
| 125 |
text = """
|
| 126 |
PROGRAM Triangle
|
|
@@ -141,9 +141,10 @@ text = """
|
|
| 141 |
END FUNCTION Area
|
| 142 |
|
| 143 |
"""
|
| 144 |
-
inputs =
|
| 145 |
with torch.no_grad():
|
| 146 |
-
logits =
|
| 147 |
predicted_class_id = logits.argmax().item()
|
| 148 |
-
|
|
|
|
| 149 |
```
|
|
|
|
| 106 |
loaded_model.config.id2label[predicted_class_id]
|
| 107 |
```
|
| 108 |
|
| 109 |
+
### Optimum with ONNX inference
|
| 110 |
|
| 111 |
Loading the model requires the 🤗 Optimum library installed.
|
| 112 |
```shell
|
|
|
|
| 115 |
|
| 116 |
```python
|
| 117 |
model_path = "philomath-1209/programming-language-identification"
|
| 118 |
+
import torch
|
| 119 |
from transformers import pipeline, AutoTokenizer
|
| 120 |
from optimum.onnxruntime import ORTModelForSequenceClassification
|
| 121 |
|
| 122 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="onnx")
|
| 123 |
+
model = ORTModelForSequenceClassification.from_pretrained(model_path, export=False, subfolder="onnx")
|
| 124 |
|
| 125 |
text = """
|
| 126 |
PROGRAM Triangle
|
|
|
|
| 141 |
END FUNCTION Area
|
| 142 |
|
| 143 |
"""
|
| 144 |
+
inputs = tokenizer(text, return_tensors="pt",truncation=True)
|
| 145 |
with torch.no_grad():
|
| 146 |
+
logits = model(**inputs).logits
|
| 147 |
predicted_class_id = logits.argmax().item()
|
| 148 |
+
model.config.id2label[predicted_class_id]
|
| 149 |
+
|
| 150 |
```
|