File size: 2,200 Bytes
0e2d401 0d7d4cd 84f1ee8 589009a 84f1ee8 0e2d401 1e9ac73 589009a 1e9ac73 0e2d401 dd92c0c 80490de 589009a 0d7d4cd 80490de 0d7d4cd 4f729af 589009a 0e2d401 80490de 0e2d401 6d7e0c5 80490de 6d7e0c5 c47d871 0d7d4cd 6d7e0c5 bfdd5dc 6d7e0c5 d7d161f 0d7d4cd 80490de 0e2d401 20e9804 0d7d4cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import asyncio
import logging

import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.encoders import jsonable_encoder
from onnxruntime import InferenceSession
from transformers import AutoTokenizer
# Identifiers for the artifacts loaded at import time.
_TOKENIZER_NAME = "Xenova/multi-qa-mpnet-base-dot-v1"
_MODEL_PATH = "model.onnx"  # resolved relative to the current working directory

# ASGI application; served by uvicorn when this module runs as a script.
app = FastAPI()

# Fast (Rust-backed) tokenizer, loaded once and shared by every request.
# NOTE(review): presumably this is the tokenizer model.onnx was exported
# with — confirm, since a vocabulary mismatch silently yields bad embeddings.
tokenizer = AutoTokenizer.from_pretrained(
    _TOKENIZER_NAME,
    legacy=False,
    use_fast=True,
)

# ONNX Runtime session, created once at import time and reused per request.
session = InferenceSession(_MODEL_PATH)
def convert_output(value):
    """Recursively convert NumPy values into native Python types.

    Conversion rules:

    - NumPy scalars and 0-d arrays become a Python ``float`` (integer and
      boolean scalars included).
    - NumPy arrays with one or more dimensions become (nested) lists of
      ``float``. Their shape is preserved even when they hold a single
      element, so ``np.array([5])`` becomes ``[5.0]``, not ``5.0``.
    - Lists are converted element-wise. Tuples are converted element-wise
      and stay tuples.
    - Dicts have their values converted; keys are left untouched.
    - Anything else is returned unchanged.

    Args:
        value: The object to convert.

    Returns:
        The converted object.
    """
    is_scalar = isinstance(value, np.generic) or (
        isinstance(value, np.ndarray) and value.ndim == 0
    )
    if is_scalar:
        # Only true scalars collapse to a float; sized arrays keep their shape.
        return float(value.item())
    if isinstance(value, np.ndarray):
        return value.astype(float).tolist()
    if isinstance(value, list):
        return [convert_output(x) for x in value]
    if isinstance(value, tuple):
        return tuple(convert_output(x) for x in value)
    if isinstance(value, dict):
        return {k: convert_output(v) for k, v in value.items()}
    return value
@app.post("/api/predict")
async def predict(request: Request):
    """Tokenize the ``text`` field of a JSON body and run the ONNX model on it.

    Expects a JSON object such as ``{"text": "..."}``.

    Returns:
        A dict with three keys:

        - ``embedding``: batch row 0 of the model's first output, as nested
          Python floats.
        - ``input_ids``: the token ids produced by the tokenizer.
        - ``attention_mask``: the attention mask produced by the tokenizer.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or has a missing/empty ``text``; 500 for any failure
            during tokenization, inference or response building.
    """
    log = logging.getLogger(__name__)

    # Validate input outside the broad handler below, so that client errors
    # stay 400s instead of being re-wrapped as 500s.
    try:
        data = await request.json()
    except ValueError as err:  # json.JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Invalid JSON body") from err
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    text = data.get("text", "")
    if not text:
        raise HTTPException(status_code=400, detail="No text provided")

    try:
        # Padding and truncation are disabled; special tokens (CLS/SEP) are
        # added. NOTE(review): presumably inputs longer than the model's
        # maximum sequence length then fail inside session.run and surface
        # as a 500 — confirm.
        inputs = tokenizer(
            text,
            return_tensors="np",
            padding=False,
            truncation=False,
            add_special_tokens=True,
        )
        onnx_inputs = {
            "input_ids": np.array(inputs["input_ids"], dtype=np.int64),
            "attention_mask": np.array(inputs["attention_mask"], dtype=np.int64),
        }

        # Inference is CPU-bound and synchronous. Run it in the default
        # executor so it does not block the event loop for other requests.
        loop = asyncio.get_running_loop()
        outputs = await loop.run_in_executor(None, session.run, None, onnx_inputs)
        log.debug("ONNX outputs: %s", outputs)

        # NOTE(review): whether outputs[0][0] is a pooled sentence vector or
        # per-token hidden states depends on model.onnx's graph — confirm.
        return {
            "embedding": outputs[0][0].astype(float).tolist(),
            "input_ids": inputs["input_ids"][0].tolist(),
            "attention_mask": inputs["attention_mask"][0].tolist(),
        }
    except Exception as e:
        log.exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    # Listens on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)