File size: 7,155 Bytes
3f43e82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Optional


from langchain_community.embeddings import SentenceTransformerEmbeddings
import os
from dotenv import load_dotenv
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
import logging

# Read a local .env file (if any) first so the os.environ lookups below see
# the values it defines.
load_dotenv()

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

app = FastAPI(title="Retriever Agent")

# Directory the FAISS index is loaded from and saved to; the default is the
# path used inside the container.
FAISS_INDEX_PATH = os.environ.get("FAISS_INDEX_PATH", "/app/faiss_index_store")

# Model name handed to SentenceTransformerEmbeddings by get_embedding_model().
EMBEDDING_MODEL_NAME = os.environ.get("EMBEDDING_MODEL_NAME", "all-MiniLM-L6-v2")


# Process-wide singletons. Both start empty and are filled in on first use by
# get_embedding_model() and get_vectorstore() respectively.
embedding_model_instance: Optional[Embeddings] = None
vectorstore_instance: Optional[FAISS] = None


def get_embedding_model() -> Embeddings:
    """Return the shared embedding model, loading it on first use.

    The instance is cached in the module-level ``embedding_model_instance``,
    so the model named by ``EMBEDDING_MODEL_NAME`` is constructed at most once
    per successful call chain.

    Returns:
        The cached ``SentenceTransformerEmbeddings`` instance.

    Raises:
        RuntimeError: If the model cannot be constructed. The underlying
            exception is attached as ``__cause__``.
    """
    global embedding_model_instance
    if embedding_model_instance is not None:
        return embedding_model_instance

    logger.info(
        f"Loading SentenceTransformerEmbeddings with model: {EMBEDDING_MODEL_NAME}"
    )
    # Keep the try body down to the one call that can actually fail.
    try:
        model = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    except Exception as e:
        logger.error(
            f"Error loading SentenceTransformerEmbeddings model '{EMBEDDING_MODEL_NAME}': {e}",
            exc_info=True,
        )
        # Chain explicitly so the traceback shows the loader failure as the
        # direct cause rather than as an incidental "during handling" context.
        raise RuntimeError(f"Could not load embedding model: {e}") from e

    embedding_model_instance = model
    logger.info(
        f"SentenceTransformerEmbeddings model '{EMBEDDING_MODEL_NAME}' loaded successfully."
    )
    return model


def _create_and_save_vectorstore(emb_model: Embeddings) -> FAISS:
    """Build a FAISS store seeded with one placeholder text and persist it.

    Args:
        emb_model: Embedding model used to embed the placeholder text.

    Returns:
        The newly created store, already saved to ``FAISS_INDEX_PATH``.

    Raises:
        RuntimeError: If building or saving the store fails. The underlying
            exception is attached as ``__cause__``.
    """
    try:
        store = FAISS.from_texts(
            texts=["Initial dummy document for FAISS."],
            embedding=emb_model,
        )
        store.save_local(FAISS_INDEX_PATH)
    except Exception as create_e:
        logger.error(f"Failed to create new FAISS index: {create_e}", exc_info=True)
        raise RuntimeError(
            f"Could not create new FAISS index: {create_e}"
        ) from create_e
    return store


def get_vectorstore() -> FAISS:
    """Return the shared FAISS vector store, loading or creating it on first use.

    On the first call the store is loaded from ``FAISS_INDEX_PATH`` when that
    path is a directory. If the path is missing, is not a directory, or loading
    fails, a new store containing a single placeholder document is created and
    saved to ``FAISS_INDEX_PATH``. The result is cached in the module-level
    ``vectorstore_instance`` and returned directly by later calls.

    The global is only assigned once the store has been fully loaded or
    created *and* saved, so a failed initialisation is retried on the next
    call instead of leaving a half-initialised singleton behind.

    Returns:
        The cached ``FAISS`` store.

    Raises:
        RuntimeError: If the embedding model cannot be loaded or a new index
            cannot be created and saved.
    """
    global vectorstore_instance
    if vectorstore_instance is not None:
        return vectorstore_instance

    emb_model = get_embedding_model()

    # os.path.isdir() is already False for a non-existent path, so a separate
    # exists() check is unnecessary.
    if os.path.isdir(FAISS_INDEX_PATH):
        logger.info(f"Attempting to load FAISS index from {FAISS_INDEX_PATH}...")
        try:
            # SECURITY: allow_dangerous_deserialization=True opts in to
            # unpickling the stored docstore. Only point FAISS_INDEX_PATH at
            # data this service wrote itself or that is otherwise trusted.
            store = FAISS.load_local(
                FAISS_INDEX_PATH,
                emb_model,
                allow_dangerous_deserialization=True,
            )
        except Exception as e:
            logger.error(
                f"Error loading FAISS index from {FAISS_INDEX_PATH}: {e}",
                exc_info=True,
            )
            logger.warning("Creating a new FAISS index due to loading error.")
            # NOTE(review): the replacement index is saved to the same path the
            # failed load read from, which presumably overwrites whatever was
            # there - confirm this is the intended recovery behaviour.
            store = _create_and_save_vectorstore(emb_model)
            logger.info(
                f"New FAISS index created with dummy doc and saved to {FAISS_INDEX_PATH}"
            )
        else:
            logger.info(
                f"FAISS index loaded from {FAISS_INDEX_PATH}. Documents: {store.index.ntotal if store.index else 'N/A'}"
            )
    else:
        logger.info(
            f"FAISS index path {FAISS_INDEX_PATH} not found or invalid. Creating new index."
        )
        store = _create_and_save_vectorstore(emb_model)
        logger.info(f"New FAISS index created and saved to {FAISS_INDEX_PATH}")

    vectorstore_instance = store
    return store


class IndexRequest(BaseModel):
    """Request body for ``POST /index``."""

    # Raw texts to add to the vector store; an empty list is accepted and
    # results in a "no documents provided" response from the endpoint.
    docs: List[str]


class RetrieveRequest(BaseModel):
    """Request body for ``POST /retrieve``."""

    # Text to search the vector store for.
    query: str
    # Number of results requested; passed as ``k`` to the similarity search.
    top_k: int = 3


@app.post("/index")
def index_docs(request: IndexRequest):
    """Add the submitted texts to the vector store and save it to disk.

    Args:
        request: Body carrying the list of texts to index.

    Returns:
        A dict with a ``status`` string and ``num_docs_in_store``, the index's
        ``ntotal`` after the operation. When ``request.docs`` is empty nothing
        is written and ``status`` is ``"no documents provided"``.

    Raises:
        HTTPException: 500 if obtaining the store, adding texts, or saving
            fails.
    """
    try:
        store = get_vectorstore()
        incoming = request.docs

        if incoming:
            logger.info(f"Indexing {len(incoming)} new documents.")
            store.add_texts(texts=incoming)
            store.save_local(FAISS_INDEX_PATH)
            logger.info(
                f"Index updated and saved. Total documents in store: {store.index.ntotal}"
            )
            total_after = store.index.ntotal
            return {"status": "indexed", "num_docs_in_store": total_after}

        # Nothing to add: report the current size without touching the store.
        logger.warning("No documents provided for indexing.")
        current_total = store.index.ntotal if store.index else 0
        return {
            "status": "no documents provided",
            "num_docs_in_store": current_total,
        }
    except Exception as err:
        logger.error(f"Error during indexing: {err}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Indexing failed: {err!s}")


@app.post("/retrieve")
def retrieve(request: RetrieveRequest):
    """Return the stored documents most similar to ``request.query``.

    Args:
        request: Body carrying the query text and ``top_k``, the number of
            results wanted.

    Returns:
        ``{"results": [{"doc": <page content>, "score": <float>}, ...]}`` with
        the scores as reported by ``similarity_search_with_score``. When the
        index holds no vectors, ``results`` is empty and a ``message`` key
        explains why.

    Raises:
        HTTPException: 422 if ``top_k`` is less than 1; 500 if the vector
            store cannot be obtained or the search fails.
    """
    # Validate outside the try block so this client error is not rewrapped as
    # a 500 by the catch-all handler below.
    if request.top_k < 1:
        raise HTTPException(
            status_code=422, detail="top_k must be a positive integer."
        )

    try:
        vecstore = get_vectorstore()
        if not vecstore.index or vecstore.index.ntotal == 0:
            logger.warning(
                "Vector store is empty or index not initialized. Cannot retrieve."
            )
            return {
                "results": [],
                "message": "Vector store is empty. Index documents first.",
            }

        total_docs = vecstore.index.ntotal

        if total_docs == 1:
            # Best-effort diagnostic only: warn when the single stored entry is
            # the placeholder written at index creation. This reads the
            # docstore's private ``_dict``, so any failure is logged and
            # ignored rather than failing the request.
            try:
                only_doc = next(iter(vecstore.docstore._dict.values()))
                if "Initial dummy document for FAISS" in only_doc.page_content:
                    logger.warning(
                        "Vector store contains only the initial dummy document."
                    )
            except Exception:
                logger.warning(
                    "Could not inspect docstore for dummy document, proceeding with retrieval."
                )

        logger.info(
            f"Retrieving documents for query: '{request.query}' (top_k={request.top_k}). Total docs: {total_docs}"
        )
        # The index cannot yield more than ``total_docs`` hits, so capping k
        # should not change the results; it keeps an arbitrarily large
        # client-supplied top_k from being forwarded to the search.
        k = min(request.top_k, total_docs)
        results_with_scores = vecstore.similarity_search_with_score(
            query=request.query, k=k
        )
        formatted_results = [
            {"doc": doc.page_content, "score": float(score)}
            for doc, score in results_with_scores
        ]
        logger.info(f"Retrieved {len(formatted_results)} results.")
        return {"results": formatted_results}
    except Exception as e:
        logger.error(f"Error during retrieval: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Retrieval failed: {e}") from e