File size: 1,934 Bytes
0e2d401
84f1ee8
589009a
84f1ee8
80490de
0e2d401
1e9ac73
589009a
1e9ac73
0e2d401
589009a
 
0e2d401
589009a
80490de
 
589009a
0e2d401
 
 
 
 
 
80490de
4f729af
0e2d401
 
4f729af
589009a
0e2d401
80490de
0e2d401
 
 
 
 
 
 
 
017d40d
 
0e2d401
 
017d40d
0e2d401
017d40d
80490de
0e2d401
017d40d
 
 
 
 
589009a
017d40d
 
0e2d401
 
 
017d40d
0e2d401
017d40d
d7d161f
 
80490de
0e2d401
20e9804
 
 
0e2d401
20e9804
 
0e2d401
20e9804
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from fastapi import FastAPI, HTTPException, Request
from onnxruntime import InferenceSession
from transformers import AutoTokenizer
import numpy as np
import os
import uvicorn

app = FastAPI()

# Tokenizer whose vocabulary matches the ONNX export loaded below.
# use_fast=True selects the Rust-backed tokenizer; legacy=False opts into
# the current (non-legacy) tokenization behavior.
tokenizer = AutoTokenizer.from_pretrained(
    "Xenova/multi-qa-mpnet-base-dot-v1",
    use_fast=True,
    legacy=False
)

# Load the ONNX model at import time so the process fails fast (and loudly)
# if model.onnx is missing or corrupt, instead of failing on the first request.
try:
    session = InferenceSession("model.onnx")
    print("Model loaded successfully")
except Exception as e:
    print(f"Failed to load model: {str(e)}")
    raise  # re-raise: the server must not start without a usable model

@app.get("/")
def health_check():
    return {"status": "OK", "model": "ONNX"}

@app.post("/api/predict")
async def predict(request: Request):
    try:
        # Get JSON input
        data = await request.json()
        text = data.get("text", "")
        
        if not text:
            raise HTTPException(status_code=400, detail="No text provided")
        
        # Tokenize input
        inputs = tokenizer(
            text,
            return_tensors="np",
            padding="max_length",
            truncation=True,
            max_length=32
        )
        
        # Prepare ONNX inputs with correct shapes
        onnx_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64)
        }
        
        # Run inference
        outputs = session.run(None, onnx_inputs)
        
        # Convert outputs to list and handle numpy types
        embedding = outputs[0][0].astype(float).tolist()  # First output, first batch
        
        return {
            "embedding": embedding,
            "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        }
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=False
    )