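"""FAISS-based semantic search service.

Embeds incoming queries with a transformer encoder (a BGE-style model by
default), runs nearest-neighbour search over a prebuilt FAISS index, and
serves the results through a FastAPI /search endpoint.
"""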
import os
import json
import faiss
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from transformers import AutoTokenizer, AutoModel
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

class FaissSearch:
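    """Query embedding + FAISS nearest-neighbour lookup over a document store."""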
    def __init__(self, model_path, index_path, index_keys_path, filtered_db_path, device='cuda:0'):
        self.device = device
        self.model_path = model_path
        self.index = faiss.read_index(index_path)
        self.max_len = 512

        # Maps FAISS row ids (as strings) to document references.
        with open(index_keys_path, 'r', encoding='utf-8') as f:
            self.index_keys = json.load(f)

        # Maps document references to the stored documents.
        with open(filtered_db_path, 'r', encoding='utf-8') as f:
            self.filtered_db_data = json.load(f)

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = None  # loaded lazily on the first search (see _load_model)

    def _load_model(self):
        # Lazy initialisation: keeps startup fast and touches the GPU only
        # when the first query arrives.
        if self.model is None:
            self.model = AutoModel.from_pretrained(self.model_path).to(self.device)

    def _query_tokenization(self, text):
        # text = "query: " + text  # prepend the instruction prefix if using an e5-style model
        tokens = self.tokenizer(
            text,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=self.max_len
        )
        return tokens

    def _query_embed_extraction(self, tokens, do_normalization=True):
        self._load_model()
        self.model.eval()
        with torch.no_grad():
            # Mixed precision is only meaningful on CUDA devices.
            with autocast(enabled=self.device.startswith("cuda")):
                inputs = {k: v.to(self.device) for k, v in tokens.items()}
                outputs = self.model(**inputs)
                # CLS-token pooling (BGE-style); cast back to fp32 because
                # FAISS expects float32 vectors.
                embedding = outputs.last_hidden_state[:, 0].float().cpu()

                if do_normalization:
                    embedding = F.normalize(embedding, dim=-1)
        return embedding.numpy()

    def _search_results_filtering(self, preds, dists):
        # Re-rank by score in descending order; assumes the index returns
        # similarity scores (inner product over normalized embeddings),
        # so higher is better.
        sorted_values = sorted(zip(preds, dists), key=lambda x: x[1], reverse=True)
        sorted_preds = [x[0] for x in sorted_values]
        sorted_scores = [x[1] for x in sorted_values]
        return sorted_preds, sorted_scores

    def search(self, query, top=20):
        query_tokens = self._query_tokenization(query)
        query_embeds = self._query_embed_extraction(query_tokens, do_normalization=True)
        # Exhaustive search over the whole collection; the top slice is taken
        # only after re-ranking.
        distances, indices = self.index.search(query_embeds, len(self.filtered_db_data))

        # FAISS pads missing neighbours with -1; drop them to avoid key errors.
        valid = indices[0] != -1
        preds = [self.index_keys[str(x)] for x in indices[0][valid]]
        preds, scores = self._search_results_filtering(preds, distances[0][valid])
        docs = [self.filtered_db_data[ref] for ref in preds]

        # Release cached GPU memory between requests.
        torch.cuda.empty_cache()

        return preds[:top], docs[:top]
    

# Checkpoint step used to select the matching index/metadata files.
STEP = 5000
model_path = os.environ.get("MODEL_PATH", "bge/")
index_path = f"faiss_indexes/faiss__bge_{STEP}.index"
index_keys_path = f"faiss_indexes/index_keys__bge_{STEP}.json"
filtered_db_path = f"faiss_indexes/filtered_db_data__bge_{STEP}.json"

# One shared searcher per process: the index and metadata load at import time,
# the model itself on the first request.
searcher = FaissSearch(model_path, index_path, index_keys_path, filtered_db_path, os.environ.get("DEVICE", "cuda:0"))

app = FastAPI()

class SearchRequest(BaseModel):
    query: str
    top: int = 10

class SearchResponse(BaseModel):
    predictions: list
    documents: list

@app.post("/search", response_model=SearchResponse)
def search_endpoint(request: SearchRequest):
    # Defined as a sync endpoint so FastAPI runs the blocking FAISS/model call
    # in its threadpool instead of stalling the event loop.
    try:
        preds, docs = searcher.search(request.query, top=request.top)
        return SearchResponse(predictions=preds, documents=docs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
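
# Example request (hypothetical host/port):
#   curl -X POST http://localhost:8000/search \
#        -H 'Content-Type: application/json' \
#        -d '{"query": "example search text", "top": 5}'
#
# Minimal local entry point; a sketch assuming `uvicorn` is installed and the
# index files referenced above exist. In production you would typically run
# `uvicorn <module>:app` directly instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)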