from fastapi import FastAPI, HTTPException
from onnxruntime import InferenceSession
from transformers import AutoTokenizer
import numpy as np
app = FastAPI()
# Initialize tokenizer (doesn't require PyTorch/TensorFlow)
tokenizer = AutoTokenizer.from_pretrained(
"Xenova/multi-qa-mpnet-base-dot-v1",
use_fast=True, # Uses Rust implementation
legacy=False
)
# Load ONNX model
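# (assumes model.onnx is present in the working directory, e.g. downloaded
#  from the model repo or exported ahead of time with Hugging Face Optimum)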
session = InferenceSession("model.onnx")
@app.post("/api/predict")
async def predict(text: str):
    try:
        # Tokenize without framework dependencies
        inputs = tokenizer(
            text,
            return_tensors="np",  # Get NumPy arrays directly
            padding=True,
            truncation=True,
            max_length=32,  # Match your model's expected input size
        )

        # Prepare ONNX inputs
        onnx_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64),
        }

        # Run inference
        outputs = session.run(None, onnx_inputs)
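
        # NOTE: outputs[0] is returned as-is below; for this model's ONNX
        # export it is typically the token-level hidden states of shape
        # (batch, seq_len, hidden), so callers usually apply a pooling step
        # (e.g. CLS or mean pooling) to get a single sentence embedding.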
        # Convert to native Python types
        return {
            "embedding": outputs[0].astype(np.float32).tolist(),
            "tokens": tokenizer.convert_ids_to_tokens(
                inputs["input_ids"][0].tolist()
            ),
        }
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
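
# Example request (a sketch; assumes the server listens on port 7860, the
# Hugging Face Spaces default). Because `text` is a plain `str` parameter,
# FastAPI reads it from the query string:
#
#   curl -X POST "http://localhost:7860/api/predict?text=What%20is%20ONNX%3F"

# Optional local entry point (assumes uvicorn is installed):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)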