bainskarman commited on
Commit
783a14e
·
verified ·
1 Parent(s): 4610f9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -31
app.py CHANGED
@@ -80,7 +80,8 @@ def search_faiss_index(query_embedding, index, top_k=5):
80
  distances, indices = index.search(query_embedding, top_k)
81
  return indices[0], distances[0]
82
 
83
- # Streamlit App
 
84
  def main():
85
  st.title("Enhanced RAG Model with FAISS Indexing")
86
 
@@ -107,45 +108,66 @@ def main():
107
  # Input Prompt
108
  prompt = st.text_input("Enter your query:")
109
 
110
- if pdf_file and prompt:
111
- # Extract text from PDF
 
 
 
 
 
 
 
 
 
 
 
 
112
  text_lines = extract_text_from_pdf(pdf_file)
113
-
114
- # Detect Language
115
- lang = detect_language(" ".join(text_lines))
116
- st.write(f"**Detected Language:** {lang}")
117
-
118
  # Chunk the text
119
- chunks = split_text_into_chunks(text_lines)
120
 
121
  # Encode chunks
122
- chunk_embeddings = embedder.encode(chunks, convert_to_tensor=False)
123
 
124
  # Build FAISS index
125
- index = build_faiss_index(np.array(chunk_embeddings))
126
-
127
- # Embed the query
128
- query_embedding = embedder.encode([prompt], convert_to_tensor=False)
129
 
130
- # Search for relevant chunks
131
- top_k_indices, _ = search_faiss_index(np.array(query_embedding), index, top_k=5)
132
 
133
- # Retrieve relevant chunks
134
- relevant_chunks = [chunks[i] for i in top_k_indices]
135
-
136
- # Combine the context
137
- context = "\n".join(relevant_chunks)
138
-
139
- # Format the system prompt
140
  formatted_prompt = DEFAULT_SYSTEM_PROMPTS[query_translation].format(question=prompt)
141
-
142
- # Query LLM
143
- llm_input = f"{formatted_prompt}\n\nContext: {context}\n\nAnswer this question: {prompt}"
144
- response = query_huggingface_model(llm_input, max_new_tokens, temperature, top_k)
145
-
146
- # Display the result
147
- st.subheader("Response:")
148
- st.write(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  if __name__ == "__main__":
151
  main()
 
80
  distances, indices = index.search(query_embedding, top_k)
81
  return indices[0], distances[0]
82
 
83
+
84
+
85
  def main():
86
  st.title("Enhanced RAG Model with FAISS Indexing")
87
 
 
108
  # Input Prompt
109
  prompt = st.text_input("Enter your query:")
110
 
111
+ # State to hold intermediate results
112
+ if 'embeddings' not in st.session_state:
113
+ st.session_state.embeddings = None
114
+ if 'chunks' not in st.session_state:
115
+ st.session_state.chunks = []
116
+ if 'faiss_index' not in st.session_state:
117
+ st.session_state.faiss_index = None
118
+ if 'relevant_chunks' not in st.session_state:
119
+ st.session_state.relevant_chunks = []
120
+ if 'translated_queries' not in st.session_state:
121
+ st.session_state.translated_queries = []
122
+
123
+ # Button 1: Embed PDF
124
+ if st.button("1. Embed PDF") and pdf_file:
125
  text_lines = extract_text_from_pdf(pdf_file)
126
+ st.session_state.lang = detect_language(" ".join(text_lines))
127
+ st.write(f"**Detected Language:** {st.session_state.lang}")
128
+
 
 
129
  # Chunk the text
130
+ st.session_state.chunks = split_text_into_chunks(text_lines)
131
 
132
  # Encode chunks
133
+ chunk_embeddings = embedder.encode(st.session_state.chunks, convert_to_tensor=False)
134
 
135
  # Build FAISS index
136
+ st.session_state.faiss_index = build_faiss_index(np.array(chunk_embeddings))
 
 
 
137
 
138
+ st.success("PDF Embedded Successfully")
 
139
 
140
+ # Button 2: Generate Translated Queries
141
+ if st.button("2. Query Translation") and prompt:
 
 
 
 
 
142
  formatted_prompt = DEFAULT_SYSTEM_PROMPTS[query_translation].format(question=prompt)
143
+ response = query_huggingface_model(formatted_prompt, max_new_tokens, temperature, top_k)
144
+ st.session_state.translated_queries = response.split("\n")
145
+ st.write("**Generated Queries:**")
146
+ st.write(st.session_state.translated_queries)
147
+
148
+ # Button 3: Retrieve Document Details
149
+ if st.button("3. Retrieve Documents") and st.session_state.translated_queries:
150
+ st.session_state.relevant_chunks = []
151
+ for query in st.session_state.translated_queries:
152
+ query_embedding = embedder.encode([query], convert_to_tensor=False)
153
+ top_k_indices, _ = search_faiss_index(np.array(query_embedding), st.session_state.faiss_index, top_k=5)
154
+ relevant_chunks = [st.session_state.chunks[i] for i in top_k_indices]
155
+ st.session_state.relevant_chunks.append(relevant_chunks)
156
+
157
+ st.write("**Retrieved Documents (for each query):**")
158
+ for i, relevant_chunks in enumerate(st.session_state.relevant_chunks):
159
+ st.write(f"**Query {i + 1}: {st.session_state.translated_queries[i]}**")
160
+ for chunk in relevant_chunks:
161
+ st.write(f"{chunk[:100]}...")
162
+
163
+ # Button 4: Generate Final Response
164
+ if st.button("4. Final Response") and st.session_state.relevant_chunks:
165
+ context = "\n".join([chunk for sublist in st.session_state.relevant_chunks for chunk in sublist])
166
+ llm_input = f"{DEFAULT_SYSTEM_PROMPTS[query_translation].format(question=prompt)}\n\nContext: {context}\n\nAnswer this question: {prompt}"
167
+ final_response = query_huggingface_model(llm_input, max_new_tokens, temperature, top_k)
168
+
169
+ st.subheader("Final Response:")
170
+ st.write(final_response)
171
 
172
  if __name__ == "__main__":
173
  main()