File size: 1,602 Bytes
658ebc3 80490de 84f1ee8 80490de 658ebc3 1e9ac73 4afa954 1e9ac73 80490de 4afa954 80490de 4afa954 80490de 4afa954 84f1ee8 4afa954 658ebc3 80490de 658ebc3 4afa954 80490de d7d161f 84f505f 80490de 658ebc3 80490de 4afa954 80490de 658ebc3 4afa954 658ebc3 4afa954 658ebc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from onnxruntime import InferenceSession
import numpy as np
import os
import uvicorn
app = FastAPI(title="ONNX Model API")

# Permissive CORS: any origin, HTTP method and request header is accepted.
# allow_credentials is not passed, so CORSMiddleware's own default applies.
_CORS_WILDCARD_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_WILDCARD_OPTIONS)

# Relative path, so it is resolved against the process's working directory.
MODEL_PATH = "model.onnx"
# Loaded once at import time; the request handlers share this single session.
session = InferenceSession(MODEL_PATH)
# Root route; the Spaces health checks depend on it being present.
@app.get("/")
def read_root():
    """Liveness probe: report that the API process is up and serving."""
    return dict(status="ONNX Model API is running")
# Main prediction endpoint
@app.post("/predict")
async def predict(request: Request):
    """Run the ONNX model on one tokenized input and return its first output.

    Expects a JSON object body with the fields ``input_ids`` and
    ``attention_mask``. Each is converted to an int64 array and reshaped to a
    single-row ``(1, n)`` batch before being passed to ``session.run``.

    Returns:
        ``{"embedding": ...}`` where the value is the model's first output,
        cast to float32 and converted to (nested) Python lists.

    Raises:
        HTTPException: status 400 when the body is not valid JSON or not a
            JSON object, a required field is missing or not convertible to
            int64, the inputs are empty or differ in length, or
            ``session.run`` raises.
    """
    try:
        data = await request.json()
    except ValueError as e:  # JSON decoding / text decoding errors are ValueErrors
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}") from e
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    try:
        # reshape(1, -1) flattens whatever was sent into a single-row batch.
        input_ids = np.array(data["input_ids"], dtype=np.int64).reshape(1, -1)
        attention_mask = np.array(data["attention_mask"], dtype=np.int64).reshape(1, -1)
    except KeyError as e:
        raise HTTPException(
            status_code=400, detail=f"Missing required field: {e.args[0]}"
        ) from e
    except (ValueError, TypeError, OverflowError) as e:
        # Ragged lists, non-numeric values, or integers outside the int64 range.
        raise HTTPException(status_code=400, detail=str(e)) from e

    if input_ids.size == 0:
        raise HTTPException(status_code=400, detail="input_ids must not be empty")
    if input_ids.shape != attention_mask.shape:
        raise HTTPException(
            status_code=400,
            detail="input_ids and attention_mask must have the same length",
        )

    try:
        # NOTE(review): session.run is synchronous and executes on the event-loop
        # thread inside this async handler; consider run_in_threadpool under load.
        outputs = session.run(None, {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        })
    except Exception as e:
        # Inference failures are still reported as 400, as before, because
        # parseable-but-unacceptable inputs surface here.
        # NOTE(review): genuine server-side faults also land here; consider 500.
        raise HTTPException(status_code=400, detail=str(e)) from e

    # FastAPI serialises a plain dict of lists/floats itself, so no encoder call
    # is needed. The previous, never-imported jsonable_encoder raised NameError,
    # which turned every request into a 400.
    return {"embedding": outputs[0].astype(np.float32).tolist()}  # Force float32 conversion
# Alias route kept for Spaces compatibility: POST /api/predict is served by the
# same handler as POST /predict.
@app.post("/api/predict")
async def spaces_predict(request: Request):
    """Delegate the incoming request unchanged to :func:`predict`."""
    response_body = await predict(request)
    return response_body
if __name__ == "__main__":
    # Serve on all interfaces. The port defaults to 7860 and can be overridden
    # through the PORT environment variable.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "7860")),
        # Required for Spaces: honour X-Forwarded-* headers set by the fronting
        # proxy, trusting any upstream address.
        proxy_headers=True,
        forwarded_allow_ips="*",
    )