from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from gradio_client import Client
import logging
import os
from pathlib import Path
from dotenv import load_dotenv
from .utils import get_auth, process_content

load_dotenv()


class VectorStoreInterface(ABC):
    """Abstract interface for different vector store implementations."""

    @abstractmethod
    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search for similar documents."""
        pass


class HuggingFaceSpacesVectorStore(VectorStoreInterface):
    """Vector store implementation for Hugging Face Spaces with MCP endpoints."""

    def __init__(self, url: str, collection_name: str, api_key: Optional[str] = None):
        repo_id = url
        logging.info(f"Connecting to Hugging Face Space: {repo_id}")
        if api_key:
            self.client = Client(repo_id, hf_token=api_key)
        else:
            self.client = Client(repo_id)
        self.collection_name = collection_name

    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search using the Hugging Face Spaces MCP API."""
        try:
            # Use the /search_text endpoint as documented in the Space's API
            result = self.client.predict(
                query=query,
                collection_name=self.collection_name,
                model_name=kwargs.get("model_name"),
                top_k=top_k,
                api_name="/search_text",
            )
            logging.info(f"Successfully retrieved {len(result) if result else 0} documents")
            return result
        except Exception as e:
            logging.error(f"Error searching Hugging Face Spaces: {e}")
            raise


class QdrantVectorStore(VectorStoreInterface):
    """Vector store implementation for a direct Qdrant connection."""

    def __init__(self, url: str, api_key: Optional[str] = None):
        from qdrant_client import QdrantClient

        # The port must be set explicitly for the Python client; it defaults
        # to 6333, while this deployment is served over HTTPS on 443.
        self.client = QdrantClient(
            host=url,
            port=443,
            https=True,
            api_key=api_key,  # read-only key; a separate key is used for write access
            timeout=120,
        )
        # Embedding model is loaded lazily on first use
        self._embedding_model = None
        self._current_model_name = None

    def _get_embedding_model(self, model_name: Optional[str] = None):
        """Lazily load the embedding model to avoid loading it if not needed."""
        if model_name is None:
            model_name = "BAAI/bge-m3"  # Default from config

        # Only (re)load if the requested model differs from the cached one
        if self._embedding_model is None or self._current_model_name != model_name:
            logging.info(f"Loading embedding model: {model_name}")
            from sentence_transformers import SentenceTransformer

            cache_folder = Path(os.getenv("HF_HUB_CACHE", "/tmp/hf_cache"))
            cache_folder.mkdir(parents=True, exist_ok=True)
            self._embedding_model = SentenceTransformer(
                model_name, cache_folder=str(cache_folder)
            )
            self._current_model_name = model_name
            logging.info(f"Successfully loaded embedding model: {model_name}")
        return self._embedding_model

    def search(self, query: str, collection_name: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search using a direct Qdrant connection.

        Note: unlike the abstract interface, this implementation takes the
        collection name per call rather than fixing it at construction time.
        """
        try:
            # Embed the query with the configured model
            model_name = kwargs.get("model_name")
            embedding_model = self._get_embedding_model(model_name)
            logging.info(f"Converting query to embedding using model: {self._current_model_name}")
            query_embedding = embedding_model.encode(query).tolist()

            # Optional Qdrant filter passed through from the caller
            filter_obj = kwargs.get("filter", None)

            # Perform the vector search
            logging.info(f"Searching Qdrant collection '{collection_name}' for top {top_k} results")
            search_result = self.client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                query_filter=filter_obj,
                limit=top_k,
                with_payload=True,
                with_vectors=False,
            )
            logging.debug(search_result)

            # Format results to match the expected output format
            results = []
            for hit in search_result:
                raw_content = hit.payload.get("text", "")
                # Process content to handle malformed nested list structures
                processed_content = process_content(raw_content)
                results.append(
                    {
                        "answer": processed_content,
                        "answer_metadata": hit.payload.get("metadata", {}),
                        "score": hit.score,
                    }
                )

            logging.info(f"Successfully retrieved {len(results)} documents from Qdrant")
            return results
        except Exception as e:
            logging.error(f"Error searching Qdrant: {e}")
            raise


def create_vectorstore(config: Any) -> VectorStoreInterface:
    """Factory function to create the appropriate vector store from configuration."""
    vectorstore_type = config.get("vectorstore", "PROVIDER")

    # Get authentication config based on the provider
    auth_config = get_auth(vectorstore_type.lower())

    if vectorstore_type.lower() == "huggingface":
        url = config.get("vectorstore", "URL")
        collection_name = config.get("vectorstore", "COLLECTION_NAME")
        api_key = auth_config["api_key"]
        return HuggingFaceSpacesVectorStore(url, collection_name, api_key)
    elif vectorstore_type.lower() == "qdrant":
        url = config.get("vectorstore", "URL")
        api_key = auth_config["api_key"]
        # The Qdrant store takes the collection name per search call, so only
        # the URL and API key are needed here (the port is fixed in __init__).
        return QdrantVectorStore(url, api_key)
    else:
        raise ValueError(f"Unsupported vector store type: {vectorstore_type}")
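

# --- Usage sketch (illustrative only, not part of the module's API) ---
# A minimal example of wiring create_vectorstore() to a configparser-style
# config, matching the config.get("vectorstore", ...) lookups above. The
# host, collection name, and query are hypothetical placeholders, and this
# assumes get_auth() can resolve credentials from the environment (.env).
# Because this module uses a relative import (.utils), run it in its
# package context, e.g. `python -m <yourpkg>.<thismodule>`.
if __name__ == "__main__":
    import configparser

    config = configparser.ConfigParser()
    config["vectorstore"] = {
        "PROVIDER": "qdrant",
        "URL": "your-qdrant-host.example.com",  # hypothetical host
    }

    store = create_vectorstore(config)
    # QdrantVectorStore takes the collection name per call:
    hits = store.search("what is a vector store?", "my_collection", top_k=5)
    for hit in hits:
        print(hit["score"], hit["answer"][:80])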