# mcp-network-doc-dem / query_interface.py
# (Hugging Face Hub page residue, preserved as comments:)
# rogerscuall's picture
# Upload folder using huggingface_hub
# 63f7ae3 verified
#!/usr/bin/env python3
# /// script
# dependencies = [
#     "langchain",
#     "langchain_community",
#     "chromadb",
#     "huggingface_hub",
#     "sentence_transformers",
#     "pydantic"
# ]
# ///
"""
Query interface for Arista AVD documentation vector database.
Provides search and retrieval capabilities.
"""
import argparse
import json
from typing import List, Dict, Any, Optional
from pathlib import Path
import logging
from pydantic import BaseModel, Field
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.schema import Document
# Module-wide logging: timestamped INFO-level messages for the whole CLI.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class EmbeddingConfig(BaseModel):
    """Configuration for embeddings."""

    # HuggingFace sentence-transformers model identifier.
    model_name: str = Field(
        default="all-MiniLM-L6-v2",
        description="The name of the HuggingFace model to use",
    )
    # Where the encoder runs; "cpu" or "cuda".
    device: str = Field(
        default="cpu",
        description="Device to use for embedding generation (cpu or cuda)",
    )
    # If True, embedding vectors are L2-normalized by the encoder.
    normalize_embeddings: bool = Field(
        default=True,
        description="Whether to normalize embeddings",
    )
class AristaDocumentQuery(BaseModel):
    """Query interface for Arista AVD documentation.

    Wraps a persisted Chroma vector store and a HuggingFace embedding model,
    exposing similarity search, metadata-filtered search, and helpers to
    format or export results.
    """
    persist_directory: str = Field(default="./chroma_db", description="Directory containing the vector store")
    embedding_config: EmbeddingConfig = Field(default_factory=EmbeddingConfig, description="Configuration for embeddings")
    # Runtime handles built in __init__; excluded from pydantic serialization.
    embeddings: Any = Field(default=None, exclude=True)
    vector_store: Any = Field(default=None, exclude=True)

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, **data):
        """Validate config, then eagerly build the embedder and open the store
        so configuration errors surface at construction time."""
        super().__init__(**data)
        self.embeddings = HuggingFaceEmbeddings(
            model_name=self.embedding_config.model_name,
            model_kwargs={'device': self.embedding_config.device},
            encode_kwargs={'normalize_embeddings': self.embedding_config.normalize_embeddings}
        )
        self.vector_store = self._load_vector_store()

    def _load_vector_store(self) -> Chroma:
        """Load the existing vector store from ``persist_directory``.

        Raises:
            Exception: re-raised from Chroma when the store cannot be opened.
        """
        try:
            vector_store = Chroma(
                persist_directory=self.persist_directory,
                embedding_function=self.embeddings
            )
            logger.info(f"Loaded vector store from {self.persist_directory}")
            return vector_store
        except Exception:
            # exception() records the traceback, which error() did not.
            logger.exception("Error loading vector store")
            raise

    def similarity_search(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None) -> List[Document]:
        """Perform similarity search on the vector store.

        Args:
            query: Free-text search query.
            k: Maximum number of documents to return.
            filter_dict: Optional metadata filter passed through to Chroma.

        Returns:
            Matching documents, or an empty list when the search fails.
        """
        # Build kwargs once instead of duplicating the call in two branches.
        kwargs: Dict[str, Any] = {"query": query, "k": k}
        if filter_dict:
            kwargs["filter"] = filter_dict
        try:
            return self.vector_store.similarity_search(**kwargs)
        except Exception:
            logger.exception("Error during similarity search")
            return []

    def search_by_category(self, query: str, category: str, k: int = 5) -> List[Document]:
        """Search documents within a specific category."""
        return self.similarity_search(query, k=k, filter_dict={"category": category})

    def search_by_type(self, query: str, doc_type: str, k: int = 5) -> List[Document]:
        """Search documents of a specific type (markdown/csv)."""
        return self.similarity_search(query, k=k, filter_dict={"type": doc_type})

    def get_categories(self) -> List[str]:
        """Get all available categories in the vector store.

        NOTE: this returns a hard-coded list; a full implementation would
        query the metadata directly from ChromaDB.
        """
        return [
            'device_configuration',
            'fabric_documentation',
            'testing',
            'netbox_integration',
            'arista_cloud_test',
            'avd_design',
            'api_usage',
            'workflow',
            'infoblox_integration',
            'network_testing',
            'general_documentation',
            'project_documentation',
        ]

    def format_results(self, results: List[Document], verbose: bool = False) -> str:
        """Format search results for display.

        Args:
            results: Documents to render.
            verbose: When True, append the full content after the preview.

        Returns:
            A printable multi-result report string.
        """
        output: List[str] = []
        for i, doc in enumerate(results, 1):
            output.append(f"\n{'='*80}")
            output.append(f"Result {i}:")
            output.append(f"Source: {doc.metadata.get('source', 'Unknown')}")
            output.append(f"Category: {doc.metadata.get('category', 'Unknown')}")
            output.append(f"Type: {doc.metadata.get('type', 'Unknown')}")
            # CSV chunks carry extra structural metadata worth surfacing.
            if doc.metadata.get('type') == 'csv':
                output.append(f"Columns: {doc.metadata.get('columns', 'Unknown')}")
                output.append(f"Rows: {doc.metadata.get('rows', 'Unknown')}")
            output.append(f"\nContent Preview:")
            content_preview = doc.page_content[:500] + "..." if len(doc.page_content) > 500 else doc.page_content
            output.append(content_preview)
            if verbose:
                output.append(f"\nFull Content:")
                output.append(doc.page_content)
        return "\n".join(output)

    def export_results(self, results: List[Document], output_file: str) -> None:
        """Export search results to a JSON file.

        Args:
            results: Documents to serialize.
            output_file: Destination path for the JSON file.
        """
        data = [
            {'content': doc.page_content, 'metadata': doc.metadata}
            for doc in results
        ]
        # Explicit encoding + ensure_ascii=False keep non-ASCII content intact.
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        logger.info(f"Results exported to {output_file}")
def main():
    """Main function for command-line interface."""
    parser = argparse.ArgumentParser(description="Query Arista AVD documentation vector database")
    parser.add_argument("query", nargs="?", help="Search query")
    parser.add_argument("-k", "--top-k", type=int, default=5, help="Number of results to return (default: 5)")
    parser.add_argument("-c", "--category", help="Filter by category")
    parser.add_argument("-t", "--type", choices=['markdown', 'csv'], help="Filter by document type")
    parser.add_argument("-v", "--verbose", action="store_true", help="Show full content")
    parser.add_argument("-e", "--export", help="Export results to JSON file")
    parser.add_argument("--list-categories", action="store_true", help="List available categories")
    opts = parser.parse_args()

    # Build the query interface up front; it opens the persisted store.
    qi = AristaDocumentQuery()

    # --list-categories short-circuits: no query string required.
    if opts.list_categories:
        print("Available categories:")
        for name in qi.get_categories():
            print(f" - {name}")
        return

    # Guard clause: without --list-categories a query is mandatory.
    if not opts.query:
        parser.error("Query is required unless using --list-categories")

    # Dispatch on the first matching filter; category takes precedence.
    if opts.category:
        docs = qi.search_by_category(opts.query, opts.category, k=opts.top_k)
    elif opts.type:
        docs = qi.search_by_type(opts.query, opts.type, k=opts.top_k)
    else:
        docs = qi.similarity_search(opts.query, k=opts.top_k)

    if not docs:
        print("No results found.")
        return

    print(qi.format_results(docs, verbose=opts.verbose))
    if opts.export:
        qi.export_results(docs, opts.export)
# Entry point when executed as a script (not on import).
if __name__ == "__main__":
    main()