from https://github.com/ArneBinder/pie-document-level/pull/233
Browse filesuse QdrantVectorStore again (since it works with download/upload now)
- vector_store.py +82 -33
vector_store.py
CHANGED
@@ -12,6 +12,16 @@ E = TypeVar("E")
|
|
12 |
|
13 |
|
14 |
class VectorStore(Generic[T, E], abc.ABC):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
@abc.abstractmethod
|
16 |
def _add(self, embedding: E, payload: T, emb_id: str) -> None:
|
17 |
"""Save an embedding with payload for a given ID."""
|
@@ -22,6 +32,11 @@ class VectorStore(Generic[T, E], abc.ABC):
|
|
22 |
"""Get the embedding for a given ID."""
|
23 |
pass
|
24 |
|
|
|
|
|
|
|
|
|
|
|
25 |
def _get_emb_id(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> str:
|
26 |
if emb_id is None:
|
27 |
if payload is None:
|
@@ -67,16 +82,41 @@ class VectorStore(Generic[T, E], abc.ABC):
|
|
67 |
def __len__(self):
|
68 |
pass
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
def save_to_directory(self, directory: str) -> None:
|
71 |
"""Save the vector store to a directory."""
|
72 |
-
|
|
|
73 |
|
74 |
def load_from_directory(self, directory: str, replace: bool = False) -> None:
|
75 |
"""Load the vector store from a directory.
|
76 |
|
77 |
If `replace` is True, the current content of the store will be replaced.
|
78 |
"""
|
79 |
-
|
|
|
|
|
80 |
|
81 |
|
82 |
def vector_norm(vector: List[float]) -> float:
|
@@ -88,10 +128,7 @@ def cosine_similarity(a: List[float], b: List[float]) -> float:
|
|
88 |
|
89 |
|
90 |
class SimpleVectorStore(VectorStore[T, List[float]]):
|
91 |
-
|
92 |
-
INDEX_FILE = "vectors_index.json"
|
93 |
-
EMBEDDINGS_FILE = "vectors_data.npy"
|
94 |
-
PAYLOADS_FILE = "vectors_payloads.json"
|
95 |
|
96 |
def __init__(self):
|
97 |
self.vectors: dict[str, List[float]] = {}
|
@@ -148,8 +185,7 @@ class SimpleVectorStore(VectorStore[T, List[float]]):
|
|
148 |
|
149 |
return [(emb_id, self.payloads[emb_id], sim) for emb_id, sim in similar_entries]
|
150 |
|
151 |
-
def
|
152 |
-
os.makedirs(directory, exist_ok=True)
|
153 |
indices = list(self.vectors.keys())
|
154 |
with open(os.path.join(directory, self.INDEX_FILE), "w") as f:
|
155 |
json.dump(indices, f)
|
@@ -159,20 +195,15 @@ class SimpleVectorStore(VectorStore[T, List[float]]):
|
|
159 |
with open(os.path.join(directory, self.PAYLOADS_FILE), "w") as f:
|
160 |
json.dump(payloads, f)
|
161 |
|
162 |
-
def
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
embeddings_np = np.load(os.path.join(directory, self.EMBEDDINGS_FILE))
|
168 |
-
with open(os.path.join(directory, self.PAYLOADS_FILE), "r") as f:
|
169 |
-
payloads = json.load(f)
|
170 |
-
for emb_id, emb, payload in zip(index, embeddings_np, payloads):
|
171 |
-
self.vectors[emb_id] = emb.tolist()
|
172 |
-
self.payloads[emb_id] = payload
|
173 |
|
174 |
|
175 |
class QdrantVectorStore(VectorStore[T, List[float]]):
|
|
|
176 |
|
177 |
COLLECTION_NAME = "ADUs"
|
178 |
MAX_LIMIT = 100
|
@@ -184,8 +215,8 @@ class QdrantVectorStore(VectorStore[T, List[float]]):
|
|
184 |
distance: Distance = Distance.COSINE,
|
185 |
):
|
186 |
self.client = QdrantClient(location=location)
|
187 |
-
self.
|
188 |
-
self.
|
189 |
self.client.create_collection(
|
190 |
collection_name=self.COLLECTION_NAME,
|
191 |
vectors_config=VectorParams(size=vector_size, distance=distance),
|
@@ -196,21 +227,26 @@ class QdrantVectorStore(VectorStore[T, List[float]]):
|
|
196 |
|
197 |
def _add(self, emb_id: str, payload: T, embedding: List[float]) -> None:
|
198 |
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
self.client.upsert(
|
204 |
collection_name=self.COLLECTION_NAME,
|
205 |
-
points=[PointStruct(id=
|
206 |
)
|
207 |
-
self.id2idx[emb_id] = _id
|
208 |
-
self.idx2id[_id] = emb_id
|
209 |
|
210 |
def _get(self, emb_id: str) -> Optional[List[float]]:
|
211 |
points = self.client.retrieve(
|
212 |
collection_name=self.COLLECTION_NAME,
|
213 |
-
ids=[self.
|
214 |
with_vectors=True,
|
215 |
)
|
216 |
if len(points) == 0:
|
@@ -225,11 +261,14 @@ class QdrantVectorStore(VectorStore[T, List[float]]):
|
|
225 |
) -> List[Tuple[str, T, float]]:
|
226 |
similar_entries = self.client.recommend(
|
227 |
collection_name=self.COLLECTION_NAME,
|
228 |
-
positive=[self.
|
229 |
limit=top_k or self.MAX_LIMIT,
|
230 |
score_threshold=min_similarity,
|
231 |
)
|
232 |
-
return [
|
|
|
|
|
|
|
233 |
|
234 |
def clear(self) -> None:
|
235 |
vectors_config = self.client.get_collection(
|
@@ -240,5 +279,15 @@ class QdrantVectorStore(VectorStore[T, List[float]]):
|
|
240 |
collection_name=self.COLLECTION_NAME,
|
241 |
vectors_config=vectors_config,
|
242 |
)
|
243 |
-
self.
|
244 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
class VectorStore(Generic[T, E], abc.ABC):
|
15 |
+
"""Abstract base class for a vector store.
|
16 |
+
|
17 |
+
A vector store is a key-value store that maps an ID to a vector embedding and a payload. The
|
18 |
+
payload can be any JSON-serializable object, e.g. a dictionary.
|
19 |
+
"""
|
20 |
+
|
21 |
+
INDEX_FILE = "vectors_index.json"
|
22 |
+
EMBEDDINGS_FILE = "vectors_data.npy"
|
23 |
+
PAYLOADS_FILE = "vectors_payloads.json"
|
24 |
+
|
25 |
@abc.abstractmethod
|
26 |
def _add(self, embedding: E, payload: T, emb_id: str) -> None:
|
27 |
"""Save an embedding with payload for a given ID."""
|
|
|
32 |
"""Get the embedding for a given ID."""
|
33 |
pass
|
34 |
|
35 |
+
@abc.abstractmethod
|
36 |
+
def clear(self) -> None:
|
37 |
+
"""Clear the store."""
|
38 |
+
pass
|
39 |
+
|
40 |
def _get_emb_id(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> str:
|
41 |
if emb_id is None:
|
42 |
if payload is None:
|
|
|
82 |
def __len__(self):
|
83 |
pass
|
84 |
|
85 |
+
def _add_from_directory(self, directory: str) -> None:
|
86 |
+
with open(os.path.join(directory, self.INDEX_FILE), "r") as f:
|
87 |
+
index = json.load(f)
|
88 |
+
embeddings_np = np.load(os.path.join(directory, self.EMBEDDINGS_FILE))
|
89 |
+
with open(os.path.join(directory, self.PAYLOADS_FILE), "r") as f:
|
90 |
+
payloads = json.load(f)
|
91 |
+
for emb_id, emb, payload in zip(index, embeddings_np, payloads):
|
92 |
+
self._add(emb_id=emb_id, payload=payload, embedding=emb.tolist())
|
93 |
+
|
94 |
+
@abc.abstractmethod
|
95 |
+
def as_indices_vectors_payloads(self) -> Tuple[List[str], np.ndarray, List[T]]:
|
96 |
+
"""Return a tuple of indices, vectors and payloads."""
|
97 |
+
pass
|
98 |
+
|
99 |
+
def _save_to_directory(self, directory: str) -> None:
|
100 |
+
indices, vectors, payloads = self.as_indices_vectors_payloads()
|
101 |
+
np.save(os.path.join(directory, self.EMBEDDINGS_FILE), vectors)
|
102 |
+
with open(os.path.join(directory, self.PAYLOADS_FILE), "w") as f:
|
103 |
+
json.dump(payloads, f)
|
104 |
+
with open(os.path.join(directory, self.INDEX_FILE), "w") as f:
|
105 |
+
json.dump(indices, f)
|
106 |
+
|
107 |
def save_to_directory(self, directory: str) -> None:
|
108 |
"""Save the vector store to a directory."""
|
109 |
+
os.makedirs(directory, exist_ok=True)
|
110 |
+
self._save_to_directory(directory)
|
111 |
|
112 |
def load_from_directory(self, directory: str, replace: bool = False) -> None:
|
113 |
"""Load the vector store from a directory.
|
114 |
|
115 |
If `replace` is True, the current content of the store will be replaced.
|
116 |
"""
|
117 |
+
if replace:
|
118 |
+
self.clear()
|
119 |
+
self._add_from_directory(directory)
|
120 |
|
121 |
|
122 |
def vector_norm(vector: List[float]) -> float:
|
|
|
128 |
|
129 |
|
130 |
class SimpleVectorStore(VectorStore[T, List[float]]):
|
131 |
+
"""Simple in-memory vector store using a dictionary."""
|
|
|
|
|
|
|
132 |
|
133 |
def __init__(self):
|
134 |
self.vectors: dict[str, List[float]] = {}
|
|
|
185 |
|
186 |
return [(emb_id, self.payloads[emb_id], sim) for emb_id, sim in similar_entries]
|
187 |
|
188 |
+
def _save_to_directory(self, directory: str) -> None:
|
|
|
189 |
indices = list(self.vectors.keys())
|
190 |
with open(os.path.join(directory, self.INDEX_FILE), "w") as f:
|
191 |
json.dump(indices, f)
|
|
|
195 |
with open(os.path.join(directory, self.PAYLOADS_FILE), "w") as f:
|
196 |
json.dump(payloads, f)
|
197 |
|
198 |
+
def as_indices_vectors_payloads(self) -> Tuple[List[str], np.ndarray, List[T]]:
|
199 |
+
indices = list(self.vectors.keys())
|
200 |
+
embeddings_np = np.array(list(self.vectors.values()))
|
201 |
+
payloads = [self.payloads[idx] for idx in indices]
|
202 |
+
return indices, embeddings_np, payloads
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
|
205 |
class QdrantVectorStore(VectorStore[T, List[float]]):
|
206 |
+
"""Vector store using Qdrant as a backend."""
|
207 |
|
208 |
COLLECTION_NAME = "ADUs"
|
209 |
MAX_LIMIT = 100
|
|
|
215 |
distance: Distance = Distance.COSINE,
|
216 |
):
|
217 |
self.client = QdrantClient(location=location)
|
218 |
+
self.emb_id2point_id = {}
|
219 |
+
self.point_id2emb_id = {}
|
220 |
self.client.create_collection(
|
221 |
collection_name=self.COLLECTION_NAME,
|
222 |
vectors_config=VectorParams(size=vector_size, distance=distance),
|
|
|
227 |
|
228 |
def _add(self, emb_id: str, payload: T, embedding: List[float]) -> None:
|
229 |
|
230 |
+
if emb_id in self.emb_id2point_id:
|
231 |
+
# update existing entry
|
232 |
+
point_id = self.emb_id2point_id[emb_id]
|
233 |
+
else:
|
234 |
+
# we use the length of the emb_id2point_id dict as the index,
|
235 |
+
# because we assume that, even when we delete an entry from
|
236 |
+
# the store, we do not delete it from the index
|
237 |
+
point_id = len(self.emb_id2point_id)
|
238 |
+
self.emb_id2point_id[emb_id] = point_id
|
239 |
+
self.point_id2emb_id[point_id] = emb_id
|
240 |
+
|
241 |
self.client.upsert(
|
242 |
collection_name=self.COLLECTION_NAME,
|
243 |
+
points=[PointStruct(id=point_id, vector=embedding, payload=payload)],
|
244 |
)
|
|
|
|
|
245 |
|
246 |
def _get(self, emb_id: str) -> Optional[List[float]]:
|
247 |
points = self.client.retrieve(
|
248 |
collection_name=self.COLLECTION_NAME,
|
249 |
+
ids=[self.emb_id2point_id[emb_id]],
|
250 |
with_vectors=True,
|
251 |
)
|
252 |
if len(points) == 0:
|
|
|
261 |
) -> List[Tuple[str, T, float]]:
|
262 |
similar_entries = self.client.recommend(
|
263 |
collection_name=self.COLLECTION_NAME,
|
264 |
+
positive=[self.emb_id2point_id[ref_id]],
|
265 |
limit=top_k or self.MAX_LIMIT,
|
266 |
score_threshold=min_similarity,
|
267 |
)
|
268 |
+
return [
|
269 |
+
(self.point_id2emb_id[entry.id], entry.payload, entry.score)
|
270 |
+
for entry in similar_entries
|
271 |
+
]
|
272 |
|
273 |
def clear(self) -> None:
|
274 |
vectors_config = self.client.get_collection(
|
|
|
279 |
collection_name=self.COLLECTION_NAME,
|
280 |
vectors_config=vectors_config,
|
281 |
)
|
282 |
+
self.emb_id2point_id.clear()
|
283 |
+
self.point_id2emb_id.clear()
|
284 |
+
|
285 |
+
def as_indices_vectors_payloads(self) -> Tuple[List[str], np.ndarray, List[T]]:
|
286 |
+
num_entries = self.client.get_collection(collection_name=self.COLLECTION_NAME).points_count
|
287 |
+
data, point_ids = self.client.scroll(
|
288 |
+
collection_name=self.COLLECTION_NAME, with_vectors=True, limit=num_entries
|
289 |
+
)
|
290 |
+
vectors_np = np.array([point.vector for point in data])
|
291 |
+
payloads = [point.payload for point in data]
|
292 |
+
emb_ids = [self.point_id2emb_id[point.id] for point in data]
|
293 |
+
return emb_ids, vectors_np, payloads
|