chryzxc's picture
Create app.py
a3e1970 verified
raw
history blame
636 Bytes
# main.py
import os

import numpy as np
import onnxruntime as ort
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()

# Location of the ONNX model file. Override with the MODEL_PATH environment
# variable; the default keeps the original behaviour ("model.onnx", resolved
# against the current working directory).
MODEL_PATH = os.environ.get("MODEL_PATH", "model.onnx")

# Created once at import time; the /predict handler reuses this single
# module-level session for every request.
session = ort.InferenceSession(MODEL_PATH)
class ModelInput(BaseModel):
    """JSON body accepted by ``POST /predict``.

    Both lists are converted to int64 arrays and fed to the model as a batch
    containing a single row.
    """

    # Integer ids, passed to the model under the "input_ids" feed name.
    input_ids: list[int]
    # Integer mask values, passed to the model under the "attention_mask" feed name.
    attention_mask: list[int]
def _to_jsonable(value):
    """Recursively convert numpy values in a model output to native Python types.

    ``numpy.ndarray`` and numpy scalars are converted with ``tolist()``; dicts,
    lists and tuples are walked so numpy values nested inside them (e.g. in
    non-tensor outputs, if the model produces any) are converted as well.
    Any other value is returned unchanged.
    """
    if isinstance(value, (np.ndarray, np.generic)):
        return value.tolist()
    if isinstance(value, dict):
        # Keys are converted too: numpy integer keys are not JSON-encodable.
        return {_to_jsonable(k): _to_jsonable(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(v) for v in value]
    return value


@app.post("/predict")
def predict(data: ModelInput):
    """Run the ONNX model on a single sequence and return all of its outputs.

    Args:
        data: Request body with equally long ``input_ids`` and
            ``attention_mask`` integer lists.

    Returns:
        ``{"output": [...]}``, one entry per model output in the order
        ``session.run`` returns them, converted to plain Python lists/scalars
        so FastAPI can JSON-encode the response.

    Raises:
        HTTPException: 422 if ``input_ids`` is empty or the two lists differ
            in length.
    """
    # Local import keeps this block self-contained; fastapi is already loaded.
    from fastapi import HTTPException

    if not data.input_ids:
        raise HTTPException(status_code=422, detail="input_ids must not be empty")
    if len(data.input_ids) != len(data.attention_mask):
        # Checked explicitly: a mismatched (e.g. length-1) mask could otherwise
        # broadcast inside the model and silently produce wrong results.
        raise HTTPException(
            status_code=422,
            detail=(
                "input_ids and attention_mask must have the same length "
                f"(got {len(data.input_ids)} and {len(data.attention_mask)})"
            ),
        )

    # Batch of one: each list becomes an int64 array of shape (1, seq_len).
    inputs = {
        "input_ids": np.array(data.input_ids, dtype=np.int64).reshape(1, -1),
        "attention_mask": np.array(data.attention_mask, dtype=np.int64).reshape(1, -1),
    }
    # output_names=None -> return every output of the model.
    outputs = session.run(None, inputs)
    # session.run yields numpy arrays, which FastAPI's JSON encoder rejects;
    # convert them to native Python types before returning.
    return {"output": [_to_jsonable(out) for out in outputs]}