chryzxc's picture
Update app.py
4afa954 verified
raw
history blame
1.48 kB
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from onnxruntime import InferenceSession
import numpy as np
import os
import uvicorn
# HTTP application object; routes and middleware are registered on it below.
app = FastAPI(title="ONNX Model API")

# CORS configuration: every origin, method and header is accepted.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Load ONNX model
# The model file defaults to "model.onnx"; set the ONNX_MODEL_PATH environment
# variable to load a different file without editing this module.
MODEL_PATH = os.environ.get("ONNX_MODEL_PATH", "model.onnx")
# Loaded once at import time and shared by every request handler.
session = InferenceSession(MODEL_PATH)
# Essential for Spaces health checks
@app.get("/")
def read_root():
    """Return a small static payload showing that the service is up."""
    status_payload = {"status": "ONNX Model API is running"}
    return status_payload
# Main prediction endpoint
@app.post("/predict")
async def predict(request: Request):
    """Run the ONNX model on the token arrays supplied in the JSON body.

    The body must be a JSON object with the keys ``"input_ids"`` and
    ``"attention_mask"``. Each value is converted to an int64 array and
    reshaped to a single row (shape ``(1, -1)``) before being passed to the
    model under the input names ``input_ids`` and ``attention_mask``.

    Returns:
        ``{"embedding": ...}`` where the value is the first model output
        converted to nested Python lists.

    Raises:
        HTTPException: status 400 when the body is not valid JSON, is not an
            object, lacks a required key, holds values that cannot be
            converted to integer arrays, or when the model run itself fails.
    """
    # Each failure mode gets its own narrow try block so the client receives
    # a message describing what was wrong instead of a raw exception repr
    # (e.g. a bare "'input_ids'" for a missing key).
    try:
        data = await request.json()
    except ValueError as e:  # JSON and unicode decode errors are ValueErrors
        raise HTTPException(
            status_code=400, detail=f"Request body is not valid JSON: {e}"
        ) from e

    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    required_keys = ("input_ids", "attention_mask")
    missing = [key for key in required_keys if key not in data]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required field(s): {', '.join(missing)}",
        )

    try:
        input_ids = np.array(data["input_ids"], dtype=np.int64).reshape(1, -1)
        attention_mask = np.array(
            data["attention_mask"], dtype=np.int64
        ).reshape(1, -1)
    except (TypeError, ValueError, OverflowError) as e:
        raise HTTPException(
            status_code=400,
            detail=f"input_ids and attention_mask must be arrays of integers: {e}",
        ) from e

    try:
        outputs = session.run(None, {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        })
        embedding = outputs[0].tolist()
    except Exception as e:
        # Kept as a 400 with the runtime's own message so existing clients
        # see the same status code as before for model-side rejections.
        raise HTTPException(status_code=400, detail=str(e)) from e

    return {"embedding": embedding}
# Special endpoint for Spaces compatibility
@app.post("/api/predict")
async def spaces_predict(request: Request):
    """Serve ``/api/predict`` by delegating to the ``/predict`` handler."""
    response = await predict(request)
    return response
if __name__ == "__main__":
    # Start the server only when this file is executed as a script.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        # Required for Spaces:
        "proxy_headers": True,
        "forwarded_allow_ips": "*",
    }
    uvicorn.run(app, **server_options)