sunbal7 committed on
Commit 3f00b29 · verified · 1 Parent(s): 66e139c

Update app.py

Files changed (1)
  1. app.py +130 -70
app.py CHANGED
@@ -1,38 +1,51 @@
 import streamlit as st
-st.set_page_config(page_title="RAG Book Analyzer", layout="wide")  # Must be the first Streamlit command
+st.set_page_config(page_title="RAG Book Analyzer", layout="wide")
 
 import torch
 import numpy as np
 import faiss
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from sentence_transformers import SentenceTransformer
-import fitz  # PyMuPDF for PDF extraction
-import docx2txt  # For DOCX extraction
+import fitz  # PyMuPDF
+import docx2txt
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 
 # ------------------------
-# Configuration
+# Configuration (optimized for reliability)
 # ------------------------
-MODEL_NAME = "microsoft/phi-2"  # Open-source model with good performance
-EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # Smaller embedding model
+MODEL_NAME = "microsoft/phi-2"
+EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # Efficient embedding model
 CHUNK_SIZE = 512
 CHUNK_OVERLAP = 64
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+MAX_TEXT_LENGTH = 3000  # To prevent OOM errors
 
 # ------------------------
-# Model Loading with Caching
+# Model Loading with Robust Error Handling
 # ------------------------
-@st.cache_resource
+@st.cache_resource(show_spinner="Loading AI models...")
 def load_models():
     try:
-        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+        # Load tokenizer with special settings for Phi-2
+        tokenizer = AutoTokenizer.from_pretrained(
+            MODEL_NAME,
+            trust_remote_code=True,
+            padding_side="left"
+        )
+        tokenizer.pad_token = tokenizer.eos_token
+
+        # Load model with safe defaults
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_NAME,
-            device_map="auto" if DEVICE == "cuda" else None,
             torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
-            trust_remote_code=True
+            trust_remote_code=True,
+            device_map="auto" if DEVICE == "cuda" else None,
+            low_cpu_mem_usage=True
         )
+
+        # Load efficient embedding model
         embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
+
         return tokenizer, model, embedder
     except Exception as e:
         st.error(f"Model loading failed: {str(e)}")
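
Note: Phi-2's tokenizer ships without a pad token, which is why the new load_models() maps it to EOS and the generation calls later pass pad_token_id=tokenizer.eos_token_id. A minimal standalone sketch of the same setup (hypothetical snippet, not part of the commit):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True, padding_side="left")
if tok.pad_token is None:          # Phi-2 defines no pad token out of the box
    tok.pad_token = tok.eos_token  # reuse EOS for padding, as load_models() does
assert tok.pad_token_id == tok.eos_token_id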
@@ -52,24 +65,19 @@ def split_text(text):
     return splitter.split_text(text)
 
 def extract_text(file):
-    file_type = file.type
-    if file_type == "application/pdf":
-        try:
+    try:
+        if file.type == "application/pdf":
             doc = fitz.open(stream=file.read(), filetype="pdf")
             return "\n".join([page.get_text() for page in doc])
-        except Exception as e:
-            st.error("Error processing PDF: " + str(e))
-            return ""
-    elif file_type == "text/plain":
-        return file.read().decode("utf-8")
-    elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
-        try:
+        elif file.type == "text/plain":
+            return file.read().decode("utf-8")
+        elif file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
             return docx2txt.process(file)
-        except Exception as e:
-            st.error("Error processing DOCX: " + str(e))
+        else:
+            st.error(f"Unsupported file type: {file.type}")
             return ""
-    else:
-        st.error("Unsupported file type: " + file_type)
+    except Exception as e:
+        st.error(f"Error processing file: {str(e)}")
         return ""
 
 def build_index(chunks):
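
Note: since extract_text() dispatches on the browser-reported MIME type, it can be exercised without Streamlit using any object that exposes type and read(). A hypothetical test double (the FakeUpload name is illustrative, not from the commit):

import io

class FakeUpload:
    # Mimics the two attributes of Streamlit's UploadedFile that extract_text() touches.
    def __init__(self, data: bytes, mime: str):
        self._buf = io.BytesIO(data)
        self.type = mime

    def read(self) -> bytes:
        return self._buf.read()

assert extract_text(FakeUpload(b"Hello, book!", "text/plain")) == "Hello, book!"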
@@ -80,74 +88,126 @@ def build_index(chunks):
     return index
 
 # ------------------------
-# Summarization and Q&A Functions
+# AI Generation Functions (with safeguards)
 # ------------------------
 def generate_summary(text):
-    # Create prompt for Phi-2 model
-    prompt = f"Instruct: Summarize this book in a concise paragraph\nInput: {text[:3000]}\nOutput:"
-    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
+    text = text[:MAX_TEXT_LENGTH]  # Prevent long inputs
+    prompt = f"Instruction: Summarize this book in a concise paragraph\nText: {text}\nSummary:"
+
+    inputs = tokenizer(
+        prompt,
+        return_tensors="pt",
+        max_length=1024,
+        truncation=True
+    ).to(DEVICE)
+
     outputs = model.generate(
         **inputs,
-        max_new_tokens=300,
+        max_new_tokens=200,
         temperature=0.7,
         top_p=0.9,
-        do_sample=True
+        do_sample=True,
+        pad_token_id=tokenizer.eos_token_id
     )
-    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return summary.split("Output:")[-1].strip()
+
+    summary = tokenizer.decode(
+        outputs[0],
+        skip_special_tokens=True
+    )
+
+    # Extract just the summary part
+    if "Summary:" in summary:
+        return summary.split("Summary:")[-1].strip()
+    return summary.replace(prompt, "").strip()
 
 def generate_answer(query, context):
-    # Create prompt for Phi-2 model
-    prompt = f"Instruct: Answer this question based on the context. If unsure, say 'I don't know'.\nQuestion: {query}\nContext: {context}\nOutput:"
-    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
+    context = context[:MAX_TEXT_LENGTH]  # Limit context size
+    prompt = f"Instruction: Answer this question based on the context. If unsure, say 'I don't know'.\nQuestion: {query}\nContext: {context}\nAnswer:"
+
+    inputs = tokenizer(
+        prompt,
+        return_tensors="pt",
+        max_length=1024,
+        truncation=True
+    ).to(DEVICE)
+
     outputs = model.generate(
         **inputs,
-        max_new_tokens=300,
-        temperature=0.5,
-        top_p=0.9,
-        repetition_penalty=1.2,
-        do_sample=True
+        max_new_tokens=150,
+        temperature=0.4,
+        top_p=0.85,
+        repetition_penalty=1.1,
+        do_sample=True,
+        pad_token_id=tokenizer.eos_token_id
+    )
+
+    answer = tokenizer.decode(
+        outputs[0],
+        skip_special_tokens=True
     )
-    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return answer.split("Output:")[-1].strip()
+
+    # Extract just the answer part
+    if "Answer:" in answer:
+        return answer.split("Answer:")[-1].strip()
+    return answer.replace(prompt, "").strip()
 
 # ------------------------
 # Streamlit UI
 # ------------------------
 st.title("📚 RAG-Based Book Analyzer")
 st.write("Upload a book (PDF, TXT, DOCX) to get a summary and ask questions about its content.")
+st.warning("Note: First run will download models (~1.5GB). Please be patient!")
 
 uploaded_file = st.file_uploader("Upload File", type=["pdf", "txt", "docx"])
 
 if uploaded_file:
-    text = extract_text(uploaded_file)
-    if text:
-        st.success("✅ File successfully processed!")
-
-        with st.spinner("Generating summary..."):
-            summary = generate_summary(text)
-            st.markdown("### Book Summary")
-            st.info(summary)
-
-        # Process text into chunks and build FAISS index
+    with st.spinner("Extracting text from file..."):
+        text = extract_text(uploaded_file)
+
+    if not text:
+        st.error("Failed to extract text. Please try another file.")
+        st.stop()
+
+    st.success(f"✅ Extracted {len(text)} characters")
+
+    with st.spinner("Generating summary (this may take a minute)..."):
+        summary = generate_summary(text)
+        st.markdown("### Book Summary")
+        st.info(summary)
+
+    with st.spinner("Preparing document for questions..."):
         chunks = split_text(text)
         index = build_index(chunks)
         st.session_state.chunks = chunks
         st.session_state.index = index
-
-        st.markdown("### ❓ Ask a Question about the Book")
-        query = st.text_input("Enter your question:")
-        if query:
-            with st.spinner("Searching for answers..."):
-                # Retrieve top 3 relevant chunks as context
-                query_embedding = embedder.encode([query])
-                distances, indices = st.session_state.index.search(query_embedding, k=3)
-                retrieved_chunks = [st.session_state.chunks[i] for i in indices[0] if i < len(st.session_state.chunks)]
-                context = "\n\n".join(retrieved_chunks)
-
-                answer = generate_answer(query, context)
-                st.markdown("### 💬 Answer")
-                st.success(answer)
-
-                with st.expander("See context used"):
-                    st.write(context)
+        st.success(f"✅ Document indexed with {len(chunks)} chunks")
+
+    st.divider()
+
+    if 'chunks' in st.session_state:
+        st.markdown("### ❓ Ask a Question about the Book")
+        query = st.text_input("Enter your question:", key="question")
+
+        if query:
+            with st.spinner("Searching for answers..."):
+                # Retrieve top 3 relevant chunks
+                query_embedding = embedder.encode([query])
+                distances, indices = st.session_state.index.search(query_embedding, k=3)
+
+                # Safely retrieve chunks
+                retrieved_chunks = []
+                for i in indices[0]:
+                    if i < len(st.session_state.chunks):
+                        retrieved_chunks.append(st.session_state.chunks[i])
+
+                context = "\n\n".join(retrieved_chunks)
+
+                # Generate answer
+                answer = generate_answer(query, context)
+
+                # Display results
+                st.markdown("### 💬 Answer")
+                st.success(answer)
+
+                with st.expander("View context used for answer"):
+                    st.text(context)
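
Note: the diff context elides the body of build_index(); given the faiss import, the shared embedder, and the later index.search(query_embedding, k=3) call, it presumably embeds each chunk and adds the vectors to a flat FAISS index. A sketch under that assumption (the actual function body may differ from the repo):

import faiss
import numpy as np

def build_index_sketch(chunks, embedder):
    # Embed every chunk; all-MiniLM-L6-v2 yields 384-dim float32 vectors.
    embeddings = np.asarray(embedder.encode(chunks), dtype="float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])  # exact L2 nearest-neighbour search
    index.add(embeddings)
    return index

A flat index performs exhaustive search, which is exact and easily fast enough at book scale (a few hundred 512-character chunks); approximate structures such as IVF only pay off on far larger corpora.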