File size: 590 Bytes
32854a5
84f1ee8
 
1e9ac73
84f505f
1e9ac73
32854a5
84f505f
84f1ee8
84f505f
32854a5
 
 
 
84f505f
32854a5
 
 
 
 
84f1ee8
32854a5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from fastapi import FastAPI, HTTPException
from onnxruntime import InferenceSession

import numpy as np

# ASGI application instance; the /predict route below registers against it.
app = FastAPI()

# Inference session created once at import time from "model.onnx" in the
# working directory (no tokenizer here — clients send pre-tokenized ids).
# NOTE(review): startup fails fast if the file is missing — presumably
# intentional; confirm deployment always ships the model alongside the app.
session = InferenceSession("model.onnx")

@app.post("/predict")
async def predict(inputs: dict):
    """Run the ONNX model on pre-tokenized input and return its first output.

    Expects a JSON body with two keys:
      - "input_ids": nested list of token ids (cast to int64)
      - "attention_mask": nested list of 0/1 flags (cast to int64)
    Tokenization happens client-side; this endpoint only runs inference.

    Returns:
        {"embedding": <first model output as nested lists>}
        # assumes outputs[0] is the embedding tensor — TODO confirm against
        # the exported model's output order.

    Raises:
        HTTPException(422) when a key is missing or a value cannot be
        converted to an int64 array — previously these surfaced as an
        unhandled KeyError/ValueError, i.e. an opaque 500 to the client.
    """
    try:
        input_ids = np.array(inputs["input_ids"], dtype=np.int64)
        attention_mask = np.array(inputs["attention_mask"], dtype=np.int64)
    except KeyError as err:
        raise HTTPException(
            status_code=422, detail=f"missing required field: {err.args[0]}"
        ) from err
    except (TypeError, ValueError) as err:
        raise HTTPException(
            status_code=422, detail=f"invalid token array: {err}"
        ) from err

    # Run the model; None selects all outputs, we return only the first.
    outputs = session.run(None, {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    })

    return {"embedding": outputs[0].tolist()}