import os

import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel


# ASGI application; route handlers below register on it via decorators.
app = FastAPI()

# Location of the ONNX model file. Defaults to "model.onnx", which is resolved
# against the process working directory (not this file's directory). Set the
# MODEL_PATH environment variable to load a model from elsewhere.
MODEL_PATH = os.environ.get("MODEL_PATH", "model.onnx")

# Created once when this module is imported and shared by every request, so a
# missing or unreadable model file surfaces as an error at import time.
session = ort.InferenceSession(MODEL_PATH)


class ModelInput(BaseModel):
    """Request body for ``POST /predict``: one tokenized sequence.

    Both lists are converted to ``int64`` arrays of shape ``(1, len)`` by the
    ``/predict`` handler and fed to the ONNX model under the same names.
    """

    # Token ids for a single sequence (presumably produced by the model's
    # tokenizer — TODO confirm).
    input_ids: list[int]

    # Per-token mask sent alongside input_ids; its length is expected to match
    # input_ids. NOTE(review): assumed 1 = attend, 0 = padding — confirm
    # against the exported model.
    attention_mask: list[int]


@app.post("/predict")
def predict(data: ModelInput):
    """Run the ONNX model on one tokenized sequence and return its outputs.

    ``input_ids`` and ``attention_mask`` are each passed to the model as an
    ``int64`` array of shape ``(1, seq_len)``, i.e. a batch of one.

    Args:
        data: Validated request body.

    Returns:
        ``{"output": [...]}`` with one entry per model output, in the order
        ``session.run`` returns them. NumPy arrays/scalars are converted to
        plain Python lists/numbers so FastAPI can serialize them as JSON.

    Raises:
        HTTPException: 422 if ``input_ids`` is empty, if the two lists differ
            in length, or if any value does not fit in int64.
    """
    # Reject malformed input up front so the client gets a 422 explaining the
    # problem instead of an opaque ONNX Runtime failure surfacing as a 500.
    if not data.input_ids:
        raise HTTPException(status_code=422, detail="input_ids must not be empty")
    if len(data.input_ids) != len(data.attention_mask):
        raise HTTPException(
            status_code=422,
            detail=(
                "input_ids and attention_mask must have the same length "
                f"(got {len(data.input_ids)} and {len(data.attention_mask)})"
            ),
        )

    try:
        # NOTE(review): int64 is assumed to match the model's input dtype —
        # confirm against session.get_inputs()[i].type.
        input_ids = np.array(data.input_ids, dtype=np.int64).reshape(1, -1)
        attention_mask = np.array(data.attention_mask, dtype=np.int64).reshape(1, -1)
    except OverflowError as err:
        # pydantic accepts arbitrarily large ints; they cannot become int64.
        raise HTTPException(
            status_code=422, detail="token values must fit in a signed 64-bit integer"
        ) from err

    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }

    # None as the first argument requests every model output.
    outputs = session.run(None, inputs)

    # Raw NumPy arrays are not JSON-serializable by FastAPI; convert them.
    return {"output": [_to_jsonable(value) for value in outputs]}


def _to_jsonable(value):
    """Recursively convert NumPy arrays/scalars in a model output to Python types."""
    if isinstance(value, (np.ndarray, np.generic)):
        return value.tolist()
    if isinstance(value, dict):
        return {_to_jsonable(key): _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    return value