File size: 636 Bytes
a3e1970 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
# main.py
from fastapi import FastAPI
from pydantic import BaseModel
import onnxruntime as ort
import numpy as np
# ASGI application object; route handlers below register themselves on it.
app = FastAPI()
# The inference session is built once at import time and reused by every
# request. "model.onnx" is a relative path, so it is resolved against the
# process's current working directory.
session = ort.InferenceSession("model.onnx")
class ModelInput(BaseModel):
    """Request body for the ``/predict`` endpoint.

    Both fields are flat lists of integers. The handler turns each one into a
    single-row int64 array and feeds it to the model input of the same name.
    """

    # Fed to the model input named "input_ids".
    # NOTE(review): presumably tokenizer output for one sequence - confirm.
    input_ids: list[int]
    # Fed to the model input named "attention_mask".
    # NOTE(review): presumably the same length as input_ids - not enforced here.
    attention_mask: list[int]
@app.post("/predict")
def predict(data: ModelInput):
    """Run the ONNX model on one input sequence and return its outputs.

    Args:
        data: Parsed request body carrying ``input_ids`` and
            ``attention_mask`` as flat lists of integers.

    Returns:
        A dict ``{"output": [...]}`` with one entry per model output. Each
        numpy array is converted to nested Python lists so the response can
        be JSON-encoded.
    """
    # Each flat list becomes a single-row batch of shape (1, len(list)).
    inputs = {
        "input_ids": np.array(data.input_ids, dtype=np.int64).reshape(1, -1),
        "attention_mask": np.array(
            data.attention_mask, dtype=np.int64
        ).reshape(1, -1),
    }
    # Passing None as the output names asks the session for all outputs.
    outputs = session.run(None, inputs)
    # numpy arrays are not JSON-serializable by FastAPI's encoder, so convert
    # them to plain lists; non-array outputs are passed through unchanged.
    return {
        "output": [
            item.tolist() if isinstance(item, np.ndarray) else item
            for item in outputs
        ]
    }
|