from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from langchain_community.embeddings import SentenceTransformerEmbeddings
import os
from dotenv import load_dotenv
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
import logging
load_dotenv()
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
app = FastAPI(title="Retriever Agent")
FAISS_INDEX_PATH = os.getenv(
"FAISS_INDEX_PATH", "/app/faiss_index_store"
) # Path inside container
EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL_NAME", "all-MiniLM-L6-v2")
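# The embedding model and vector store are lazily initialized on first use and
# cached in these module-level singletons, so a loading failure surfaces as a
# request error rather than an import-time crash.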
embedding_model_instance: Optional[Embeddings] = None
vectorstore_instance: Optional[FAISS] = None
def get_embedding_model() -> Embeddings:
"""Initialize and return the SentenceTransformerEmbeddings model."""
global embedding_model_instance
if embedding_model_instance is None:
try:
logger.info(
f"Loading SentenceTransformerEmbeddings with model: {EMBEDDING_MODEL_NAME}"
)
embedding_model_instance = SentenceTransformerEmbeddings(
model_name=EMBEDDING_MODEL_NAME
)
logger.info(
f"SentenceTransformerEmbeddings model '{EMBEDDING_MODEL_NAME}' loaded successfully."
)
except Exception as e:
logger.error(
f"Error loading SentenceTransformerEmbeddings model '{EMBEDDING_MODEL_NAME}': {e}",
exc_info=True,
)
raise RuntimeError(f"Could not load embedding model: {e}")
return embedding_model_instance
def get_vectorstore() -> FAISS:
"""Load or create the FAISS vector store."""
global vectorstore_instance
if vectorstore_instance is None:
emb_model = get_embedding_model()
if os.path.exists(FAISS_INDEX_PATH) and os.path.isdir(FAISS_INDEX_PATH):
try:
logger.info(
f"Attempting to load FAISS index from {FAISS_INDEX_PATH}..."
)
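                # allow_dangerous_deserialization is required because load_local
                # unpickles the saved docstore; acceptable here since the index
                # files are written by this same service.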
vectorstore_instance = FAISS.load_local(
FAISS_INDEX_PATH,
emb_model,
allow_dangerous_deserialization=True,
)
logger.info(
f"FAISS index loaded from {FAISS_INDEX_PATH}. Documents: {vectorstore_instance.index.ntotal if vectorstore_instance.index else 'N/A'}"
)
except Exception as e:
logger.error(
f"Error loading FAISS index from {FAISS_INDEX_PATH}: {e}",
exc_info=True,
)
logger.warning("Creating a new FAISS index due to loading error.")
try:
vectorstore_instance = FAISS.from_texts(
texts=["Initial dummy document for FAISS."],
embedding=emb_model,
)
vectorstore_instance.save_local(FAISS_INDEX_PATH)
logger.info(
f"New FAISS index created with dummy doc and saved to {FAISS_INDEX_PATH}"
)
except Exception as create_e:
logger.error(
f"Failed to create new FAISS index: {create_e}", exc_info=True
)
raise RuntimeError(f"Could not create new FAISS index: {create_e}")
else:
logger.info(
f"FAISS index path {FAISS_INDEX_PATH} not found or invalid. Creating new index."
)
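            # FAISS.from_texts needs at least one text to build an index, so the
            # store is bootstrapped with a placeholder document.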
try:
vectorstore_instance = FAISS.from_texts(
texts=["Initial dummy document for FAISS."], embedding=emb_model
)
vectorstore_instance.save_local(FAISS_INDEX_PATH)
logger.info(f"New FAISS index created and saved to {FAISS_INDEX_PATH}")
except Exception as create_e:
logger.error(
f"Failed to create new FAISS index: {create_e}", exc_info=True
)
raise RuntimeError(f"Could not create new FAISS index: {create_e}")
return vectorstore_instance
class IndexRequest(BaseModel):
docs: List[str]
class RetrieveRequest(BaseModel):
query: str
top_k: int = 3
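# Illustrative request payloads for the endpoints below (example values only):
#   POST /index    -> {"docs": ["FAISS indexes dense vectors for similarity search."]}
#   POST /retrieve -> {"query": "What does FAISS index?", "top_k": 3}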
@app.post("/index")
def index_docs(request: IndexRequest):
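    """Embed and add the submitted documents to the FAISS index, then persist it to disk."""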
try:
vecstore = get_vectorstore()
if not request.docs:
logger.warning("No documents provided for indexing.")
return {
"status": "no documents provided",
"num_docs_in_store": vecstore.index.ntotal if vecstore.index else 0,
}
logger.info(f"Indexing {len(request.docs)} new documents.")
vecstore.add_texts(texts=request.docs)
vecstore.save_local(FAISS_INDEX_PATH)
logger.info(
f"Index updated and saved. Total documents in store: {vecstore.index.ntotal}"
)
return {"status": "indexed", "num_docs_in_store": vecstore.index.ntotal}
except Exception as e:
logger.error(f"Error during indexing: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}")
@app.post("/retrieve")
def retrieve(request: RetrieveRequest):
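    """Return the top_k documents most similar to the query, along with their scores."""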
try:
vecstore = get_vectorstore()
if not vecstore.index or vecstore.index.ntotal == 0:
logger.warning(
"Vector store is empty or index not initialized. Cannot retrieve."
)
return {
"results": [],
"message": "Vector store is empty. Index documents first.",
}
if vecstore.index.ntotal == 1:
try:
first_doc_id = list(vecstore.docstore._dict.keys())[0]
first_doc_content = vecstore.docstore._dict[first_doc_id].page_content
if "Initial dummy document for FAISS" in first_doc_content:
logger.warning(
"Vector store contains only the initial dummy document."
)
except Exception:
logger.warning(
"Could not inspect docstore for dummy document, proceeding with retrieval."
)
logger.info(
f"Retrieving documents for query: '{request.query}' (top_k={request.top_k}). Total docs: {vecstore.index.ntotal}"
)
results_with_scores = vecstore.similarity_search_with_score(
query=request.query, k=request.top_k
)
formatted_results = [
{"doc": doc.page_content, "score": float(score)}
for doc, score in results_with_scores
]
logger.info(f"Retrieved {len(formatted_results)} results.")
return {"results": formatted_results}
except Exception as e:
logger.error(f"Error during retrieval: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Retrieval failed: {str(e)}")