File size: 1,983 Bytes
0e2d401 0d7d4cd 84f1ee8 589009a 84f1ee8 0e2d401 1e9ac73 589009a 1e9ac73 0e2d401 dd92c0c 80490de 589009a 0d7d4cd 80490de 0d7d4cd 4f729af 589009a 0e2d401 80490de 0e2d401 855d918 80490de 0d7d4cd 017d40d 0d7d4cd 0e2d401 0d7d4cd 017d40d d7d161f 3be5098 0d7d4cd 80490de 0e2d401 20e9804 0d7d4cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from fastapi import FastAPI, HTTPException, Request
from fastapi.encoders import jsonable_encoder
from onnxruntime import InferenceSession
from transformers import AutoTokenizer
import numpy as np
import uvicorn
# ASGI application object that the route decorators below attach to.
app = FastAPI()

# Tokenizer for the embedding model, created once at import time so each
# request reuses the same instance.
tokenizer = AutoTokenizer.from_pretrained(
    "Xenova/multi-qa-mpnet-base-dot-v1",
    use_fast=True,
    legacy=False,
)

# ONNX Runtime session over the model file, resolved relative to the
# process working directory.
session = InferenceSession("model.onnx")
def convert_output(value):
    """Recursively convert numpy values to native Python types.

    Args:
        value: A numpy scalar or array, a list, tuple or dict that may
            contain such values, or any other object.

    Returns:
        * numpy scalar or single-element array -> ``float`` (the array's
          shape is deliberately collapsed to a scalar);
        * any other numpy array -> nested ``list`` of ``float``;
        * numpy data that cannot be represented as floats (e.g. strings)
          -> its native Python equivalent via ``item()`` / ``tolist()``;
        * ``list`` / ``tuple`` / ``dict`` -> same container type with every
          element (or dict value) converted recursively;
        * anything else -> returned unchanged.
    """
    if isinstance(value, (np.generic, np.ndarray)):
        is_single = value.size == 1
        try:
            if is_single:
                return float(value.item())
            return value.astype(float).tolist()
        except (TypeError, ValueError):
            # Non-numeric data cannot be coerced to float; fall back to the
            # plain Python value instead of failing the whole conversion.
            return value.item() if is_single else value.tolist()
    if isinstance(value, list):
        return [convert_output(item) for item in value]
    if isinstance(value, tuple):
        # Tuples may hold numpy values too; keep the container a tuple.
        return tuple(convert_output(item) for item in value)
    if isinstance(value, dict):
        return {key: convert_output(item) for key, item in value.items()}
    return value
@app.post("/api/predict")
async def predict(request: Request):
    """Run the ONNX model on the posted text and return its first output.

    The request body must be a JSON object with a non-empty ``"text"``
    entry.

    Args:
        request: Incoming request whose body is parsed as JSON.

    Returns:
        A JSON-serialisable dict with:
        * ``"embedding"``: the first model output, converted to native
          Python floats/lists by ``convert_output``;
        * ``"tokens"``: the token strings of the first tokenized sequence.

    Raises:
        HTTPException: status 400 when the body is not valid JSON, is not a
            JSON object, or has no (or an empty) ``"text"``; status 500
            when tokenization, inference or response encoding fails.
    """
    # Validate outside the broad handler below so client errors keep their
    # 400 status instead of being re-wrapped as 500.
    try:
        data = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc

    text = data.get("text", "") if isinstance(data, dict) else ""
    if not text:
        raise HTTPException(status_code=400, detail="No text provided")

    try:
        # Ask for numpy tensors: by default the tokenizer returns plain
        # Python lists, which have no ``.astype`` and are not valid ONNX
        # Runtime inputs.
        inputs = tokenizer(text, return_tensors="np")
        input_ids = inputs["input_ids"].astype(np.int64)
        attention_mask = inputs["attention_mask"].astype(np.int64)

        # Passing None as the output names requests every model output.
        outputs = session.run(None, {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        })

        response = {
            "embedding": convert_output(outputs[0]),
            # Row 0 is the first sequence in the batch; hand the tokenizer
            # native ints rather than numpy scalars.
            "tokens": tokenizer.convert_ids_to_tokens(input_ids[0].tolist()),
        }
        print("embeddings", response["embedding"])
        return jsonable_encoder(response)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Executed as a script: serve the app with uvicorn on all interfaces,
    # port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )