bainskarman committed on
Commit
28202fc
·
verified ·
1 Parent(s): 783a14e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -108
app.py CHANGED
@@ -8,34 +8,25 @@ from sentence_transformers import SentenceTransformer
8
  from langdetect import detect
9
 
10
  # Load the Hugging Face token
11
- huggingface_token = os.environ.get("Key2")
 
12
 
13
  # Load Sentence Transformer Model
14
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
15
-
16
- # Default system prompts for each query translation method
17
- DEFAULT_SYSTEM_PROMPTS = {
18
- "Multi-Query": """You are an AI language model assistant. Your task is to generate five \
19
- different versions of the given user question to retrieve relevant documents from a vector \
20
- database. By generating multiple perspectives on the user question, your goal is to help\
21
- the user overcome some of the limitations of the distance-based similarity search.\
22
- Provide these alternative questions separated by newlines. Original question: {question}""",
23
- "RAG Fusion": """You are an AI language model assistant. Your task is to combine multiple \
24
- queries into a single, refined query to improve retrieval accuracy. Original question: {question}""",
25
- "Decomposition": """You are an AI language model assistant. Your task is to break down \
26
- the given user question into simpler sub-questions. Provide these sub-questions separated \
27
- by newlines. Original question: {question}""",
28
- "Step Back": """You are an AI language model assistant. Your task is to refine the given \
29
- user question by taking a step back and asking a more general question. Original question: {question}""",
30
- "HyDE": """You are an AI language model assistant. Your task is to generate a hypothetical \
31
- document that would be relevant to the given user question. Original question: {question}""",
32
  }
33
 
34
- # Function to query the Hugging Face model
35
- def query_huggingface_model(prompt, max_new_tokens=1000, temperature=0.7, top_k=50):
36
- model_name = "HuggingFaceH4/zephyr-7b-alpha"
37
- api_url = f"https://api-inference.huggingface.co/models/{model_name}"
38
- headers = {"Authorization": f"Bearer {huggingface_token}"}
39
  payload = {
40
  "inputs": prompt,
41
  "parameters": {
@@ -44,63 +35,77 @@ def query_huggingface_model(prompt, max_new_tokens=1000, temperature=0.7, top_k=
44
  "top_k": top_k,
45
  },
46
  }
47
- response = requests.post(api_url, headers=headers, json=payload)
48
  if response.status_code == 200:
49
  return response.json()[0]["generated_text"]
50
- else:
51
- st.error(f"Error: {response.status_code} - {response.text}")
52
- return None
53
-
54
- # Function to detect language
55
- def detect_language(text):
56
- try:
57
- return detect(text)
58
- except:
59
- return "en"
60
-
61
- # Extract text from PDF with line and page numbers
62
- def extract_text_from_pdf(pdf_file):
63
- text = extract_text(pdf_file)
64
- return text.split("\n")
65
-
66
- # Chunk text into smaller segments
67
- def split_text_into_chunks(text_lines, chunk_size=500):
68
  words = " ".join(text_lines).split()
69
  return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
70
 
 
71
  # Build FAISS Index
72
- def build_faiss_index(embeddings):
73
  dimension = embeddings.shape[1]
74
  index = faiss.IndexFlatL2(dimension)
75
  index.add(embeddings)
76
  return index
77
 
78
- # Search in FAISS Index
79
- def search_faiss_index(query_embedding, index, top_k=5):
 
80
  distances, indices = index.search(query_embedding, top_k)
81
- return indices[0], distances[0]
82
 
83
 
 
 
 
 
 
 
 
84
 
85
- def main():
86
- st.title("Enhanced RAG Model with FAISS Indexing")
87
 
88
- # Sidebar for options
89
- st.sidebar.header("Upload PDF")
90
- pdf_file = st.sidebar.file_uploader("Upload a PDF file", type="pdf")
 
 
91
 
92
- st.sidebar.header("Query Translation")
93
- query_translation = st.sidebar.selectbox(
94
- "Select Query Translation Method",
95
- ["Multi-Query", "RAG Fusion", "Decomposition", "Step Back", "HyDE"]
96
- )
 
 
 
97
 
98
- st.sidebar.header("Similarity Search")
99
- similarity_method = st.sidebar.selectbox("Select Similarity Search Method", ["Cosine Similarity", "KNN"])
100
- if similarity_method == "KNN":
101
- k_value = st.sidebar.slider("Select K Value", 1, 10, 5)
102
 
103
- # LLM Parameters
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  max_new_tokens = st.sidebar.slider("Max New Tokens", 10, 1000, 500)
105
  temperature = st.sidebar.slider("Temperature", 0.1, 1.0, 0.7)
106
  top_k = st.sidebar.slider("Top K", 1, 100, 50)
@@ -108,66 +113,34 @@ def main():
108
  # Input Prompt
109
  prompt = st.text_input("Enter your query:")
110
 
111
- # State to hold intermediate results
112
- if 'embeddings' not in st.session_state:
113
- st.session_state.embeddings = None
114
  if 'chunks' not in st.session_state:
115
  st.session_state.chunks = []
116
  if 'faiss_index' not in st.session_state:
117
  st.session_state.faiss_index = None
118
- if 'relevant_chunks' not in st.session_state:
119
- st.session_state.relevant_chunks = []
120
- if 'translated_queries' not in st.session_state:
121
- st.session_state.translated_queries = []
122
 
123
- # Button 1: Embed PDF
124
  if st.button("1. Embed PDF") and pdf_file:
125
- text_lines = extract_text_from_pdf(pdf_file)
126
- st.session_state.lang = detect_language(" ".join(text_lines))
127
- st.write(f"**Detected Language:** {st.session_state.lang}")
128
-
129
- # Chunk the text
130
- st.session_state.chunks = split_text_into_chunks(text_lines)
131
-
132
- # Encode chunks
133
- chunk_embeddings = embedder.encode(st.session_state.chunks, convert_to_tensor=False)
134
-
135
- # Build FAISS index
136
- st.session_state.faiss_index = build_faiss_index(np.array(chunk_embeddings))
137
-
138
  st.success("PDF Embedded Successfully")
139
 
140
- # Button 2: Generate Translated Queries
141
  if st.button("2. Query Translation") and prompt:
142
- formatted_prompt = DEFAULT_SYSTEM_PROMPTS[query_translation].format(question=prompt)
143
- response = query_huggingface_model(formatted_prompt, max_new_tokens, temperature, top_k)
144
- st.session_state.translated_queries = response.split("\n")
145
- st.write("**Generated Queries:**")
146
- st.write(st.session_state.translated_queries)
147
 
148
- # Button 3: Retrieve Document Details
149
  if st.button("3. Retrieve Documents") and st.session_state.translated_queries:
150
- st.session_state.relevant_chunks = []
151
- for query in st.session_state.translated_queries:
152
- query_embedding = embedder.encode([query], convert_to_tensor=False)
153
- top_k_indices, _ = search_faiss_index(np.array(query_embedding), st.session_state.faiss_index, top_k=5)
154
- relevant_chunks = [st.session_state.chunks[i] for i in top_k_indices]
155
- st.session_state.relevant_chunks.append(relevant_chunks)
156
-
157
- st.write("**Retrieved Documents (for each query):**")
158
- for i, relevant_chunks in enumerate(st.session_state.relevant_chunks):
159
- st.write(f"**Query {i + 1}: {st.session_state.translated_queries[i]}**")
160
- for chunk in relevant_chunks:
161
- st.write(f"{chunk[:100]}...")
162
-
163
- # Button 4: Generate Final Response
164
- if st.button("4. Final Response") and st.session_state.relevant_chunks:
165
- context = "\n".join([chunk for sublist in st.session_state.relevant_chunks for chunk in sublist])
166
- llm_input = f"{DEFAULT_SYSTEM_PROMPTS[query_translation].format(question=prompt)}\n\nContext: {context}\n\nAnswer this question: {prompt}"
167
- final_response = query_huggingface_model(llm_input, max_new_tokens, temperature, top_k)
168
 
 
 
 
 
169
  st.subheader("Final Response:")
170
  st.write(final_response)
171
 
 
172
  if __name__ == "__main__":
173
  main()
 
8
from langdetect import detect

# Hugging Face API credentials and hosted model selection.
HUGGINGFACE_TOKEN = os.environ.get("Key2")
HF_MODEL = "HuggingFaceH4/zephyr-7b-alpha"

# Sentence-embedding model shared by chunk indexing and query retrieval.
EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")

# Default system prompts, one per query-translation strategy; each template
# is filled with the user's question via .format(question=...).
SYSTEM_PROMPTS = {
    "Multi-Query": "Generate five alternative versions of the user question: {question}",
    "RAG Fusion": "Combine multiple queries into a single, refined query: {question}",
    "Decomposition": "Break down the user question into simpler sub-questions: {question}",
    "Step Back": "Refine the user question by asking a more general question: {question}",
    "HyDE": "Generate a hypothetical document relevant to the user question: {question}",
}
25
 
26
+
27
# Helper function to interact with the Hugging Face Inference API
def query_hf(prompt, max_new_tokens=1000, temperature=0.7, top_k=50):
    """Send *prompt* to the hosted HF_MODEL and return the generated text.

    Parameters mirror the HF text-generation API. On a non-200 response the
    error is surfaced via st.error and None is returned — callers must handle
    a None result.
    """
    headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_k": top_k,
        },
    }
    # Timeout keeps the Streamlit app from hanging forever on a stalled request.
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{HF_MODEL}",
        headers=headers,
        json=payload,
        timeout=120,
    )
    if response.status_code == 200:
        return response.json()[0]["generated_text"]
    st.error(f"Error: {response.status_code} - {response.text}")
    return None  # explicit failure marker instead of an implicit fall-through
+
43
+
44
# Extract text from PDF
def extract_pdf_text(pdf_file):
    """Extract the PDF's text and return it as a list of lines."""
    raw_text = extract_text(pdf_file)
    return raw_text.split("\n")
47
+
48
+
49
# Chunk text into segments
def chunk_text(text_lines, chunk_size=500):
    """Join *text_lines* into one word stream and slice it into chunks.

    Each chunk is a single string of at most *chunk_size* whitespace-separated
    words; an empty input yields an empty list.
    """
    words = " ".join(text_lines).split()
    segments = []
    for start in range(0, len(words), chunk_size):
        segments.append(" ".join(words[start:start + chunk_size]))
    return segments
53
 
54
+
55
# Build FAISS Index
def build_index(embeddings):
    """Create a flat L2 FAISS index and load *embeddings* (2-D array) into it."""
    n_dims = embeddings.shape[1]
    flat_index = faiss.IndexFlatL2(n_dims)
    flat_index.add(embeddings)
    return flat_index
61
 
62
+
63
# Search FAISS Index
def search_index(query_embedding, index, top_k=5):
    """Return the ids of the *top_k* nearest neighbours for the single query.

    FAISS reports (distances, indices); distances are discarded here and only
    the index row for the one query vector is returned.
    """
    _, neighbour_ids = index.search(query_embedding, top_k)
    return neighbour_ids[0]
67
 
68
 
69
# Embed PDF content and build FAISS index
def process_pdf(pdf_file):
    """Run the full ingestion pipeline: extract, chunk, embed, index.

    Returns a (chunks, faiss_index) pair ready for retrieval.
    """
    lines = extract_pdf_text(pdf_file)
    pdf_chunks = chunk_text(lines)
    chunk_vectors = EMBEDDER.encode(pdf_chunks, convert_to_tensor=False)
    return pdf_chunks, build_index(np.array(chunk_vectors))
76
 
 
 
77
 
78
# Generate query translations
def translate_query(prompt, method, max_new_tokens, temperature, top_k):
    """Reformulate *prompt* with the system prompt selected by *method*.

    Returns the model's output split into a list of non-empty query strings.
    Returns [] when the LLM call fails (query_hf reports the error and yields
    None) instead of crashing on None.split("\\n").
    """
    formatted_prompt = SYSTEM_PROMPTS[method].format(question=prompt)
    response = query_hf(formatted_prompt, max_new_tokens, temperature, top_k)
    if response is None:
        return []
    # Drop blank lines the model commonly emits between alternatives.
    return [line.strip() for line in response.split("\n") if line.strip()]
82
+
83
 
84
# Retrieve relevant chunks from FAISS index
def retrieve_chunks(translated_queries, faiss_index, chunks, top_k=5):
    """Embed every translated query and gather its *top_k* nearest chunks.

    Results are concatenated across queries in order; duplicates are kept.
    """
    gathered = []
    for translated in translated_queries:
        vector = EMBEDDER.encode([translated], convert_to_tensor=False)
        for idx in search_index(np.array(vector), faiss_index, top_k):
            gathered.append(chunks[idx])
    return gathered
92
 
 
 
 
 
93
 
94
# Generate final response using RAG approach
def generate_final_response(prompt, context, max_new_tokens, temperature, top_k):
    """Ask the LLM to answer *prompt* grounded in the retrieved *context*."""
    return query_hf(
        f"Context: {context}\n\nAnswer this question: {prompt}",
        max_new_tokens,
        temperature,
        top_k,
    )
98
+
99
+
100
# Streamlit UI
def main():
    """Streamlit entry point: a four-step RAG workflow over an uploaded PDF."""
    st.title("Enhanced RAG Model with FAISS Indexing")

    # Sidebar Inputs
    pdf_file = st.sidebar.file_uploader("Upload PDF", type="pdf")
    query_translation = st.sidebar.selectbox("Query Translation Method", list(SYSTEM_PROMPTS.keys()))
    similarity_method = st.sidebar.selectbox("Similarity Search Method", ["Cosine Similarity", "KNN"])
    # NOTE(review): the similarity-method choice only changes k_value here;
    # retrieval itself always goes through the FAISS L2 index — confirm intent.
    k_value = st.sidebar.slider("K Value (for KNN)", 1, 10, 5) if similarity_method == "KNN" else 5
    max_new_tokens = st.sidebar.slider("Max New Tokens", 10, 1000, 500)
    temperature = st.sidebar.slider("Temperature", 0.1, 1.0, 0.7)
    top_k = st.sidebar.slider("Top K", 1, 100, 50)

    # Input Prompt
    prompt = st.text_input("Enter your query:")

    # State Management — initialise every key the buttons below read, so that
    # pressing a later step before an earlier one cannot raise AttributeError.
    if 'chunks' not in st.session_state:
        st.session_state.chunks = []
    if 'faiss_index' not in st.session_state:
        st.session_state.faiss_index = None
    if 'translated_queries' not in st.session_state:
        st.session_state.translated_queries = []
    if 'relevant_chunks' not in st.session_state:
        st.session_state.relevant_chunks = []

    # Step 1: Process PDF
    if st.button("1. Embed PDF") and pdf_file:
        st.session_state.chunks, st.session_state.faiss_index = process_pdf(pdf_file)
        st.success("PDF Embedded Successfully")

    # Step 2: Generate Translated Queries
    if st.button("2. Query Translation") and prompt:
        st.session_state.translated_queries = translate_query(prompt, query_translation, max_new_tokens, temperature, top_k)
        st.write("**Generated Queries:**", st.session_state.translated_queries)

    # Step 3: Retrieve Relevant Chunks
    if st.button("3. Retrieve Documents") and st.session_state.translated_queries:
        st.session_state.relevant_chunks = retrieve_chunks(
            st.session_state.translated_queries,
            st.session_state.faiss_index,
            st.session_state.chunks,
            top_k=k_value,
        )
        st.write("**Retrieved Chunks:**", st.session_state.relevant_chunks)

    # Step 4: Generate Final Response
    if st.button("4. Generate Final Response") and st.session_state.relevant_chunks:
        context = "\n".join(st.session_state.relevant_chunks)
        final_response = generate_final_response(prompt, context, max_new_tokens, temperature, top_k)
        st.subheader("Final Response:")
        st.write(final_response)


if __name__ == "__main__":
    main()