File size: 5,581 Bytes
c88e992 d6aeaee 2182c09 d6aeaee 46ca867 0b82f4b 46ca867 0b82f4b c88e992 5f34b1a c88e992 46ca867 c88e992 46ca867 0b82f4b 46ca867 d6aeaee 46ca867 0b82f4b 2182c09 d6aeaee 2182c09 d6aeaee 2182c09 d6aeaee 2182c09 d6aeaee 2182c09 d6aeaee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import chromadb
from chromadb.config import Settings, System
import time,json
from pydantic import BaseModel
from typing import List, Dict, Any, Generator, Optional, cast, Callable
from chromadb.api.types import (
Documents,
Embeddings,
EmbeddingFunction,
IDs,
Include,
Metadatas,
Where,
WhereDocument,
GetResult,
QueryResult,
CollectionMetadata,
)
from chromadb.errors import (
ChromaError,
InvalidUUIDError,
InvalidDimensionException,
)
from chromadb.server.fastapi.types import (
AddEmbedding,
DeleteEmbedding,
GetEmbedding,
QueryEmbedding,
RawSql, # Results,
CreateCollection,
UpdateCollection,
UpdateEmbedding,
)
from chromadb.api import API
from chromadb.config import System
import chromadb.utils.embedding_functions as ef
import pandas as pd
import requests
import json
from typing import Sequence
from chromadb.api.models.Collection import Collection
import chromadb.errors as errors
from uuid import UUID
from chromadb.telemetry import Telemetry
from overrides import override
class client:
    """Thin convenience wrapper around a local ``chromadb`` client.

    Almost every method forwards directly to the underlying client object
    stored in ``self.db``. Several of them call that object's
    underscore-prefixed, collection-id based methods (``_add``, ``_get``,
    ``_query``, ...), which are private chromadb API.
    NOTE(review): confirm those private methods exist in the pinned
    chromadb version.
    """

    def __init__(self, persist_directory: str = "./index/chroma"):
        """Create the underlying chromadb client.

        Args:
            persist_directory: Directory handed to chromadb's ``Settings`` for
                the ``duckdb+parquet`` store. Defaults to ``./index/chroma``,
                the location previously hard-coded here.
        """
        self.db = chromadb.Client(
            Settings(
                chroma_db_impl="duckdb+parquet",
                persist_directory=persist_directory,
            )
        )

    def heartbeat(self):
        """Return the current wall-clock time in nanoseconds, keyed for display."""
        # time.time_ns() already returns an int; no conversion needed.
        return {"nanosecond heartbeat": time.time_ns()}

    def list_collections(self):
        """Return the collections reported by the underlying client."""
        return self.db.list_collections()

    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
        get_or_create: bool = False,
    ) -> Collection:
        """Create collection ``name`` via the underlying client, print it and return it.

        All arguments are forwarded unchanged to ``self.db.create_collection``.

        NOTE(review): the default ``embedding_function`` is instantiated once,
        when this class is defined, and that single instance is shared by
        every call that omits the argument.
        """
        col = self.db.create_collection(
            name,
            metadata=metadata,
            embedding_function=embedding_function,
            get_or_create=get_or_create,
        )
        print(col)
        return col

    def get_collection(
        self,
        name: str,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
    ) -> Collection:
        """Fetch collection ``name`` via the underlying client, print it and return it.

        NOTE(review): as in ``create_collection``, the default embedding
        function is a single instance created at class-definition time.
        """
        col = self.db.get_collection(name, embedding_function=embedding_function)
        print(col)
        return col

    def reset(self):
        """Forward to ``self.db.reset()`` and return its result."""
        return self.db.reset()

    def version(self):
        """Return the version string reported by ``self.db.get_version()``."""
        return self.db.get_version()

    def persist(self):
        """Forward to ``self.db.persist()`` and return its result."""
        return self.db.persist()

    def raw_sql(self, raw_sql: RawSql):
        """Execute the SQL text held in ``raw_sql.raw_sql`` via the underlying client."""
        return self.db.raw_sql(raw_sql.raw_sql)

    def add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        increment_index: bool = True,
    ) -> bool:
        """Add records to the collection ``collection_id``.

        All arguments are forwarded by keyword to ``self.db._add``.
        """
        return self.db._add(
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            ids=ids,
            increment_index=increment_index,
        )

    def update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """Update the records ``ids`` in collection ``collection_id`` via ``self.db._update``."""
        return self.db._update(
            ids=ids,
            collection_id=collection_id,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )

    def upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        increment_index: bool = True,
    ) -> bool:
        """Upsert the records ``ids`` into collection ``collection_id`` via ``self.db._upsert``."""
        return self.db._upsert(
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            ids=ids,
            increment_index=increment_index,
        )

    def get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[Include] = None,
    ) -> GetResult:
        """Fetch records from collection ``collection_id`` via ``self.db._get``.

        ``where`` and ``where_document`` default to empty filters (``{}``).
        ``include`` defaults to ``["embeddings", "metadatas", "documents"]``.
        ``page`` and ``page_size`` are forwarded along with the other
        arguments; previously they were accepted but silently dropped.
        """
        # Build fresh containers per call instead of sharing mutable defaults.
        if where is None:
            where = {}
        if where_document is None:
            where_document = {}
        if include is None:
            include = ["embeddings", "metadatas", "documents"]
        return self.db._get(
            collection_id=collection_id,
            ids=ids,
            where=where,
            sort=sort,
            limit=limit,
            offset=offset,
            page=page,
            page_size=page_size,
            where_document=where_document,
            include=include,
        )

    def delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs],
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> IDs:
        """Delete records from collection ``collection_id`` via ``self.db._delete``.

        ``where`` and ``where_document`` default to empty filters (``{}``).
        """
        if where is None:
            where = {}
        if where_document is None:
            where_document = {}
        return self.db._delete(
            where=where,
            ids=ids,
            collection_id=collection_id,
            where_document=where_document,
        )

    def count(self, collection_id: UUID) -> int:
        """Return the number of records in collection ``collection_id``."""
        # The original called ``self.db._.count``, which raises AttributeError.
        return self.db._count(collection_id)

    def get_nearest_neighbors(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[Include] = None,
    ) -> QueryResult:
        """Query collection ``collection_id`` via ``self.db._query``.

        ``where`` and ``where_document`` default to empty filters (``{}``).
        ``include`` defaults to
        ``["embeddings", "metadatas", "documents", "distances"]``.
        """
        if where is None:
            where = {}
        if where_document is None:
            where_document = {}
        if include is None:
            include = ["embeddings", "metadatas", "documents", "distances"]
        return self.db._query(
            collection_id=collection_id,
            where=where,
            where_document=where_document,
            query_embeddings=query_embeddings,
            n_results=n_results,
            include=include,
        )

    def create_index(self, collection_name: str) -> bool:
        """Forward to ``self.db.create_index(collection_name)`` and return its result."""
        return self.db.create_index(collection_name)

    def modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        """Update the name and/or metadata of the collection ``id``.

        The parameter name ``id`` shadows the builtin but is kept for
        backward compatibility with keyword callers.
        """
        return self.db._modify(id=id, new_name=new_name, new_metadata=new_metadata)

    def delete_collection(self, name: str) -> None:
        """Delete the collection called ``name`` via the underlying client."""
        return self.db.delete_collection(name)
|