Spaces:
Sleeping
Sleeping
import streamlit as st | |
st.set_page_config(page_title="RAG Book Analyzer", layout="wide") # Must be the first Streamlit command | |
import torch | |
import numpy as np | |
import faiss | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from sentence_transformers import SentenceTransformer | |
import fitz # PyMuPDF for PDF extraction | |
import docx2txt # For DOCX extraction | |
from langchain_text_splitters import RecursiveCharacterTextSplitter | |
# ------------------------
# Configuration
# ------------------------
# Hugging Face model ids, both loaded by load_models().
MODEL_NAME = "microsoft/phi-2"  # open-source causal LM used for summaries and answers
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # compact sentence embedder for retrieval

# Chunking parameters, measured in characters (split_text uses len()).
CHUNK_SIZE, CHUNK_OVERLAP = 512, 64

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# ------------------------
# Model Loading with Caching
# ------------------------
@st.cache_resource(show_spinner="Loading models...")
def load_models():
    """Load the causal LM, its tokenizer and the sentence embedder.

    Decorated with ``st.cache_resource`` so the models are loaded once per
    server process and reused across Streamlit reruns. Without it, every
    widget interaction (e.g. typing a question) re-executes this script and
    would reload the multi-GB model from scratch.

    Returns:
        tuple: ``(tokenizer, model, embedder)``.

    On failure an error is shown in the UI and the script run is halted with
    ``st.stop()``. Exceptions are not cached, so the next rerun retries.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            # Let accelerate place weights on the GPU; plain load on CPU.
            device_map="auto" if DEVICE == "cuda" else None,
            # Half precision only on GPU; CPU kernels expect float32.
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            trust_remote_code=True,
        )
        embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
        return tokenizer, model, embedder
    except Exception as e:
        st.error(f"Model loading failed: {e}")
        st.stop()


tokenizer, model, embedder = load_models()
# ------------------------
# Text Processing Functions
# ------------------------
def split_text(text):
    """Break *text* into overlapping chunks for embedding and retrieval.

    Chunk size and overlap come from the module-level ``CHUNK_SIZE`` and
    ``CHUNK_OVERLAP`` settings and are measured in characters (``len``).

    Returns:
        list[str]: The text chunks produced by the recursive splitter.
    """
    chunker = RecursiveCharacterTextSplitter(
        length_function=len,
        chunk_overlap=CHUNK_OVERLAP,
        chunk_size=CHUNK_SIZE,
    )
    pieces = chunker.split_text(text)
    return pieces
def extract_text(file):
    """Extract plain text from an uploaded PDF, TXT or DOCX file.

    Dispatches on the file's MIME type (``file.type``). Processing problems
    are reported in the UI via ``st.error`` / ``st.warning`` rather than
    raised.

    Args:
        file: An uploaded file object exposing a ``type`` attribute (MIME
            type) and a binary ``read()`` method, such as Streamlit's
            ``UploadedFile``.

    Returns:
        str: The extracted text, or ``""`` when the type is unsupported or
        extraction failed.
    """
    file_type = file.type
    if file_type == "application/pdf":
        try:
            # The context manager closes the PyMuPDF document when done.
            with fitz.open(stream=file.read(), filetype="pdf") as doc:
                return "\n".join(page.get_text() for page in doc)
        except Exception as e:
            st.error(f"Error processing PDF: {e}")
            return ""
    elif file_type == "text/plain":
        raw = file.read()
        try:
            # "utf-8-sig" decodes ordinary UTF-8 unchanged and additionally
            # strips a leading byte-order mark (e.g. from Windows Notepad).
            return raw.decode("utf-8-sig")
        except UnicodeDecodeError:
            # Keep the content instead of letting the decode error crash
            # the whole app; undecodable bytes become U+FFFD.
            st.warning("File is not valid UTF-8; undecodable characters were replaced.")
            return raw.decode("utf-8-sig", errors="replace")
    elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        try:
            return docx2txt.process(file)
        except Exception as e:
            st.error(f"Error processing DOCX: {e}")
            return ""
    else:
        # f-string instead of "+" so a missing (None) MIME type is reported
        # rather than raising TypeError.
        st.error(f"Unsupported file type: {file_type}")
        return ""
def build_index(chunks):
    """Embed *chunks* and put them in an exact L2 FAISS index.

    Vectors are added in order, so row ``i`` of the index corresponds to
    ``chunks[i]`` and search results map straight back to the chunk list.

    Args:
        chunks: List of text chunks, e.g. the output of ``split_text``.

    Returns:
        faiss.IndexFlatL2: One vector per chunk. When *chunks* is empty an
        empty (but valid, searchable) index is returned instead of failing.
    """
    if not chunks:
        # e.g. a scanned PDF whose pages carry no text layer: there are no
        # embeddings to read a dimension from, so ask the model for it.
        return faiss.IndexFlatL2(embedder.get_sentence_embedding_dimension())
    # FAISS requires a C-contiguous float32 matrix of shape (n, dim).
    embeddings = np.ascontiguousarray(
        embedder.encode(chunks, show_progress_bar=False), dtype=np.float32
    )
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index
# ------------------------
# Summarization and Q&A Functions
# ------------------------
def _generate(prompt, **generation_kwargs):
    """Run the causal LM on *prompt* and return only the newly generated text.

    Only the tokens produced after the prompt are decoded. Splitting the full
    decoded sequence on ``"Output:"`` would return the wrong text whenever the
    model's continuation itself contains ``"Output:"``.

    Args:
        prompt: Full prompt in phi-2's ``Instruct: ... Output:`` format.
        **generation_kwargs: Forwarded to ``model.generate``.

    Returns:
        str: The stripped completion.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        # phi-2's tokenizer defines no pad token; pass EOS explicitly so
        # generate() does not have to substitute it (and warn) on every call.
        pad_token_id=tokenizer.eos_token_id,
        **generation_kwargs,
    )
    prompt_length = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    # phi-2 tends to keep going with a fresh "Instruct:" turn after answering;
    # keep only the first response.
    return completion.split("Instruct:", 1)[0].strip()


def generate_summary(text, max_chars=3000):
    """Summarize the opening of a book with the causal LM.

    Args:
        text: Full extracted book text.
        max_chars: Number of leading characters of *text* placed in the
            prompt, bounding prompt length (default 3000).

    Returns:
        str: The generated summary. Sampling is enabled, so repeated calls
        on the same text can give different summaries.
    """
    prompt = f"Instruct: Summarize this book in a concise paragraph\nInput: {text[:max_chars]}\nOutput:"
    return _generate(
        prompt,
        max_new_tokens=300,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )


def generate_answer(query, context):
    """Answer *query* from retrieved *context* with the causal LM.

    The prompt instructs the model to reply "I don't know" when unsure.

    Args:
        query: The user's question.
        context: Retrieved passages, joined into a single string.

    Returns:
        str: The generated answer (sampled, lower temperature than the
        summary, with a repetition penalty).
    """
    prompt = f"Instruct: Answer this question based on the context. If unsure, say 'I don't know'.\nQuestion: {query}\nContext: {context}\nOutput:"
    return _generate(
        prompt,
        max_new_tokens=300,
        temperature=0.5,
        top_p=0.9,
        repetition_penalty=1.2,
        do_sample=True,
    )
# ------------------------
# Streamlit UI
# ------------------------
TOP_K = 3  # number of chunks retrieved as context for each question

st.title("📚 RAG-Based Book Analyzer")
st.write("Upload a book (PDF, TXT, DOCX) to get a summary and ask questions about its content.")
uploaded_file = st.file_uploader("Upload File", type=["pdf", "txt", "docx"])

if uploaded_file:
    # Streamlit re-executes this whole script on every interaction, e.g. each
    # question typed. Process a given upload only once and keep the results in
    # session_state. Otherwise the slow LLM summary would be regenerated (and,
    # being sampled, would change) and the book re-embedded on every question.
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("file_key") != file_key:
        # New upload: forget results belonging to a previously processed file.
        for state_key in ("file_key", "summary", "chunks", "index"):
            st.session_state.pop(state_key, None)
        text = extract_text(uploaded_file)
        if text.strip():
            with st.spinner("Generating summary..."):
                st.session_state.summary = generate_summary(text)
            with st.spinner("Indexing book..."):
                st.session_state.chunks = split_text(text)
                st.session_state.index = build_index(st.session_state.chunks)
            st.session_state.file_key = file_key
        elif text:
            # Whitespace only, e.g. a scanned PDF without a text layer. An
            # empty string means extract_text has already reported an error.
            st.warning("No text could be extracted from this file.")

    if st.session_state.get("file_key") == file_key:
        st.success("✅ File successfully processed!")
        st.markdown("### Book Summary")
        st.info(st.session_state.summary)

        st.markdown("### ❓ Ask a Question about the Book")
        query = st.text_input("Enter your question:")
        if query:
            with st.spinner("Searching for answers..."):
                chunks = st.session_state.chunks
                query_embedding = np.asarray(embedder.encode([query]), dtype=np.float32)
                _, indices = st.session_state.index.search(query_embedding, k=TOP_K)
                # FAISS pads missing results with -1 when the index holds
                # fewer than k vectors; a plain `i < len(chunks)` check would
                # let -1 through and silently re-use the last chunk.
                retrieved_chunks = [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
                context = "\n\n".join(retrieved_chunks)
                answer = generate_answer(query, context)
            st.markdown("### 💬 Answer")
            st.success(answer)
            with st.expander("See context used"):
                st.write(context)