File size: 2,200 Bytes
0e2d401
0d7d4cd
84f1ee8
589009a
84f1ee8
0e2d401
1e9ac73
589009a
1e9ac73
0e2d401
dd92c0c
 
 
 
 
80490de
589009a
0d7d4cd
80490de
0d7d4cd
 
 
 
 
 
 
 
 
 
 
4f729af
589009a
0e2d401
80490de
0e2d401
 
 
 
 
 
 
6d7e0c5
 
 
 
 
 
 
80490de
6d7e0c5
 
 
 
 
 
c47d871
0d7d4cd
6d7e0c5
bfdd5dc
6d7e0c5
 
d7d161f
0d7d4cd
80490de
0e2d401
20e9804
 
0d7d4cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from fastapi import FastAPI, HTTPException, Request
from fastapi.encoders import jsonable_encoder
from onnxruntime import InferenceSession
from transformers import AutoTokenizer
import numpy as np
import uvicorn

# ASGI application object; the route defined below is registered on it.
app = FastAPI()

# Initialize tokenizer
# Created once at import time and reused by every request.
# NOTE(review): `legacy=False` is forwarded to the tokenizer constructor;
# confirm it has any effect for this (fast) tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    "Xenova/multi-qa-mpnet-base-dot-v1",
    use_fast=True,
    legacy=False
)

# Load ONNX model
# Also created once at import time. "model.onnx" is a relative path, so it is
# presumably resolved against the process working directory — verify against
# the deployment layout.
session = InferenceSession("model.onnx")

def convert_output(value):
    """Recursively convert numpy values into plain Python objects.

    Numeric numpy data is always expressed as ``float``.

    Args:
        value: A numpy scalar or array, a list/tuple/dict (possibly nested)
            containing such values, or any other object.

    Returns:
        A ``float`` for a numpy scalar or 0-d array; nested lists of
        ``float`` mirroring the array's shape for an n-d array; a list,
        tuple or dict of the same structure with every element converted;
        otherwise ``value`` itself, unchanged.
    """
    if isinstance(value, np.ndarray):
        # tolist() keeps the array's dimensionality: a (1, 1) array becomes
        # [[x]] instead of being collapsed to a bare scalar, so the output
        # shape never depends on how many elements happen to be present.
        # A 0-d array still comes back as a plain float.
        return value.astype(float).tolist()
    if isinstance(value, np.generic):
        return float(value.item())
    if isinstance(value, list):
        return [convert_output(item) for item in value]
    if isinstance(value, tuple):
        # Recurse into tuples too so numpy values inside them are converted.
        return tuple(convert_output(item) for item in value)
    if isinstance(value, dict):
        return {key: convert_output(item) for key, item in value.items()}
    return value

@app.post("/api/predict")
async def predict(request: Request):
    """Tokenize the posted text, run the ONNX model and return its output.

    The request body must be a JSON object with a non-empty ``"text"`` entry.

    Args:
        request: Incoming request whose body is parsed as JSON.

    Returns:
        A dict with three keys: ``"embedding"`` (the first batch item of the
        model's first output, as floats), ``"input_ids"`` and
        ``"attention_mask"`` (the tokenizer output for the first item).

    Raises:
        HTTPException: 400 when the body is not valid JSON, is not a JSON
            object, or has no/empty ``"text"``; 500 when tokenization or
            inference fails.
    """
    # Client errors are detected outside the broad handler below so they are
    # reported as 400 instead of being re-wrapped into a 500.
    try:
        data = await request.json()
    except ValueError as exc:  # includes json.JSONDecodeError
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc

    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="JSON body must be an object")

    text = data.get("text", "")
    if not text:
        raise HTTPException(status_code=400, detail="No text provided")

    try:
        # Tokenize input
        inputs = tokenizer(
            text,
            return_tensors="np",
            padding=False,  # Disable padding
            truncation=False,  # Disable truncation
            add_special_tokens=True  # Ensure CLS/SEP tokens
        )

        # The tokenizer arrays are cast to int64 before being fed to the
        # session under the input names "input_ids" and "attention_mask".
        onnx_inputs = {
            "input_ids": np.array(inputs["input_ids"], dtype=np.int64),
            "attention_mask": np.array(inputs["attention_mask"], dtype=np.int64)
        }

        # NOTE(review): run() is a synchronous call inside an async handler.
        outputs = session.run(None, onnx_inputs)
        # NOTE(review): debug output; prints the full model output per request.
        print("OUTPUTS",outputs)
        # Prepare response with converted types
        # NOTE(review): the shape of outputs[0][0] depends on the exported
        # model (pooled vector vs. per-token states) — confirm.
        return {
            "embedding": outputs[0][0].astype(float).tolist(),
            "input_ids": inputs["input_ids"][0].tolist(),
            "attention_mask": inputs["attention_mask"][0].tolist()
        }
    except Exception as e:
        # Anything failing during tokenization/inference is a server error.
        raise HTTPException(status_code=500, detail=str(e)) from e

# Start the server only when this file is executed directly, not on import.
if __name__ == "__main__":
    # host 0.0.0.0 listens on all network interfaces; the server uses port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)