# utils/vectorstore_interface.py
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from gradio_client import Client
import logging
import os
from pathlib import Path
from dotenv import load_dotenv
from .utils import get_auth, process_content
load_dotenv()
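
# Environment variables read by this module (from the process environment or
# a .env file via load_dotenv above):
#   HF_HUB_CACHE - optional cache directory for sentence-transformers models;
#                  defaults to /tmp/hf_cache in _get_embedding_model below.
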
class VectorStoreInterface(ABC):
"""Abstract interface for different vector store implementations."""
@abstractmethod
def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
"""Search for similar documents."""
pass
class HuggingFaceSpacesVectorStore(VectorStoreInterface):
"""Vector store implementation for Hugging Face Spaces with MCP endpoints."""
    def __init__(self, url: str, collection_name: str, api_key: Optional[str] = None):
        logging.info(f"Connecting to Hugging Face Space: {url}")
        if api_key:
            self.client = Client(url, hf_token=api_key)
        else:
            self.client = Client(url)
        self.collection_name = collection_name
self.collection_name = collection_name
def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
"""Search using Hugging Face Spaces MCP API."""
try:
# Use the /search_text endpoint as documented in the API
result = self.client.predict(
query=query,
collection_name=self.collection_name,
model_name=kwargs.get('model_name'),
top_k=top_k,
api_name="/search_text"
)
logging.info(f"Successfully retrieved {len(result) if result else 0} documents")
return result
        except Exception as e:
            logging.error(f"Error searching Hugging Face Spaces: {e}")
            raise
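
    # Usage sketch (illustrative only; the Space repo id, collection name, and
    # model name below are placeholders, not values taken from this repo):
    #
    #     store = HuggingFaceSpacesVectorStore(
    #         url="org/retriever-space",
    #         collection_name="docs",
    #         api_key=os.getenv("HF_TOKEN"),
    #     )
    #     hits = store.search("what does the policy cover?", top_k=5,
    #                         model_name="BAAI/bge-m3")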
class QdrantVectorStore(VectorStoreInterface):
"""Vector store implementation for direct Qdrant connection."""
    def __init__(self, url: str, api_key: Optional[str] = None):
        from qdrant_client import QdrantClient

        # The Python client needs an explicit port even when the endpoint is
        # reached over HTTPS; the hosted REST API listens on port 443.
        self.client = QdrantClient(
            host=url,
            port=443,
            https=True,
            api_key=api_key,
            timeout=120,
        )
# Initialize embedding model as None - will be loaded on first use
self._embedding_model = None
self._current_model_name = None
    def _get_embedding_model(self, model_name: Optional[str] = None):
        """Lazily load the embedding model so it is only loaded when needed."""
        if model_name is None:
            model_name = "BAAI/bge-m3"  # Default from config
# Only reload if model name changed
if self._embedding_model is None or self._current_model_name != model_name:
logging.info(f"Loading embedding model: {model_name}")
from sentence_transformers import SentenceTransformer
cache_folder = Path(os.getenv("HF_HUB_CACHE", "/tmp/hf_cache"))
cache_folder.mkdir(parents=True, exist_ok=True)
self._embedding_model = SentenceTransformer(
model_name,
cache_folder=str(cache_folder)
)
self._current_model_name = model_name
logging.info(f"Successfully loaded embedding model: {model_name}")
return self._embedding_model
    # Note: unlike the base interface, Qdrant searches take the collection
    # name per call rather than at construction time.
    def search(self, query: str, collection_name: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search using a direct Qdrant connection."""
try:
# Get embedding model
model_name = kwargs.get('model_name')
embedding_model = self._get_embedding_model(model_name)
# Convert query to embedding
logging.info(f"Converting query to embedding using model: {self._current_model_name}")
query_embedding = embedding_model.encode(query).tolist()
# Get filter from kwargs if provided
filter_obj = kwargs.get('filter', None)
# Perform vector search
logging.info(f"Searching Qdrant collection '{collection_name}' for top {top_k} results")
search_result = self.client.search(
collection_name=collection_name,
query_vector=query_embedding,
query_filter=filter_obj, # Add filter support
limit=top_k,
with_payload=True,
with_vectors=False
)
            logging.debug(f"Raw Qdrant search result: {search_result}")
# Format results to match expected output format
results = []
for hit in search_result:
raw_content = hit.payload.get("text", "")
# Process content to handle malformed nested list structures
processed_content = process_content(raw_content)
result_dict = {
"answer": processed_content,
"answer_metadata": hit.payload.get("metadata", {}),
"score": hit.score
}
results.append(result_dict)
logging.info(f"Successfully retrieved {len(results)} documents from Qdrant")
return results
        except Exception as e:
            logging.error(f"Error searching Qdrant: {e}")
            raise
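
    # Filter usage sketch (illustrative; the payload field and value are
    # assumptions about your collection's metadata schema):
    #
    #     from qdrant_client import models
    #
    #     year_filter = models.Filter(must=[
    #         models.FieldCondition(key="metadata.year",
    #                               match=models.MatchValue(value=2023)),
    #     ])
    #     hits = store.search("renewable energy targets", "my_collection",
    #                         top_k=5, filter=year_filter)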
def create_vectorstore(config: Any) -> VectorStoreInterface:
"""Factory function to create appropriate vector store based on configuration."""
vectorstore_type = config.get("vectorstore", "PROVIDER")
# Get authentication config based on provider
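    # get_auth is expected to return a mapping with an "api_key" entry for the
    # given provider (see .utils.get_auth).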
auth_config = get_auth(vectorstore_type.lower())
if vectorstore_type.lower() == "huggingface":
url = config.get("vectorstore", "URL")
collection_name = config.get("vectorstore", "COLLECTION_NAME")
api_key = auth_config["api_key"]
return HuggingFaceSpacesVectorStore(url, collection_name, api_key)
    elif vectorstore_type.lower() == "qdrant":
        url = config.get("vectorstore", "URL")  # Host name of the Qdrant endpoint
        api_key = auth_config["api_key"]
        # For Qdrant the collection name is supplied per search call,
        # so it is not read from the config here.
        return QdrantVectorStore(url, api_key)
else:
raise ValueError(f"Unsupported vector store type: {vectorstore_type}")
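

if __name__ == "__main__":
    # Minimal smoke test (a sketch): the "params.cfg" file name and its
    # [vectorstore] section values are assumptions and must match your
    # deployment. Note that QdrantVectorStore.search additionally requires
    # a collection_name argument.
    import configparser

    cfg = configparser.ConfigParser()
    cfg.read("params.cfg")  # hypothetical config file
    store = create_vectorstore(cfg)
    print(store.search("example query", top_k=3))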