from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from gradio_client import Client
import logging
import os
from pathlib import Path
from dotenv import load_dotenv

from .utils import get_auth, process_content

load_dotenv()

class VectorStoreInterface(ABC):
    """Abstract interface for different vector store implementations."""

    @abstractmethod
    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search for similar documents."""
        pass

class HuggingFaceSpacesVectorStore(VectorStoreInterface):
    """Vector store implementation for Hugging Face Spaces with MCP endpoints."""

    def __init__(self, url: str, collection_name: str, api_key: Optional[str] = None):
        repo_id = url
        logging.info(f"Connecting to Hugging Face Space: {repo_id}")
        if api_key:
            self.client = Client(repo_id, hf_token=api_key)
        else:
            self.client = Client(repo_id)
        self.collection_name = collection_name

    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search using the Hugging Face Spaces MCP API."""
        try:
            # Use the /search_text endpoint as documented in the Space's API
            result = self.client.predict(
                query=query,
                collection_name=self.collection_name,
                model_name=kwargs.get('model_name'),
                top_k=top_k,
                api_name="/search_text"
            )
            logging.info(f"Successfully retrieved {len(result) if result else 0} documents")
            return result
        except Exception as e:
            logging.error(f"Error searching Hugging Face Spaces: {str(e)}")
            raise

class QdrantVectorStore(VectorStoreInterface):
    """Vector store implementation for a direct Qdrant connection."""

    def __init__(self, url: str, api_key: Optional[str] = None):
        from qdrant_client import QdrantClient

        self.client = QdrantClient(
            host=url,
            # The Python client needs the HTTPS port set explicitly
            port=443,
            https=True,
            api_key=api_key,
            timeout=120,
        )
        # Embedding model is loaded lazily on first use
        self._embedding_model = None
        self._current_model_name = None

    def _get_embedding_model(self, model_name: Optional[str] = None):
        """Lazily load the embedding model to avoid loading it if not needed."""
        if model_name is None:
            model_name = "BAAI/bge-m3"  # Default from config
        # Only (re)load if the model name changed
        if self._embedding_model is None or self._current_model_name != model_name:
            logging.info(f"Loading embedding model: {model_name}")
            from sentence_transformers import SentenceTransformer

            cache_folder = Path(os.getenv("HF_HUB_CACHE", "/tmp/hf_cache"))
            cache_folder.mkdir(parents=True, exist_ok=True)
            self._embedding_model = SentenceTransformer(
                model_name,
                cache_folder=str(cache_folder)
            )
            self._current_model_name = model_name
            logging.info(f"Successfully loaded embedding model: {model_name}")
        return self._embedding_model

    def search(self, query: str, collection_name: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search using a direct Qdrant connection."""
        try:
            # Embed the query with the configured model
            model_name = kwargs.get('model_name')
            embedding_model = self._get_embedding_model(model_name)
            logging.info(f"Converting query to embedding using model: {self._current_model_name}")
            query_embedding = embedding_model.encode(query).tolist()
            # Optional metadata filter passed through from the caller
            filter_obj = kwargs.get('filter', None)
            # Perform the vector search
            logging.info(f"Searching Qdrant collection '{collection_name}' for top {top_k} results")
            search_result = self.client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                query_filter=filter_obj,
                limit=top_k,
                with_payload=True,
                with_vectors=False
            )
            logging.debug(search_result)
            # Format hits to match the expected output format
            results = []
            for hit in search_result:
                raw_content = hit.payload.get("text", "")
                # Process content to handle malformed nested list structures
                processed_content = process_content(raw_content)
                results.append({
                    "answer": processed_content,
                    "answer_metadata": hit.payload.get("metadata", {}),
                    "score": hit.score
                })
            logging.info(f"Successfully retrieved {len(results)} documents from Qdrant")
            return results
        except Exception as e:
            logging.error(f"Error searching Qdrant: {str(e)}")
            raise

def create_vectorstore(config: Any) -> VectorStoreInterface:
    """Factory function to create the appropriate vector store based on configuration."""
    vectorstore_type = config.get("vectorstore", "PROVIDER")
    # Get authentication config for the selected provider
    auth_config = get_auth(vectorstore_type.lower())
    if vectorstore_type.lower() == "huggingface":
        url = config.get("vectorstore", "URL")
        collection_name = config.get("vectorstore", "COLLECTION_NAME")
        api_key = auth_config["api_key"]
        return HuggingFaceSpacesVectorStore(url, collection_name, api_key)
    elif vectorstore_type.lower() == "qdrant":
        # Qdrant takes the collection name per search() call, so only the URL is needed here
        url = config.get("vectorstore", "URL")
        api_key = auth_config["api_key"]
        return QdrantVectorStore(url, api_key)
    else:
        raise ValueError(f"Unsupported vector store type: {vectorstore_type}")