Spaces:
Running
on
T4
Running
on
T4
File size: 7,049 Bytes
137c471 a6471b0 137c471 b63e909 137c471 a6471b0 137c471 b63e909 137c471 a6471b0 137c471 a6471b0 137c471 a6471b0 137c471 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from gradio_client import Client
import logging
import os
from pathlib import Path
from dotenv import load_dotenv
# Local package helpers: provider auth lookup and payload text cleanup.
from .utils import get_auth, process_content
# Side effect at import time: pull variables (API keys, cache paths) from a
# local .env file into os.environ so get_auth/HF_HUB_CACHE lookups see them.
load_dotenv()
class VectorStoreInterface(ABC):
    """Contract shared by every vector-store backend in this module.

    Concrete subclasses wire a specific storage service (Hugging Face
    Space, Qdrant, ...) behind a uniform ``search`` call.
    """

    @abstractmethod
    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Return the documents most similar to ``query``.

        Args:
            query: Free-text query string.
            top_k: Maximum number of hits to return.
            **kwargs: Backend-specific options (e.g. model name, filters).

        Returns:
            A list of result dictionaries.
        """
class HuggingFaceSpacesVectorStore(VectorStoreInterface):
    """Vector store backed by a Hugging Face Space exposing MCP endpoints."""

    def __init__(self, url: str, collection_name: str, api_key: Optional[str] = None):
        """Connect to the Space hosting the search API.

        Args:
            url: Space repo id (or URL) accepted by ``gradio_client.Client``.
            collection_name: Collection on the Space to query.
            api_key: Optional HF token; needed only for private Spaces.
        """
        repo_id = url
        logging.info("Connecting to Hugging Face Space: %s", repo_id)
        # Only pass hf_token when one was supplied so public Spaces keep
        # working without credentials.
        if api_key:
            self.client = Client(repo_id, hf_token=api_key)
        else:
            self.client = Client(repo_id)
        self.collection_name = collection_name

    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search via the Space's ``/search_text`` MCP endpoint.

        Args:
            query: Free-text query string.
            top_k: Maximum number of documents to return.
            **kwargs: Optional ``model_name`` forwarded to the endpoint
                (None when absent; the Space applies its own default).

        Returns:
            The endpoint's result payload (list of result dicts).

        Raises:
            Exception: Any error from the remote call is logged and re-raised.
        """
        try:
            # Use the /search_text endpoint as documented in the API.
            result = self.client.predict(
                query=query,
                collection_name=self.collection_name,
                model_name=kwargs.get('model_name'),
                top_k=top_k,
                api_name="/search_text"
            )
            # Lazy %-args: the message is only formatted if this level is on.
            logging.info("Successfully retrieved %s documents",
                         len(result) if result else 0)
            return result
        except Exception as e:
            logging.error("Error searching Hugging Face Spaces: %s", e)
            # Bare raise preserves the original traceback exactly.
            raise
class QdrantVectorStore(VectorStoreInterface):
    """Vector store implementation for a direct Qdrant connection.

    Note: unlike the base interface, ``search`` takes an explicit
    ``collection_name`` argument, since one Qdrant server hosts many
    collections (existing callers already pass it — kept as-is).
    """

    def __init__(self, url: str, api_key: Optional[str] = None):
        """Open an HTTPS client against a Qdrant server.

        Args:
            url: Qdrant hostname (port/scheme are fixed below).
            api_key: Optional API key; required for write access.
        """
        # Imported lazily so merely importing this module does not require
        # qdrant_client to be installed.
        from qdrant_client import QdrantClient

        self.client = QdrantClient(
            host=url,
            # Very important: the Python client needs the port set
            # explicitly (443) when talking HTTPS.
            port=443,
            https=True,
            api_key=api_key,
            timeout=120,
        )
        # Embedding model is lazy-loaded on first search to avoid paying
        # the (large) model download/load cost when it is never used.
        self._embedding_model = None
        self._current_model_name = None

    def _get_embedding_model(self, model_name: Optional[str] = None):
        """Return a SentenceTransformer, (re)loading only when needed.

        Args:
            model_name: HF model id; defaults to "BAAI/bge-m3" when None.

        Returns:
            The cached or freshly loaded SentenceTransformer instance.
        """
        if model_name is None:
            model_name = "BAAI/bge-m3"  # Default from config
        # Reload only when nothing is cached yet or a different model is asked for.
        if self._embedding_model is None or self._current_model_name != model_name:
            logging.info("Loading embedding model: %s", model_name)
            from sentence_transformers import SentenceTransformer

            # Use a writable cache dir: HF_HUB_CACHE if set, /tmp fallback.
            cache_folder = Path(os.getenv("HF_HUB_CACHE", "/tmp/hf_cache"))
            cache_folder.mkdir(parents=True, exist_ok=True)
            self._embedding_model = SentenceTransformer(
                model_name,
                cache_folder=str(cache_folder)
            )
            self._current_model_name = model_name
            logging.info("Successfully loaded embedding model: %s", model_name)
        return self._embedding_model

    def search(self, query: str, collection_name: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Embed ``query`` locally, then run a vector search against Qdrant.

        Args:
            query: Free-text query string.
            collection_name: Qdrant collection to search.
            top_k: Maximum number of hits to return.
            **kwargs: Optional ``model_name`` (embedding model id) and
                ``filter`` (Qdrant filter object) entries.

        Returns:
            List of dicts with 'answer', 'answer_metadata' and 'score' keys.

        Raises:
            Exception: Any error from embedding or Qdrant is logged and re-raised.
        """
        try:
            model_name = kwargs.get('model_name')
            embedding_model = self._get_embedding_model(model_name)

            logging.info("Converting query to embedding using model: %s",
                         self._current_model_name)
            query_embedding = embedding_model.encode(query).tolist()

            filter_obj = kwargs.get('filter', None)

            logging.info("Searching Qdrant collection '%s' for top %s results",
                         collection_name, top_k)
            search_result = self.client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                query_filter=filter_obj,  # optional server-side filtering
                limit=top_k,
                with_payload=True,
                with_vectors=False
            )
            # Raw hit dump is only useful when debugging, not at info level.
            logging.debug(search_result)

            # Reshape Qdrant hits into the format the rest of the app expects.
            results = []
            for hit in search_result:
                raw_content = hit.payload.get('text', '')
                # Process content to handle malformed nested list structures.
                processed_content = process_content(raw_content)
                results.append({
                    'answer': processed_content,
                    'answer_metadata': hit.payload.get('metadata', {}),
                    'score': hit.score,
                })
            logging.info("Successfully retrieved %s documents from Qdrant",
                         len(results))
            return results
        except Exception as e:
            logging.error("Error searching Qdrant: %s", e)
            # Bare raise preserves the original traceback exactly.
            raise
def create_vectorstore(config: Any) -> VectorStoreInterface:
    """Factory: build the vector store selected by the configuration.

    Args:
        config: configparser-style object supporting ``get(section, key)``;
            reads the [vectorstore] section (PROVIDER, URL, COLLECTION_NAME).

    Returns:
        A VectorStoreInterface implementation for the configured provider.

    Raises:
        ValueError: If PROVIDER names an unsupported vector store type.
    """
    vectorstore_type = config.get("vectorstore", "PROVIDER")
    # Normalize once; reused for both dispatch and auth lookup.
    provider = vectorstore_type.lower()

    if provider == "huggingface":
        auth_config = get_auth(provider)
        url = config.get("vectorstore", "URL")
        collection_name = config.get("vectorstore", "COLLECTION_NAME")
        return HuggingFaceSpacesVectorStore(url, collection_name,
                                            auth_config["api_key"])

    if provider == "qdrant":
        auth_config = get_auth(provider)
        # Full URL/hostname; the Qdrant client pins port 443 itself.
        url = config.get("vectorstore", "URL")
        return QdrantVectorStore(url, auth_config["api_key"])

    # Auth is only fetched for known providers, so a bad PROVIDER fails fast here.
    raise ValueError(f"Unsupported vector store type: {vectorstore_type}")