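"""FastAPI service that serves embeddings from an ONNX export of the
Xenova/multi-qa-mpnet-base-dot-v1 model."""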
from fastapi import FastAPI, HTTPException, Request
from onnxruntime import InferenceSession
from transformers import AutoTokenizer
import numpy as np
import os
import uvicorn

app = FastAPI()

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "Xenova/multi-qa-mpnet-base-dot-v1",
    use_fast=True,
    legacy=False,
)

# Load ONNX model
try:
    session = InferenceSession("model.onnx")
    print("Model loaded successfully")
except Exception as e:
    print(f"Failed to load model: {str(e)}")
    raise
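# Sanity check (an addition, not in the original file): ONNX exports vary in
# the input names they expect, so list the graph's inputs to confirm they
# match the keys passed to session.run() below.
for inp in session.get_inputs():
    print(f"Model input: {inp.name}, shape: {inp.shape}")
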
@app.get("/")
def health_check():
    return {"status": "OK", "model": "ONNX"}

@app.post("/api/predict")
async def predict(request: Request):
    try:
        # Get JSON input
        data = await request.json()
        text = data.get("text", "")
        if not text:
            raise HTTPException(status_code=400, detail="No text provided")

        # Tokenize input
        inputs = tokenizer(
            text,
            return_tensors="np",
            padding="max_length",
            truncation=True,
            max_length=32,
        )
        # Prepare ONNX inputs, cast to the int64 dtype the graph expects
        onnx_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64),
        }

        # Run inference
        outputs = session.run(None, onnx_inputs)

        # Convert outputs to a JSON-serializable list and drop numpy types
        embedding = outputs[0][0].astype(float).tolist()  # first output, first batch item
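        # Aside (an assumption, not from the original file): if the first
        # output is per-token hidden states rather than an already-pooled
        # vector, a mean-pooled sentence embedding could be computed instead:
        #     mask = inputs["attention_mask"][..., None].astype(np.float32)
        #     pooled = (outputs[0] * mask).sum(axis=1) / mask.sum(axis=1)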
        return {
            "embedding": embedding,
            # .tolist() hands plain Python ints to the tokenizer
            "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist()),
        }
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 400 above) unchanged
        # rather than masking them as a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=False,
    )
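# Example request (a sketch; assumes the server is running locally on the
# port configured above):
#   curl -X POST http://localhost:7860/api/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "hello world"}'
# The response is a JSON body with "embedding" (a list of floats) and
# "tokens" (the token strings for the padded input ids).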