from fastapi import FastAPI, HTTPException, Request
from onnxruntime import InferenceSession
from transformers import AutoTokenizer
import numpy as np
import uvicorn

app = FastAPI()

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "Xenova/multi-qa-mpnet-base-dot-v1",
    use_fast=True,
    legacy=False
)

# Load ONNX model
try:
    session = InferenceSession("model.onnx")
    print("Model loaded successfully")
except Exception as e:
    print(f"Failed to load model: {str(e)}")
    raise
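
# Note: "model.onnx" is assumed to sit in the working directory alongside this
# app; adjust the path above if the exported ONNX model lives elsewhere.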


@app.get("/")
def health_check():
    return {"status": "OK", "model": "ONNX"}


@app.post("/api/predict")
async def predict(request: Request):
    try:
        # Get JSON input
        data = await request.json()
        text = data.get("text", "")
        if not text:
            raise HTTPException(status_code=400, detail="No text provided")

        # Tokenize input
        inputs = tokenizer(
            text,
            return_tensors="np",
            padding="max_length",
            truncation=True,
            max_length=32
        )
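        # With padding="max_length" and truncation=True, every request is padded
        # or truncated to exactly 32 tokens, so only the first 32 tokens of a
        # longer input contribute to the embedding.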

        # Prepare ONNX inputs with correct shapes
        onnx_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64)
        }
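        # Note: the keys above must match the input names baked into the ONNX
        # graph; they can be inspected with [inp.name for inp in session.get_inputs()].
        # Some BERT-style exports also expect a "token_type_ids" input.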

        # Run inference
        outputs = session.run(None, onnx_inputs)

        # Convert outputs to list and handle numpy types
        embedding = outputs[0][0].astype(float).tolist()  # First output, first batch

        return {
            "embedding": embedding,
            "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        }
    except HTTPException:
        # Re-raise intentional HTTP errors (e.g. the 400 above) unchanged
        # instead of converting them into a generic 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=False
    )
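
# Example request (a sketch; assumes the server is reachable on localhost:7860
# as configured above):
#
#   curl -X POST http://localhost:7860/api/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "How do I reset my password?"}'
#
# The response is a JSON object with "embedding" (a list of floats) and
# "tokens" (the tokenized input).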