chryzxc committed on
Commit
658ebc3
·
verified ·
1 Parent(s): d6190c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -14
app.py CHANGED
@@ -1,10 +1,12 @@
1
- from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from onnxruntime import InferenceSession
4
  import numpy as np
5
  import os
 
6
 
7
- app = FastAPI(title="ONNX Model API")
 
8
 
9
  # CORS configuration
10
  app.add_middleware(
@@ -15,20 +17,28 @@ app.add_middleware(
15
  )
16
 
17
  # Load ONNX model
18
- model_path = os.path.join(os.getcwd(), "model.onnx")
19
- session = InferenceSession(model_path)
 
 
 
 
20
 
21
  @app.get("/")
22
- def health_check():
23
- return {"status": "healthy", "message": "ONNX model is ready"}
24
 
25
- @app.post("/predict")
26
- async def predict(inputs: dict):
27
- """Expects {'input_ids': [], 'attention_mask': []}"""
28
  try:
29
- input_ids = np.array(inputs["input_ids"], dtype=np.int64).reshape(1, -1)
30
- attention_mask = np.array(inputs["attention_mask"], dtype=np.int64).reshape(1, -1)
31
 
 
 
 
 
 
32
  outputs = session.run(
33
  None,
34
  {
@@ -40,9 +50,13 @@ async def predict(inputs: dict):
40
  return {"embedding": outputs[0].tolist()}
41
 
42
  except Exception as e:
43
- return {"error": str(e)}
44
 
45
  # Required for Hugging Face Spaces
46
  if __name__ == "__main__":
47
- import uvicorn
48
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
1
+ from fastapi import FastAPI, Request, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from onnxruntime import InferenceSession
4
  import numpy as np
5
  import os
6
+ import uvicorn
7
 
8
+ # Initialize FastAPI with docs disabled for Spaces
9
+ app = FastAPI(docs_url=None, redoc_url=None)
10
 
11
  # CORS configuration
12
  app.add_middleware(
 
17
  )
18
 
19
  # Load ONNX model
20
+ try:
21
+ session = InferenceSession("model.onnx")
22
+ print("Model loaded successfully")
23
+ except Exception as e:
24
+ print(f"Model loading failed: {str(e)}")
25
+ raise
26
 
27
  @app.get("/")
28
+ async def health_check():
29
+ return {"status": "ready", "model": "onnx"}
30
 
31
+ @app.post("/api/predict")
32
+ async def predict(request: Request):
 
33
  try:
34
+ # Get JSON input
35
+ data = await request.json()
36
 
37
+ # Convert to numpy arrays with correct shape
38
+ input_ids = np.array(data["input_ids"], dtype=np.int64).reshape(1, -1)
39
+ attention_mask = np.array(data["attention_mask"], dtype=np.int64).reshape(1, -1)
40
+
41
+ # Run inference
42
  outputs = session.run(
43
  None,
44
  {
 
50
  return {"embedding": outputs[0].tolist()}
51
 
52
  except Exception as e:
53
+ raise HTTPException(status_code=400, detail=str(e))
54
 
55
  # Required for Hugging Face Spaces
56
  if __name__ == "__main__":
57
+ uvicorn.run(
58
+ "app:app",
59
+ host="0.0.0.0",
60
+ port=7860,
61
+ reload=False
62
+ )