import os

import numpy as np
from fastapi import FastAPI, HTTPException
from onnxruntime import InferenceSession
app = FastAPI()

# Load the ONNX model once at import time so every request reuses the same
# session. Path is overridable via the MODEL_PATH env var (defaults to the
# original hard-coded "model.onnx", so existing deployments are unaffected).
MODEL_PATH = os.environ.get("MODEL_PATH", "model.onnx")
session = InferenceSession(MODEL_PATH)
@app.post("/predict")
async def predict(inputs: dict):
    """Run the ONNX model on pre-tokenized input and return its first output.

    The request body must contain "input_ids" and "attention_mask"
    (presumably nested lists of token ids shaped (batch, seq) — confirm
    against the exported model's input signature).

    Returns:
        {"embedding": ...} — the model's first output as nested Python lists.

    Raises:
        HTTPException(422): if a required key is missing or the values
            cannot be converted to int64 arrays.
    """
    # Keep the try body minimal: only the conversions that can fail on
    # malformed client input. Previously a bad payload surfaced as a bare
    # KeyError -> opaque 500; now it is a descriptive 422.
    try:
        input_ids = np.array(inputs["input_ids"], dtype=np.int64)
        attention_mask = np.array(inputs["attention_mask"], dtype=np.int64)
    except (KeyError, TypeError, ValueError) as err:
        raise HTTPException(status_code=422, detail=f"Invalid input: {err}") from err

    # NOTE(review): session.run is a blocking call inside an async handler —
    # it stalls the event loop for the full inference duration. Consider a
    # plain `def` endpoint (FastAPI runs those in a threadpool) or
    # loop.run_in_executor if request concurrency matters.
    outputs = session.run(None, {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    })
    return {"embedding": outputs[0].tolist()}