chryzxc committed on
Commit
1e9ac73
·
verified ·
1 Parent(s): 48abe47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -19
app.py CHANGED
@@ -1,24 +1,34 @@
1
- # main.py
2
- from fastapi import FastAPI
3
- from pydantic import BaseModel
4
- import onnxruntime as ort
5
  import numpy as np
 
 
 
 
 
 
 
 
 
 
6
 
7
- app = FastAPI()
 
 
 
 
8
 
9
- session = ort.InferenceSession("model.onnx")
 
10
 
11
- class ModelInput(BaseModel):
12
- input_ids: list[int]
13
- attention_mask: list[int]
 
 
 
 
 
 
 
14
 
15
- @app.post("/predict")
16
- def predict(data: ModelInput):
17
- input_ids = np.array(data.input_ids, dtype=np.int64).reshape(1, -1)
18
- attention_mask = np.array(data.attention_mask, dtype=np.int64).reshape(1, -1)
19
- inputs = {
20
- "input_ids": input_ids,
21
- "attention_mask": attention_mask,
22
- }
23
- outputs = session.run(None, inputs)
24
- return {"output": outputs}
 
1
+ import gradio as gr
 
 
 
2
  import numpy as np
3
+ import onnxruntime as ort
4
+
5
+ # Load the ONNX model
6
+ session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
7
+
8
+ # Prediction function
9
+ def predict(input_ids: list[int], attention_mask: list[int]):
10
+ # Convert to numpy arrays and batch them
11
+ input_ids_np = np.array([input_ids], dtype=np.int64)
12
+ attention_mask_np = np.array([attention_mask], dtype=np.int64)
13
 
14
+ # Run the model
15
+ outputs = session.run(None, {
16
+ "input_ids": input_ids_np,
17
+ "attention_mask": attention_mask_np
18
+ })
19
 
20
+ # Return raw outputs or post-process as needed
21
+ return outputs
22
 
23
+ # Expose API endpoint
24
+ demo = gr.Interface(
25
+ fn=predict,
26
+ inputs=[
27
+ gr.JSON(label="input_ids"),
28
+ gr.JSON(label="attention_mask")
29
+ ],
30
+ outputs="json",
31
+ allow_flagging="never"
32
+ )
33
 
34
+ app = gr.mount_gradio_app(app=None, blocks=demo, path="/")