"""Retrieval-augmented question answering over a local text file.

Documents are read from ``DATA_PATH/DOCS_FILENAME``, split into chunks,
embedded with a sentence-transformers model, indexed in a FAISS vector store
persisted under ``VECTORSTORE_PATH``, and answered through a Hugging Face Hub
hosted LLM via a LangChain ``RetrievalQA`` chain.
"""

import os
from functools import lru_cache

# Set safe caching directories to avoid permission denied errors.
# This is done *before* importing the Hugging Face-backed libraries below:
# some of them may read these variables when first imported, in which case
# setting them afterwards would have no effect.
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
os.environ["HF_HOME"] = "/app/cache"
os.makedirs("/app/cache", exist_ok=True)

from langchain.vectorstores import FAISS  # noqa: E402
from langchain_huggingface import HuggingFaceEmbeddings  # noqa: E402
from langchain_community.document_loaders import TextLoader  # noqa: E402
from langchain.text_splitter import CharacterTextSplitter  # noqa: E402
from langchain.docstore.document import Document  # noqa: E402
from langchain.chains import RetrievalQA  # noqa: E402
from langchain_community.llms import HuggingFaceHub  # noqa: E402
from langchain.embeddings.base import Embeddings  # noqa: E402

# Constants
DATA_PATH = "/app/data"
VECTORSTORE_PATH = "/app/vectorstore"
DOCS_FILENAME = "context.txt"
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L6-v2"

# Files that FAISS.save_local writes with its default index_name ("index").
# The persisted index is only usable when both are present.
_FAISS_INDEX_FILES = ("index.faiss", "index.pkl")


@lru_cache(maxsize=1)
def load_embedding_model() -> Embeddings:
    """Return the HuggingFace embedding model, loading it on first use.

    Loading the model weights is expensive, so the instance is cached and the
    same object is returned on every subsequent call in this process.
    """
    return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)


def load_documents() -> list[Document]:
    """Load ``DATA_PATH/DOCS_FILENAME`` and split it into chunks.

    Returns:
        The document split by ``CharacterTextSplitter`` with
        ``chunk_size=1000`` and ``chunk_overlap=100``. May be empty if the
        source file has no content.
    """
    # Explicit encoding so the result does not depend on the container locale.
    loader = TextLoader(os.path.join(DATA_PATH, DOCS_FILENAME), encoding="utf-8")
    raw_docs = loader.load()
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return splitter.split_documents(raw_docs)


def _index_is_complete(folder: str) -> bool:
    """Return True if *folder* contains every file a saved FAISS index needs."""
    return all(
        os.path.isfile(os.path.join(folder, name)) for name in _FAISS_INDEX_FILES
    )


def load_vectorstore() -> FAISS:
    """Load the persisted FAISS vectorstore, or build and persist it.

    Returns:
        A FAISS vectorstore using the embedding model from
        ``load_embedding_model``.

    Raises:
        ValueError: If the index must be built but the source document
            yields no chunks.
    """
    vectorstore_dir = os.path.join(VECTORSTORE_PATH, "faiss_index")
    embedding_model = load_embedding_model()

    # Check for the actual index files rather than just the folder, so a
    # half-written save (e.g. interrupted run) triggers a rebuild instead of
    # a load failure.
    if _index_is_complete(vectorstore_dir):
        # SECURITY: load_local unpickles index.pkl. This is acceptable only
        # because the index is produced by this module; do not point
        # VECTORSTORE_PATH at a location untrusted parties can write to.
        return FAISS.load_local(
            vectorstore_dir, embedding_model, allow_dangerous_deserialization=True
        )

    docs = load_documents()
    if not docs:
        raise ValueError(
            f"No content to index in {os.path.join(DATA_PATH, DOCS_FILENAME)}"
        )
    vectorstore = FAISS.from_documents(docs, embedding_model)
    vectorstore.save_local(vectorstore_dir)
    return vectorstore


@lru_cache(maxsize=1)
def _get_qa_chain() -> RetrievalQA:
    """Build the retrieval QA chain once and reuse it for every question."""
    vectorstore = load_vectorstore()
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.1",
        # The HF Inference API text-generation task limits output length via
        # ``max_new_tokens``; ``max_tokens`` is not one of its parameters.
        model_kwargs={"temperature": 0.5, "max_new_tokens": 256},
    )
    return RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())


def ask_question(query: str) -> str:
    """Answer *query* using documents retrieved from the vectorstore.

    The embedding model, vectorstore and LLM chain are built on the first
    call and reused afterwards.

    Args:
        query: The natural-language question.

    Returns:
        The LLM's answer text.
    """
    # RetrievalQA takes its input under "query" and returns the answer
    # under "result" by default.
    result = _get_qa_chain().invoke({"query": query})
    return result["result"]