import streamlit as st
st.set_page_config(page_title="RAG Book Analyzer", layout="wide")
import torch
import numpy as np
import faiss
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import fitz # PyMuPDF
import docx2txt
from langchain_text_splitters import RecursiveCharacterTextSplitter
# ------------------------
# Configuration (optimized for reliability)
# ------------------------
MODEL_NAME = "microsoft/phi-2"
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2" # Efficient embedding model
CHUNK_SIZE = 512
CHUNK_OVERLAP = 64
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_TEXT_LENGTH = 3000  # Truncate inputs (measured in characters) to prevent OOM errors
# ------------------------
# Model Loading with Robust Error Handling
# ------------------------
@st.cache_resource(show_spinner="Loading AI models...")
def load_models():
    try:
        # Load tokenizer with special settings for Phi-2
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            padding_side="left"
        )
        tokenizer.pad_token = tokenizer.eos_token

        # Load model with safe defaults
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            trust_remote_code=True,
            device_map="auto" if DEVICE == "cuda" else None,
            low_cpu_mem_usage=True
        )

        # Load efficient embedding model
        embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)

        return tokenizer, model, embedder
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        st.stop()


tokenizer, model, embedder = load_models()
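# Because load_models() is wrapped in st.cache_resource, the models are downloaded
# and initialised once per server process and then reused across Streamlit reruns
# and user sessions instead of being reloaded on every interaction.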
# ------------------------
# Text Processing Functions
# ------------------------
def split_text(text):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len
    )
    return splitter.split_text(text)
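# split_text() returns a list of roughly CHUNK_SIZE-character strings; consecutive
# chunks overlap by up to CHUNK_OVERLAP characters so that sentences straddling a
# chunk boundary are still retrievable from at least one chunk.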
def extract_text(file):
    try:
        if file.type == "application/pdf":
            doc = fitz.open(stream=file.read(), filetype="pdf")
            return "\n".join([page.get_text() for page in doc])
        elif file.type == "text/plain":
            return file.read().decode("utf-8")
        elif file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
            return docx2txt.process(file)
        else:
            st.error(f"Unsupported file type: {file.type}")
            return ""
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return ""
def build_index(chunks):
    embeddings = embedder.encode(chunks, show_progress_bar=False)
    # FAISS expects a contiguous float32 array
    embeddings = np.asarray(embeddings, dtype=np.float32)
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index
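# Untested sketch: for very large books the index could be persisted to disk and
# reloaded instead of being rebuilt on every upload. The file name below is
# illustrative only, not part of this app:
#
#     faiss.write_index(index, "book_index.faiss")   # save to disk
#     index = faiss.read_index("book_index.faiss")   # reload in a later session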
# ------------------------
# AI Generation Functions (with safeguards)
# ------------------------
def generate_summary(text):
    text = text[:MAX_TEXT_LENGTH]  # Prevent long inputs
    prompt = f"Instruction: Summarize this book in a concise paragraph\nText: {text}\nSummary:"
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=1024,
        truncation=True
    ).to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    summary = tokenizer.decode(
        outputs[0],
        skip_special_tokens=True
    )
    # Extract just the summary part
    if "Summary:" in summary:
        return summary.split("Summary:")[-1].strip()
    return summary.replace(prompt, "").strip()
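# Note: because the input is cut to MAX_TEXT_LENGTH characters (and the prompt is
# further truncated to 1024 tokens), the generated summary reflects only the
# opening of the book rather than the full text.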
def generate_answer(query, context):
    context = context[:MAX_TEXT_LENGTH]  # Limit context size
    prompt = f"Instruction: Answer this question based on the context. If unsure, say 'I don't know'.\nQuestion: {query}\nContext: {context}\nAnswer:"
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=1024,
        truncation=True
    ).to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        temperature=0.4,
        top_p=0.85,
        repetition_penalty=1.1,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    answer = tokenizer.decode(
        outputs[0],
        skip_special_tokens=True
    )
    # Extract just the answer part
    if "Answer:" in answer:
        return answer.split("Answer:")[-1].strip()
    return answer.replace(prompt, "").strip()
# ------------------------
# Streamlit UI
# ------------------------
st.title("📚 RAG-Based Book Analyzer")
st.write("Upload a book (PDF, TXT, DOCX) to get a summary and ask questions about its content.")
st.warning("Note: First run will download models (several GB). Please be patient!")

uploaded_file = st.file_uploader("Upload File", type=["pdf", "txt", "docx"])
if uploaded_file:
    with st.spinner("Extracting text from file..."):
        text = extract_text(uploaded_file)

    if not text:
        st.error("Failed to extract text. Please try another file.")
        st.stop()

    st.success(f"✅ Extracted {len(text)} characters")

    with st.spinner("Generating summary (this may take a minute)..."):
        summary = generate_summary(text)

    st.markdown("### Book Summary")
    st.info(summary)

    with st.spinner("Preparing document for questions..."):
        chunks = split_text(text)
        index = build_index(chunks)
        st.session_state.chunks = chunks
        st.session_state.index = index

    st.success(f"✅ Document indexed with {len(chunks)} chunks")
    st.divider()
if 'chunks' in st.session_state:
    st.markdown("### ❓ Ask a Question about the Book")
    query = st.text_input("Enter your question:", key="question")

    if query:
        with st.spinner("Searching for answers..."):
            # Retrieve top 3 relevant chunks
            query_embedding = np.asarray(embedder.encode([query]), dtype=np.float32)
            distances, indices = st.session_state.index.search(query_embedding, k=3)

            # Safely retrieve chunks (FAISS returns -1 when fewer than k neighbours exist)
            retrieved_chunks = []
            for i in indices[0]:
                if 0 <= i < len(st.session_state.chunks):
                    retrieved_chunks.append(st.session_state.chunks[i])
            context = "\n\n".join(retrieved_chunks)

            # Generate answer
            answer = generate_answer(query, context)

        # Display results
        st.markdown("### 💬 Answer")
        st.success(answer)

        with st.expander("View context used for answer"):
            st.text(context)
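# To run this app locally (approximate dependency list, not pinned by this file):
#   pip install streamlit torch transformers sentence-transformers faiss-cpu \
#       pymupdf docx2txt langchain-text-splitters
#   streamlit run app.py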