|
import chromadb |
|
from chromadb.config import Settings, System |
|
import time,json |
|
from pydantic import BaseModel |
|
from typing import List, Dict, Any, Generator, Optional, cast, Callable |
|
from chromadb.api.types import ( |
|
Documents, |
|
Embeddings, |
|
EmbeddingFunction, |
|
IDs, |
|
Include, |
|
Metadatas, |
|
Where, |
|
WhereDocument, |
|
GetResult, |
|
QueryResult, |
|
CollectionMetadata, |
|
) |
|
from chromadb.errors import ( |
|
ChromaError, |
|
InvalidUUIDError, |
|
InvalidDimensionException, |
|
) |
|
from chromadb.server.fastapi.types import ( |
|
AddEmbedding, |
|
DeleteEmbedding, |
|
GetEmbedding, |
|
QueryEmbedding, |
|
RawSql, |
|
CreateCollection, |
|
UpdateCollection, |
|
UpdateEmbedding, |
|
) |
|
from chromadb.api import API |
|
from chromadb.config import System |
|
import chromadb.utils.embedding_functions as ef |
|
import pandas as pd |
|
import requests |
|
import json |
|
from typing import Sequence |
|
from chromadb.api.models.Collection import Collection |
|
import chromadb.errors as errors |
|
from uuid import UUID |
|
from chromadb.telemetry import Telemetry |
|
from overrides import override |
|
|
|
|
|
class client:
    """Thin convenience wrapper around a Chroma client instance.

    Each method forwards its arguments to the client object stored in
    ``self.db``; no validation or transformation is done here beyond
    replacing omitted filter/include arguments with fresh default
    containers.
    """

    def __init__(self):
        """Create the underlying Chroma client.

        The client is configured with the ``duckdb+parquet`` implementation
        and ``./index/chroma`` as its persist directory.
        """
        self.db = chromadb.Client(
            Settings(
                chroma_db_impl="duckdb+parquet",
                persist_directory="./index/chroma",
            )
        )

    def heartbeat(self):
        """Return the current wall-clock time in nanoseconds.

        Returns:
            A dict with the single key ``"nanosecond heartbeat"``.
        """
        # time.time_ns() already returns an int; no conversion needed.
        return {"nanosecond heartbeat": time.time_ns()}

    def list_collections(self):
        """Return whatever ``self.db.list_collections()`` returns."""
        return self.db.list_collections()

    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
        get_or_create: bool = False,
    ) -> Collection:
        """Create a collection through the underlying client and return it.

        Args:
            name: Name passed to ``self.db.create_collection``.
            metadata: Optional collection metadata, forwarded unchanged.
            embedding_function: Embedding function forwarded unchanged.
            get_or_create: Forwarded unchanged to the underlying client.

        Returns:
            The collection object returned by the underlying client.
        """
        col = self.db.create_collection(
            name,
            metadata=metadata,
            embedding_function=embedding_function,
            get_or_create=get_or_create,
        )
        print(col)
        return col

    def get_collection(
        self,
        name: str,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
    ) -> Collection:
        """Look up a collection by name through the underlying client.

        Args:
            name: Name passed to ``self.db.get_collection``.
            embedding_function: Embedding function forwarded unchanged.

        Returns:
            The collection object returned by the underlying client.
        """
        col = self.db.get_collection(name, embedding_function=embedding_function)
        print(col)
        return col

    def reset(self):
        """Forward to ``self.db.reset()`` and return its result."""
        return self.db.reset()

    def version(self):
        """Return the result of ``self.db.get_version()``."""
        return self.db.get_version()

    def persist(self):
        """Forward to ``self.db.persist()`` and return its result."""
        return self.db.persist()

    def raw_sql(self, raw_sql: RawSql):
        """Run a raw SQL statement through the underlying client.

        Args:
            raw_sql: Request object whose ``raw_sql`` attribute holds the
                statement passed to ``self.db.raw_sql``.
        """
        return self.db.raw_sql(raw_sql.raw_sql)

    def add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        increment_index: bool = True,
    ) -> bool:
        """Forward an add request to ``self.db._add``.

        All arguments are passed through by keyword, unchanged.

        Returns:
            Whatever ``self.db._add`` returns.
        """
        return self.db._add(
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            ids=ids,
            increment_index=increment_index,
        )

    def update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """Forward an update request to ``self.db._update``.

        All arguments are passed through by keyword, unchanged.

        Returns:
            Whatever ``self.db._update`` returns.
        """
        return self.db._update(
            ids=ids,
            collection_id=collection_id,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )

    def upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        increment_index: bool = True,
    ) -> bool:
        """Forward an upsert request to ``self.db._upsert``.

        All arguments are passed through by keyword, unchanged.

        Returns:
            Whatever ``self.db._upsert`` returns.
        """
        return self.db._upsert(
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            ids=ids,
            increment_index=increment_index,
        )

    def get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[Include] = None,
    ) -> GetResult:
        """Forward a get request to ``self.db._get``.

        Args:
            collection_id: Target collection, forwarded unchanged.
            ids: Optional ids, forwarded unchanged.
            where: Metadata filter; ``None`` is replaced by a fresh ``{}``.
            sort, limit, offset: Forwarded unchanged.
            page, page_size: Forwarded unchanged.
            where_document: Document filter; ``None`` is replaced by a
                fresh ``{}``.
            include: Fields to include; ``None`` is replaced by
                ``["embeddings", "metadatas", "documents"]``.

        Returns:
            Whatever ``self.db._get`` returns.
        """
        # Build the default containers per call so no state is shared
        # between invocations (or with the backend) via default arguments.
        if where is None:
            where = {}
        if where_document is None:
            where_document = {}
        if include is None:
            include = ["embeddings", "metadatas", "documents"]

        # page/page_size used to be accepted and then silently dropped;
        # forward them so callers' pagination arguments reach the backend.
        # NOTE(review): assumes the backend `_get` accepts `page` and
        # `page_size` keywords, as chromadb's API base class declares.
        return self.db._get(
            collection_id=collection_id,
            ids=ids,
            where=where,
            where_document=where_document,
            sort=sort,
            limit=limit,
            offset=offset,
            page=page,
            page_size=page_size,
            include=include,
        )

    def delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs],
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> IDs:
        """Forward a delete request to ``self.db._delete``.

        Args:
            collection_id: Target collection, forwarded unchanged.
            ids: Ids forwarded unchanged (may be ``None``).
            where: Metadata filter; ``None`` is replaced by a fresh ``{}``.
            where_document: Document filter; ``None`` is replaced by a
                fresh ``{}``.

        Returns:
            Whatever ``self.db._delete`` returns.
        """
        # Fresh containers per call instead of shared mutable defaults.
        if where is None:
            where = {}
        if where_document is None:
            where_document = {}
        return self.db._delete(
            where=where,
            ids=ids,
            collection_id=collection_id,
            where_document=where_document,
        )

    def count(self, collection_id: UUID) -> int:
        """Return the result of ``self.db._count`` for ``collection_id``."""
        # Previously spelled `self.db._.count(...)`, which raised
        # AttributeError on every call.
        return self.db._count(collection_id)

    def get_nearest_neighbors(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[Include] = None,
    ) -> QueryResult:
        """Forward a nearest-neighbour query to ``self.db._query``.

        Args:
            collection_id: Target collection, forwarded unchanged.
            query_embeddings: Query embeddings, forwarded unchanged.
            n_results: Forwarded unchanged (defaults to 10).
            where: Metadata filter; ``None`` is replaced by a fresh ``{}``.
            where_document: Document filter; ``None`` is replaced by a
                fresh ``{}``.
            include: Fields to include; ``None`` is replaced by
                ``["embeddings", "metadatas", "documents", "distances"]``.

        Returns:
            Whatever ``self.db._query`` returns.
        """
        # Fresh containers per call instead of shared mutable defaults.
        if where is None:
            where = {}
        if where_document is None:
            where_document = {}
        if include is None:
            include = ["embeddings", "metadatas", "documents", "distances"]
        return self.db._query(
            collection_id=collection_id,
            where=where,
            where_document=where_document,
            query_embeddings=query_embeddings,
            n_results=n_results,
            include=include,
        )

    def create_index(self, collection_name: str) -> bool:
        """Forward to ``self.db.create_index(collection_name)``."""
        return self.db.create_index(collection_name)

    def modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        """Update a collection's name and/or metadata via ``self.db._modify``.

        Args:
            id: Collection id, forwarded unchanged. (The parameter name
                shadows the builtin but is kept for caller compatibility.)
            new_name: Optional new name, forwarded unchanged.
            new_metadata: Optional new metadata, forwarded unchanged.
        """
        return self.db._modify(id=id, new_name=new_name, new_metadata=new_metadata)

    def delete_collection(self, name: str) -> None:
        """Forward to ``self.db.delete_collection(name)``."""
        return self.db.delete_collection(name)
|
|
|
|