import streamlit as st
import logging
import os
import tempfile
import shutil
import pdfplumber
import ollama
import time
import httpx

from langchain_community.document_loaders import UnstructuredPDFLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.chat_models import ChatOllama
from langchain_core.runnables import RunnablePassthrough
from langchain.retrievers.multi_query import MultiQueryRetriever
from typing import List, Tuple, Dict, Any, Optional

# Streamlit page configuration
st.set_page_config(
    page_title="Ollama PDF RAG Streamlit UI",
    page_icon="🎈",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logger = logging.getLogger(__name__)


def ollama_list_with_retry(retries=3, delay=5):
    """Attempt to list models from Ollama with retry logic."""
    for attempt in range(retries):
        try:
            response = ollama.list()
            logger.info("Successfully retrieved model list from Ollama")
            return response
        except httpx.ConnectError as e:
            logger.error(f"Connection error: {e}. Attempt {attempt + 1} of {retries}")
            if attempt < retries - 1:
                time.sleep(delay)
            else:
                logger.error("All retry attempts failed. Cannot connect to Ollama service.")
                raise


@st.cache_resource(show_spinner=True)
def extract_model_names(models_info: Dict[str, List[Dict[str, Any]]]) -> Tuple[str, ...]:
    """Extract model names from the provided models information."""
    logger.info("Extracting model names from models_info")
    model_names = tuple(model["name"] for model in models_info["models"])
    logger.info(f"Extracted model names: {model_names}")
    return model_names


def create_vector_db(file_upload) -> Chroma:
    """Create a vector database from an uploaded PDF file."""
    logger.info(f"Creating vector DB from file upload: {file_upload.name}")
    temp_dir = tempfile.mkdtemp()

    # Persist the uploaded file to a temporary path so the loader can read it
    path = os.path.join(temp_dir, file_upload.name)
    with open(path, "wb") as f:
        f.write(file_upload.getvalue())
        logger.info(f"File saved to temporary path: {path}")

    loader = UnstructuredPDFLoader(path)
    data = loader.load()

    # Split the extracted text into overlapping chunks for embedding
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=7500, chunk_overlap=100)
    chunks = text_splitter.split_documents(data)
    logger.info("Document split into chunks")

    embeddings = OllamaEmbeddings(model="nomic-embed-text", show_progress=True)
    vector_db = Chroma.from_documents(
        documents=chunks, embedding=embeddings, collection_name="myRAG"
    )
    logger.info("Vector DB created")

    shutil.rmtree(temp_dir)
    logger.info(f"Temporary directory {temp_dir} removed")
    return vector_db


def process_question(question: str, vector_db: Chroma, selected_model: str) -> str:
    """Process a user question using the vector database and selected language model."""
    logger.info(f"Processing question: {question} using model: {selected_model}")
    llm = ChatOllama(model=selected_model, temperature=0)

    # Prompt used by the MultiQueryRetriever to rephrase the user's question
    QUERY_PROMPT = PromptTemplate(
        input_variables=["question"],
        template="""You are an AI language model assistant. Your task is to generate 3
        different versions of the given user question to retrieve relevant documents from
        a vector database. By generating multiple perspectives on the user question, your
        goal is to help the user overcome some of the limitations of the distance-based
        similarity search. Provide these alternative questions separated by newlines.
        Original question: {question}""",
    )

    retriever = MultiQueryRetriever.from_llm(
        vector_db.as_retriever(), llm, prompt=QUERY_PROMPT
    )

    # Answering prompt: the model is restricted to the retrieved context
    template = """Answer the question based ONLY on the following context:
{context}
Question: {question}
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Only provide the answer from the {context}, nothing else.
Add snippets of the context you used to answer the question.
"""

    prompt = ChatPromptTemplate.from_template(template)

    chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    response = chain.invoke(question)
    logger.info("Question processed and response generated")
    return response


@st.cache_data
def extract_all_pages_as_images(file_upload) -> List[Any]:
    """Extract all pages from a PDF file as images."""
    logger.info(f"Extracting all pages as images from file: {file_upload.name}")
    pdf_pages = []
    with pdfplumber.open(file_upload) as pdf:
        pdf_pages = [page.to_image().original for page in pdf.pages]
    logger.info("PDF pages extracted as images")
    return pdf_pages


def delete_vector_db(vector_db: Optional[Chroma]) -> None:
    """Delete the vector database and clear related session state."""
    logger.info("Deleting vector DB")
    if vector_db is not None:
        vector_db.delete_collection()
        st.session_state.pop("pdf_pages", None)
        st.session_state.pop("file_upload", None)
        st.session_state.pop("vector_db", None)
        st.success("Collection and temporary files deleted successfully.")
        logger.info("Vector DB and related session state cleared")
        st.rerun()
    else:
        st.error("No vector database found to delete.")
        logger.warning("Attempted to delete vector DB, but none was found")


def main() -> None:
    """Main function to run the Streamlit application."""
    st.subheader("🧠 Ollama PDF RAG playground", divider="gray", anchor=False)

    try:
        models_info = ollama_list_with_retry()
        available_models = extract_model_names(models_info)
    except httpx.ConnectError:
        st.error("Could not connect to the Ollama service. Please check your setup and try again.")
        return

    col1, col2 = st.columns([1.5, 2])

    if "messages" not in st.session_state:
        st.session_state["messages"] = []
    if "vector_db" not in st.session_state:
        st.session_state["vector_db"] = None