from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer
from onnxruntime import InferenceSession
import numpy as np
from typing import Dict

app = FastAPI(title="ONNX Model API with Tokenizer")

# CORS configuration (wide open for demo use; restrict allow_origins in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize components
tokenizer = AutoTokenizer.from_pretrained("Xenova/multi-qa-mpnet-base-dot-v1")
session = InferenceSession("model.onnx")
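
# Optional startup check (a sketch): log the graph's declared inputs so the
# feed names used in /api/process below can be verified; exported models
# sometimes also expect e.g. token_type_ids.
for node in session.get_inputs():
    print(f"ONNX model input: {node.name} shape={node.shape} dtype={node.type}")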

def convert_outputs(outputs):
    """Ensure all numpy values are converted to Python native types"""
    if isinstance(outputs, (np.generic, np.ndarray)):
        return outputs.item() if outputs.ndim == 0 else outputs.tolist()
    return outputs

@app.post("/api/process")
async def process_text(request: Dict[str, str]):
    try:
        text = request.get("text", "")
        if not text:
            raise HTTPException(status_code=400, detail="Request body must include a non-empty 'text' field")
        
        # Tokenize the input text
        inputs = tokenizer(
            text,
            return_tensors="np",
            padding=True,
            truncation=True,
            max_length=32  # Match your model's expected input size
        )
        
        # Convert to ONNX-compatible format
        onnx_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64)
        }
        
        # Run model inference
        outputs = session.run(None, onnx_inputs)
        
        # Convert all numpy types to native Python types
        processed_outputs = [convert_outputs(output) for output in outputs]
        
        return {
            "embedding": processed_outputs[0],  # Assuming first output is embeddings
            "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        }
        
    except HTTPException:
        raise  # propagate deliberate client errors (e.g. empty input) unchanged
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))  # unexpected server-side failure
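
# Note (a sketch, not part of the original endpoint): sentence-transformers
# models such as multi-qa-mpnet-base-dot-v1 typically emit per-token
# embeddings; a single sentence vector is commonly obtained by mean pooling
# over the attention mask. Assumes token_embeddings has shape
# [batch, seq_len, hidden].
def mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token embeddings, ignoring padded positions."""
    mask = attention_mask[..., None].astype(np.float32)  # [batch, seq, 1]
    summed = (token_embeddings * mask).sum(axis=1)       # [batch, hidden]
    counts = np.clip(mask.sum(axis=1), 1e-9, None)       # avoid divide-by-zero
    return summed / counts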

@app.get("/health")
async def health_check():
    return {"status": "healthy"}