chryzxc's picture
Update app.py
84f505f verified
raw
history blame
758 Bytes
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import numpy as np
from onnxruntime import InferenceSession
from transformers import AutoTokenizer
import os  # NOTE(review): unused in this file — likely leftover; confirm before removing
# FastAPI application exposing the embedding endpoint below.
app = FastAPI()
# CORS: permit browser clients on any origin to call this API directly.
# Wildcard origins/methods/headers — the endpoint is treated as public.
_cors_options = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# Load model
# ONNX Runtime session for the embedding model; "model.onnx" must exist in
# the working directory (InferenceSession raises at startup if it is missing).
session = InferenceSession("model.onnx")
# Tokenizer matching the model's vocabulary/preprocessing.
tokenizer = AutoTokenizer.from_pretrained("Xenova/multi-qa-mpnet-base-dot-v1")
@app.post("/predict")
async def predict(query: str):
    """Embed *query* with the ONNX model and return the vector as JSON.

    Args:
        query: Text to embed (FastAPI reads this from the query string).

    Returns:
        dict with key "embedding": a list of floats.
    """
    # Tokenize to NumPy arrays; ONNX Runtime expects int64 input tensors.
    inputs = tokenizer(query, return_tensors="np")
    inputs = {k: v.astype(np.int64) for k, v in inputs.items()}
    # Run the full graph (None = all outputs); batch size is 1, so take [0].
    outputs = session.run(None, inputs)
    # NOTE(review): outputs[0][0] may be per-token hidden states rather than a
    # pooled sentence embedding — depends on the ONNX export; confirm shape.
    embedding = outputs[0][0].tolist()
    return {"embedding": embedding}