# app.py — FastAPI ONNX inference service (Hugging Face Space, author: chryzxc)
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from onnxruntime import InferenceSession
import numpy as np
import os
import uvicorn
# Initialize FastAPI with docs disabled for Spaces
# (Spaces fronts the app with its own UI, so /docs and /redoc are turned off).
app = FastAPI(docs_url=None, redoc_url=None)
# CORS configuration
# NOTE(review): wildcard origins/methods/headers — acceptable for a public demo
# endpoint, but tighten allow_origins if this ever serves authenticated clients.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# Load ONNX model at import time so a missing/corrupt model fails fast
# instead of erroring on the first request. model.onnx is expected to sit
# next to this file (working directory of the Space).
try:
session = InferenceSession("model.onnx")
print("Model loaded successfully")
except Exception as e:
print(f"Model loading failed: {str(e)}")
raise
@app.get("/")
async def health_check():
    """Liveness probe: confirm the service is up and which backend it runs."""
    payload = {"status": "ready", "model": "onnx"}
    return payload
@app.post("/api/predict")
async def predict(request: Request):
    """Run ONNX inference on a single tokenized example.

    Expects a JSON body with "input_ids" and "attention_mask" (flat lists of
    token ids / mask values); each is reshaped into a batch of one: (1, seq_len).
    Returns {"embedding": [...]} — the model's first output as nested lists.
    Raises HTTPException(400) on malformed JSON, missing fields, or
    inference failure.
    """
    # Parse the body separately so a syntax error in the JSON gets a
    # distinct, actionable message rather than a generic one.
    try:
        data = await request.json()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}")

    # Validate required keys up front: previously a missing key surfaced as
    # the cryptic detail "'input_ids'" from a bare KeyError.
    required = ("input_ids", "attention_mask")
    missing = [k for k in required if not isinstance(data, dict) or k not in data]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required field(s): {', '.join(missing)}",
        )

    try:
        # Convert to numpy int64 arrays with a leading batch dimension of 1,
        # matching the model's expected (batch, seq_len) input shape.
        input_ids = np.array(data["input_ids"], dtype=np.int64).reshape(1, -1)
        attention_mask = np.array(data["attention_mask"], dtype=np.int64).reshape(1, -1)
        # None => return all model outputs; we use only the first.
        outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask
            }
        )
    except HTTPException:
        # Never re-wrap an already-formed HTTP error.
        raise
    except Exception as e:
        # Shaping/inference errors map to 400, preserving the prior contract.
        raise HTTPException(status_code=400, detail=str(e))
    return {"embedding": outputs[0].tolist()}
# Required for Hugging Face Spaces
if __name__ == "__main__":
    # Spaces route external traffic to port 7860; bind all interfaces so the
    # container's reverse proxy can reach the app. reload=False for production.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=False
    )