File size: 2,019 Bytes
017d40d 80490de 017d40d 84f1ee8 80490de 017d40d 1e9ac73 017d40d 1e9ac73 80490de 017d40d 4afa954 80490de 017d40d 84f1ee8 017d40d 80490de 017d40d 658ebc3 017d40d 80490de 017d40d d7d161f 80490de 658ebc3 80490de 017d40d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import asyncio
import os
from typing import Dict

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer
from onnxruntime import InferenceSession
import numpy as np
# --- Application ------------------------------------------------------------
app = FastAPI(title="ONNX Model API with Tokenizer")

# Cross-origin requests are accepted from every origin, for every HTTP method
# and with any request header.
# NOTE(review): wildcard CORS is convenient for development; consider
# restricting the allowed origins in production.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# --- Model components (created once, when this module is imported) ---------
_TOKENIZER_ID = "Xenova/multi-qa-mpnet-base-dot-v1"
# Relative path: resolved against the process's working directory.
_ONNX_MODEL_PATH = "model.onnx"

tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_ID)
session = InferenceSession(_ONNX_MODEL_PATH)
def convert_outputs(outputs):
    """Recursively convert numpy values to native Python types.

    NumPy scalars and 0-d arrays become a single Python scalar, other NumPy
    arrays become (nested) Python lists, and lists, tuples and dicts are
    converted element by element (dict keys included), so numpy values nested
    inside containers are converted too. Any other value is returned as-is.

    Args:
        outputs: A value to convert, e.g. one element of the list returned by
            ``InferenceSession.run``.

    Returns:
        An equivalent value built only from native Python types (for the
        inputs handled above), suitable for JSON serialization.
    """
    if isinstance(outputs, (np.generic, np.ndarray)):
        # A 0-d array / numpy scalar collapses to one Python scalar.
        return outputs.item() if outputs.ndim == 0 else outputs.tolist()
    # Model outputs are not always plain arrays (ONNX sequence/map outputs
    # may come back as lists or dicts), so descend into containers as well.
    if isinstance(outputs, dict):
        return {
            convert_outputs(key): convert_outputs(value)
            for key, value in outputs.items()
        }
    if isinstance(outputs, list):
        return [convert_outputs(item) for item in outputs]
    if isinstance(outputs, tuple):
        return tuple(convert_outputs(item) for item in outputs)
    return outputs
@app.post("/api/process")
async def process_text(request: Dict[str, str]):
    """Tokenize ``request["text"]``, run the ONNX model and return its output.

    The request body is a JSON object with string values; only the ``"text"``
    key is read, and a missing key is treated as an empty string.

    Returns:
        A dict with ``"embedding"`` (the model's first output, converted to
        native Python types) and ``"tokens"`` (the tokens of the first
        tokenized sequence).

    Raises:
        HTTPException: status 400 carrying the error message when
            tokenization, inference or output conversion fails.
    """
    try:
        text = request.get("text", "")

        # Tokenization stays on the event-loop thread: Hugging Face fast
        # tokenizers are known to fail ("Already borrowed") when one instance
        # is used concurrently from several threads.
        inputs = tokenizer(
            text,
            return_tensors="np",
            padding=True,
            truncation=True,
            max_length=32,  # Match your model's expected input size
        )

        # The ONNX graph is fed int64 ids and attention mask.
        onnx_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64),
        }

        # session.run is blocking, CPU-bound work. Calling it directly inside
        # this coroutine would stall the event loop for every other request,
        # so it runs in the default executor instead (ONNX Runtime allows
        # concurrent run() calls on one session).
        loop = asyncio.get_running_loop()
        outputs = await loop.run_in_executor(None, session.run, None, onnx_inputs)

        # Make every output JSON-serializable.
        processed_outputs = [convert_outputs(output) for output in outputs]

        return {
            # NOTE(review): assumes the model's first output is the embedding;
            # confirm against the exported graph.
            "embedding": processed_outputs[0],
            "tokens": tokenizer.convert_ids_to_tokens(
                inputs["input_ids"][0].tolist()
            ),
        }
    except Exception as e:
        # NOTE(review): server-side failures (e.g. inference errors) are also
        # reported as 400; chain the cause so it is not lost in tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.get("/health")
async def health_check():
    """Liveness probe that always reports the service as healthy.

    No tokenizer or model check is performed here.

    Returns:
        The constant payload ``{"status": "healthy"}``.
    """
    return {"status": "healthy"}