# chroma / server.py
# Commit d6aeaee by anubhav77: "Adding all other methods"
# (scraped page residue: raw | history | blame | 5.58 kB)
import chromadb
from chromadb.config import Settings, System
import time,json
from pydantic import BaseModel
from typing import List, Dict, Any, Generator, Optional, cast, Callable
from chromadb.api.types import (
Documents,
Embeddings,
EmbeddingFunction,
IDs,
Include,
Metadatas,
Where,
WhereDocument,
GetResult,
QueryResult,
CollectionMetadata,
)
from chromadb.errors import (
ChromaError,
InvalidUUIDError,
InvalidDimensionException,
)
from chromadb.server.fastapi.types import (
AddEmbedding,
DeleteEmbedding,
GetEmbedding,
QueryEmbedding,
RawSql, # Results,
CreateCollection,
UpdateCollection,
UpdateEmbedding,
)
from chromadb.api import API
from chromadb.config import System
import chromadb.utils.embedding_functions as ef
import pandas as pd
import requests
import json
from typing import Sequence
from chromadb.api.models.Collection import Collection
import chromadb.errors as errors
from uuid import UUID
from chromadb.telemetry import Telemetry
from overrides import override
class client:
    """Thin wrapper around a local chromadb client.

    Every method delegates to ``self.db``, a ``chromadb.Client`` built with the
    ``duckdb+parquet`` implementation. Many methods call the client's
    underscore-prefixed (private) methods (``_add``, ``_get``, ``_query``, ...),
    so this class is tied to a chromadb version that exposes them.
    """

    # Columns returned by get() / get_nearest_neighbors() when the caller does
    # not pass ``include``. Stored as tuples so no shared mutable default exists;
    # a fresh list is built per call.
    _GET_INCLUDE = ("embeddings", "metadatas", "documents")
    _QUERY_INCLUDE = ("embeddings", "metadatas", "documents", "distances")

    def __init__(self, persist_directory: str = "./index/chroma"):
        """Create the underlying chromadb client.

        Args:
            persist_directory: directory the duckdb+parquet store persists to.
                Defaults to the previously hard-coded ``./index/chroma``.
        """
        self.db = chromadb.Client(Settings(
            chroma_db_impl="duckdb+parquet",
            persist_directory=persist_directory,
        ))

    def heartbeat(self):
        """Return the current wall-clock time in nanoseconds as a liveness probe."""
        # time.time_ns() already returns an int; no conversion needed.
        return {"nanosecond heartbeat": time.time_ns()}

    def list_collections(self):
        """Return whatever ``self.db.list_collections()`` returns."""
        return self.db.list_collections()

    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        # NOTE: this default is instantiated once, when the class is defined,
        # and the same instance is shared by every call that omits it.
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
        get_or_create: bool = False,
    ) -> Collection:
        """Create (or, with ``get_or_create=True``, fetch) a collection.

        The resulting collection is printed to stdout before being returned.
        """
        col = self.db.create_collection(
            name,
            metadata=metadata,
            embedding_function=embedding_function,
            get_or_create=get_or_create,
        )
        print(col)  # NOTE(review): debug output; consider logging instead
        return col

    def get_collection(
        self,
        name: str,
        # NOTE: shared default instance, created once at class definition.
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
    ) -> Collection:
        """Fetch an existing collection by name; it is printed to stdout first."""
        col = self.db.get_collection(name, embedding_function=embedding_function)
        print(col)  # NOTE(review): debug output; consider logging instead
        return col

    def reset(self):
        """Delegate to ``self.db.reset()``.

        Presumably destructive (clears stored data); confirm against the
        installed chromadb version before exposing it.
        """
        return self.db.reset()

    def version(self):
        """Return the version string reported by the chromadb client."""
        return self.db.get_version()

    def persist(self):
        """Ask the client to persist its state to ``persist_directory``."""
        return self.db.persist()

    def raw_sql(self, raw_sql: RawSql):
        """Execute ``raw_sql.raw_sql`` against the client's backing store.

        SECURITY: the SQL text is passed through verbatim. If this method is
        reachable by untrusted callers it permits arbitrary queries; confirm
        access control at the caller.
        """
        return self.db.raw_sql(raw_sql.raw_sql)

    def add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        increment_index: bool = True,
    ) -> bool:
        """Add records to collection ``collection_id`` via the client's ``_add``."""
        return self.db._add(
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            ids=ids,
            increment_index=increment_index,
        )

    def update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """Update existing records by id via the client's ``_update``."""
        return self.db._update(
            ids=ids,
            collection_id=collection_id,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )

    def upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        increment_index: bool = True,
    ) -> bool:
        """Insert-or-update records by id via the client's ``_upsert``."""
        return self.db._upsert(
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            ids=ids,
            increment_index=increment_index,
        )

    def get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[Include] = None,
    ) -> GetResult:
        """Fetch records from a collection via the client's ``_get``.

        ``where``/``where_document`` default to an empty filter and ``include``
        to embeddings, metadatas and documents (the same effective defaults as
        before, without shared mutable default objects).

        When both ``page`` (1-based) and ``page_size`` are given they override
        ``offset``/``limit``. Previously these two arguments were accepted but
        silently ignored.
        """
        if page and page_size:
            offset = (page - 1) * page_size
            limit = page_size
        return self.db._get(
            collection_id=collection_id,
            ids=ids,
            where={} if where is None else where,
            where_document={} if where_document is None else where_document,
            sort=sort,
            limit=limit,
            offset=offset,
            include=list(self._GET_INCLUDE) if include is None else include,
        )

    def delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs],
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> IDs:
        """Delete records matching ``ids``/filters via the client's ``_delete``.

        ``ids`` is intentionally left without a default so callers must state
        it explicitly. ``where``/``where_document`` default to empty filters.
        """
        return self.db._delete(
            where={} if where is None else where,
            ids=ids,
            collection_id=collection_id,
            where_document={} if where_document is None else where_document,
        )

    def count(self, collection_id: UUID) -> int:
        """Return the number of records in collection ``collection_id``."""
        # Fixed typo: was ``self.db._.count(...)``, which raised AttributeError.
        return self.db._count(collection_id)

    def get_nearest_neighbors(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[Include] = None,
    ) -> QueryResult:
        """Query the ``n_results`` nearest records via the client's ``_query``.

        ``where``/``where_document`` default to empty filters. ``include``
        defaults to embeddings, metadatas, documents and distances.
        """
        return self.db._query(
            collection_id=collection_id,
            where={} if where is None else where,
            where_document={} if where_document is None else where_document,
            query_embeddings=query_embeddings,
            n_results=n_results,
            include=list(self._QUERY_INCLUDE) if include is None else include,
        )

    def create_index(self, collection_name: str) -> bool:
        """Build the index for ``collection_name`` via the client."""
        return self.db.create_index(collection_name)

    def modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        """Rename and/or replace the metadata of the collection with this ``id``."""
        return self.db._modify(id=id, new_name=new_name, new_metadata=new_metadata)

    def delete_collection(self, name: str) -> None:
        """Delete the collection called ``name``."""
        return self.db.delete_collection(name)