import os
import re
from pathlib import Path

import chromadb
from flask import Flask, request, jsonify
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import Chroma
from unidecode import unidecode

app = Flask(__name__)

# Configuration variables
# Replace with your static PDF path
PDF_PATH = "https://huggingface.co/spaces/CCCDev/PDFChat/resolve/main/Data-privacy-policy.pdf"
CHUNK_SIZE = 512
CHUNK_OVERLAP = 24
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
TEMPERATURE = 0.1
MAX_TOKENS = 512
TOP_K = 20


# Load PDF document and create doc splits
def load_doc(pdf_path, chunk_size, chunk_overlap):
    """Load a PDF and split its pages into overlapping text chunks.

    Args:
        pdf_path: Location of the PDF handed to ``PyPDFLoader``.
        chunk_size: ``chunk_size`` given to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` given to the same splitter.

    Returns:
        The documents produced by ``split_documents`` on the loaded pages.
    """
    loaded_pages = PyPDFLoader(pdf_path).load()
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(loaded_pages)


# Create vector database
def create_db(splits, collection_name):
    """Embed the document chunks into a new Chroma collection.

    A fresh ``chromadb.EphemeralClient`` and a default-configured
    ``HuggingFaceEmbeddings`` instance are created on every call.

    Args:
        splits: Document chunks to index.
        collection_name: Name of the Chroma collection to create.

    Returns:
        The ``Chroma`` vector store built from ``splits``.
    """
    return Chroma.from_documents(
        documents=splits,
        embedding=HuggingFaceEmbeddings(),
        client=chromadb.EphemeralClient(),
        collection_name=collection_name,
    )


# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Build the conversational retrieval chain over ``vector_db``.

    Args:
        llm_model: Hugging Face repo id passed to ``HuggingFaceEndpoint``.
        temperature: Sampling temperature passed to the endpoint.
        max_tokens: Passed to the endpoint as ``max_new_tokens``.
        top_k: Passed to the endpoint as ``top_k`` (an LLM setting; the
            retriever is created with its defaults).
        vector_db: Vector store whose ``as_retriever()`` feeds the chain.

    Returns:
        A ``ConversationalRetrievalChain`` of type ``"stuff"`` that returns
        source documents and carries its own ``ConversationBufferMemory``.
    """
    endpoint = HuggingFaceEndpoint(
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    # output_key tells the memory which chain output to record, since the
    # chain also returns the source documents.
    conversation_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        endpoint,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=conversation_memory,
        return_source_documents=True,
        verbose=False,
    )
# Generate collection name for vector database
def create_collection_name(filepath):
    """Derive a collection name from the file name in ``filepath``.

    The path's stem is transliterated to ASCII with ``unidecode``, every run
    of non-alphanumeric characters becomes a single ``-``, and the result is
    cut to 50 characters. Names shorter than 3 characters are padded with
    ``'xyz'``, and a non-alphanumeric first/last character is replaced by
    ``'A'``/``'Z'``. (Presumably this is meant to satisfy Chroma's
    collection-name rules -- confirm against the Chroma documentation.)

    Args:
        filepath: Path or URL whose final component names the document.

    Returns:
        The sanitised collection name.
    """
    collection_name = Path(filepath).stem
    collection_name = collection_name.replace(" ", "-")
    collection_name = unidecode(collection_name)
    collection_name = re.sub(r'[^A-Za-z0-9]+', '-', collection_name)
    collection_name = collection_name[:50]
    # Pad before indexing so the [0] / [-1] checks below never see an
    # empty string.
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    return collection_name


# Initialize database and QA chain
# This runs at import time: the PDF is loaded, embedded and indexed once,
# and a single chain object is shared by every request.
doc_splits = load_doc(PDF_PATH, CHUNK_SIZE, CHUNK_OVERLAP)
collection_name = create_collection_name(PDF_PATH)
vector_db = create_db(doc_splits, collection_name)
qa_chain = initialize_llmchain(LLM_MODEL, TEMPERATURE, MAX_TOKENS, TOP_K, vector_db)


@app.route('/chat', methods=['POST'])
def chat():
    """Answer a question about the indexed PDF.

    Expects a JSON object with:
        message: The user's question (non-empty string).
        history: Optional list of ``[user_message, bot_message]`` pairs.

    Returns:
        A JSON object with ``answer`` and ``sources`` (each source has
        ``content`` and a ``page`` number), or a JSON ``error`` with HTTP
        status 400 when the request body is malformed.
    """
    # silent=True yields None instead of raising on a missing or invalid
    # JSON body, so malformed requests get a 400 rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    message = data.get('message', '')
    if not isinstance(message, str) or not message.strip():
        return jsonify({"error": "'message' must be a non-empty string."}), 400

    # An explicit null history is treated the same as a missing one.
    history = data.get('history') or []
    if not isinstance(history, list):
        return jsonify({"error": "'history' must be a list of [user, assistant] pairs."}), 400

    formatted_chat_history = []
    try:
        for user_message, bot_message in history:
            formatted_chat_history.append(f"User: {user_message}")
            formatted_chat_history.append(f"Assistant: {bot_message}")
    except (TypeError, ValueError):
        # An entry was not a two-item sequence.
        return jsonify({"error": "'history' must be a list of [user, assistant] pairs."}), 400

    # NOTE(review): qa_chain is one module-level object built with its own
    # conversation memory, so the chat_history passed here may be superseded
    # by that memory, and the memory would then be shared between all
    # clients -- confirm the intended behaviour.
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})

    response_answer = response["answer"]
    # Keep only the text after the last "Helpful Answer:" marker when the
    # model output contains it.
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1]

    response_sources = response["source_documents"]
    result = {
        "answer": response_answer,
        "sources": [
            # +1 presumably converts the loader's zero-based page index to a
            # human-readable page number -- confirm.
            {"content": doc.page_content.strip(), "page": doc.metadata["page"] + 1}
            for doc in response_sources
        ],
    }
    return jsonify(result)


if __name__ == '__main__':
    # The server binds to every interface, and Flask's debug mode enables
    # an interactive debugger that can execute arbitrary code. Debug is
    # therefore opt-in via the FLASK_DEBUG environment variable instead of
    # being always on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(debug=debug_enabled, host='0.0.0.0', port=5000)