from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer
from onnxruntime import InferenceSession
import numpy as np
from typing import Dict

app = FastAPI(title="ONNX Model API with Tokenizer")
# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
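# Note: allow_origins=["*"] accepts requests from any origin, which is
# convenient for a demo Space but worth narrowing to known frontends in
# production.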
# Initialize components
tokenizer = AutoTokenizer.from_pretrained("Xenova/multi-qa-mpnet-base-dot-v1")
session = InferenceSession("model.onnx")
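# model.onnx is loaded from the working directory; the graph is expected to
# accept int64 "input_ids" and "attention_mask" inputs, matching the feed
# built in /api/process below.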

def convert_outputs(outputs):
    """Ensure all numpy values are converted to Python native types."""
    if isinstance(outputs, (np.generic, np.ndarray)):
        return outputs.item() if outputs.ndim == 0 else outputs.tolist()
    return outputs
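# For example, convert_outputs(np.float32(0.5)) returns the plain float 0.5,
# while convert_outputs(np.array([[1, 2]])) returns the nested list [[1, 2]];
# both are JSON-serializable, unlike their numpy counterparts.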

@app.post("/api/process")
async def process_text(request: Dict[str, str]):
    try:
        text = request.get("text", "")

        # Tokenize the input text
        inputs = tokenizer(
            text,
            return_tensors="np",
            padding=True,
            truncation=True,
            max_length=32,  # Match your model's expected input size
        )

        # Convert to ONNX-compatible format
        onnx_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64),
        }

        # Run model inference
        outputs = session.run(None, onnx_inputs)

        # Convert all numpy types to native Python types
        processed_outputs = [convert_outputs(output) for output in outputs]

        return {
            "embedding": processed_outputs[0],  # Assuming first output is embeddings
            "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
        }
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
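
# Example request (assuming the server runs locally on port 7860, the usual
# Hugging Face Spaces port; adjust host/port to your deployment):
#   curl -X POST http://localhost:7860/api/process \
#        -H "Content-Type: application/json" \
#        -d '{"text": "What is ONNX?"}'
# The response contains "embedding" (nested lists of floats) and "tokens"
# (the tokenizer's string tokens for the input).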

@app.get("/health")
async def health_check():
    return {"status": "healthy"}