# AnyNameHack / indexer.py
# Danil's picture
# Upload indexer.py
# c05a5b4
# raw
# history blame
# 2.06 kB
import pickle
import faiss
import numpy as np
# from grammar import remove_verbs, clean_text
from utils import *
from sentence_transformers import SentenceTransformer
class FAISS:
    """Text store backed by a flat L2 FAISS index and a sentence encoder.

    Every added text is embedded (or supplied with a ready embedding), pushed
    into ``self.index`` and mirrored in ``self.vectors`` under the key
    ``self.counter``. ``search`` uses the ids returned by the index as keys
    into ``self.vectors``, so the two structures must stay aligned.
    """

    def __init__(self, dimensions: int):
        """Create an empty index for vectors of size *dimensions*.

        Also instantiates the sentence encoder named by ``self.model_name``.
        """
        self.dimensions = dimensions
        self.index = faiss.IndexFlatL2(dimensions)
        # key -> (idx, text, pop, embedding); keys are the values of
        # ``self.counter`` at insertion time.
        self.vectors = {}
        self.counter = 0
        self.model_name = 'paraphrase-multilingual-MiniLM-L12-v2'
        self.sentence_encoder = SentenceTransformer(self.model_name)

    @staticmethod
    def _as_matrix(vec):
        """Return *vec* as a 2-D float32 array (a lone vector becomes one row)."""
        matrix = np.asarray(vec, dtype=np.float32)
        if matrix.ndim == 1:
            matrix = matrix.reshape(1, -1)
        return matrix

    def init_vectors(self, path):
        """Load the pickled metadata mapping from *path* into ``self.vectors``.

        The insertion counter is moved past the loaded entries so that a later
        ``add`` appends a new key instead of overwriting key 0.
        """
        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # only call this with files produced by a trusted source.
        with open(path, 'rb') as pkl_file:
            self.vectors = pickle.load(pkl_file)
        self.counter = len(self.vectors)

    def init_index(self, path):
        """Load a serialized FAISS index from *path* into ``self.index``.

        The insertion counter is set to the number of vectors already stored
        in the loaded index so new entries continue after the existing rows.
        """
        self.index = faiss.read_index(path)
        self.counter = self.index.ntotal

    def add(self, text, idx, pop, emb=None):
        """Add one text to the index together with its metadata.

        Args:
            text: The text to store; encoded with the sentence encoder when
                *emb* is not given.
            idx: Caller-supplied identifier, returned first in search hits.
            pop: Caller-supplied value stored with the text, returned third
                in search hits.
            emb: Optional precomputed embedding. A 1-D vector is accepted and
                treated as a single row.
        """
        if emb is None:
            emb = self.sentence_encoder.encode([text])
        # The index needs a 2-D float32 matrix; store the same normalized
        # array so the metadata matches what was indexed.
        text_vec = self._as_matrix(emb)
        self.index.add(text_vec)
        self.vectors[self.counter] = (idx, text, pop, text_vec)
        self.counter += 1

    def search(self, v: np.ndarray, k: int = 10):
        """Return up to *k* nearest stored entries for the query embedding *v*.

        Args:
            v: Query embedding; a 1-D vector or a matrix whose first row is
                the query (only the first row of the results is read).
            k: Maximum number of neighbours to request from the index.

        Returns:
            A list of ``(idx, text, pop, distance)`` tuples in the order the
            index returned them.
        """
        result = []
        distance, item_index = self.index.search(self._as_matrix(v), k)
        for dist, i in zip(distance[0], item_index[0]):
            # The index pads with -1 when it has fewer than k vectors, so the
            # first -1 marks the end of the real hits.
            if i == -1:
                break
            idx, text, pop = self.vectors[int(i)][:3]
            result.append((idx, text, pop, dist))
        return result

    def suggest_tags(self, query, top_n=10, k=30) -> list:
        """Suggest up to *top_n* stored texts related to *query*.

        The lower-cased query is embedded, the *k* nearest entries are
        fetched, filtered with ``check``, ranked, and then thinned out with
        ``sweet_check`` before truncation.

        Returns:
            A list of texts, best ranked first.
        """
        emb = self.sentence_encoder.encode([query.lower()])
        candidates = [hit for hit in self.search(emb, k) if check(query, hit[1])]
        # TODO: add a weight relative to length.
        # NOTE(review): the score uses hit[0], which is the stored ``idx``,
        # not ``pop`` (hit[2]) - confirm this is the intended ranking signal.
        candidates.sort(key=lambda hit: hit[0] * 0.3 - hit[-1], reverse=True)
        # Keep an entry only if sweet_check accepts it against every entry
        # ranked below it (assumes sweet_check is a side-effect-free predicate).
        total_result = [
            hit[1]
            for pos, hit in enumerate(candidates)
            if all(sweet_check(hit[1], later[1]) for later in candidates[pos + 1:])
        ]
        return total_result[:top_n]