# Page header (extraction residue): Spaces status "Runtime error";
# file size: 2,064 bytes; revision e5d55e9.
import pickle
import faiss
import numpy as np
# from grammar import remove_verbs, clean_text
from utils import *
from sentence_transformers import SentenceTransformer
class FAISS:
    """Flat L2 FAISS index paired with per-vector metadata for tag suggestion.

    ``self.index`` stores the embeddings, while ``self.vectors`` maps the
    position of each embedding inside the index to an
    ``(idx, text, pop, embedding)`` tuple, so that search hits (which FAISS
    reports as positions) can be turned back into records.
    """

    def __init__(self, dimensions: int,
                 model_name: str = 'paraphrase-multilingual-MiniLM-L12-v2'):
        """Create an empty index and load the sentence encoder.

        Args:
            dimensions: Length of the vectors stored in the index.
            model_name: Name of the SentenceTransformer model used to embed
                text. NOTE(review): the model's output size presumably has to
                equal ``dimensions`` -- confirm when changing the model.
        """
        self.dimensions = dimensions
        self.index = faiss.IndexFlatL2(dimensions)
        self.vectors = {}
        # Mirrors the number of vectors in the index; kept for callers that
        # read it, but ``add`` derives keys from the index itself.
        self.counter = 0
        self.model_name = model_name
        self.sentence_encoder = SentenceTransformer(self.model_name)

    @staticmethod
    def _as_matrix(data):
        """Return *data* as a C-contiguous 2-D float32 array.

        FAISS only accepts float32 matrices; a 1-D vector becomes one row.
        """
        matrix = np.ascontiguousarray(data, dtype=np.float32)
        if matrix.ndim == 1:
            matrix = matrix.reshape(1, -1)
        return matrix

    def init_vectors(self, path):
        """Replace ``self.vectors`` with the object unpickled from *path*.

        The file is expected to contain a dict shaped like the one ``add``
        builds (index position -> ``(idx, text, pop, embedding)``).

        Args:
            path: Path of the pickle file to read.
        """
        # NOTE(security): pickle.load can execute arbitrary code contained in
        # the file; only load files that come from a trusted source.
        with open(path, 'rb') as pkl_file:
            self.vectors = pickle.load(pkl_file)

    def init_index(self, path):
        """Replace ``self.index`` with the FAISS index stored at *path*.

        Args:
            path: Path of a file readable by ``faiss.read_index``.
        """
        self.index = faiss.read_index(path)
        # Keep the counter in step with the index that was just loaded.
        self.counter = self.index.ntotal

    def add(self, text, idx, pop, emb=None):
        """Add one record to the index and remember its metadata.

        Args:
            text: Text of the record; encoded with the sentence encoder when
                *emb* is not given.
            idx: Caller-defined identifier stored with the record.
            pop: Caller-defined value stored with the record.
            emb: Optional precomputed embedding, either a single vector or a
                one-row matrix.

        Raises:
            ValueError: If the embedding holds more than one row; a single
                call registers exactly one metadata entry.
        """
        if emb is None:
            emb = self.sentence_encoder.encode([text])
        text_vec = self._as_matrix(emb)
        if text_vec.shape[0] != 1:
            raise ValueError(
                'add() expects exactly one embedding, got %d rows'
                % text_vec.shape[0]
            )
        # A flat index numbers its vectors sequentially, so the position the
        # new vector will receive is the current size of the index. Using it
        # as the key keeps metadata aligned with search results even after an
        # index has been loaded from disk.
        position = self.index.ntotal
        self.index.add(text_vec)
        self.vectors[position] = (idx, text, pop, text_vec)
        self.counter = self.index.ntotal

    def search(self, v: np.ndarray, k: int = 10):
        """Return metadata of the nearest neighbours of a query embedding.

        Args:
            v: Query embedding (a single vector or a matrix); only the
                results for the first row are used.
            k: Maximum number of neighbours to request from the index.

        Returns:
            A list of ``(idx, text, pop, distance)`` tuples in the order the
            index returned them.
        """
        query = self._as_matrix(v)
        distances, positions = self.index.search(query, k)
        hits = []
        for dist, pos in zip(distances[0], positions[0]):
            if pos == -1:
                # FAISS pads the result with -1 when fewer than k neighbours
                # exist, so nothing useful follows.
                break
            idx, text, pop = self.vectors[int(pos)][:3]
            hits.append((idx, text, pop, dist))
        return hits

    def suggest_tags(self, query, top_n=10, k=30) -> list:
        """Suggest up to *top_n* stored texts for *query*.

        The lower-cased query is embedded, its *k* nearest neighbours are
        filtered with ``check``, ranked, and a text is kept only if
        ``sweet_check`` accepts it against every lower-ranked candidate.

        Args:
            query: Free-form query text.
            top_n: Maximum number of suggestions to return.
            k: Number of neighbours fetched from the index before filtering.

        Returns:
            A list of suggested texts, best first.
        """
        emb = self.sentence_encoder.encode([query.lower()])
        candidates = [hit for hit in self.search(emb, k) if check(query, hit[1])]
        # TODO: add a weight relative to length
        # NOTE(review): hit[0] is the record's `idx`, not `pop` -- confirm
        # that `idx` is really the intended component of the ranking score.
        ranked = sorted(
            candidates,
            key=lambda hit: hit[0] * 0.3 - hit[-1],
            reverse=True,
        )
        tags = []
        for rank, hit in enumerate(ranked):
            # Assumes sweet_check returns a bool -- TODO confirm.
            if all(sweet_check(hit[1], later[1]) for later in ranked[rank + 1:]):
                tags.append(hit[1])
        return tags[:top_n]