import streamlit as st
st.set_page_config(page_title="RAG Book Analyzer", layout="wide")

import torch
import numpy as np
import faiss
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import fitz  # PyMuPDF
import docx2txt
from langchain_text_splitters import RecursiveCharacterTextSplitter

# ------------------------
# Configuration (optimized for reliability)
# ------------------------
MODEL_NAME = "microsoft/phi-2"
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # Efficient embedding model
CHUNK_SIZE = 512
CHUNK_OVERLAP = 64
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_TEXT_LENGTH = 3000  # Character cap on prompts, to prevent OOM errors

# ------------------------
# Model Loading with Robust Error Handling
# ------------------------
@st.cache_resource(show_spinner="Loading AI models...")
def load_models():
    try:
        # Load tokenizer with special settings for Phi-2
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            padding_side="left"
        )
        tokenizer.pad_token = tokenizer.eos_token

        # Load model with safe defaults
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            trust_remote_code=True,
            device_map="auto" if DEVICE == "cuda" else None,
            low_cpu_mem_usage=True
        )

        # Load efficient embedding model
        embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)

        return tokenizer, model, embedder
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        st.stop()

tokenizer, model, embedder = load_models()

# ------------------------
# Text Processing Functions
# ------------------------
def split_text(text):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len
    )
    return splitter.split_text(text)

def extract_text(file):
    try:
        if file.type == "application/pdf":
            doc = fitz.open(stream=file.read(), filetype="pdf")
            return "\n".join([page.get_text() for page in doc])
        elif file.type == "text/plain":
            return file.read().decode("utf-8")
        elif file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
            return docx2txt.process(file)
        else:
            st.error(f"Unsupported file type: {file.type}")
            return ""
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return ""

def build_index(chunks):
    embeddings = embedder.encode(chunks, show_progress_bar=False)
    # FAISS expects a contiguous float32 array
    embeddings = np.asarray(embeddings, dtype="float32")
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index

# ------------------------
# AI Generation Functions (with safeguards)
# ------------------------
def generate_summary(text):
    text = text[:MAX_TEXT_LENGTH]  # Prevent long inputs
    prompt = f"Instruction: Summarize this book in a concise paragraph\nText: {text}\nSummary:"

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=1024,
        truncation=True
    ).to(DEVICE)

    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

    summary = tokenizer.decode(
        outputs[0],
        skip_special_tokens=True
    )

    # Extract just the summary part
    if "Summary:" in summary:
        return summary.split("Summary:")[-1].strip()
    return summary.replace(prompt, "").strip()

def generate_answer(query, context):
    context = context[:MAX_TEXT_LENGTH]  # Limit context size
    prompt = f"Instruction: Answer this question based on the context. If unsure, say 'I don't know'.\nQuestion: {query}\nContext: {context}\nAnswer:"

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=1024,
        truncation=True
    ).to(DEVICE)

    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        temperature=0.4,
        top_p=0.85,
        repetition_penalty=1.1,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

    answer = tokenizer.decode(
        outputs[0],
        skip_special_tokens=True
    )

    # Extract just the answer part
    if "Answer:" in answer:
        return answer.split("Answer:")[-1].strip()
    return answer.replace(prompt, "").strip()

# ------------------------
# Streamlit UI
# ------------------------
st.title("📚 RAG-Based Book Analyzer")
st.write("Upload a book (PDF, TXT, DOCX) to get a summary and ask questions about its content.")
st.warning("Note: First run will download models (several GB). Please be patient!")

uploaded_file = st.file_uploader("Upload File", type=["pdf", "txt", "docx"])

if uploaded_file:
    with st.spinner("Extracting text from file..."):
        text = extract_text(uploaded_file)

    if not text:
        st.error("Failed to extract text. Please try another file.")
        st.stop()

    st.success(f"✅ Extracted {len(text)} characters")

    with st.spinner("Generating summary (this may take a minute)..."):
        summary = generate_summary(text)

    st.markdown("### Book Summary")
    st.info(summary)

    with st.spinner("Preparing document for questions..."):
        chunks = split_text(text)
        index = build_index(chunks)
        st.session_state.chunks = chunks
        st.session_state.index = index

    st.success(f"✅ Document indexed with {len(chunks)} chunks")
    st.divider()

if 'chunks' in st.session_state:
    st.markdown("### ❓ Ask a Question about the Book")
    query = st.text_input("Enter your question:", key="question")

    if query:
        with st.spinner("Searching for answers..."):
            # Retrieve top 3 relevant chunks
            query_embedding = embedder.encode([query])
            distances, indices = st.session_state.index.search(query_embedding, k=3)

            # Safely retrieve chunks (FAISS returns -1 when fewer than k results exist)
            retrieved_chunks = []
            for i in indices[0]:
                if 0 <= i < len(st.session_state.chunks):
                    retrieved_chunks.append(st.session_state.chunks[i])
            context = "\n\n".join(retrieved_chunks)

            # Generate answer
            answer = generate_answer(query, context)

        # Display results
        st.markdown("### 💬 Answer")
        st.success(answer)

        with st.expander("View context used for answer"):
            st.text(context)
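
# ------------------------
# Setup notes (a minimal sketch, inferred from the imports above)
# ------------------------
# Assumed dependencies; the PyPI package names below are not stated in the
# script itself, so adjust to your environment (e.g. faiss-gpu instead of
# faiss-cpu on CUDA machines):
#
#   pip install streamlit torch numpy faiss-cpu transformers \
#       sentence-transformers pymupdf docx2txt langchain-text-splitters
#
# Launch with Streamlit's CLI (assuming this file is saved as app.py):
#
#   streamlit run app.py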