Spaces:
Sleeping
Sleeping
File size: 3,161 Bytes
287a0bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import multiprocessing
import re
from typing import Any, Callable, Dict, Union
from chromadb.types import Metadata
Validator = Callable[[Union[str, int, float]], bool]


def _is_strict_int(p: Union[str, int, float]) -> bool:
    """Return True for real ints; ``bool`` is an ``int`` subclass and is rejected."""
    return isinstance(p, int) and not isinstance(p, bool)


def _is_strict_number(p: Union[str, int, float]) -> bool:
    """Return True for real ints or floats; ``bool`` is rejected."""
    return isinstance(p, (int, float)) and not isinstance(p, bool)


# Validators for the hnsw:* keys accepted by every HNSW index.
param_validators: Dict[str, Validator] = {
    # fullmatch rather than match + "$": "$" also matches just before a
    # trailing newline, so the anchored pattern accepted e.g. "l2\n".
    "hnsw:space": lambda p: re.fullmatch(r"l2|cosine|ip", str(p)) is not None,
    "hnsw:construction_ef": _is_strict_int,
    "hnsw:search_ef": _is_strict_int,
    "hnsw:M": _is_strict_int,
    "hnsw:num_threads": _is_strict_int,
    "hnsw:resize_factor": _is_strict_number,
}

# Extra params used for persistent hnsw
persistent_param_validators: Dict[str, Validator] = {
    "hnsw:batch_size": lambda p: _is_strict_int(p) and p > 2,
    "hnsw:sync_threshold": lambda p: _is_strict_int(p) and p > 2,
}
class Params:
    """Shared helpers for selecting and validating ``hnsw:*`` metadata keys."""

    @staticmethod
    def _select(metadata: Metadata) -> Dict[str, Any]:
        """Return a new dict holding only the ``hnsw:``-prefixed entries of *metadata*.

        ``None`` (or an empty mapping) yields ``{}`` instead of raising.
        """
        if not metadata:
            return {}
        return {
            param: value
            for param, value in metadata.items()
            if param.startswith("hnsw:")
        }

    @staticmethod
    def _validate(metadata: Dict[str, Any], validators: Dict[str, Validator]) -> None:
        """Check every entry of *metadata* against its validator in *validators*.

        Raises:
            ValueError: if a key has no entry in *validators*, or if its
                validator returns a falsy result for the value.
        """
        for param, value in metadata.items():
            if param not in validators:
                raise ValueError(f"Unknown HNSW parameter: {param}")
            if not validators[param](value):
                raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}")
class HnswParams(Params):
    """Typed view over the ``hnsw:*`` entries of a metadata mapping.

    Missing keys fall back to defaults (space "l2", construction_ef 100,
    search_ef 10, M 16, num_threads = CPU count, resize_factor 1.2).
    Values are coerced with ``str``/``int``/``float`` here but not
    validated; use ``extract`` for validation.
    """

    space: str
    construction_ef: int
    search_ef: int
    M: int
    num_threads: int
    resize_factor: float

    def __init__(self, metadata: Metadata):
        source = metadata if metadata else {}
        read = source.get
        self.space = str(read("hnsw:space", "l2"))
        self.construction_ef = int(read("hnsw:construction_ef", 100))
        self.search_ef = int(read("hnsw:search_ef", 10))
        self.M = int(read("hnsw:M", 16))
        # The CPU count is computed up front, whether or not the key is present.
        default_threads = multiprocessing.cpu_count()
        self.num_threads = int(read("hnsw:num_threads", default_threads))
        self.resize_factor = float(read("hnsw:resize_factor", 1.2))

    @staticmethod
    def extract(metadata: Metadata) -> Metadata:
        """Pick out the ``hnsw:*`` keys of *metadata*, validate them, and return them.

        Raises:
            ValueError: on an ``hnsw:*`` key without a validator in
                ``param_validators`` (this includes the persistent-only
                keys), or on a value its validator rejects.
        """
        selected = HnswParams._select(metadata)
        HnswParams._validate(selected, param_validators)
        return selected
class PersistentHnswParams(HnswParams):
    """HNSW parameters plus the extra settings used by the persistent index.

    Adds ``batch_size`` (``hnsw:batch_size``, default 100) and
    ``sync_threshold`` (``hnsw:sync_threshold``, default 1000) on top of
    the base HNSW parameters.
    """

    batch_size: int
    sync_threshold: int

    def __init__(self, metadata: Metadata):
        super().__init__(metadata)
        # The parent's ``metadata or {}`` only rebinds its own local name, so
        # guard again here; otherwise ``None`` metadata crashes on ``.get``.
        metadata = metadata or {}
        self.batch_size = int(metadata.get("hnsw:batch_size", 100))
        self.sync_threshold = int(metadata.get("hnsw:sync_threshold", 1000))

    @staticmethod
    def extract(metadata: Metadata) -> Metadata:
        """Validate and return only the relevant hnsw params.

        Both the base HNSW keys and the persistent-only keys
        (``hnsw:batch_size``, ``hnsw:sync_threshold``) are accepted.

        Raises:
            ValueError: on an unknown ``hnsw:*`` key or an invalid value.
        """
        all_validators = {**param_validators, **persistent_param_validators}
        segment_metadata = PersistentHnswParams._select(metadata)
        PersistentHnswParams._validate(segment_metadata, all_validators)
        return segment_metadata
|