File size: 1,983 Bytes
0e2d401
0d7d4cd
84f1ee8
589009a
84f1ee8
0e2d401
1e9ac73
589009a
1e9ac73
0e2d401
dd92c0c
 
 
 
 
80490de
589009a
0d7d4cd
80490de
0d7d4cd
 
 
 
 
 
 
 
 
 
 
4f729af
589009a
0e2d401
80490de
0e2d401
 
 
 
 
 
 
855d918
80490de
0d7d4cd
 
017d40d
 
0d7d4cd
0e2d401
0d7d4cd
 
 
017d40d
d7d161f
3be5098
0d7d4cd
 
80490de
0e2d401
20e9804
 
0d7d4cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from fastapi import FastAPI, HTTPException, Request
from fastapi.encoders import jsonable_encoder
from onnxruntime import InferenceSession
from transformers import AutoTokenizer
import numpy as np
import uvicorn

app = FastAPI()

# Initialize tokenizer
# Fast (Rust-backed) tokenizer for the multi-qa-mpnet-base-dot-v1 checkpoint;
# downloaded from the Hugging Face hub on first run (network side effect).
tokenizer = AutoTokenizer.from_pretrained(
    "Xenova/multi-qa-mpnet-base-dot-v1",
    use_fast=True,
    legacy=False
)

# Load ONNX model
# NOTE(review): "model.onnx" is resolved relative to the working directory —
# presumably the exported mpnet model matching the tokenizer above; confirm.
session = InferenceSession("model.onnx")

def convert_output(value):
    """Recursively turn numpy scalars/arrays into plain Python objects.

    Single-element arrays (and numpy scalars) collapse to a ``float``;
    larger arrays become (nested) lists of floats. Dicts and lists are
    walked recursively; any other value is returned untouched.
    """
    if isinstance(value, dict):
        return {key: convert_output(item) for key, item in value.items()}
    if isinstance(value, list):
        return [convert_output(item) for item in value]
    if isinstance(value, (np.generic, np.ndarray)):
        # A lone value is flattened to a scalar float; anything bigger
        # is cast to float64 and serialized as nested lists.
        if value.size == 1:
            return float(value.item())
        return value.astype(float).tolist()
    return value

@app.post("/api/predict")
async def predict(request: Request):
    """Embed the posted text with the ONNX model.

    Expects a JSON body of the form ``{"text": "..."}`` and returns
    ``{"embedding": [...], "tokens": [...]}`` where the embedding is the
    model's first output converted to native Python floats.

    Raises:
        HTTPException 400: the body has no non-empty "text" field.
        HTTPException 500: tokenization or inference failed.
    """
    try:
        data = await request.json()
        text = data.get("text", "")

        if not text:
            raise HTTPException(status_code=400, detail="No text provided")

        # BUG FIX: return_tensors="np" is required — a plain tokenizer(text)
        # call returns Python lists, which have no .astype method, so the
        # session.run() feed below would raise AttributeError.
        inputs = tokenizer(text, return_tensors="np")

        # ONNX Runtime expects int64 ids/mask for this model's inputs.
        outputs = session.run(None, {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64)
        })

        # Prepare response with numpy types converted to native Python types.
        response = {
            "embedding": convert_output(outputs[0]),  # main model output
            # .tolist() so convert_ids_to_tokens receives plain Python ints
            "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
        }
        return jsonable_encoder(response)

    except HTTPException:
        # BUG FIX: let deliberate HTTP errors (e.g. the 400 above) propagate
        # unchanged instead of being re-wrapped as a 500 by the handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

# Run the development server only when executed as a script
# (skipped when the module is imported by an external ASGI host).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)