import streamlit as st
import torch
import numpy as np
import faiss
import time
import re
from typing import List, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import fitz  # PyMuPDF
import docx2txt
from langchain_text_splitters import RecursiveCharacterTextSplitter
from io import BytesIO
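# Dependencies (inferred from the imports above; package names are assumptions, not a
# pinned requirements file): streamlit, torch, numpy, faiss-cpu (or faiss-gpu),
# transformers, sentence-transformers, PyMuPDF, docx2txt, langchain-text-splitters,
# plus bitsandbytes for 4-bit quantization on GPU.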
# ------------------------
# Configuration
# ------------------------
MODEL_NAME = "ibm-granite/granite-3.1-1b-a400m-instruct"
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
CHUNK_SIZE = 1024  # Increased for better context
CHUNK_OVERLAP = 128
MAX_FILE_SIZE_MB = 10
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
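# Note: CHUNK_SIZE and CHUNK_OVERLAP are measured in characters, not tokens, since the
# splitter below uses len() as its length function.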
# ------------------------
# Model Loading with Quantization
# ------------------------
@st.cache_resource
def load_models():
    """Load and cache the LLM, tokenizer, and embedding model."""
    try:
        # bitsandbytes 4-bit quantization requires a CUDA GPU; on CPU we load
        # full-precision weights instead.
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
        ) if DEVICE == "cuda" else None
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            revision="main"
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            revision="main",
            device_map="auto",
            quantization_config=quant_config,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            low_cpu_mem_usage=True
        ).eval()
        # Load embedding model; cast to FP16 on GPU to halve its memory footprint
        embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
        if DEVICE == "cuda":
            embedder = embedder.half()
        return tokenizer, model, embedder
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        st.stop()
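# The first call to load_models() downloads both checkpoints from the Hugging Face Hub,
# so initial startup can be slow; st.cache_resource keeps the loaded models in memory
# across Streamlit reruns instead of reloading them on every interaction.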
# ------------------------
# Enhanced Text Processing
# ------------------------
def clean_text(text: str) -> str:
    """Advanced text cleaning with multiple normalization steps"""
    text = re.sub(r'\s+', ' ', text)  # Remove extra whitespace
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)  # Remove non-ASCII
    text = re.sub(r'\bPage \d+\b', '', text)  # Remove page numbers
    text = re.sub(r'http\S+', '', text)  # Remove URLs
    text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '', text)  # Remove emails
    return text.strip()
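# Caveat: the [^\x00-\x7F] filter strips every non-ASCII character (accented letters,
# non-Latin scripts, smart quotes), which is lossy for multilingual documents.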
def extract_text(file: BytesIO) -> str:
    """Improved text extraction with format-specific handling"""
    try:
        if file.size > MAX_FILE_SIZE_MB * 1024 * 1024:
            raise ValueError(f"File size exceeds {MAX_FILE_SIZE_MB}MB limit")
        file_type = file.type
        text = ""
        if file_type == "application/pdf":
            doc = fitz.open(stream=file.read(), filetype="pdf")
            text = "\n".join([page.get_text("text", flags=fitz.TEXT_PRESERVE_WHITESPACE) for page in doc])
            # Extract image metadata for future multimodal expansion
            images = [img for page in doc for img in page.get_images()]
            if images:
                st.session_state.images = images
        elif file_type == "text/plain":
            text = file.read().decode("utf-8", errors="replace")
        elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
            text = docx2txt.process(file)
        else:
            raise ValueError("Unsupported file type")
        return clean_text(text)
    except Exception as e:
        st.error(f"Text extraction failed: {str(e)}")
        st.stop()
def semantic_chunking(text: str) -> List[str]:
    """Context-aware text splitting into overlapping character chunks"""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len,
        add_start_index=True  # only affects create_documents(); split_text() returns plain strings
    )
    chunks = splitter.split_text(text)
    return chunks
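# Example: with CHUNK_SIZE=1024 and CHUNK_OVERLAP=128, consecutive chunks share roughly
# 128 characters, so an answer that straddles a chunk boundary still appears intact in
# at least one chunk at retrieval time.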
# ------------------------
# Enhanced Vector Indexing
# ------------------------
def build_faiss_index(chunks: List[str], embedder) -> faiss.Index:
    """Build optimized FAISS index with error handling"""
    try:
        embeddings = embedder.encode(
            chunks,
            batch_size=32,
            show_progress_bar=True,
            convert_to_tensor=True
        )
        # FAISS expects contiguous float32 arrays (the embedder may return FP16 on GPU)
        embeddings = np.ascontiguousarray(embeddings.cpu().numpy().astype(np.float32))
        dimension = embeddings.shape[1]
        # Inner product over L2-normalized vectors is equivalent to cosine similarity
        index = faiss.IndexFlatIP(dimension)
        faiss.normalize_L2(embeddings)
        index.add(embeddings)
        return index
    except Exception as e:
        st.error(f"Index creation failed: {str(e)}")
        st.stop()
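# Optional (not part of the original flow): a FAISS index can be persisted with
# faiss.write_index(index, "book.faiss") and reloaded with faiss.read_index("book.faiss")
# to skip re-embedding the same document later; the filename here is illustrative.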
# ------------------------
# Improved Generation Functions
# ------------------------
def format_prompt(system_prompt: str, user_input: str) -> str:
    """Structured prompt formatting for better model performance"""
    return f"""<|system|>
{system_prompt}
<|user|>
{user_input}
<|assistant|>
"""
def generate_summary(text: str, tokenizer, model) -> str:
    """Hierarchical summarization with chunk processing"""
    try:
        # First-stage summary
        chunks = [text[i:i+3000] for i in range(0, len(text), 3000)]
        summaries = []
        for chunk in chunks:
            prompt = format_prompt(
                "Generate a detailed summary of this text excerpt:",
                chunk[:2500]
            )
            inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
            outputs = model.generate(
                **inputs,
                max_new_tokens=300,
                temperature=0.3,
                do_sample=True
            )
            summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
            summaries.append(summary.split("<|assistant|>")[-1].strip())
        # Final synthesis
        final_prompt = format_prompt(
            "Synthesize these summaries into a comprehensive overview:",
            "\n".join(summaries)
        )
        inputs = tokenizer(final_prompt, return_tensors="pt").to(DEVICE)
        outputs = model.generate(
            **inputs,
            max_new_tokens=500,
            temperature=0.4,
            do_sample=True
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True).split("<|assistant|>")[-1].strip()
    except Exception as e:
        st.error(f"Summarization failed: {str(e)}")
        return "Summary generation failed"
def retrieve_context(query: str, index, chunks: List[str], embedder, top_k: int = 3) -> str:
    """Enhanced retrieval with score thresholding"""
    query_embed = embedder.encode([query], convert_to_tensor=True)
    # FAISS expects a contiguous float32 array
    query_embed = np.ascontiguousarray(query_embed.cpu().numpy().astype(np.float32))
    faiss.normalize_L2(query_embed)
    scores, indices = index.search(query_embed, top_k * 2)  # Retrieve extra candidates for filtering
    # Keep only hits above a cosine-similarity threshold
    valid_indices = [i for i, score in zip(indices[0], scores[0]) if score > 0.35]
    return " ".join([chunks[i] for i in valid_indices[:top_k]])
# ------------------------
# Streamlit UI Improvements
# ------------------------
def main():
    st.set_page_config(
        page_title="RAG Book Analyzer Pro",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    # Initialize session state
    if "processed" not in st.session_state:
        st.session_state.processed = False
    if "index" not in st.session_state:
        st.session_state.index = None
    # Load models once (cached by st.cache_resource)
    tokenizer, model, embedder = load_models()
    # Sidebar controls
    with st.sidebar:
        st.header("Settings")
        top_k = st.slider("Number of context passages", 1, 5, 3)
        temp = st.slider("Generation Temperature", 0.1, 1.0, 0.4)
    # Main interface
    st.title("📚 Advanced Book Analyzer")
    st.write("Upload technical manuals, research papers, or books for deep analysis")
    uploaded_file = st.file_uploader(
        "Choose a document",
        type=["pdf", "txt", "docx"],
        accept_multiple_files=False
    )
    if uploaded_file and not st.session_state.processed:
        with st.spinner("Analyzing document..."):
            start_time = time.time()
            # Process document
            text = extract_text(uploaded_file)
            chunks = semantic_chunking(text)
            index = build_faiss_index(chunks, embedder)
            # Store in session state
            st.session_state.update({
                "chunks": chunks,
                "index": index,
                "processed": True,
                "text": text
            })
            st.success(f"Processed {len(chunks)} chunks in {time.time()-start_time:.1f}s")
    if st.session_state.processed:
        # Summary section (generated once and cached so reruns don't re-invoke the model)
        with st.expander("Document Summary", expanded=True):
            if "summary" not in st.session_state:
                st.session_state.summary = generate_summary(st.session_state.text, tokenizer, model)
            st.markdown(st.session_state.summary)
        # Q&A Section
        st.divider()
        col1, col2 = st.columns([3, 1])
        with col1:
            query = st.text_input("Ask about the document:", placeholder="What are the key findings...")
        with col2:
            show_context = st.checkbox("Show context sources")
        if query:
            with st.spinner("Searching document..."):
                context = retrieve_context(
                    query,
                    st.session_state.index,
                    st.session_state.chunks,
                    embedder,
                    top_k=top_k
                )
            if not context:
                st.warning("No relevant context found in document")
                return
            with st.expander("Generated Answer", expanded=True):
                answer = generate_answer(query, context, tokenizer, model, temp)
                st.markdown(answer)
            if show_context:
                st.divider()
                st.subheader("Source Context")
                st.write(context)
def generate_answer(query: str, context: str, tokenizer, model, temp: float) -> str:
    """Improved answer generation with context validation"""
    try:
        prompt = format_prompt(
            f"""Answer the question using only the provided context.
Follow these rules:
1. Be precise and factual
2. If unsure, say 'The document does not specify'
3. Use bullet points when listing items
4. Keep answers under 3 sentences
Context: {context[:2000]}""",
            query
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        outputs = model.generate(
            **inputs,
            max_new_tokens=400,
            temperature=temp,
            top_p=0.9,
            repetition_penalty=1.2,
            do_sample=True
        )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return answer.split("<|assistant|>")[-1].strip()
    except Exception as e:
        st.error(f"Generation failed: {str(e)}")
        return "Unable to generate answer"
if __name__ == "__main__":
    main()