chryzxc commited on
Commit
84f1ee8
·
verified ·
1 Parent(s): 4286dff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -17
app.py CHANGED
@@ -1,29 +1,28 @@
1
- from fastapi import FastAPI
2
- from fastapi.middleware.cors import CORSMiddleware
3
- import numpy as np
4
- from onnxruntime import InferenceSession
5
  from transformers import AutoTokenizer
6
- import os
 
 
 
7
 
8
  app = FastAPI()
9
 
10
- # CORS setup
11
- app.add_middleware(
12
- CORSMiddleware,
13
- allow_origins=["*"],
14
- allow_methods=["*"],
15
- allow_headers=["*"],
16
  )
17
-
18
- # Load model
19
  session = InferenceSession("model.onnx")
20
- tokenizer = AutoTokenizer.from_pretrained("Xenova/multi-qa-mpnet-base-dot-v1")
 
 
21
 
22
  @app.post("/predict")
23
  async def predict(query: str):
 
24
  inputs = tokenizer(query, return_tensors="np")
25
  inputs = {k: v.astype(np.int64) for k, v in inputs.items()}
26
- outputs = session.run(None, inputs)
27
- embedding = outputs[0][0].tolist()
28
 
29
- return {"embedding": embedding}
 
 
 
 
 
 
 
 
1
  from transformers import AutoTokenizer
2
+ from onnxruntime import InferenceSession
3
+ import numpy as np
4
+ import json
5
+ from fastapi import FastAPI
6
 
7
  app = FastAPI()
8
 
9
+ # Initialize components
10
+ tokenizer = AutoTokenizer.from_pretrained(
11
+ "Xenova/multi-qa-mpnet-base-dot-v1",
12
+ use_fast=False # Avoids framework dependencies
 
 
13
  )
 
 
14
  session = InferenceSession("model.onnx")
15
+
16
+ def cosine_similarity(a, b):
17
+ return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
18
 
19
  @app.post("/predict")
20
  async def predict(query: str):
21
+ # Tokenize
22
  inputs = tokenizer(query, return_tensors="np")
23
  inputs = {k: v.astype(np.int64) for k, v in inputs.items()}
 
 
24
 
25
+ # Get embedding
26
+ embedding = session.run(None, inputs)[0][0]
27
+
28
+ return {"embedding": embedding.tolist()}