import streamlit as st
from streamlit_option_menu import option_menu
import fitz  # PyMuPDF
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFaceEndpoint
import requests
import os
import json

# Page configuration
st.set_page_config(
    page_title="PDF Study Assistant",
    page_icon="📚",
    layout="wide",
    initial_sidebar_state="collapsed"
)

# Custom CSS for colorful design
st.markdown("""
""", unsafe_allow_html=True)

# Initialize session state
if 'pdf_processed' not in st.session_state:
    st.session_state.pdf_processed = False
if 'vector_store' not in st.session_state:
    st.session_state.vector_store = None
if 'pages' not in st.session_state:
    st.session_state.pages = []
if 'history' not in st.session_state:
    st.session_state.history = []

# Load embedding model with caching
@st.cache_resource
def load_embedding_model():
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

def query_hf_inference_api(prompt, model="google/flan-t5-xxl", max_tokens=200):
    """Query Hugging Face Inference API directly"""
    API_URL = f"https://api-inference.huggingface.co/models/{model}"
    headers = {"Authorization": f"Bearer {os.getenv('HF_API_KEY')}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_tokens,
            "temperature": 0.5,
            "do_sample": False
        }
    }
    try:
        response = requests.post(API_URL, headers=headers, json=payload)
        response.raise_for_status()
        result = response.json()
        return result[0]['generated_text'] if result else ""
    except Exception as e:
        st.error(f"Error querying model: {str(e)}")
        return ""

def process_pdf(pdf_file):
    """Extract text from PDF and create vector store"""
    with st.spinner("📖 Reading PDF..."):
        doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
        text = ""
        st.session_state.pages = []
        for page in doc:
            page_text = page.get_text()
            text += page_text
            st.session_state.pages.append(page_text)

    with st.spinner("🔍 Processing text..."):
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len
        )
        chunks = text_splitter.split_text(text)
        embeddings = load_embedding_model()
        st.session_state.vector_store = FAISS.from_texts(chunks, embeddings)
        st.session_state.pdf_processed = True

    st.success("✅ PDF processed successfully!")

def ask_question(question):
    """Answer a question using the vector store and Hugging Face API"""
    if not st.session_state.vector_store:
        return "PDF not processed yet", []

    # Find relevant passages
    docs = st.session_state.vector_store.similarity_search(question, k=3)
    context = "\n\n".join([doc.page_content for doc in docs])

    # Format prompt for the model
    prompt = f"""
Based on the following context, answer the question. If the answer isn't in the context, say "I don't know".
Context: {context}

Question: {question}

Answer: """

    # Query the model
    answer = query_hf_inference_api(prompt)

    # Add to history
    st.session_state.history.append({
        "question": question,
        "answer": answer,
        "sources": [doc.page_content for doc in docs]
    })

    return answer, docs

def generate_qa_for_chapter(start_page, end_page):
    """Generate Q&A for specific chapter pages"""
    if start_page < 1 or end_page > len(st.session_state.pages) or start_page > end_page:
        st.error("Invalid page range")
        return []

    chapter_text = "\n".join(st.session_state.pages[start_page-1:end_page])
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=800,
        chunk_overlap=100,
        length_function=len
    )
    chunks = text_splitter.split_text(chapter_text)

    qa_pairs = []
    with st.spinner(f"🧠 Generating Q&A for pages {start_page}-{end_page}..."):
        for i, chunk in enumerate(chunks):
            if i % 2 == 0:
                # Generate a question from this chunk
                prompt = f"Based on this text, generate one study question: {chunk[:500]}"
                question = query_hf_inference_api(prompt, max_tokens=100).strip()
                if question and not question.endswith("?"):
                    question += "?"
                if question:
                    # Store the question so the next chunk can supply its answer
                    qa_pairs.append((question, ""))
            else:
                # Generate an answer for the most recent question
                if qa_pairs:  # Ensure we have a question to answer
                    prompt = f"Answer this question: {qa_pairs[-1][0]} using this context: {chunk[:500]}"
                    answer = query_hf_inference_api(prompt, max_tokens=200)
                    qa_pairs[-1] = (qa_pairs[-1][0], answer)

    return qa_pairs

# App header
st.markdown("