Spaces:
Sleeping
Sleeping
File size: 11,569 Bytes
287a0bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 |
from multiprocessing.connection import Connection
import sys
import os
import shutil
import subprocess
import tempfile
from types import ModuleType
from typing import Generator, List, Tuple, Dict, Any, Callable, Type
from hypothesis import given, settings
import hypothesis.strategies as st
import pytest
import json
from urllib import request
from chromadb import config
from chromadb.api import ServerAPI
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
import chromadb.test.property.strategies as strategies
import chromadb.test.property.invariants as invariants
from packaging import version as packaging_version
import re
import multiprocessing
from chromadb.config import Settings
# Oldest chromadb release exercised by the cross-version persistence test;
# versions() pairs it with the latest release found on PyPI.
MINIMUM_VERSION = "0.4.1"
# Matches plain "X.Y.Z" release strings only, so versions() can drop PyPI
# entries carrying suffixes such as "devXYZ".
version_re = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$")
# Some modules do not work across versions because the versions we support change
# over time. switch_to_version() evicts them (and their submodules) from
# sys.modules so that they are explicitly re-imported in the subprocess.
VERSIONED_MODULES = ["pydantic"]
def versions() -> List[str]:
    """Return the pinned minimum version and the latest release of chromadb.

    Queries the PyPI JSON API for chromadb and keeps only plain ``X.Y.Z``
    release strings, because older PyPI entries carry ``devXYZ`` suffixes.

    Returns:
        A two-element list: ``[MINIMUM_VERSION, <latest X.Y.Z release>]``.

    Raises:
        urllib.error.URLError: if PyPI cannot be reached.
        IndexError: if PyPI lists no plain ``X.Y.Z`` release.
    """
    url = "https://pypi.org/pypi/chromadb/json"
    # Use the response as a context manager so the HTTP connection is closed
    # deterministically instead of leaking until garbage collection.
    with request.urlopen(request.Request(url)) as response:
        data = json.load(response)
    # Named ``releases`` (not ``versions``) to avoid shadowing this function.
    releases = [v for v in data["releases"] if version_re.match(v)]
    releases.sort(key=packaging_version.Version)
    return [MINIMUM_VERSION, releases[-1]]
def _bool_to_int(metadata: Dict[str, Any]) -> Dict[str, Any]:
metadata.update((k, 1) for k, v in metadata.items() if v is True)
metadata.update((k, 0) for k, v in metadata.items() if v is False)
return metadata
def _patch_boolean_metadata(
    collection: strategies.Collection,
    embeddings: strategies.RecordSet,
    settings: Settings,
) -> None:
    """Rewrite boolean metadata values as ints, in place, for old versions.

    Old chromadb versions do not support boolean metadata values, so ``True``
    and ``False`` are converted to ``1`` and ``0``. This applies to the
    collection's metadata and to every record metadata dict. ``settings`` is
    unused; it is accepted only to match the common patch signature used in
    ``version_patches``.
    """
    coll_meta = collection.metadata
    if coll_meta is not None:
        _bool_to_int(coll_meta)  # type: ignore

    record_metas = embeddings["metadatas"]
    if record_metas is None:
        return
    if isinstance(record_metas, list):
        for meta in record_metas:
            # Individual entries may be None; only dicts carry values to convert.
            if isinstance(meta, dict):
                _bool_to_int(meta)
    elif isinstance(record_metas, dict):
        _bool_to_int(record_metas)
def _patch_telemetry_client(
    collection: strategies.Collection,
    embeddings: strategies.RecordSet,
    settings: Settings,
) -> None:
    """Set ``settings.chroma_telemetry_impl`` to the Posthog client, in place.

    chroma 0.4.14 added OpenTelemetry, distinct from ProductTelemetry. Before
    0.4.14, ProductTelemetry was simply called Telemetry, so this setting is
    pointed at the Posthog implementation for those releases. ``collection``
    and ``embeddings`` are unused; they exist to match the common patch
    signature used in ``version_patches``.
    """
    posthog_impl = "chromadb.telemetry.posthog.Posthog"
    settings.chroma_telemetry_impl = posthog_impl
# (version, patch) pairs. patch_for_version() applies each patch, in list order,
# when the version under test is <= the paired version. The patches work around
# breaking changes in older chromadb releases.
version_patches: List[
Tuple[str, Callable[[strategies.Collection, strategies.RecordSet, Settings], None]]
] = [
("0.4.3", _patch_boolean_metadata),
("0.4.14", _patch_telemetry_client),
]
def patch_for_version(
    version: str,
    collection: strategies.Collection,
    embeddings: strategies.RecordSet,
    settings: Settings,
) -> None:
    """Mutate the test inputs so they are valid for an older chromadb ``version``.

    Every entry of ``version_patches`` whose version is greater than or equal
    to ``version`` is applied, in list order, to ``collection``, ``embeddings``
    and ``settings``.
    """
    target = packaging_version.Version(version)
    for max_version, patch in version_patches:
        # The patch only applies when target <= max_version; skip newer targets.
        if packaging_version.Version(max_version) < target:
            continue
        patch(collection, embeddings, settings)
def api_import_for_version(module: Any, version: str) -> Type:  # type: ignore
    """Return the API class to instantiate from the chromadb ``module``.

    Releases up to and including 0.4.14 expose it as ``module.api.API``;
    later releases expose ``module.api.ServerAPI``.
    """
    is_legacy = packaging_version.Version(version) <= packaging_version.Version(
        "0.4.14"
    )
    api_module = module.api
    return api_module.API if is_legacy else api_module.ServerAPI  # type: ignore
def configurations(versions: List[str]) -> List[Tuple[str, Settings]]:
    """Build one ``(version, Settings)`` pair for each entry in ``versions``.

    Every pair configures the segment API with SQLite-backed sysdb, producer
    and consumer, and the local segment manager. Reset and persistence are
    enabled. All pairs share a single persist directory under the system
    temp dir.
    """
    pairs: List[Tuple[str, Settings]] = []
    for version in versions:
        version_config = Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            allow_reset=True,
            is_persistent=True,
            persist_directory=tempfile.gettempdir() + "/persistence_test_chromadb",
        )
        pairs.append((version, version_config))
    return pairs
# Versions under test, resolved when this module is imported. Note that
# versions() makes a network request to PyPI.
test_old_versions = versions()
# Root directory; each tested chromadb version is pip-installed into its own
# subdirectory (see get_path_to_version_install).
base_install_dir = tempfile.gettempdir() + "/persistence_test_chromadb_versions"
# This fixture is not shared with the rest of the tests because it is unique in how it
# installs the versions of chromadb
@pytest.fixture(scope="module", params=configurations(test_old_versions))  # type: ignore
def version_settings(request) -> Generator[Tuple[str, Settings], None, None]:
    """Install one chromadb version, yield its ``(version, Settings)`` pair, then clean up.

    Teardown deletes the pip ``--target`` directory the version was installed
    into, and removes the persist directory named in its Settings if present.
    """
    version, version_config = request.param
    install_version(version)
    yield request.param
    # Teardown, part 1: the installed package tree for this version.
    shutil.rmtree(get_path_to_version_install(version))
    # Teardown, part 2: any data persisted while the tests ran.
    persist_dir = version_config.persist_directory
    if os.path.exists(persist_dir):
        shutil.rmtree(persist_dir, ignore_errors=True)
def get_path_to_version_install(version: str) -> str:
    """Return the directory that chromadb ``version`` is pip-installed into."""
    return "/".join((base_install_dir, version))
def get_path_to_version_library(version: str) -> str:
    """Return the path of ``chromadb/__init__.py`` inside the install dir for ``version``."""
    install_dir = get_path_to_version_install(version)
    return "/".join((install_dir, "chromadb", "__init__.py"))
def install_version(version: str) -> None:
    """Pip-install chromadb ``version`` into its own directory unless already present.

    The presence of the installed ``chromadb/__init__.py`` is treated as
    "already installed".
    """
    already_installed = os.path.exists(get_path_to_version_library(version))
    if not already_installed:
        target_dir = get_path_to_version_install(version)
        install(f"chromadb=={version}", target_dir)
def install(pkg: str, path: str) -> int:
    """Pip-install ``pkg`` into ``path`` using the current interpreter.

    Args:
        pkg: A pip requirement specifier, e.g. ``"chromadb==0.4.1"``.
        path: Directory passed to ``pip install --target``.

    Returns:
        pip's exit status, which is always 0 because failures raise.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    # ``pkg`` is already a full specifier such as "chromadb==0.4.1", so don't
    # prefix it with "chromadb version" (that printed "chromadb version chromadb==...").
    print(f"Installing {pkg} to {path}")
    # -q -q to suppress pip output to ERROR level
    # https://pip.pypa.io/en/stable/cli/pip/#quiet
    return subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "-q",
            "-q",
            "install",
            pkg,
            f"--target={path}",
        ]
    )
def switch_to_version(version: str) -> ModuleType:
    """Import and return the chromadb package installed for ``version``.

    Removes ``chromadb``, its submodules, and every module in
    ``VERSIONED_MODULES`` (plus their submodules) from ``sys.modules``. Then
    prepends the version's install directory to ``sys.path``, so the following
    ``import chromadb`` resolves to that copy. This mutates interpreter-wide
    import state, which is why callers run it in a separate process.

    Raises:
        AssertionError: if the imported package reports a different ``__version__``.
    """
    module_name = "chromadb"
    evict_exact = {module_name, *VERSIONED_MODULES}
    evict_prefixes = tuple(f"{name}." for name in (module_name, *VERSIONED_MODULES))
    # Snapshot the names first: sys.modules must not change size while iterated.
    stale_names = [
        name
        for name in sys.modules
        if name in evict_exact or name.startswith(evict_prefixes)
    ]
    for name in stale_names:
        del sys.modules[name]

    sys.path.insert(0, get_path_to_version_install(version))
    import chromadb

    assert chromadb.__version__ == version
    return chromadb
class not_implemented_ef(EmbeddingFunction[Documents]):
    """Embedding function that fails if it is ever invoked.

    The record sets in this test always carry precomputed embeddings, so any
    call means chromadb unexpectedly tried to embed documents itself.
    """

    def __call__(self, input: Documents) -> Embeddings:
        # Raise explicitly instead of ``assert False``: asserts are stripped
        # under ``python -O``, which would silently make this return None.
        raise AssertionError("Embedding function should not be called")
def persist_generated_data_with_old_version(
    version: str,
    settings: Settings,
    collection_strategy: strategies.Collection,
    embeddings_strategy: strategies.RecordSet,
    conn: Connection,
) -> None:
    """Write a generated collection using the old chromadb ``version``.

    This is the target of the subprocess started by ``test_cycle_versions``.
    It does the following:

    - switches the interpreter to the old chromadb;
    - resets the API and creates the collection;
    - adds the record set;
    - checks the count and ids;
    - stops the system.

    Any exception is first sent over ``conn`` so the parent can re-raise it.
    It is then re-raised here so this process exits with a failure status.
    """
    try:
        old_module = switch_to_version(version)
        system = old_module.config.System(settings)
        api = system.instance(api_import_for_version(old_module, version))
        system.start()

        api.reset()
        coll = api.create_collection(
            name=collection_strategy.name,
            metadata=collection_strategy.metadata,
            # In order to test old versions, we can't rely on the not_implemented function
            embedding_function=not_implemented_ef(),
        )
        coll.add(**embeddings_strategy)

        # Just use some basic checks for sanity and manual testing where you break the new
        # version
        check_embeddings = invariants.wrap_all(embeddings_strategy)
        # Check count
        assert coll.count() == len(check_embeddings["embeddings"] or [])
        # Check ids, ordered the same way as the generated record set
        result = coll.get()
        embedding_id_to_index = {
            record_id: i for i, record_id in enumerate(check_embeddings["ids"])
        }
        actual_ids = sorted(
            result["ids"], key=lambda record_id: embedding_id_to_index[record_id]
        )
        assert actual_ids == check_embeddings["ids"]

        # Shutdown system
        system.stop()
    except Exception as e:
        try:
            conn.send(e)
        except Exception:
            # Some exceptions cannot be pickled. Send a plain summary instead so
            # the parent still learns why this process failed.
            conn.send(RuntimeError(f"{type(e).__name__}: {e}"))
        # Bare ``raise`` re-raises ``e`` with its original traceback.
        raise
# Since we can't pickle the embedding function, we always generate record sets with embeddings.
# st.shared(..., key="coll") makes every use of this strategy within one example draw
# the same collection, so the record sets generated from it match that collection.
collection_st: st.SearchStrategy[strategies.Collection] = st.shared(
strategies.collections(with_hnsw_params=True, has_embeddings=True), key="coll"
)
@given(
    collection_strategy=collection_st,
    embeddings_strategy=strategies.recordsets(collection_st),
)
@settings(deadline=None)
def test_cycle_versions(
    version_settings: Tuple[str, Settings],
    collection_strategy: strategies.Collection,
    embeddings_strategy: strategies.RecordSet,
) -> None:
    """Data persisted by an older chromadb must load correctly in the current one.

    A spawned subprocess writes the generated collection and records with the
    old ``version`` (``persist_generated_data_with_old_version``). This process
    then reopens the same persist directory with the current chromadb. It
    checks the count, metadata, document, id and ANN-accuracy invariants.
    """
    version, settings = version_settings

    # The strategies can generate metadatas of malformed inputs. Other tests
    # will error check and cover these cases to make sure they error. Here we
    # just convert them to valid values since the error cases are already tested
    if embeddings_strategy["metadatas"] == {}:
        embeddings_strategy["metadatas"] = None
    if embeddings_strategy["metadatas"] is not None and isinstance(
        embeddings_strategy["metadatas"], list
    ):
        embeddings_strategy["metadatas"] = [
            m if m is None or len(m) > 0 else None  # type: ignore
            for m in embeddings_strategy["metadatas"]
        ]

    patch_for_version(version, collection_strategy, embeddings_strategy, settings)

    # Can't pickle a function, and we won't need them
    collection_strategy.embedding_function = None
    collection_strategy.known_metadata_keys = {}

    # Run the task in a separate process to avoid polluting the current process
    # with the old version. Using spawn instead of fork to avoid sharing the
    # current process memory which would cause the old version to be loaded
    ctx = multiprocessing.get_context("spawn")
    conn1, conn2 = ctx.Pipe()
    p = ctx.Process(
        target=persist_generated_data_with_old_version,
        args=(version, settings, collection_strategy, embeddings_strategy, conn2),
    )
    try:
        p.start()
        p.join()
        if conn1.poll():
            raise conn1.recv()
        # The child can die without reporting anything (e.g. a hard crash), so a
        # silent pipe is not proof of success.
        if p.exitcode != 0:
            raise RuntimeError(
                f"Subprocess persisting data with chromadb {version} "
                f"exited with code {p.exitcode}"
            )
    finally:
        # Release process and pipe resources even when the child failed.
        if p.is_alive():
            p.terminate()
            p.join()
        p.close()
        conn1.close()
        conn2.close()

    # Switch to the current version (local working directory) and check the invariants
    # are preserved for the collection
    system = config.System(settings)
    api = system.instance(ServerAPI)
    system.start()
    try:
        coll = api.get_collection(
            name=collection_strategy.name,
            embedding_function=not_implemented_ef(),  # type: ignore
        )
        invariants.count(coll, embeddings_strategy)
        invariants.metadatas_match(coll, embeddings_strategy)
        invariants.documents_match(coll, embeddings_strategy)
        invariants.ids_match(coll, embeddings_strategy)
        invariants.ann_accuracy(coll, embeddings_strategy)
    finally:
        # Shut the system down even when an invariant fails, so the next
        # example does not run against a still-running system.
        system.stop()
|