from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import numpy as np
import polars as pl
					
						
						|  | DEFAULT_MATCH_VECTOR_TOPN = 10 | 
					
						
						|  | DEFAULT_MATCH_SPARSE_TOPN = 10 | 
					
						
						|  | VEC = Union[list, np.ndarray] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @dataclass | 
					
						
						|  | class SparseVector: | 
					
						
						|  | indices: list[int] | 
					
						
						|  | values: Union[list[float], list[int], None] = None | 
					
						
						|  |  | 
					
						
						|  | def __post_init__(self): | 
					
						
						|  | assert (self.values is None) or (len(self.indices) == len(self.values)) | 
					
						
						|  |  | 
					
						
						|  | def to_dict_old(self): | 
					
						
						|  | d = {"indices": self.indices} | 
					
						
						|  | if self.values is not None: | 
					
						
						|  | d["values"] = self.values | 
					
						
						|  | return d | 
					
						
						|  |  | 
					
						
						|  | def to_dict(self): | 
					
						
						|  | if self.values is None: | 
					
						
						|  | raise ValueError("SparseVector.values is None") | 
					
						
						|  | result = {} | 
					
						
						|  | for i, v in zip(self.indices, self.values): | 
					
						
						|  | result[str(i)] = v | 
					
						
						|  | return result | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def from_dict(d): | 
					
						
						|  | return SparseVector(d["indices"], d.get("values")) | 
					
						
						|  |  | 
					
						
						|  | def __str__(self): | 
					
						
						|  | return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})" | 
					
						
						|  |  | 
					
						
						|  | def __repr__(self): | 
					
						
						|  | return str(self) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class MatchTextExpr(ABC): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | fields: str, | 
					
						
						|  | matching_text: str, | 
					
						
						|  | topn: int, | 
					
						
						|  | extra_options: dict = dict(), | 
					
						
						|  | ): | 
					
						
						|  | self.fields = fields | 
					
						
						|  | self.matching_text = matching_text | 
					
						
						|  | self.topn = topn | 
					
						
						|  | self.extra_options = extra_options | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class MatchDenseExpr(ABC): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | vector_column_name: str, | 
					
						
						|  | embedding_data: VEC, | 
					
						
						|  | embedding_data_type: str, | 
					
						
						|  | distance_type: str, | 
					
						
						|  | topn: int = DEFAULT_MATCH_VECTOR_TOPN, | 
					
						
						|  | extra_options: dict = dict(), | 
					
						
						|  | ): | 
					
						
						|  | self.vector_column_name = vector_column_name | 
					
						
						|  | self.embedding_data = embedding_data | 
					
						
						|  | self.embedding_data_type = embedding_data_type | 
					
						
						|  | self.distance_type = distance_type | 
					
						
						|  | self.topn = topn | 
					
						
						|  | self.extra_options = extra_options | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class MatchSparseExpr(ABC): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | vector_column_name: str, | 
					
						
						|  | sparse_data: SparseVector | dict, | 
					
						
						|  | distance_type: str, | 
					
						
						|  | topn: int, | 
					
						
						|  | opt_params: Optional[dict] = None, | 
					
						
						|  | ): | 
					
						
						|  | self.vector_column_name = vector_column_name | 
					
						
						|  | self.sparse_data = sparse_data | 
					
						
						|  | self.distance_type = distance_type | 
					
						
						|  | self.topn = topn | 
					
						
						|  | self.opt_params = opt_params | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class MatchTensorExpr(ABC): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | column_name: str, | 
					
						
						|  | query_data: VEC, | 
					
						
						|  | query_data_type: str, | 
					
						
						|  | topn: int, | 
					
						
						|  | extra_option: Optional[dict] = None, | 
					
						
						|  | ): | 
					
						
						|  | self.column_name = column_name | 
					
						
						|  | self.query_data = query_data | 
					
						
						|  | self.query_data_type = query_data_type | 
					
						
						|  | self.topn = topn | 
					
						
						|  | self.extra_option = extra_option | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class FusionExpr(ABC): | 
					
						
						|  | def __init__(self, method: str, topn: int, fusion_params: Optional[dict] = None): | 
					
						
						|  | self.method = method | 
					
						
						|  | self.topn = topn | 
					
						
						|  | self.fusion_params = fusion_params | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | MatchExpr = Union[ | 
					
						
						|  | MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class OrderByExpr(ABC): | 
					
						
						|  | def __init__(self): | 
					
						
						|  | self.fields = list() | 
					
						
						|  | def asc(self, field: str): | 
					
						
						|  | self.fields.append((field, 0)) | 
					
						
						|  | return self | 
					
						
						|  | def desc(self, field: str): | 
					
						
						|  | self.fields.append((field, 1)) | 
					
						
						|  | return self | 
					
						
						|  | def fields(self): | 
					
						
						|  | return self.fields | 
					
						
						|  |  | 
					
						
						|  | class DocStoreConnection(ABC): | 
					
						
						|  | """ | 
					
						
						|  | Database operations | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def dbType(self) -> str: | 
					
						
						|  | """ | 
					
						
						|  | Return the type of the database. | 
					
						
						|  | """ | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def health(self) -> dict: | 
					
						
						|  | """ | 
					
						
						|  | Return the health status of the database. | 
					
						
						|  | """ | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  | Table operations | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int): | 
					
						
						|  | """ | 
					
						
						|  | Create an index with given name | 
					
						
						|  | """ | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def deleteIdx(self, indexName: str, knowledgebaseId: str): | 
					
						
						|  | """ | 
					
						
						|  | Delete an index with given name | 
					
						
						|  | """ | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def indexExist(self, indexName: str, knowledgebaseId: str) -> bool: | 
					
						
						|  | """ | 
					
						
						|  | Check if an index with given name exists | 
					
						
						|  | """ | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  | CRUD operations | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def search( | 
					
						
						|  | self, selectFields: list[str], highlight: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str] | 
					
						
						|  | ) -> list[dict] | pl.DataFrame: | 
					
						
						|  | """ | 
					
						
						|  | Search with given conjunctive equivalent filtering condition and return all fields of matched documents | 
					
						
						|  | """ | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None: | 
					
						
						|  | """ | 
					
						
						|  | Get single chunk with given id | 
					
						
						|  | """ | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]: | 
					
						
						|  | """ | 
					
						
						|  | Update or insert a bulk of rows | 
					
						
						|  | """ | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool: | 
					
						
						|  | """ | 
					
						
						|  | Update rows with given conjunctive equivalent filtering condition | 
					
						
						|  | """ | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: | 
					
						
						|  | """ | 
					
						
						|  | Delete rows with given conjunctive equivalent filtering condition | 
					
						
						|  | """ | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  | Helper functions for search result | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def getTotal(self, res): | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def getChunkIds(self, res): | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def getFields(self, res, fields: List[str]) -> Dict[str, dict]: | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def getHighlight(self, res, keywords: List[str], fieldnm: str): | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def getAggregation(self, res, fieldnm: str): | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  | SQL | 
					
						
						|  | """ | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def sql(sql: str, fetch_size: int, format: str): | 
					
						
						|  | """ | 
					
						
						|  | Run the sql generated by text-to-sql | 
					
						
						|  | """ | 
					
						
						|  | raise NotImplementedError("Not implemented") | 
					
						
						|  |  |