import streamlit as st
st.set_page_config(page_title="RAG Book Analyzer", layout="wide")  # Must be the first Streamlit command

import torch
import numpy as np
import faiss
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import fitz  # PyMuPDF for PDF extraction
import docx2txt  # For DOCX extraction
from langchain_text_splitters import RecursiveCharacterTextSplitter
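
# Approximate dependency set for this script (package names assumed, versions
# unpinned): streamlit, torch, transformers, sentence-transformers, faiss-cpu
# (or faiss-gpu), pymupdf, docx2txt, langchain-text-splitters.
# Assuming the file is saved as app.py, it is typically launched with:
#   streamlit run app.py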

# ------------------------
# Configuration
# ------------------------
MODEL_NAME = "microsoft/phi-2"  # Open-source model with good performance
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # Smaller embedding model
CHUNK_SIZE = 512
CHUNK_OVERLAP = 64
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
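# Rough hardware note: phi-2 has ~2.7B parameters; it runs on CPU but generation
# is slow, and in float16 on a GPU it needs roughly 6 GB of VRAM (approximate).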

# ------------------------
# Model Loading with Caching
# ------------------------
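# st.cache_resource keeps the loaded models in memory across Streamlit reruns,
# so weights are downloaded and initialised only once per server process.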
@st.cache_resource
def load_models():
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto" if DEVICE == "cuda" else None,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            trust_remote_code=True
        )
        embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
        return tokenizer, model, embedder
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        st.stop()

tokenizer, model, embedder = load_models()

# ------------------------
# Text Processing Functions
# ------------------------
def split_text(text):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len
    )
    return splitter.split_text(text)
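
# Rough sizing: with CHUNK_SIZE=512 and CHUNK_OVERLAP=64, a ~600k-character book
# (about 300 pages) yields on the order of 1,300 chunks; the exact count depends
# on where the splitter finds paragraph and sentence boundaries.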

def extract_text(file):
    file_type = file.type
    if file_type == "application/pdf":
        try:
            doc = fitz.open(stream=file.read(), filetype="pdf")
            return "\n".join([page.get_text() for page in doc])
        except Exception as e:
            st.error("Error processing PDF: " + str(e))
            return ""
    elif file_type == "text/plain":
        # Decode defensively; some plain-text books are not valid UTF-8
        return file.read().decode("utf-8", errors="ignore")
    elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        try:
            return docx2txt.process(file)
        except Exception as e:
            st.error("Error processing DOCX: " + str(e))
            return ""
    else:
        st.error("Unsupported file type: " + file_type)
        return ""

def build_index(chunks):
    # Embed each chunk and store the vectors in a flat L2 FAISS index
    embeddings = embedder.encode(chunks, show_progress_bar=False)
    embeddings = np.asarray(embeddings, dtype="float32")  # FAISS expects float32
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index
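
# IndexFlatL2 performs exact (brute-force) nearest-neighbour search; at the
# scale of a single book this is fast enough that approximate indexes
# (IVF, HNSW) are unnecessary.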

# ------------------------
# Summarization and Q&A Functions
# ------------------------
def generate_summary(text):
    # Create prompt for Phi-2 model
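    # Only the first 3000 characters are included so the prompt stays within
    # phi-2's 2048-token context window (a rough character budget, not exact)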
    prompt = f"Instruct: Summarize this book in a concise paragraph\nInput: {text[:3000]}\nOutput:"
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_new_tokens=300,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id  # phi-2 defines no pad token; reuse EOS
    )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return summary.split("Output:")[-1].strip()

def generate_answer(query, context):
    # Create prompt for Phi-2 model
    prompt = f"Instruct: Answer this question based on the context. If unsure, say 'I don't know'.\nQuestion: {query}\nContext: {context}\nOutput:"
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_new_tokens=300,
        temperature=0.5,
        top_p=0.9,
        repetition_penalty=1.2,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id  # phi-2 defines no pad token; reuse EOS
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer.split("Output:")[-1].strip()
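
# Compared with summarization, answering uses a lower temperature (0.5) plus a
# repetition penalty, so responses stay closer to the retrieved context and are
# less prone to looping.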

# ------------------------
# Streamlit UI
# ------------------------
st.title("📚 RAG-Based Book Analyzer")
st.write("Upload a book (PDF, TXT, DOCX) to get a summary and ask questions about its content.")

uploaded_file = st.file_uploader("Upload File", type=["pdf", "txt", "docx"])

if uploaded_file:
    # Streamlit reruns this entire script on every interaction (e.g. each time a
    # question is typed), so the book is only re-processed when a new file is
    # uploaded and the expensive results are kept in session_state.
    if st.session_state.get("file_name") != uploaded_file.name:
        text = extract_text(uploaded_file)
        if text:
            with st.spinner("Generating summary..."):
                st.session_state.summary = generate_summary(text)
            # Process text into chunks and build the FAISS index once per file
            chunks = split_text(text)
            st.session_state.chunks = chunks
            st.session_state.index = build_index(chunks)
            st.session_state.file_name = uploaded_file.name

    if "index" in st.session_state:
        st.success("✅ File successfully processed!")
        st.markdown("### Book Summary")
        st.info(st.session_state.summary)
        
        st.markdown("### ❓ Ask a Question about the Book")
        query = st.text_input("Enter your question:")
        if query:
            with st.spinner("Searching for answers..."):
                # Retrieve top 3 relevant chunks as context
                query_embedding = np.asarray(embedder.encode([query]), dtype="float32")
                distances, indices = st.session_state.index.search(query_embedding, k=3)
                retrieved_chunks = [st.session_state.chunks[i] for i in indices[0] if 0 <= i < len(st.session_state.chunks)]
                context = "\n\n".join(retrieved_chunks)
                
                answer = generate_answer(query, context)
                st.markdown("### 💬 Answer")
                st.success(answer)
                
                with st.expander("See context used"):
                    st.write(context)