Spaces:
Paused
Paused
import os | |
from unittest.mock import MagicMock | |
import pytest | |
from _pytest.monkeypatch import MonkeyPatch | |
from pymochow import MochowClient | |
from pymochow.model.database import Database | |
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState | |
from pymochow.model.schema import HNSWParams, VectorIndex | |
from pymochow.model.table import Table | |
from requests.adapters import HTTPAdapter | |
class AttrDict(dict):
    """A ``dict`` whose keys can also be read as attributes.

    ``d.key`` is equivalent to ``d.get("key")``: a missing key reads as
    ``None`` instead of raising ``AttributeError``.
    """

    def __getattr__(self, name):
        # Python only calls __getattr__ after normal attribute lookup has
        # failed, so real dict methods (get, items, ...) still win; anything
        # else falls back to a key lookup with a None default.
        if name in self:
            return self[name]
        return None
class MockBaiduVectorDBClass:
    """Canned, network-free fakes for pymochow SDK methods.

    The functions here are not meant to be called on ``MockBaiduVectorDBClass``
    instances. They are meant to be grafted onto ``MochowClient``, ``Database``
    and ``Table`` via ``monkeypatch.setattr`` by ``setup_baiduvectordb_mock``.
    Inside each function, ``self`` is therefore an instance of the SDK class it
    was attached to. Every fake returns a fixed, successful response.
    """

    def mock_vector_db_client(
        self,
        config=None,
        adapter: HTTPAdapter = None,
    ):
        """Fake ``MochowClient.__init__`` that opens no connection.

        ``config`` and ``adapter`` are accepted but ignored. ``conn`` and
        ``_config`` are set to ``MagicMock`` objects; ``list_databases`` and
        ``create_database`` read them back.
        """
        self.conn = MagicMock()
        self._config = MagicMock()

    def list_databases(self, config=None) -> list[Database]:
        """Fake ``MochowClient.list_databases``: one database named ``dify``."""
        return [
            Database(
                conn=self.conn,
                database_name="dify",
                config=self._config,
            )
        ]

    def create_database(self, database_name: str, config=None) -> Database:
        """Fake ``MochowClient.create_database``: build a Database, persist nothing."""
        return Database(conn=self.conn, database_name=database_name, config=config)

    def list_table(self, config=None) -> list[Table]:
        """Fake ``Database.list_table``: the database never contains tables."""
        return []

    def drop_table(self, table_name: str, config=None):
        """Fake ``Database.drop_table``: always report success."""
        return {"code": 0, "msg": "Success"}

    def create_table(
        self,
        table_name: str,
        replication: int,
        partition: int,
        schema,
        enable_dynamic_field=False,
        description: str = "",
        config=None,
    ) -> Table:
        """Fake ``Database.create_table``: build a Table object locally.

        ``self`` is the patched ``Database``. It becomes the table's parent,
        exactly as in ``describe_table``.
        """
        # The optional arguments are passed by keyword, matching
        # describe_table. Passing them positionally would silently bind
        # ``config`` to whatever parameter happens to be 8th in
        # ``Table.__init__``.
        return Table(
            self,
            table_name,
            replication,
            partition,
            schema,
            enable_dynamic_field=enable_dynamic_field,
            description=description,
            config=config,
        )

    def describe_table(self, table_name: str, config=None) -> Table:
        """Fake ``Database.describe_table`` (also stands in for ``Database.table``).

        Returns a table in ``TableState.NORMAL`` with replication 3,
        partition 1 and no schema, whatever ``table_name`` is.
        """
        return Table(
            self,
            table_name,
            3,
            1,
            None,
            enable_dynamic_field=False,
            description="table for dify",
            config=config,
            state=TableState.NORMAL,
        )

    def upsert(self, rows, config=None):
        """Fake for ``Table.upsert``: always reports one affected row, ignoring ``rows``."""
        return {"code": 0, "msg": "operation success", "affectedCount": 1}

    def rebuild_index(self, index_name: str, config=None):
        """Fake ``Table.rebuild_index``: always report success."""
        return {"code": 0, "msg": "Success"}

    def describe_index(self, index_name: str, config=None):
        """Fake ``Table.describe_index``: a ready HNSW/L2 index on field ``vector``."""
        return VectorIndex(
            index_name=index_name,
            index_type=IndexType.HNSW,
            field="vector",
            metric_type=MetricType.L2,
            params=HNSWParams(m=16, efconstruction=200),
            auto_build=False,
            state=IndexState.NORMAL,
        )

    def query(
        self,
        primary_key,
        partition_key=None,
        projections=None,
        retrieve_vector=False,
        read_consistency=ReadConsistency.EVENTUAL,
        config=None,
    ):
        """Fake ``Table.query``: echo back a fixed row carrying the requested id.

        ``primary_key`` is read with ``.get("id")``, so a mapping is expected.
        The result is an ``AttrDict``, so callers can use ``res.row`` /
        ``res.code``.
        """
        return AttrDict(
            {
                "row": {
                    "id": primary_key.get("id"),
                    "vector": [0.23432432, 0.8923744, 0.89238432],
                    "text": "text",
                    "metadata": '{"doc_id": "doc_id_001"}',
                },
                "code": 0,
                "msg": "Success",
            }
        )

    def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
        """Fake ``Table.delete``: always report success.

        ``filter`` shadows the builtin. It is kept because callers may pass it
        by keyword.
        """
        return {"code": 0, "msg": "Success"}

    def search(
        self,
        anns,
        partition_key=None,
        projections=None,
        retrieve_vector=False,
        read_consistency=ReadConsistency.EVENTUAL,
        config=None,
    ):
        """Fake ``Table.search``: a single hit (``doc_id_001``) for any request.

        The hit has distance 0.1 and score 0.5. The result is an ``AttrDict``,
        so callers can iterate ``res.rows``.
        """
        return AttrDict(
            {
                "rows": [
                    {
                        "row": {
                            "id": "doc_id_001",
                            "vector": [0.23432432, 0.8923744, 0.89238432],
                            "text": "text",
                            "metadata": '{"doc_id": "doc_id_001"}',
                        },
                        "distance": 0.1,
                        "score": 0.5,
                    }
                ],
                "code": 0,
                "msg": "Success",
            }
        )
# Fakes replace the real Baidu VectorDB SDK only when MOCK_SWITCH is "true"
# (any letter case). An unset variable or any other value leaves mocking off.
MOCK = "true" == os.environ.get("MOCK_SWITCH", "false").lower()
# Registered as a fixture so tests can request it by name; without the
# decorator pytest reports "fixture 'setup_baiduvectordb_mock' not found".
@pytest.fixture
def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
    """Swap pymochow SDK methods for the fakes in ``MockBaiduVectorDBClass``.

    Patching happens only when ``MOCK`` is true (``MOCK_SWITCH=true``).
    Otherwise the fixture does nothing and the real SDK methods are used.

    Args:
        request: pytest's fixture request object; unused here.
        monkeypatch: pytest's ``monkeypatch`` fixture, used to apply and
            revert the patches.

    Yields:
        None, once the patches (if any) are in place.
    """
    if MOCK:
        patches = (
            (MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client),
            (MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases),
            (MochowClient, "create_database", MockBaiduVectorDBClass.create_database),
            # Database.table(name) is served by the same fake as describe_table.
            (Database, "table", MockBaiduVectorDBClass.describe_table),
            (Database, "list_table", MockBaiduVectorDBClass.list_table),
            (Database, "create_table", MockBaiduVectorDBClass.create_table),
            (Database, "drop_table", MockBaiduVectorDBClass.drop_table),
            (Database, "describe_table", MockBaiduVectorDBClass.describe_table),
            (Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index),
            (Table, "describe_index", MockBaiduVectorDBClass.describe_index),
            (Table, "delete", MockBaiduVectorDBClass.delete),
            (Table, "query", MockBaiduVectorDBClass.query),
            (Table, "search", MockBaiduVectorDBClass.search),
            # The upsert fake was defined but never wired in, so writes still
            # went through the real Table.upsert against a MagicMock connection.
            (Table, "upsert", MockBaiduVectorDBClass.upsert),
        )
        for target, attr_name, fake in patches:
            monkeypatch.setattr(target, attr_name, fake)

    yield

    if MOCK:
        # monkeypatch also reverts on its own teardown; undoing here restores
        # the real SDK as soon as this fixture finishes.
        monkeypatch.undo()