from fastapi import FastAPI, HTTPException from pydantic import BaseModel from typing import List, Dict, Optional from langchain_community.embeddings import SentenceTransformerEmbeddings import os from dotenv import load_dotenv from langchain_community.vectorstores import FAISS from langchain_core.embeddings import Embeddings import logging load_dotenv() logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) app = FastAPI(title="Retriever Agent") FAISS_INDEX_PATH = os.getenv( "FAISS_INDEX_PATH", "/app/faiss_index_store" ) # Path inside container EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL_NAME", "all-MiniLM-L6-v2") embedding_model_instance: Optional[Embeddings] = None vectorstore_instance: Optional[FAISS] = None def get_embedding_model() -> Embeddings: """Initialize and return the SentenceTransformerEmbeddings model.""" global embedding_model_instance if embedding_model_instance is None: try: logger.info( f"Loading SentenceTransformerEmbeddings with model: {EMBEDDING_MODEL_NAME}" ) embedding_model_instance = SentenceTransformerEmbeddings( model_name=EMBEDDING_MODEL_NAME ) logger.info( f"SentenceTransformerEmbeddings model '{EMBEDDING_MODEL_NAME}' loaded successfully." ) except Exception as e: logger.error( f"Error loading SentenceTransformerEmbeddings model '{EMBEDDING_MODEL_NAME}': {e}", exc_info=True, ) raise RuntimeError(f"Could not load embedding model: {e}") return embedding_model_instance def get_vectorstore() -> FAISS: """Load or create the FAISS vector store.""" global vectorstore_instance if vectorstore_instance is None: emb_model = get_embedding_model() if os.path.exists(FAISS_INDEX_PATH) and os.path.isdir(FAISS_INDEX_PATH): try: logger.info( f"Attempting to load FAISS index from {FAISS_INDEX_PATH}..." ) vectorstore_instance = FAISS.load_local( FAISS_INDEX_PATH, emb_model, allow_dangerous_deserialization=True, ) logger.info( f"FAISS index loaded from {FAISS_INDEX_PATH}. Documents: {vectorstore_instance.index.ntotal if vectorstore_instance.index else 'N/A'}" ) except Exception as e: logger.error( f"Error loading FAISS index from {FAISS_INDEX_PATH}: {e}", exc_info=True, ) logger.warning("Creating a new FAISS index due to loading error.") try: vectorstore_instance = FAISS.from_texts( texts=["Initial dummy document for FAISS."], embedding=emb_model, ) vectorstore_instance.save_local(FAISS_INDEX_PATH) logger.info( f"New FAISS index created with dummy doc and saved to {FAISS_INDEX_PATH}" ) except Exception as create_e: logger.error( f"Failed to create new FAISS index: {create_e}", exc_info=True ) raise RuntimeError(f"Could not create new FAISS index: {create_e}") else: logger.info( f"FAISS index path {FAISS_INDEX_PATH} not found or invalid. Creating new index." ) try: vectorstore_instance = FAISS.from_texts( texts=["Initial dummy document for FAISS."], embedding=emb_model ) vectorstore_instance.save_local(FAISS_INDEX_PATH) logger.info(f"New FAISS index created and saved to {FAISS_INDEX_PATH}") except Exception as create_e: logger.error( f"Failed to create new FAISS index: {create_e}", exc_info=True ) raise RuntimeError(f"Could not create new FAISS index: {create_e}") return vectorstore_instance class IndexRequest(BaseModel): docs: List[str] class RetrieveRequest(BaseModel): query: str top_k: int = 3 @app.post("/index") def index_docs(request: IndexRequest): try: vecstore = get_vectorstore() if not request.docs: logger.warning("No documents provided for indexing.") return { "status": "no documents provided", "num_docs_in_store": vecstore.index.ntotal if vecstore.index else 0, } logger.info(f"Indexing {len(request.docs)} new documents.") vecstore.add_texts(texts=request.docs) vecstore.save_local(FAISS_INDEX_PATH) logger.info( f"Index updated and saved. Total documents in store: {vecstore.index.ntotal}" ) return {"status": "indexed", "num_docs_in_store": vecstore.index.ntotal} except Exception as e: logger.error(f"Error during indexing: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}") @app.post("/retrieve") def retrieve(request: RetrieveRequest): try: vecstore = get_vectorstore() if not vecstore.index or vecstore.index.ntotal == 0: logger.warning( "Vector store is empty or index not initialized. Cannot retrieve." ) return { "results": [], "message": "Vector store is empty. Index documents first.", } if vecstore.index.ntotal == 1: try: first_doc_id = list(vecstore.docstore._dict.keys())[0] first_doc_content = vecstore.docstore._dict[first_doc_id].page_content if "Initial dummy document for FAISS" in first_doc_content: logger.warning( "Vector store contains only the initial dummy document." ) except Exception: logger.warning( "Could not inspect docstore for dummy document, proceeding with retrieval." ) logger.info( f"Retrieving documents for query: '{request.query}' (top_k={request.top_k}). Total docs: {vecstore.index.ntotal}" ) results_with_scores = vecstore.similarity_search_with_score( query=request.query, k=request.top_k ) formatted_results = [ {"doc": doc.page_content, "score": float(score)} for doc, score in results_with_scores ] logger.info(f"Retrieved {len(formatted_results)} results.") return {"results": formatted_results} except Exception as e: logger.error(f"Error during retrieval: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Retrieval failed: {str(e)}")