chryzxc committed
Commit 0e2d401 · verified · 1 Parent(s): 20e9804

Update app.py

Files changed (1): app.py (+32 -19)
app.py CHANGED

@@ -1,38 +1,51 @@
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Request
 from onnxruntime import InferenceSession
 from transformers import AutoTokenizer
 import numpy as np
 import os
+import uvicorn
 
 app = FastAPI()
 
-# Initialize tokenizer (doesn't require PyTorch/TensorFlow)
+# Initialize tokenizer
 tokenizer = AutoTokenizer.from_pretrained(
     "Xenova/multi-qa-mpnet-base-dot-v1",
-    use_fast=True,  # Uses Rust implementation
+    use_fast=True,
     legacy=False
 )
 
 # Load ONNX model
-session = InferenceSession("model.onnx")
+try:
+    session = InferenceSession("model.onnx")
+    print("Model loaded successfully")
+except Exception as e:
+    print(f"Failed to load model: {str(e)}")
+    raise
 
 @app.get("/")
-def read_root():
-    return {"status": "ONNX Model API is running"}
+def health_check():
+    return {"status": "OK", "model": "ONNX"}
 
 @app.post("/api/predict")
-async def predict(text: str):
+async def predict(request: Request):
     try:
-        # Tokenize without framework dependencies
+        # Get JSON input
+        data = await request.json()
+        text = data.get("text", "")
+
+        if not text:
+            raise HTTPException(status_code=400, detail="No text provided")
+
+        # Tokenize input
         inputs = tokenizer(
             text,
-            return_tensors="np",  # Get NumPy arrays directly
-            padding=True,
+            return_tensors="np",
+            padding="max_length",
             truncation=True,
-            max_length=32  # Match your model's expected input size
+            max_length=32
         )
 
-        # Prepare ONNX inputs
+        # Prepare ONNX inputs with correct shapes
         onnx_inputs = {
             "input_ids": inputs["input_ids"].astype(np.int64),
             "attention_mask": inputs["attention_mask"].astype(np.int64)
@@ -41,21 +54,21 @@ async def predict(text: str):
         # Run inference
         outputs = session.run(None, onnx_inputs)
 
-        # Convert to native Python types
+        # Convert outputs to list and handle numpy types
+        embedding = outputs[0][0].astype(float).tolist()  # First output, first batch
+
         return {
-            "embedding": outputs[0].astype(np.float32).tolist(),
+            "embedding": embedding,
             "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
         }
 
     except Exception as e:
-        raise HTTPException(status_code=400, detail=str(e))
+        raise HTTPException(status_code=500, detail=str(e))
 
 if __name__ == "__main__":
     uvicorn.run(
-        app,
+        "app:app",
         host="0.0.0.0",
         port=7860,
-        # Required for Spaces:
-        proxy_headers=True,
-        forwarded_allow_ips="*"
+        reload=False
     )
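
With this commit, /api/predict reads its input from the JSON request body instead of a query parameter. A minimal client sketch against the updated endpoint, assuming the app from this commit is serving on localhost:7860 (the URL and the example sentence are illustrative, not part of the commit):

import requests

# Hypothetical client for the updated /api/predict endpoint.
# The handler reads data.get("text", ""), so the payload key must be "text".
resp = requests.post(
    "http://localhost:7860/api/predict",
    json={"text": "How do I load an ONNX model?"},
)
resp.raise_for_status()

result = resp.json()
# "embedding" is outputs[0][0] converted to a plain Python list by the server;
# its exact shape depends on how the ONNX model was exported.
print(len(result["embedding"]))
print(result["tokens"][:8])  # tokens after padding/truncation to max_length=32

Note that because the handler now takes a raw Request rather than a typed parameter, a missing "text" key produces the explicit 400 added in this commit instead of FastAPI's automatic validation error.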