Tesneem committed · verified
Commit dc6ea0c · Parent: ffd0f02

Update app.py

Files changed (1): app.py +73 -47
app.py CHANGED
@@ -12,7 +12,7 @@ from pymongo import MongoClient
 from PyPDF2 import PdfReader
 st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")
 
-
+from typing import List
 
 from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
 from langchain.embeddings import HuggingFaceEmbeddings
@@ -229,9 +229,7 @@ def init_vector_search() -> MongoDBAtlasVectorSearch:
 #         if len(clean) > 10 and not any(bad in clean.lower() for bad in ["phone", "email", "address", "website"]):
 #             prompts.append(clean)
 #     return prompts
-from typing import List
-import os
-import openai
+
 
 def extract_with_llm_local(text: str, use_openai: bool = False) -> List[str]:
     # Example context to prime the model
@@ -266,7 +264,7 @@ PROMPTS:
         return "⚠️ OpenAI key missing."
     try:
         response = client.chat.completions.create(
-            model="gpt-3.5-turbo",
+            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You extract prompts and headers from grant text."},
                {"role": "user", "content": prompt},
@@ -276,6 +274,8 @@ PROMPTS:
        )
        # raw_output = response["choices"][0]["message"]["content"]
        raw_output = response.choices[0].message.content
+        st.markdown(f"🧮 Extract Tokens: Prompt = {response.usage.prompt_tokens}, "
+                    f"Completion = {response.usage.completion_tokens}, Total = {response.usage.total_tokens}")
    except Exception as e:
        st.error(f"❌ OpenAI extraction failed: {e}")
        return []
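The new logging lines read token counts off the `usage` object that the OpenAI v1 Python SDK attaches to every chat completion. A minimal, self-contained sketch of that pattern; the model name and message are illustrative placeholders, not the app's real prompt:

```python
# Standalone sketch of the token-usage readout added in this commit.
# Assumes the OpenAI v1 SDK and an OPENAI_API_KEY in the environment.
import os
from openai import OpenAI

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello."}],
)

usage = response.usage  # has prompt_tokens, completion_tokens, total_tokens
print(f"Prompt = {usage.prompt_tokens}, "
      f"Completion = {usage.completion_tokens}, Total = {usage.total_tokens}")
```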
@@ -351,16 +351,12 @@ def load_local_model():
 tokenizer, model = load_local_model()
 
 def generate_response(input_dict, use_openai=False):
-    if use_openai:
-        if not openai.api_key:
-            st.error("❌ OPENAI_API_KEY is not set.")
-            return "⚠️ OpenAI key missing."
-
-        prompt = grantbuddy_prompt.format(**input_dict)
+    prompt = grantbuddy_prompt.format(**input_dict)
 
+    if use_openai:
        try:
            response = client.chat.completions.create(
-                model="gpt-3.5-turbo",
+                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": prompt},
                    {"role": "user", "content": input_dict["question"]},
@@ -368,14 +364,30 @@ def generate_response(input_dict, use_openai=False):
                temperature=0.2,
                max_tokens=700,
            )
-            return response.choices[0].message.content.strip()
+            answer = response.choices[0].message.content.strip()
+
+            # ✅ Token logging
+            prompt_tokens = response.usage.prompt_tokens
+            completion_tokens = response.usage.completion_tokens
+            total_tokens = response.usage.total_tokens
+
+            return {
+                "answer": answer,
+                "tokens": {
+                    "prompt": prompt_tokens,
+                    "completion": completion_tokens,
+                    "total": total_tokens
+                }
+            }
+
        except Exception as e:
            st.error(f"❌ OpenAI error: {e}")
-            return "⚠️ OpenAI request failed."
+            return {
+                "answer": "⚠️ OpenAI request failed.",
+                "tokens": {}
+            }
 
    else:
-        # Local TinyLlama path
-        prompt = grantbuddy_prompt.format(**input_dict)
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(
            **inputs,
@@ -385,17 +397,31 @@ def generate_response(input_dict, use_openai=False):
            pad_token_id=tokenizer.eos_token_id
        )
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return decoded[len(prompt):].strip()
+        return {
+            "answer": decoded[len(prompt):].strip(),
+            "tokens": {}
+        }
 
 
 
 
 # =================== RAG Chain ===================
 def get_rag_chain(retriever, use_openai=False):
-    return {
-        "context": retriever | RunnableLambda(format_docs),
-        "question": RunnablePassthrough()
-    } | RunnableLambda(lambda input_dict: generate_response(input_dict, use_openai=use_openai))
+    def merge_contexts(inputs):
+        retrieved_chunks = format_docs(retriever.invoke(inputs["question"]))
+        combined = "\n\n".join(filter(None, [
+            inputs.get("manual_context", ""),
+            retrieved_chunks
+        ]))
+        return {
+            "context": combined,
+            "question": inputs["question"]
+        }
+
+    return RunnableLambda(merge_contexts) | RunnableLambda(
+        lambda input_dict: generate_response(input_dict, use_openai=use_openai)
+    )
+
 
 # =================== Streamlit UI ===================
 def main():
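The rewritten `get_rag_chain` replaces the dict-plus-`RunnablePassthrough` construction with two `RunnableLambda` steps piped into a sequence. A self-contained toy of the same composition pattern; the step bodies are stand-ins, not the app's real retriever or model:

```python
# Toy sketch of the two-step RunnableLambda pipeline in get_rag_chain.
# Only the piping pattern matches the app code; the functions are fakes.
from langchain_core.runnables import RunnableLambda

def merge_contexts(inputs: dict) -> dict:
    # Stand-in for retrieval: pretend this chunk came from the vector store.
    combined = "\n\n".join(filter(None, [
        inputs.get("manual_context", ""),
        "retrieved chunk about eligibility",
    ]))
    return {"context": combined, "question": inputs["question"]}

def fake_generate(input_dict: dict) -> dict:
    # Stand-in for the LLM call; mirrors the {"answer", "tokens"} contract.
    return {"answer": f"[would answer: {input_dict['question']}]", "tokens": {}}

chain = RunnableLambda(merge_contexts) | RunnableLambda(fake_generate)

print(chain.invoke({
    "question": "What is the maximum award?",
    "manual_context": "Our mission is food security.",
}))
```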
@@ -404,7 +430,8 @@ def main():
     USE_OPENAI = st.sidebar.checkbox("Use OpenAI (Costs Tokens)", value=False)
     if "generated_queries" not in st.session_state:
         st.session_state.generated_queries = {}
-
+
+    manual_context = st.text_area("📝 Optional: Add your own context (e.g., mission, goals)", height=150)
 
     retriever = init_vector_search().as_retriever(search_kwargs={"k": 10, "score_threshold": 0.75})
     rag_chain = get_rag_chain(retriever, use_openai=USE_OPENAI)
@@ -440,12 +467,18 @@ def main():
            selected_questions.append(q)
        submit_button = st.form_submit_button("Submit")
 
+    # Multi-select question answering
    if 'submit_button' in locals() and submit_button:
        if selected_questions:
            with st.spinner("💡 Generating answers..."):
                answers = []
                for q in selected_questions:
-                    full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
+                    # full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
+                    combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
+                    response = rag_chain.invoke({
+                        "question": q,
+                        "manual_context": combined_context
+                    })
                    # response = rag_chain.invoke(full_query)
                    # answers.append({"question": q, "answer": response})
                    if q in st.session_state.generated_queries:
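The `combined_context` line relies on `filter(None, ...)` to drop empty strings, so a blank manual context or upload contributes no stray separators. A quick illustration:

```python
# "\n\n".join(filter(None, ...)) skips falsy entries, so empty sources
# never leave a leading or trailing separator in the combined context.
manual_context = ""
uploaded_text = "Attachment: program budget narrative."

combined = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
print(repr(combined))  # 'Attachment: program budget narrative.'
```

Note that both values must already be strings here: `.strip()` on `None` would raise.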
@@ -456,29 +489,16 @@ def main():
                    answers.append({"question": q, "answer": response})
                for item in answers:
                    st.markdown(f"### ❓ {item['question']}")
-                    st.markdown(f"💬 {item['answer']}")
+                    st.markdown(f"💬 {item['answer']['answer']}")
+                    tokens = item['answer'].get("tokens", {})
+                    if tokens:
+                        st.markdown(f"🧮 **Token Usage:** Prompt = {tokens.get('prompt')}, "
+                                    f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
+
        else:
            st.info("No prompts selected for answering.")
 
 
-
-    # #select prompts to answer
-    # selected_questions = st.multiselect("✅ Choose prompts to answer:", filtered_questions, default=filtered_questions)
-
-    # if selected_questions:
-    #     with st.spinner("💡 Generating answers..."):
-    #         answers = []
-    #         for q in selected_questions:
-    #             full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
-    #             response = rag_chain.invoke(full_query)
-    #             answers.append({"question": q, "answer": response})
-
-    #         for item in answers:
-    #             st.markdown(f"### ❓ {item['question']}")
-    #             st.markdown(f"💬 {item['answer']}")
-    # else:
-    #     st.info("No prompts selected for answering.")
-
    # ✍️ Manual single-question input
    query = st.text_input("Ask a grant-related question")
    if st.button("Submit"):
@@ -486,13 +506,19 @@ def main():
            st.warning("Please enter a question.")
            return
 
-        full_query = f"{query}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else query
+        # full_query = f"{query}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else query
+        combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
        with st.spinner("🤖 Thinking..."):
-            response = rag_chain.invoke(full_query)
-            st.text_area("Grant Buddy says:", value=response, height=250, disabled=True)
-
+            # response = rag_chain.invoke(full_query)
+            response = rag_chain.invoke({"question": query, "manual_context": combined_context})
+            st.text_area("Grant Buddy says:", value=response["answer"], height=250, disabled=True)
+            tokens = response.get("tokens", {})
+            if tokens:
+                st.markdown(f"🧮 **Token Usage:** Prompt = {tokens.get('prompt')}, "
+                            f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
+
        with st.expander("🔍 Retrieved Chunks"):
-            context_docs = retriever.get_relevant_documents(full_query)
+            context_docs = retriever.get_relevant_documents(query)
            for doc in context_docs:
                # st.json(doc.metadata)
                st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')} | **Title:** {doc.metadata['metadata'].get('title', 'unknown')}")
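With this commit, `generate_response` always returns a dict with an `answer` string and a `tokens` dict, and the token dict is empty on the local-model path and on failed OpenAI calls, so callers must render usage conditionally. A hypothetical `render` helper (not part of app.py) showing defensive consumption of that contract:

```python
# Hypothetical helper illustrating the new {"answer": str, "tokens": dict}
# return contract; token usage is rendered only when counts are present.
def render(response: dict) -> str:
    lines = [f"💬 {response['answer']}"]
    tokens = response.get("tokens", {})
    if tokens:  # empty for local-model responses and failed OpenAI calls
        lines.append(f"🧮 Token Usage: Prompt = {tokens.get('prompt')}, "
                     f"Completion = {tokens.get('completion')}, "
                     f"Total = {tokens.get('total')}")
    return "\n".join(lines)

print(render({"answer": "Fiscal sponsors are eligible.",
              "tokens": {"prompt": 812, "completion": 96, "total": 908}}))
print(render({"answer": "Local model reply.", "tokens": {}}))
```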
 