chryzxc committed on
Commit
017d40d
·
verified ·
1 Parent(s): d7d161f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -36
app.py CHANGED
@@ -1,11 +1,12 @@
1
- from fastapi import FastAPI, Request, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
 
3
  from onnxruntime import InferenceSession
4
  import numpy as np
5
  import os
6
- import uvicorn
7
 
8
- app = FastAPI(title="ONNX Model API")
9
 
10
  # CORS configuration
11
  app.add_middleware(
@@ -15,47 +16,50 @@ app.add_middleware(
15
  allow_headers=["*"],
16
  )
17
 
18
- # Load ONNX model
 
19
  session = InferenceSession("model.onnx")
20
 
21
- # Essential for Spaces health checks
22
- @app.get("/")
23
- def read_root():
24
- return {"status": "ONNX Model API is running"}
 
25
 
26
- # Main prediction endpoint
27
- @app.post("/predict")
28
- async def predict(request: Request):
29
  try:
30
- data = await request.json()
31
- input_ids = np.array(data["input_ids"], dtype=np.int64).reshape(1, -1)
32
- attention_mask = np.array(data["attention_mask"], dtype=np.int64).reshape(1, -1)
33
 
34
- outputs = session.run(None, {
35
- "input_ids": input_ids,
36
- "attention_mask": attention_mask
37
- })
 
 
 
 
38
 
39
- result = {
40
- "embedding": outputs[0].astype(np.float32).tolist() # Force float32 conversion
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  }
42
 
43
- return jsonable_encoder(result)
44
-
45
  except Exception as e:
46
  raise HTTPException(status_code=400, detail=str(e))
47
 
48
- # Special endpoint for Spaces compatibility
49
- @app.post("/api/predict")
50
- async def spaces_predict(request: Request):
51
- return await predict(request)
52
-
53
- if __name__ == "__main__":
54
- uvicorn.run(
55
- app,
56
- host="0.0.0.0",
57
- port=7860,
58
- # Required for Spaces:
59
- proxy_headers=True,
60
- forwarded_allow_ips="*"
61
- )
 
1
+ from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from transformers import AutoTokenizer
4
  from onnxruntime import InferenceSession
5
  import numpy as np
6
  import os
7
+ from typing import Dict
8
 
9
+ app = FastAPI(title="ONNX Model API with Tokenizer")
10
 
11
  # CORS configuration
12
  app.add_middleware(
 
16
  allow_headers=["*"],
17
  )
18
 
19
+ # Initialize components
20
+ tokenizer = AutoTokenizer.from_pretrained("Xenova/multi-qa-mpnet-base-dot-v1")
21
  session = InferenceSession("model.onnx")
22
 
23
def convert_outputs(outputs):
    """Convert NumPy values to native Python types, recursing into containers.

    Args:
        outputs: A NumPy array, a NumPy scalar, a list/tuple/dict that may
            contain such values, or any other object.

    Returns:
        NumPy arrays become (nested) lists, or a plain scalar when the array
        is 0-dimensional; NumPy scalars become Python scalars; lists, tuples
        and dicts are rebuilt with their members (and dict keys) converted;
        anything else is returned unchanged.
    """
    if isinstance(outputs, np.ndarray):
        # ndarray.tolist() already yields a plain scalar for 0-d arrays.
        return outputs.tolist()
    if isinstance(outputs, np.generic):
        return outputs.item()
    if isinstance(outputs, dict):
        # Keys are converted too: NumPy scalar keys are not JSON-serializable.
        return {
            convert_outputs(key): convert_outputs(value)
            for key, value in outputs.items()
        }
    if isinstance(outputs, list):
        return [convert_outputs(item) for item in outputs]
    if isinstance(outputs, tuple):
        return tuple(convert_outputs(item) for item in outputs)
    return outputs
28
 
29
@app.post("/api/process")
async def process_text(request: Dict[str, str]):
    """Tokenize the request's ``text`` and run it through the ONNX model.

    The text is tokenized with truncation to at most 32 tokens, the resulting
    ``input_ids`` and ``attention_mask`` are passed to the ONNX session as
    int64 arrays, and the first model output is returned together with the
    token strings of the first sequence.

    Args:
        request: JSON object body. Only the ``"text"`` key is read; a missing
            key is treated as the empty string.

    Returns:
        A dict with ``"embedding"`` (the first session output converted to
        native Python types) and ``"tokens"`` (token strings for the first
        sequence of ``input_ids``).

    Raises:
        HTTPException: status 500 if tokenization or inference fails. The
            body has already been validated by the time this code runs, so a
            failure here is a server-side problem rather than a bad request.
    """
    # Plain dict lookup cannot fail, so it stays outside the guarded section.
    text = request.get("text", "")

    try:
        # Tokenize the input text
        inputs = tokenizer(
            text,
            return_tensors="np",
            padding=True,
            truncation=True,
            max_length=32,  # Match your model's expected input size
        )

        # Convert to ONNX-compatible format
        onnx_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64),
        }

        # Run model inference
        outputs = session.run(None, onnx_inputs)

        # Convert all numpy types to native Python types
        processed_outputs = [convert_outputs(output) for output in outputs]

        return {
            "embedding": processed_outputs[0],  # Assuming first output is embeddings
            "tokens": tokenizer.convert_ids_to_tokens(
                inputs["input_ids"][0].tolist()
            ),
        }
    except HTTPException:
        # Never re-wrap an HTTP error that was raised deliberately.
        raise
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
62
 
63
@app.get("/health")
async def health_check():
    """Return a fixed status payload for ``GET /health``."""
    payload = {"status": "healthy"}
    return payload