Danil commited on
Commit
e793d79
·
1 Parent(s): 1cf4b2c

Update indexer.py

Browse files
Files changed (1) hide show
  1. indexer.py +105 -9
indexer.py CHANGED
@@ -1,13 +1,15 @@
1
  import pickle
2
  import faiss
3
  import numpy as np
4
- # from grammar import remove_verbs, clean_text
5
  from utils import *
6
  from sentence_transformers import SentenceTransformer
7
 
 
 
 
8
 
9
  class FAISS:
10
- def __init__(self, dimensions: int):
11
  self.dimensions = dimensions
12
  self.index = faiss.IndexFlatL2(dimensions)
13
  self.vectors = {}
@@ -15,23 +17,76 @@ class FAISS:
15
  self.model_name = 'paraphrase-multilingual-MiniLM-L12-v2'
16
  self.sentence_encoder = SentenceTransformer(self.model_name)
17
 
18
- def init_vectors(self, path):
 
 
 
 
 
 
19
  with open(path, 'rb') as pkl_file:
20
  self.vectors = pickle.load(pkl_file)
21
 
22
- def init_index(self, path):
 
 
 
 
 
 
 
 
23
  self.index = faiss.read_index(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- def add(self, text, idx, pop, emb=None):
 
 
 
 
 
 
 
 
 
26
  if emb is None:
27
  text_vec = self.sentence_encoder.encode([text])
28
  else:
29
  text_vec = emb
 
30
  self.index.add(text_vec)
31
  self.vectors[self.counter] = (idx, text, pop, text_vec)
 
32
  self.counter += 1
33
 
34
- def search(self, v: list, k: int = 10):
 
 
 
 
 
 
 
 
 
35
  result = []
36
  distance, item_index = self.index.search(v, k)
37
  for dist, i in zip(distance[0], item_index[0]):
@@ -42,8 +97,17 @@ class FAISS:
42
 
43
  return result
44
 
45
- def suggest_tags(self, query, top_n=10, k=30) -> list:
 
 
46
 
 
 
 
 
 
 
 
47
  emb = self.sentence_encoder.encode([query.lower()])
48
  r = self.search(emb, k)
49
 
@@ -57,8 +121,40 @@ class FAISS:
57
  for i in range(len(result)):
58
  flag = True
59
  for j in result[i + 1:]:
60
- flag &= easy_check(result[i][1], j[1])
61
  if flag:
62
  total_result.append(result[i][1])
63
 
64
- return total_result[:top_n]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pickle
2
  import faiss
3
  import numpy as np
 
4
  from utils import *
5
  from sentence_transformers import SentenceTransformer
6
 
7
+ from tqdm import tqdm
8
+ from typing import List
9
+
10
 
11
  class FAISS:
12
+ def __init__(self, dimensions: int) -> None:
13
  self.dimensions = dimensions
14
  self.index = faiss.IndexFlatL2(dimensions)
15
  self.vectors = {}
 
17
  self.model_name = 'paraphrase-multilingual-MiniLM-L12-v2'
18
  self.sentence_encoder = SentenceTransformer(self.model_name)
19
 
20
+ def init_vectors(self, path: str) -> None:
21
+ """
22
+ Заполняет набор векторов предобученными значениями
23
+
24
+ Args:
25
+ path: путь к файлу в формате pickle
26
+ """
27
  with open(path, 'rb') as pkl_file:
28
  self.vectors = pickle.load(pkl_file)
29
 
30
+ self.counter = len(self.vectors)
31
+
32
+ def init_index(self, path) -> None:
33
+ """
34
+ Заполняет индекс FAISS предобученными значениями
35
+
36
+ Args:
37
+ path: путь к файлу в формате FAISS
38
+ """
39
  self.index = faiss.read_index(path)
40
+
41
+ def save_vectors(self, path: str) -> None:
42
+ """
43
+ Сохраняет набор векторов
44
+
45
+ Args:
46
+ path: желаемый путь к файлу
47
+ """
48
+ with open(path, "wb") as fp:
49
+ pickle.dump(self.index.vectors, fp)
50
+
51
+ def save_index(self, path: str) -> None:
52
+ """
53
+ Сохраняет индекс FAISS
54
+
55
+ Args:
56
+ path: желаемый путь к файлу
57
+ """
58
+ faiss.write_index(self.index, path)
59
 
60
+ def add(self, text: str, idx: int, pop: float, emb=None) -> None:
61
+ """
62
+ Добавляет в поисковый индекс новый вектор
63
+
64
+ Args:
65
+ text: текст запроса
66
+ idx: индекс нового вектора
67
+ pop: популярность запроса
68
+ emb (optional): эмбеддинг текста запроса (если не указан, то будет подготовлен с помощью self.sentence_encoder)
69
+ """
70
  if emb is None:
71
  text_vec = self.sentence_encoder.encode([text])
72
  else:
73
  text_vec = emb
74
+
75
  self.index.add(text_vec)
76
  self.vectors[self.counter] = (idx, text, pop, text_vec)
77
+
78
  self.counter += 1
79
 
80
+ def search(self, v: List, k: int = 10) -> List[List]:
81
+ """
82
+ Ищет в поисковом индексе ближайших соседей к вектору v
83
+
84
+ Args:
85
+ v: вектор для поиска ближайших соседей
86
+ k: число векторов в выдаче
87
+ Returns:
88
+ список векторов, ближайших к вектору v, в формате [idx, text, popularity, similarity]
89
+ """
90
  result = []
91
  distance, item_index = self.index.search(v, k)
92
  for dist, i in zip(distance[0], item_index[0]):
 
97
 
98
  return result
99
 
100
+ def suggest_tags(self, query: str, top_n: int = 10, k: int = 30) -> List[str]:
101
+ """
102
+ Получает список тегов для пользователя по текстовому запросу
103
 
104
+ Args:
105
+ query: запрос пользователя
106
+ top_n (optional): число тегов в выдаче
107
+ k (optional): число векторов из индекса, среди которых будут искаться теги для выдачи
108
+ Returns:
109
+ список тегов для выдачи пользователю
110
+ """
111
  emb = self.sentence_encoder.encode([query.lower()])
112
  r = self.search(emb, k)
113
 
 
121
  for i in range(len(result)):
122
  flag = True
123
  for j in result[i + 1:]:
124
+ flag &= sweet_check(result[i][1], j[1])
125
  if flag:
126
  total_result.append(result[i][1])
127
 
128
+ return total_result[:top_n]
129
+
130
+ def fill(self, queries: List[str], popularities: pd.DataFrame) -> None:
131
+ """
132
+ Заполняет поисковый индекс запросами queries, популярности которых берутся из таблицы popularities
133
+
134
+ Args:
135
+ queries: список запросов
136
+ popularities: таблица, в которой содержатся колонки query и query_popularity
137
+ """
138
+ idx = -1
139
+ for query in tqdm(queries):
140
+ idx += 1
141
+ if type(query) == str:
142
+ emb = self.index.sentence_encoder.encode([query.lower()])
143
+ bool_add = True
144
+ search_sim = self.index.search(emb, 1)
145
+
146
+ try:
147
+ popularity = popularities[popularities["query"] == query]["query_popularity"].item()
148
+ except ValueError:
149
+ # Если для текущего запроса неизвестна популярность, возьмем значение 5
150
+ popularity = 5
151
+
152
+ if len(search_sim) > 0:
153
+ search_sim = search_sim[0]
154
+ if search_sim[-1] < 0.15:
155
+ # Не добавляем вектор, если он находится достаточно близко к уже присутствующему в индексе
156
+ bool_add = False
157
+ if bool_add:
158
+ self.index.add(query, popularity, idx, emb)
159
+ else:
160
+ self.index.add(query, popularity, idx, emb)