from typing import Optional, Sequence, Any, Tuple, cast, Dict, Union, Set
from uuid import UUID
from overrides import override
from pypika import Table, Column
from itertools import groupby
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System
from chromadb.db.base import (
    Cursor,
    SqlDB,
    ParameterValue,
    get_sql,
    NotFoundError,
    UniqueConstraintError,
)
from chromadb.db.system import SysDB
from chromadb.telemetry.opentelemetry import (
    add_attributes_to_current_span,
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.ingest import CollectionAssignmentPolicy, Producer
from chromadb.types import (
    Database,
    OptionalArgument,
    Segment,
    Metadata,
    Collection,
    SegmentScope,
    Tenant,
    Unspecified,
    UpdateMetadata,
)


class SqlSysDB(SqlDB, SysDB):
    """SQL implementation of the system database (SysDB), storing tenants,
    databases, collections and segments."""

    _assignment_policy: CollectionAssignmentPolicy
    # Used only to delete topics on collection deletion.
    # TODO: refactor to remove this dependency into a separate interface
    _producer: Producer

    def __init__(self, system: System):
        self._assignment_policy = system.instance(CollectionAssignmentPolicy)
        super().__init__(system)
        self._opentelemetry_client = system.require(OpenTelemetryClient)

    def start(self) -> None:
        super().start()
        self._producer = self._system.instance(Producer)
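
    # create_database inserts a row into "databases", resolving the owning
    # tenant id via a subquery on the tenants table; a duplicate (name, tenant)
    # pair surfaces as a UniqueConstraintError.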

    def create_database(
        self, id: UUID, name: str, tenant: str = DEFAULT_TENANT
    ) -> None:
        with self.tx() as cur:
            # Get the tenant id for the tenant name, then insert the database
            # with the id, name and tenant id
            databases = Table("databases")
            tenants = Table("tenants")
            insert_database = (
                self.querybuilder()
                .into(databases)
                .columns(databases.id, databases.name, databases.tenant_id)
                .insert(
                    ParameterValue(self.uuid_to_db(id)),
                    ParameterValue(name),
                    self.querybuilder()
                    .select(tenants.id)
                    .from_(tenants)
                    .where(tenants.id == ParameterValue(tenant)),
                )
            )
            sql, params = get_sql(insert_database, self.parameter_format())
            try:
                cur.execute(sql, params)
            except self.unique_constraint_error() as e:
                raise UniqueConstraintError(
                    f"Database {name} already exists for tenant {tenant}"
                ) from e

    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
        with self.tx() as cur:
            databases = Table("databases")
            q = (
                self.querybuilder()
                .from_(databases)
                .select(databases.id, databases.name)
                .where(databases.name == ParameterValue(name))
                .where(databases.tenant_id == ParameterValue(tenant))
            )
            sql, params = get_sql(q, self.parameter_format())
            row = cur.execute(sql, params).fetchone()
            if not row:
                raise NotFoundError(f"Database {name} not found for tenant {tenant}")
            if row[0] is None:
                raise NotFoundError(f"Database {name} not found for tenant {tenant}")
            id: UUID = cast(UUID, self.uuid_from_db(row[0]))
            return Database(
                id=id,
                name=row[1],
                tenant=tenant,
            )

    def create_tenant(self, name: str) -> None:
        with self.tx() as cur:
            tenants = Table("tenants")
            insert_tenant = (
                self.querybuilder()
                .into(tenants)
                .columns(tenants.id)
                .insert(ParameterValue(name))
            )
            sql, params = get_sql(insert_tenant, self.parameter_format())
            try:
                cur.execute(sql, params)
            except self.unique_constraint_error() as e:
                raise UniqueConstraintError(f"Tenant {name} already exists") from e

    def get_tenant(self, name: str) -> Tenant:
        with self.tx() as cur:
            tenants = Table("tenants")
            q = (
                self.querybuilder()
                .from_(tenants)
                .select(tenants.id)
                .where(tenants.id == ParameterValue(name))
            )
            sql, params = get_sql(q, self.parameter_format())
            row = cur.execute(sql, params).fetchone()
            if not row:
                raise NotFoundError(f"Tenant {name} not found")
            return Tenant(name=name)
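
    # create_segment writes the segment row and, when metadata is present, its
    # metadata rows inside a single transaction, so a failed insert leaves no
    # partial state behind.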

    def create_segment(self, segment: Segment) -> None:
        add_attributes_to_current_span(
            {
                "segment_id": str(segment["id"]),
                "segment_type": segment["type"],
                "segment_scope": segment["scope"].value,
                "segment_topic": str(segment["topic"]),
                "collection": str(segment["collection"]),
            }
        )
        with self.tx() as cur:
            segments = Table("segments")
            insert_segment = (
                self.querybuilder()
                .into(segments)
                .columns(
                    segments.id,
                    segments.type,
                    segments.scope,
                    segments.topic,
                    segments.collection,
                )
                .insert(
                    ParameterValue(self.uuid_to_db(segment["id"])),
                    ParameterValue(segment["type"]),
                    ParameterValue(segment["scope"].value),
                    ParameterValue(segment["topic"]),
                    ParameterValue(self.uuid_to_db(segment["collection"])),
                )
            )
            sql, params = get_sql(insert_segment, self.parameter_format())
            try:
                cur.execute(sql, params)
            except self.unique_constraint_error() as e:
                raise UniqueConstraintError(
                    f"Segment {segment['id']} already exists"
                ) from e
            metadata_t = Table("segment_metadata")
            if segment["metadata"]:
                self._insert_metadata(
                    cur,
                    metadata_t,
                    metadata_t.segment_id,
                    segment["id"],
                    segment["metadata"],
                )
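
    # create_collection implements get-or-create semantics: if a collection
    # with the same name already exists, it is either returned (after updating
    # its metadata when new metadata is supplied) or a UniqueConstraintError is
    # raised, depending on get_or_create. The returned bool reports whether a
    # new collection was created.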

    def create_collection(
        self,
        id: UUID,
        name: str,
        metadata: Optional[Metadata] = None,
        dimension: Optional[int] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Tuple[Collection, bool]:
        if id is None and not get_or_create:
            raise ValueError("id must be specified if get_or_create is False")
        add_attributes_to_current_span(
            {
                "collection_id": str(id),
                "collection_name": name,
            }
        )
        existing = self.get_collections(name=name, tenant=tenant, database=database)
        if existing:
            if get_or_create:
                collection = existing[0]
                if metadata is not None and collection["metadata"] != metadata:
                    self.update_collection(
                        collection["id"],
                        metadata=metadata,
                    )
                return (
                    self.get_collections(
                        id=collection["id"], tenant=tenant, database=database
                    )[0],
                    False,
                )
            else:
                raise UniqueConstraintError(f"Collection {name} already exists")

        topic = self._assignment_policy.assign_collection(id)
        collection = Collection(
            id=id,
            topic=topic,
            name=name,
            metadata=metadata,
            dimension=dimension,
            tenant=tenant,
            database=database,
        )

        with self.tx() as cur:
            collections = Table("collections")
            databases = Table("databases")
            insert_collection = (
                self.querybuilder()
                .into(collections)
                .columns(
                    collections.id,
                    collections.topic,
                    collections.name,
                    collections.dimension,
                    collections.database_id,
                )
                .insert(
                    ParameterValue(self.uuid_to_db(collection["id"])),
                    ParameterValue(collection["topic"]),
                    ParameterValue(collection["name"]),
                    ParameterValue(collection["dimension"]),
                    # Get the database id for the database with the given name and tenant
                    self.querybuilder()
                    .select(databases.id)
                    .from_(databases)
                    .where(databases.name == ParameterValue(database))
                    .where(databases.tenant_id == ParameterValue(tenant)),
                )
            )
            sql, params = get_sql(insert_collection, self.parameter_format())
            try:
                cur.execute(sql, params)
            except self.unique_constraint_error() as e:
                raise UniqueConstraintError(
                    f"Collection {collection['id']} already exists"
                ) from e
            metadata_t = Table("collection_metadata")
            if collection["metadata"]:
                self._insert_metadata(
                    cur,
                    metadata_t,
                    metadata_t.collection_id,
                    collection["id"],
                    collection["metadata"],
                )
        return collection, True

    def get_segments(
        self,
        id: Optional[UUID] = None,
        type: Optional[str] = None,
        scope: Optional[SegmentScope] = None,
        topic: Optional[str] = None,
        collection: Optional[UUID] = None,
    ) -> Sequence[Segment]:
        add_attributes_to_current_span(
            {
                "segment_id": str(id),
                "segment_type": type if type else "",
                "segment_scope": scope.value if scope else "",
                "segment_topic": topic if topic else "",
                "collection": str(collection),
            }
        )
        segments_t = Table("segments")
        metadata_t = Table("segment_metadata")
        q = (
            self.querybuilder()
            .from_(segments_t)
            .select(
                segments_t.id,
                segments_t.type,
                segments_t.scope,
                segments_t.topic,
                segments_t.collection,
                metadata_t.key,
                metadata_t.str_value,
                metadata_t.int_value,
                metadata_t.float_value,
            )
            .left_join(metadata_t)
            .on(segments_t.id == metadata_t.segment_id)
            .orderby(segments_t.id)
        )
        if id:
            q = q.where(segments_t.id == ParameterValue(self.uuid_to_db(id)))
        if type:
            q = q.where(segments_t.type == ParameterValue(type))
        if scope:
            q = q.where(segments_t.scope == ParameterValue(scope.value))
        if topic:
            q = q.where(segments_t.topic == ParameterValue(topic))
        if collection:
            q = q.where(
                segments_t.collection == ParameterValue(self.uuid_to_db(collection))
            )

        with self.tx() as cur:
            sql, params = get_sql(q, self.parameter_format())
            rows = cur.execute(sql, params).fetchall()
            by_segment = groupby(rows, lambda r: cast(object, r[0]))
            segments = []
            for segment_id, segment_rows in by_segment:
                id = self.uuid_from_db(str(segment_id))
                rows = list(segment_rows)
                type = str(rows[0][1])
                scope = SegmentScope(str(rows[0][2]))
                topic = str(rows[0][3]) if rows[0][3] else None
                collection = self.uuid_from_db(rows[0][4]) if rows[0][4] else None
                metadata = self._metadata_from_rows(rows)
                segments.append(
                    Segment(
                        id=cast(UUID, id),
                        type=type,
                        scope=scope,
                        topic=topic,
                        collection=collection,
                        metadata=metadata,
                    )
                )
            return segments
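
    # get_collections left-joins the metadata table, so the query yields one
    # row per (collection, metadata key) pair; rows are ordered and grouped by
    # collection id before being folded into Collection objects.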

    def get_collections(
        self,
        id: Optional[UUID] = None,
        topic: Optional[str] = None,
        name: Optional[str] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> Sequence[Collection]:
        """Get collections by id, topic and/or name, scoped to a tenant and database"""
        if name is not None and (tenant is None or database is None):
            raise ValueError(
                "If name is specified, tenant and database must also be specified in order to uniquely identify the collection"
            )
        add_attributes_to_current_span(
            {
                "collection_id": str(id),
                "collection_topic": topic if topic else "",
                "collection_name": name if name else "",
            }
        )
        collections_t = Table("collections")
        metadata_t = Table("collection_metadata")
        databases_t = Table("databases")
        q = (
            self.querybuilder()
            .from_(collections_t)
            .select(
                collections_t.id,
                collections_t.name,
                collections_t.topic,
                collections_t.dimension,
                databases_t.name,
                databases_t.tenant_id,
                metadata_t.key,
                metadata_t.str_value,
                metadata_t.int_value,
                metadata_t.float_value,
            )
            .left_join(metadata_t)
            .on(collections_t.id == metadata_t.collection_id)
            .left_join(databases_t)
            .on(collections_t.database_id == databases_t.id)
            .orderby(collections_t.id)
        )
        if id:
            q = q.where(collections_t.id == ParameterValue(self.uuid_to_db(id)))
        if topic:
            q = q.where(collections_t.topic == ParameterValue(topic))
        if name:
            q = q.where(collections_t.name == ParameterValue(name))
        # Only if we have a name, tenant and database do we need to filter on the
        # databases table. Given an id, the collection is already uniquely
        # identified, so no database filter is needed.
        if id is None and tenant and database:
            databases_t = Table("databases")
            q = q.where(
                collections_t.database_id
                == self.querybuilder()
                .select(databases_t.id)
                .from_(databases_t)
                .where(databases_t.name == ParameterValue(database))
                .where(databases_t.tenant_id == ParameterValue(tenant))
            )

        # Can't apply limit and offset in SQL here: the metadata join produces
        # multiple rows per collection, so we reduce first and slice afterwards.
        with self.tx() as cur:
            sql, params = get_sql(q, self.parameter_format())
            rows = cur.execute(sql, params).fetchall()
            by_collection = groupby(rows, lambda r: cast(object, r[0]))
            collections = []
            for collection_id, collection_rows in by_collection:
                id = self.uuid_from_db(str(collection_id))
                rows = list(collection_rows)
                name = str(rows[0][1])
                topic = str(rows[0][2])
                dimension = int(rows[0][3]) if rows[0][3] else None
                metadata = self._metadata_from_rows(rows)
                collections.append(
                    Collection(
                        id=cast(UUID, id),
                        topic=topic,
                        name=name,
                        metadata=metadata,
                        dimension=dimension,
                        tenant=str(rows[0][5]),
                        database=str(rows[0][4]),
                    )
                )

            # Apply limit and offset in Python, after grouping.
            if limit is not None:
                if offset is None:
                    offset = 0
                collections = collections[offset : offset + limit]
            else:
                collections = collections[offset:]

            return collections
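
    # Deletes rely on ON DELETE CASCADE to clean up metadata rows; a RETURNING
    # clause is appended so a miss can be reported as a NotFoundError.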

    def delete_segment(self, id: UUID) -> None:
        """Delete a segment from the SysDB"""
        add_attributes_to_current_span(
            {
                "segment_id": str(id),
            }
        )
        t = Table("segments")
        q = (
            self.querybuilder()
            .from_(t)
            .where(t.id == ParameterValue(self.uuid_to_db(id)))
            .delete()
        )
        with self.tx() as cur:
            # no need for explicit del from metadata table because of ON DELETE CASCADE
            sql, params = get_sql(q, self.parameter_format())
            sql = sql + " RETURNING id"
            result = cur.execute(sql, params).fetchone()
            if not result:
                raise NotFoundError(f"Segment {id} not found")

    def delete_collection(
        self,
        id: UUID,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Delete a collection, and its associated topic, from the SysDB"""
        add_attributes_to_current_span(
            {
                "collection_id": str(id),
            }
        )
        t = Table("collections")
        databases_t = Table("databases")
        q = (
            self.querybuilder()
            .from_(t)
            .where(t.id == ParameterValue(self.uuid_to_db(id)))
            .where(
                t.database_id
                == self.querybuilder()
                .select(databases_t.id)
                .from_(databases_t)
                .where(databases_t.name == ParameterValue(database))
                .where(databases_t.tenant_id == ParameterValue(tenant))
            )
            .delete()
        )
        with self.tx() as cur:
            # no need for explicit del from metadata table because of ON DELETE CASCADE
            sql, params = get_sql(q, self.parameter_format())
            sql = sql + " RETURNING id, topic"
            result = cur.execute(sql, params).fetchone()
            if not result:
                raise NotFoundError(f"Collection {id} not found")
            self._producer.delete_topic(result[1])

    def update_segment(
        self,
        id: UUID,
        topic: OptionalArgument[Optional[str]] = Unspecified(),
        collection: OptionalArgument[Optional[UUID]] = Unspecified(),
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
    ) -> None:
        add_attributes_to_current_span(
            {
                "segment_id": str(id),
                "collection": str(collection),
            }
        )
        segments_t = Table("segments")
        metadata_t = Table("segment_metadata")

        q = (
            self.querybuilder()
            .update(segments_t)
            .where(segments_t.id == ParameterValue(self.uuid_to_db(id)))
        )

        if not topic == Unspecified():
            q = q.set(segments_t.topic, ParameterValue(topic))

        if not collection == Unspecified():
            collection = cast(Optional[UUID], collection)
            q = q.set(
                segments_t.collection, ParameterValue(self.uuid_to_db(collection))
            )

        with self.tx() as cur:
            sql, params = get_sql(q, self.parameter_format())
            if sql:  # pypika emits a blank string if nothing to do
                cur.execute(sql, params)

            if metadata is None:
                q = (
                    self.querybuilder()
                    .from_(metadata_t)
                    .where(metadata_t.segment_id == ParameterValue(self.uuid_to_db(id)))
                    .delete()
                )
                sql, params = get_sql(q, self.parameter_format())
                cur.execute(sql, params)
            elif metadata != Unspecified():
                metadata = cast(UpdateMetadata, metadata)
                self._insert_metadata(
                    cur,
                    metadata_t,
                    metadata_t.segment_id,
                    id,
                    metadata,
                    set(metadata.keys()),
                )
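
    # update_collection mirrors update_segment: scalar fields are updated in
    # place, while passing metadata replaces the existing metadata wholesale
    # (None clears it entirely).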

    def update_collection(
        self,
        id: UUID,
        topic: OptionalArgument[Optional[str]] = Unspecified(),
        name: OptionalArgument[str] = Unspecified(),
        dimension: OptionalArgument[Optional[int]] = Unspecified(),
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
    ) -> None:
        add_attributes_to_current_span(
            {
                "collection_id": str(id),
            }
        )
        collections_t = Table("collections")
        metadata_t = Table("collection_metadata")

        q = (
            self.querybuilder()
            .update(collections_t)
            .where(collections_t.id == ParameterValue(self.uuid_to_db(id)))
        )

        if not topic == Unspecified():
            q = q.set(collections_t.topic, ParameterValue(topic))

        if not name == Unspecified():
            q = q.set(collections_t.name, ParameterValue(name))

        if not dimension == Unspecified():
            q = q.set(collections_t.dimension, ParameterValue(dimension))

        with self.tx() as cur:
            sql, params = get_sql(q, self.parameter_format())
            if sql:  # pypika emits a blank string if nothing to do
                sql = sql + " RETURNING id"
                result = cur.execute(sql, params)
                if not result.fetchone():
                    raise NotFoundError(f"Collection {id} not found")

            # TODO: Update to use better semantics where it's possible to update
            # individual keys without wiping all the existing metadata.
            # For now, follow the current legacy semantics where metadata is fully reset.
            if metadata != Unspecified():
                q = (
                    self.querybuilder()
                    .from_(metadata_t)
                    .where(
                        metadata_t.collection_id == ParameterValue(self.uuid_to_db(id))
                    )
                    .delete()
                )
                sql, params = get_sql(q, self.parameter_format())
                cur.execute(sql, params)

                if metadata is not None:
                    metadata = cast(UpdateMetadata, metadata)
                    self._insert_metadata(
                        cur,
                        metadata_t,
                        metadata_t.collection_id,
                        id,
                        metadata,
                        set(metadata.keys()),
                    )

    def _metadata_from_rows(
        self, rows: Sequence[Tuple[Any, ...]]
    ) -> Optional[Metadata]:
        """Given SQL rows, return a metadata map (assuming that the last four
        columns are the key, str_value, int_value & float_value)"""
        add_attributes_to_current_span(
            {
                "num_rows": len(rows),
            }
        )
        metadata: Dict[str, Union[str, int, float]] = {}
        for row in rows:
            key = str(row[-4])
            if row[-3] is not None:
                metadata[key] = str(row[-3])
            elif row[-2] is not None:
                metadata[key] = int(row[-2])
            elif row[-1] is not None:
                metadata[key] = float(row[-1])
        return metadata or None
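
    # _insert_metadata stores each value in a type-specific column
    # (str_value / int_value / float_value). The optional clear_keys set lets
    # callers emulate an upsert by deleting the listed keys before re-inserting.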

    def _insert_metadata(
        self,
        cur: Cursor,
        table: Table,
        id_col: Column,
        id: UUID,
        metadata: UpdateMetadata,
        clear_keys: Optional[Set[str]] = None,
    ) -> None:
        # It would be cleaner to use something like ON CONFLICT UPDATE here, but
        # that is very difficult to do in a portable way (e.g. sqlite and postgres
        # have completely different syntax).
        add_attributes_to_current_span(
            {
                "num_keys": len(metadata),
            }
        )
        if clear_keys:
            q = (
                self.querybuilder()
                .from_(table)
                .where(id_col == ParameterValue(self.uuid_to_db(id)))
                .where(table.key.isin([ParameterValue(k) for k in clear_keys]))
                .delete()
            )
            sql, params = get_sql(q, self.parameter_format())
            cur.execute(sql, params)

        q = (
            self.querybuilder()
            .into(table)
            .columns(
                id_col,
                table.key,
                table.str_value,
                table.int_value,
                table.float_value,
            )
        )
        sql_id = self.uuid_to_db(id)
        for k, v in metadata.items():
            if isinstance(v, str):
                q = q.insert(
                    ParameterValue(sql_id),
                    ParameterValue(k),
                    ParameterValue(v),
                    None,
                    None,
                )
            elif isinstance(v, int):
                q = q.insert(
                    ParameterValue(sql_id),
                    ParameterValue(k),
                    None,
                    ParameterValue(v),
                    None,
                )
            elif isinstance(v, float):
                q = q.insert(
                    ParameterValue(sql_id),
                    ParameterValue(k),
                    None,
                    None,
                    ParameterValue(v),
                )
            elif v is None:
                continue

        sql, params = get_sql(q, self.parameter_format())
        if sql:
            cur.execute(sql, params)
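
# A minimal usage sketch (illustrative only; the tenant/database/collection
# names and the way the concrete SqlSysDB instance is obtained are assumptions,
# not part of this module):
#
#     from uuid import uuid4
#     sysdb = system.instance(SysDB)  # some concrete SqlSysDB implementation
#     sysdb.create_tenant("my_tenant")
#     sysdb.create_database(uuid4(), "my_db", tenant="my_tenant")
#     collection, created = sysdb.create_collection(
#         id=uuid4(),
#         name="docs",
#         get_or_create=True,
#         tenant="my_tenant",
#         database="my_db",
#     )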