# api/src/vector_store_manager/chroma_manager.py
# Chandima Prabhath
# Refactor code structure for improved readability and maintainability
# 10b392a
# src/vector_store_manager/chroma_manager.py
from langchain_chroma import Chroma # cite: embed_pipeline.py, query_pipeline.py
from langchain.schema import Document # cite: embed_pipeline.py
from config.settings import PERSIST_DIR, CHROMADB_COLLECTION_NAME # cite: embed_pipeline.py, query_pipeline.py
from src.embedding_generator.embedder import EmbeddingGenerator
import logging
from typing import List, Dict, Any
# Module-level logger, named after this module so its level/handlers can be configured independently.
logger = logging.getLogger(name=__name__)
class ChromaManager:
    """
    Manages interactions with the ChromaDB vector store.

    Wraps a LangChain ``Chroma`` store (``self.vectordb``) persisted under
    ``PERSIST_DIR`` in the collection ``CHROMADB_COLLECTION_NAME``. Adding and
    retrieval go through the LangChain wrapper; update, delete and raw ``get``
    calls go through ``self.vectordb._collection``, the underlying Chroma
    collection. That is a private LangChain attribute, so re-check it when
    upgrading ``langchain_chroma``.

    The add/update/delete/get methods log failures and re-raise the original
    exception so callers can decide how to recover.
    """

    def __init__(self, embedding_generator: EmbeddingGenerator):
        """
        Open (or create) the persistent Chroma collection.

        Args:
            embedding_generator: Supplies ``.embedder``, the LangChain embeddings
                instance handed to Chroma as its embedding function.

        Raises:
            Exception: Whatever ``Chroma(...)`` raised, re-raised after being
                logged at CRITICAL level with its traceback.
        """
        self.embedding_generator = embedding_generator
        # --- Financial Ministry Adaptation ---
        # TODO: Configure the Chroma client to use a scalable backend (e.g.
        # ClickHouse) instead of, or in addition to, persist_directory for
        # production (this might involve chromadb.HttpClient or specific backend
        # configurations). Handle connection errors and retries to the database
        # backend, and implement authentication/authorization for ChromaDB access.
        # ------------------------------------
        try:
            self.vectordb = Chroma(
                persist_directory=PERSIST_DIR,
                collection_name=CHROMADB_COLLECTION_NAME,
                # The LangChain embedder: used when adding documents and when
                # retrievers embed queries.
                embedding_function=self.embedding_generator.embedder,
            )
        except Exception as e:
            logger.critical("Failed to initialize ChromaDB: %s", e, exc_info=True)
            # Bare ``raise`` re-raises with the original traceback untouched.
            raise
        logger.info(
            "Initialized ChromaDB collection: '%s' at '%s'",
            CHROMADB_COLLECTION_NAME,
            PERSIST_DIR,
        )
        # TODO: optionally check that the collection exists and is healthy.

    def add_documents(self, chunks: List[Document]) -> List[str]:
        """
        Adds document chunks to the ChromaDB collection.

        LangChain embeds the chunks with the embedding function given at
        construction; LangChain/Chroma generate IDs when none are provided.

        Args:
            chunks: A list of LangChain Document chunks with metadata.

        Returns:
            The IDs returned by LangChain's ``add_documents`` for the stored
            chunks; an empty list when ``chunks`` is empty.

        Raises:
            Exception: Any embedding/storage error, re-raised after logging so a
                failed ingestion is never reported as a success.
        """
        # --- Financial Ministry Adaptation ---
        # TODO: Add retry logic for batch additions and consider transactional
        # behavior for large batches. Use stable document IDs (e.g. source +
        # chunk index, or a content hash) so chunks can be updated/deleted
        # later, e.g. ``self.vectordb.add_documents(chunks, ids=chunk_ids)``.
        # ------------------------------------
        if not chunks:
            logger.warning("No chunks to add to ChromaDB.")
            return []
        try:
            added_ids = self.vectordb.add_documents(chunks)  # Langchain handles IDs if not provided
        except Exception as e:
            logger.exception("Failed to add documents to ChromaDB: %s", e)
            raise
        logger.info("Added %d chunks to ChromaDB.", len(chunks))
        return added_ids

    def update_documents(self, ids: List[str], documents: List[str], metadatas: List[Dict[str, Any]]):
        """
        Updates documents in the ChromaDB collection by ID.

        Replacement embeddings are computed with ``self.embedding_generator.embedder``
        (the same embedder handed to Chroma at construction) and passed to the
        collection explicitly. The collection behind a LangChain ``Chroma`` store
        is not guaranteed to carry that embedder (current ``langchain_chroma``
        versions create it with ``embedding_function=None`` -- verify when
        upgrading), so leaving the embedding to Chroma would either fail or
        store vectors from a different model than the one used for queries.

        Args:
            ids: List of document IDs to update. An empty list is a logged no-op.
            documents: New document contents, one per ID. ``None`` skips
                re-embedding and is forwarded to Chroma as-is.
            metadatas: New metadata dictionaries, one per ID (``None`` is
                forwarded to Chroma as-is).

        Raises:
            ValueError: If ``documents`` or ``metadatas`` has a different length
                than ``ids``.
            Exception: Any embedding/storage error, re-raised after logging.
        """
        # --- Financial Ministry Adaptation ---
        # TODO: Add retry logic and validate that the IDs exist before updating.
        # ------------------------------------
        if not ids:
            logger.warning("No document IDs given to update in ChromaDB.")
            return
        try:
            # Fail fast on misaligned input before spending embedding calls on it.
            for field_name, values in (("documents", documents), ("metadatas", metadatas)):
                if values is not None and len(values) != len(ids):
                    raise ValueError(
                        f"Got {len(values)} {field_name} for {len(ids)} ids; lengths must match."
                    )
            embeddings = (
                self.embedding_generator.embedder.embed_documents(documents)
                if documents is not None
                else None
            )
            # NOTE: private LangChain attribute -- the raw Chroma collection.
            self.vectordb._collection.update(
                ids=ids,
                embeddings=embeddings,
                documents=documents,
                metadatas=metadatas,
            )
        except Exception as e:
            logger.exception("Failed to update documents with IDs %s: %s", ids, e)
            raise
        logger.info("Updated documents with IDs: %s", ids)

    def delete_documents(self, ids: List[str] = None, where: Dict[str, Any] = None):
        """
        Deletes documents from the ChromaDB collection by ID and/or metadata filter.

        When both ``ids`` and ``where`` are given, both are forwarded to Chroma,
        which applies them together (only listed IDs that also match the filter
        are deleted). Calling with neither, or with only empty values, logs a
        warning and never reaches the collection.

        Args:
            ids: List of document IDs to delete.
            where: A dictionary for metadata filtering (e.g., {"source": "old_file.txt"}).

        Raises:
            Exception: Any storage error, re-raised after logging.
        """
        # --- Financial Ministry Adaptation ---
        # TODO: Add retry logic and record which documents were deleted and why
        # (audit trail), especially for ``where``-based deletes.
        # ------------------------------------
        if not ids and not where:
            logger.warning("Delete called without specifying ids or where filter.")
            return
        try:
            # Normalize empty containers to None so Chroma only sees real filters.
            # NOTE: private LangChain attribute -- the raw Chroma collection.
            self.vectordb._collection.delete(ids=ids or None, where=where or None)
        except Exception as e:
            logger.exception("Failed to delete documents (ids: %s, where: %s): %s", ids, where, e)
            raise
        if ids and where:
            logger.info("Deleted documents with IDs: %s matching metadata filter: %s", ids, where)
        elif ids:
            logger.info("Deleted documents with IDs: %s", ids)
        else:
            logger.info("Deleted documents matching metadata filter: %s", where)

    def get_documents(self, ids: List[str] = None, where: Dict[str, Any] = None,
                      where_document: Dict[str, Any] = None, limit: int = None,
                      offset: int = None, include: List[str] = None) -> Dict[str, List[Any]]:
        """
        Retrieves documents and their details from the ChromaDB collection.

        Args:
            ids: List of document IDs to retrieve.
            where: Metadata filter.
            where_document: Document content filter.
            limit: Maximum number of results.
            offset: Offset for pagination.
            include: List of fields to include (e.g., ['metadatas', 'documents']).
                Defaults to ['metadatas', 'documents']. IDs are always included.

        Returns:
            The result of the collection's ``get`` call (ids, documents,
            metadatas, etc.), returned unchanged.

        Raises:
            Exception: Any storage error, re-raised after logging.
        """
        # --- Financial Ministry Adaptation ---
        # TODO: Add retry logic and make sure sensitive metadata is handled
        # appropriately when it is retrieved.
        # ------------------------------------
        if include is None:
            include = ['metadatas', 'documents']  # Default as per Chroma docs
        try:
            # NOTE: private LangChain attribute -- the raw Chroma collection.
            results = self.vectordb._collection.get(
                ids=ids,
                where=where,
                where_document=where_document,
                limit=limit,
                offset=offset,
                include=include,
            )
        except Exception as e:
            logger.exception("Failed to retrieve documents from ChromaDB: %s", e)
            raise
        logger.debug("Retrieved %d documents from ChromaDB.", len(results.get('ids', [])))
        return results

    def as_retriever(self, search_kwargs: Dict[str, Any] = None):
        """
        Returns a LangChain Retriever over the Chroma collection.

        The retriever embeds queries with the embedding function supplied when
        ``self.vectordb`` was constructed.

        Args:
            search_kwargs: Arguments for the retriever (e.g., {"k": 5}).
                ``None`` is passed on as an empty dict.

        Returns:
            A LangChain Retriever.
        """
        # --- Financial Ministry Adaptation ---
        # TODO: Consider default search_kwargs (e.g. a project-wide ``k``).
        # ------------------------------------
        if search_kwargs is None:
            search_kwargs = {}
        return self.vectordb.as_retriever(search_kwargs=search_kwargs)