Tesneem committed on
Commit
1bba21a
·
verified ·
1 Parent(s): ccb0207

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -33
app.py CHANGED
@@ -161,30 +161,60 @@ def extract_with_llm(text: str) -> List[str]:
161
  st.error(str(e))
162
  return []
163
 
 
 
 
 
 
 
 
164
 
165
  # =================== Format Retrieved Chunks ===================
166
  def format_docs(docs: List[Document]) -> str:
167
  return "\n\n".join(doc.page_content or doc.metadata.get("content", "") for doc in docs)
168
 
169
  # =================== Generate Response from Hugging Face Model ===================
170
- def generate_response(input_dict: Dict[str, Any]) -> str:
171
- client = InferenceClient(api_key=HF_TOKEN.strip())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  prompt = grantbuddy_prompt.format(**input_dict)
 
 
 
 
 
 
 
 
173
 
174
- try:
175
- response = client.chat.completions.create(
176
- model="HuggingFaceH4/zephyr-7b-beta",
177
- messages=[
178
- {"role": "system", "content": prompt},
179
- {"role": "user", "content": input_dict["question"]},
180
- ],
181
- max_tokens=1000,
182
- temperature=0.2,
183
- )
184
- return response.choices[0].message.content
185
- except Exception as e:
186
- st.error(f"❌ Error from model: {e}")
187
- return "⚠️ Failed to generate response. Please check your model, HF token, or request format."
188
 
189
 
190
  # =================== RAG Chain ===================
@@ -199,13 +229,12 @@ def main():
199
  st.set_page_config(page_title="Grant Buddy RAG", page_icon="πŸ€–")
200
  st.title("πŸ€– Grant Buddy: Grant-Writing Assistant")
201
 
 
 
 
202
  uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
203
  uploaded_text = ""
204
 
205
- retriever = init_vector_search().as_retriever(search_kwargs={"k": 10, "score_threshold": 0.75})
206
- rag_chain = get_rag_chain(retriever) # βœ… Initialize before usage
207
-
208
- # πŸ” Process uploaded file
209
  if uploaded_file:
210
  with st.spinner("πŸ“„ Processing uploaded file..."):
211
  if uploaded_file.name.endswith(".pdf"):
@@ -214,26 +243,40 @@ def main():
214
  elif uploaded_file.name.endswith(".txt"):
215
  uploaded_text = uploaded_file.read().decode("utf-8")
216
 
217
- questions = extract_with_llm(uploaded_text)
218
- st.success(f"βœ… Found {len(questions)} questions or headers.")
219
- with st.expander("🧠 Extracted Prompts from Upload"):
220
- st.write(questions)
 
 
 
 
 
 
 
 
221
 
222
- # Generate answers
223
- answers = []
224
- for q in questions:
225
- full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
226
- response = rag_chain.invoke(full_query)
227
- answers.append({"question": q, "answer": response})
 
 
 
 
228
 
229
  for item in answers:
230
  st.markdown(f"### ❓ {item['question']}")
231
  st.markdown(f"πŸ’¬ {item['answer']}")
 
 
232
 
233
- # βœ… Manual query box
234
  query = st.text_input("Ask a grant-related question")
235
  if st.button("Submit"):
236
- if not query and not uploaded_file:
237
  st.warning("Please enter a question.")
238
  return
239
 
@@ -245,13 +288,14 @@ def main():
245
  with st.expander("πŸ” Retrieved Chunks"):
246
  context_docs = retriever.get_relevant_documents(full_query)
247
  for doc in context_docs:
248
- st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown','title')}")
249
  st.markdown(doc.page_content[:700] + "...")
250
  st.markdown("---")
251
 
252
 
253
 
254
 
 
255
  if __name__ == "__main__":
256
  main()
257
 
 
161
  st.error(str(e))
162
  return []
163
 
164
+ # def is_meaningful_prompt(text: str) -> bool:
165
+ # too_short = len(text.strip()) < 10
166
+ # banned_keywords = ["phone", "email", "fax", "address", "date", "contact", "website"]
167
+ # contains_bad_word = any(word in text.lower() for word in banned_keywords)
168
+ # is_just_punctuation = all(c in ":.*- " for c in text.strip())
169
+
170
+ # return not (too_short or contains_bad_word or is_just_punctuation)
171
 
172
  # =================== Format Retrieved Chunks ===================
173
  def format_docs(docs: List[Document]) -> str:
174
  return "\n\n".join(doc.page_content or doc.metadata.get("content", "") for doc in docs)
175
 
176
  # =================== Generate Response from Hugging Face Model ===================
177
+ # def generate_response(input_dict: Dict[str, Any]) -> str:
178
+ # client = InferenceClient(api_key=HF_TOKEN.strip())
179
+ # prompt = grantbuddy_prompt.format(**input_dict)
180
+
181
+ # try:
182
+ # response = client.chat.completions.create(
183
+ # model="HuggingFaceH4/zephyr-7b-beta",
184
+ # messages=[
185
+ # {"role": "system", "content": prompt},
186
+ # {"role": "user", "content": input_dict["question"]},
187
+ # ],
188
+ # max_tokens=1000,
189
+ # temperature=0.2,
190
+ # )
191
+ # return response.choices[0].message.content
192
+ # except Exception as e:
193
+ # st.error(f"❌ Error from model: {e}")
194
+ # return "⚠️ Failed to generate response. Please check your model, HF token, or request format."
195
+ from transformers import AutoModelForCausalLM, AutoTokenizer
196
+ import torch
197
+
198
+ @st.cache_resource
199
+ def load_local_model():
200
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
201
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
202
+ model = AutoModelForCausalLM.from_pretrained(model_name)
203
+ return tokenizer, model
204
+
205
+ tokenizer, model = load_local_model()
206
+
207
+ def generate_response(input_dict):
208
  prompt = grantbuddy_prompt.format(**input_dict)
209
+ inputs = tokenizer(prompt, return_tensors="pt")
210
+ outputs = model.generate(
211
+ **inputs,
212
+ max_new_tokens=512,
213
+ temperature=0.7,
214
+ do_sample=True
215
+ )
216
+ return tokenizer.decode(outputs[0], skip_special_tokens=True).split("QUESTION:")[-1].strip()
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
 
220
  # =================== RAG Chain ===================
 
229
  st.set_page_config(page_title="Grant Buddy RAG", page_icon="πŸ€–")
230
  st.title("πŸ€– Grant Buddy: Grant-Writing Assistant")
231
 
232
+ retriever = init_vector_search().as_retriever(search_kwargs={"k": 10, "score_threshold": 0.75})
233
+ rag_chain = get_rag_chain(retriever)
234
+
235
  uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
236
  uploaded_text = ""
237
 
 
 
 
 
238
  if uploaded_file:
239
  with st.spinner("πŸ“„ Processing uploaded file..."):
240
  if uploaded_file.name.endswith(".pdf"):
 
243
  elif uploaded_file.name.endswith(".txt"):
244
  uploaded_text = uploaded_file.read().decode("utf-8")
245
 
246
+ # 🧠 Extract prompts using LLM
247
+ questions = extract_with_llm(uploaded_text)
248
+
249
+ # 🚫 Filter out irrelevant junk
250
+ def is_meaningful_prompt(text: str) -> bool:
251
+ too_short = len(text.strip()) < 10
252
+ banned_keywords = ["phone", "email", "fax", "address", "date", "contact", "website"]
253
+ contains_bad_word = any(word in text.lower() for word in banned_keywords)
254
+ is_just_punctuation = all(c in ":.*- " for c in text.strip())
255
+ return not (too_short or contains_bad_word or is_just_punctuation)
256
+
257
+ filtered_questions = [q for q in questions if is_meaningful_prompt(q)]
258
 
259
+ # 🎯 Prompt selection UI
260
+ selected_questions = st.multiselect("βœ… Choose prompts to answer:", filtered_questions, default=filtered_questions)
261
+
262
+ if selected_questions:
263
+ with st.spinner("πŸ’‘ Generating answers..."):
264
+ answers = []
265
+ for q in selected_questions:
266
+ full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
267
+ response = rag_chain.invoke(full_query)
268
+ answers.append({"question": q, "answer": response})
269
 
270
  for item in answers:
271
  st.markdown(f"### ❓ {item['question']}")
272
  st.markdown(f"πŸ’¬ {item['answer']}")
273
+ else:
274
+ st.info("No prompts selected for answering.")
275
 
276
+ # ✍️ Manual single-question input
277
  query = st.text_input("Ask a grant-related question")
278
  if st.button("Submit"):
279
+ if not query:
280
  st.warning("Please enter a question.")
281
  return
282
 
 
288
  with st.expander("πŸ” Retrieved Chunks"):
289
  context_docs = retriever.get_relevant_documents(full_query)
290
  for doc in context_docs:
291
+ st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')} | **Title:** {doc.metadata.get('title', 'unknown')}")
292
  st.markdown(doc.page_content[:700] + "...")
293
  st.markdown("---")
294
 
295
 
296
 
297
 
298
+
299
  if __name__ == "__main__":
300
  main()
301