chryzxc commited on
Commit
80490de
·
verified ·
1 Parent(s): befb5c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -13
app.py CHANGED
@@ -1,22 +1,48 @@
1
  from fastapi import FastAPI
 
2
  from onnxruntime import InferenceSession
3
  import numpy as np
 
4
 
5
- app = FastAPI()
6
 
7
- # Load ONNX model only
8
- # session = InferenceSession("model.onnx")
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  @app.post("/predict")
11
  async def predict(inputs: dict):
12
- # Expect pre-tokenized input from client
13
- ##input_ids = np.array(inputs["input_ids"], dtype=np.int64)
14
- #attention_mask = np.array(inputs["attention_mask"], dtype=np.int64)
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Run model
17
- #outputs = session.run(None, {
18
- # "input_ids": input_ids,
19
- # "attention_mask": attention_mask
20
- #})
21
- return "Status ok"
22
- #return {"embedding": outputs[0].tolist()}
 
1
  from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
  from onnxruntime import InferenceSession
4
  import numpy as np
5
+ import os
6
 
7
+ app = FastAPI(title="ONNX Model API")
8
 
9
+ # CORS configuration
10
+ app.add_middleware(
11
+ CORSMiddleware,
12
+ allow_origins=["*"],
13
+ allow_methods=["*"],
14
+ allow_headers=["*"],
15
+ )
16
+
17
+ # Load ONNX model
18
+ model_path = os.path.join(os.getcwd(), "model.onnx")
19
+ session = InferenceSession(model_path)
20
+
21
+ @app.get("/")
22
+ def health_check():
23
+ return {"status": "healthy", "message": "ONNX model is ready"}
24
 
25
  @app.post("/predict")
26
  async def predict(inputs: dict):
27
+ """Expects {'input_ids': [], 'attention_mask': []}"""
28
+ try:
29
+ input_ids = np.array(inputs["input_ids"], dtype=np.int64).reshape(1, -1)
30
+ attention_mask = np.array(inputs["attention_mask"], dtype=np.int64).reshape(1, -1)
31
+
32
+ outputs = session.run(
33
+ None,
34
+ {
35
+ "input_ids": input_ids,
36
+ "attention_mask": attention_mask
37
+ }
38
+ )
39
+
40
+ return {"embedding": outputs[0].tolist()}
41
 
42
+ except Exception as e:
43
+ return {"error": str(e)}
44
+
45
+ # Required for Hugging Face Spaces
46
+ if __name__ == "__main__":
47
+ import uvicorn
48
+ uvicorn.run(app, host="0.0.0.0", port=7860)