File size: 1,602 Bytes
658ebc3
80490de
84f1ee8
 
80490de
658ebc3
1e9ac73
4afa954
1e9ac73
80490de
 
 
 
 
 
 
 
 
4afa954
80490de
4afa954
80490de
4afa954
 
84f1ee8
4afa954
 
658ebc3
80490de
658ebc3
 
 
 
4afa954
 
 
 
80490de
d7d161f
 
 
 
 
84f505f
80490de
658ebc3
80490de
4afa954
 
 
 
 
80490de
658ebc3
4afa954
658ebc3
 
4afa954
 
 
658ebc3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from onnxruntime import InferenceSession
import numpy as np
import os
import uvicorn

app = FastAPI(title="ONNX Model API")

# CORS configuration: accept requests from any origin, using any HTTP method
# and any request header. allow_credentials is not set here.
_CORS_ALLOW_ALL = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_ALLOW_ALL)

# Load the ONNX model once, at import time, so every request reuses the same
# session. The path defaults to "model.onnx" (relative to the current working
# directory); set the MODEL_PATH environment variable to load another file.
session = InferenceSession(os.environ.get("MODEL_PATH", "model.onnx"))

# Essential for Spaces health checks
@app.get("/")
def read_root():
    """Health-check endpoint: report a fixed 'running' status message."""
    status_payload = {"status": "ONNX Model API is running"}
    return status_payload

# Main prediction endpoint
@app.post("/predict")
async def predict(request: Request):
    """Run the ONNX model on a tokenized input and return its first output.

    The JSON body must contain ``input_ids`` and ``attention_mask``. Each is
    converted to an int64 array and reshaped to a single batch row of shape
    ``(1, N)`` before being fed to the model under the same input names.

    Returns:
        ``{"embedding": ...}``, where the value is the model's first output
        cast to float32 and converted to (nested) Python lists.

    Raises:
        HTTPException: status 400 carrying the error text, for any exception
            raised while parsing the body, building the inputs, running the
            model, or converting its output.
    """
    try:
        data = await request.json()
        # reshape(1, -1) flattens whatever was sent into a single-row batch.
        input_ids = np.array(data["input_ids"], dtype=np.int64).reshape(1, -1)
        attention_mask = np.array(data["attention_mask"], dtype=np.int64).reshape(1, -1)

        # NOTE(review): session.run is synchronous and is called directly from
        # this async handler, so it blocks the event loop while inference runs.
        outputs = session.run(None, {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        })

        # .tolist() already produces JSON-native Python floats, and FastAPI
        # serializes a returned dict itself, so no extra encoding step is needed.
        return {
            "embedding": outputs[0].astype(np.float32).tolist()  # Force float32 conversion
        }

    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

# Special endpoint for Spaces compatibility
@app.post("/api/predict")
async def spaces_predict(request: Request):
    """Serve the prediction endpoint under the alternate ``/api/predict`` path.

    Delegates to ``predict`` with the same request object, so the expected
    body, the response and the error handling are identical.
    """
    response = await predict(request)
    return response

if __name__ == "__main__":
    # Serve on all interfaces. The port defaults to 7860; a PORT environment
    # variable, when set, overrides it.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "7860")),
        # Required for Spaces: honour X-Forwarded-* headers from the proxy in
        # front of the app, trusting them from any client IP.
        proxy_headers=True,
        forwarded_allow_ips="*",
    )