chryzxc committed
Commit 589009a · verified · 1 Parent(s): 600eac4

Update app.py

Files changed (1)
  1. app.py +17 -35
app.py CHANGED
@@ -1,65 +1,47 @@
 from fastapi import FastAPI, HTTPException
-from fastapi.middleware.cors import CORSMiddleware
-from transformers import AutoTokenizer
 from onnxruntime import InferenceSession
+from transformers import AutoTokenizer
 import numpy as np
 import os
-from typing import Dict
 
-app = FastAPI(title="ONNX Model API with Tokenizer")
+app = FastAPI()
 
-# CORS configuration
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_methods=["*"],
-    allow_headers=["*"],
+# Initialize tokenizer (doesn't require PyTorch/TensorFlow)
+tokenizer = AutoTokenizer.from_pretrained(
+    "Xenova/multi-qa-mpnet-base-dot-v1",
+    use_fast=True,  # Uses Rust implementation
+    legacy=False
 )
 
-# Initialize components
-tokenizer = AutoTokenizer.from_pretrained("Xenova/multi-qa-mpnet-base-dot-v1")
+# Load ONNX model
 session = InferenceSession("model.onnx")
 
-def convert_outputs(outputs):
-    """Ensure all numpy values are converted to Python native types"""
-    if isinstance(outputs, (np.generic, np.ndarray)):
-        return outputs.item() if outputs.ndim == 0 else outputs.tolist()
-    return outputs
-
-@app.post("/api/process")
-async def process_text(request: Dict[str, str]):
+@app.post("/api/predict")
+async def predict(text: str):
     try:
-        text = request.get("text", "")
-
-        # Tokenize the input text
+        # Tokenize without framework dependencies
         inputs = tokenizer(
             text,
-            return_tensors="np",
+            return_tensors="np",  # Get NumPy arrays directly
             padding=True,
             truncation=True,
             max_length=32  # Match your model's expected input size
         )
 
-        # Convert to ONNX-compatible format
+        # Prepare ONNX inputs
         onnx_inputs = {
             "input_ids": inputs["input_ids"].astype(np.int64),
             "attention_mask": inputs["attention_mask"].astype(np.int64)
         }
 
-        # Run model inference
+        # Run inference
        outputs = session.run(None, onnx_inputs)
 
-        # Convert all numpy types to native Python types
-        processed_outputs = [convert_outputs(output) for output in outputs]
-
+        # Convert to native Python types
         return {
-            "embedding": processed_outputs[0],  # Assuming first output is embeddings
+            "embedding": outputs[0].astype(np.float32).tolist(),
             "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
         }
 
     except Exception as e:
-        raise HTTPException(status_code=400, detail=str(e))
-
-@app.get("/health")
-async def health_check():
-    return {"status": "healthy"}
+        raise HTTPException(status_code=400, detail=str(e))
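
For reference, a minimal client sketch for the new route. Because predict(text: str) declares a bare string parameter, FastAPI treats text as a query parameter rather than a JSON body field (the old /api/process handler accepted a JSON dict instead). The host, port, and sample sentence below are illustrative assumptions, not part of the commit; it presumes the app is being served with something like `uvicorn app:app --port 8000`.

# Hypothetical client for the updated /api/predict endpoint.
# Host and port are assumptions; start the server first, e.g.:
#   uvicorn app:app --port 8000
import requests

resp = requests.post(
    "http://localhost:8000/api/predict",
    # A bare `text: str` in the handler signature means FastAPI expects
    # the input as a query parameter, not in the request body.
    params={"text": "How do ONNX models run without PyTorch?"},
)
resp.raise_for_status()

data = resp.json()
print(data["tokens"])          # tokens produced by the AutoTokenizer
print(len(data["embedding"]))  # first ONNX output, serialized via .tolist()

Note that the handler now returns outputs[0] as-is: for an mpnet-style encoder the first ONNX output is typically the per-token hidden states with shape (batch, tokens, hidden), so the serialized "embedding" is a nested list, and any pooling down to a single sentence vector is left to the caller.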