chryzxc commited on
Commit
84f505f
·
verified ·
1 Parent(s): 1e9ac73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -29
app.py CHANGED
@@ -1,34 +1,29 @@
1
- import gradio as gr
 
2
  import numpy as np
3
- import onnxruntime as ort
 
 
4
 
5
- # Load the ONNX model
6
- session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
7
 
8
- # Prediction function
9
- def predict(input_ids: list[int], attention_mask: list[int]):
10
- # Convert to numpy arrays and batch them
11
- input_ids_np = np.array([input_ids], dtype=np.int64)
12
- attention_mask_np = np.array([attention_mask], dtype=np.int64)
13
-
14
- # Run the model
15
- outputs = session.run(None, {
16
- "input_ids": input_ids_np,
17
- "attention_mask": attention_mask_np
18
- })
19
-
20
- # Return raw outputs or post-process as needed
21
- return outputs
22
-
23
- # Expose API endpoint
24
- demo = gr.Interface(
25
- fn=predict,
26
- inputs=[
27
- gr.JSON(label="input_ids"),
28
- gr.JSON(label="attention_mask")
29
- ],
30
- outputs="json",
31
- allow_flagging="never"
32
  )
33
 
34
- app = gr.mount_gradio_app(app=None, blocks=demo, path="/")
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
  import numpy as np
4
+ from onnxruntime import InferenceSession
5
+ from transformers import AutoTokenizer
6
+ import os
7
 
8
+ app = FastAPI()
 
9
 
10
+ # CORS setup
11
+ app.add_middleware(
12
+ CORSMiddleware,
13
+ allow_origins=["*"],
14
+ allow_methods=["*"],
15
+ allow_headers=["*"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  )
17
 
18
+ # Load model
19
+ session = InferenceSession("model.onnx")
20
+ tokenizer = AutoTokenizer.from_pretrained("Xenova/multi-qa-mpnet-base-dot-v1")
21
+
22
+ @app.post("/predict")
23
+ async def predict(query: str):
24
+ inputs = tokenizer(query, return_tensors="np")
25
+ inputs = {k: v.astype(np.int64) for k, v in inputs.items()}
26
+ outputs = session.run(None, inputs)
27
+ embedding = outputs[0][0].tolist()
28
+
29
+ return {"embedding": embedding}