Tesneem committed on
Commit 4f22430 · verified · 1 Parent(s): 7ca40a8

Update app.py

Files changed (1)
  1. app.py +584 -7
app.py CHANGED
@@ -408,7 +408,10 @@ def generate_response(input_dict, use_openai=False, max_tokens=700):
408
  # =================== RAG Chain ===================
409
  def get_rag_chain(retriever, use_openai=False, max_tokens=700):
410
  def merge_contexts(inputs):
411
- retrieved_chunks = format_docs(retriever.invoke(inputs["question"]))
412
  combined = "\n\n".join(filter(None, [
413
  inputs.get("manual_context", ""),
414
  retrieved_chunks
@@ -422,6 +425,26 @@ def get_rag_chain(retriever, use_openai=False, max_tokens=700):
422
  lambda input_dict: generate_response(input_dict, use_openai=use_openai, max_tokens=max_tokens)
423
  )
424
 
425
 
426
  # =================== Streamlit UI ===================
427
  def main():
@@ -432,7 +455,9 @@ def main():
432
 
433
  k_value = st.sidebar.slider("How many chunks to retrieve (k)", min_value=5, max_value=40, step=5, value=10)
434
  score_threshold = st.sidebar.slider("Minimum relevance score", min_value=0.0, max_value=1.0, step=0.05, value=0.75)
435
-
436
  st.sidebar.markdown("### Generation Settings")
437
  max_tokens = st.sidebar.number_input("Max tokens in response", min_value=100, max_value=1500, value=700, step=50)
438
 
@@ -440,8 +465,15 @@ def main():
440
  st.session_state.generated_queries = {}
441
 
442
  manual_context = st.text_area("📝 Optional: Add your own context (e.g., mission, goals)", height=150)
443
-
444
- retriever = init_vector_search().as_retriever(search_kwargs={"k": k_value, "score_threshold": score_threshold})
445
  rag_chain = get_rag_chain(retriever, use_openai=USE_OPENAI, max_tokens=max_tokens)
446
 
447
  uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
@@ -488,7 +520,8 @@ def main():
488
  else:
489
  response = rag_chain.invoke({
490
  "question": q,
491
- "manual_context": combined_context
492
  })
493
  st.session_state.generated_queries[q] = response
494
  answers.append({"question": q, "answer": response})
@@ -515,7 +548,7 @@ def main():
515
  combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
516
  with st.spinner("🤖 Thinking..."):
517
  # response = rag_chain.invoke(full_query)
518
- response = rag_chain.invoke({"question":query,"manual_context": combined_context})
519
  st.text_area("Grant Buddy says:", value=response["answer"], height=250, disabled=True)
520
  tokens=response.get("tokens",{})
521
  if tokens:
@@ -523,11 +556,14 @@ def main():
523
  f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
524
 
525
  with st.expander("🔍 Retrieved Chunks"):
526
- context_docs = retriever.get_relevant_documents(query)
527
  for doc in context_docs:
528
  # st.json(doc.metadata)
529
  st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')} | **Title:** {doc.metadata['metadata'].get('title', 'unknown')}")
530
  st.markdown(doc.page_content[:700] + "...")
531
  st.markdown("---")
532
 
533
 
@@ -538,3 +574,544 @@ if __name__ == "__main__":
538
  main()
539
 
540
 
408
  # =================== RAG Chain ===================
409
  def get_rag_chain(retriever, use_openai=False, max_tokens=700):
410
  def merge_contexts(inputs):
411
+ #use chunks if provided
412
+ retrieved_chunks = format_docs(inputs["context_docs"]) if "context_docs" in inputs \
413
+ else format_docs(retriever.invoke(inputs["question"]))
414
+
415
  combined = "\n\n".join(filter(None, [
416
  inputs.get("manual_context", ""),
417
  retrieved_chunks
 
425
  lambda input_dict: generate_response(input_dict, use_openai=use_openai, max_tokens=max_tokens)
426
  )
427
 
428
+ )
429
+ def rerank_with_topics(chunks, topics, alpha=0.2):
430
+ """
431
+ Boosts similarity based on topic overlap.
432
+ Since chunks don't have scores, we use rank order and topic matches.
433
+ """
434
+ topics_lower = set(t.lower() for t in topics)
435
+
436
+ def score(chunk, rank):
437
+ chunk_topics = [t.lower() for t in chunk.metadata.get("topics", [])]
438
+ topic_matches = len(topics_lower.intersection(chunk_topics))
439
+ # Lower is better: original rank minus boost
440
+ return rank - alpha * topic_matches
441
+
442
+ reranked = sorted(
443
+ enumerate(chunks),
444
+ key=lambda x: score(x[1], x[0]) # x[0] is rank, x[1] is chunk
445
+ )
446
+ return [chunk for _, chunk in reranked]
447
+
448
 
449
  # =================== Streamlit UI ===================
450
  def main():
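
The hunk above does two things: `merge_contexts` now prefers an explicit `context_docs` list over calling the retriever itself, and a new `rerank_with_topics` helper nudges chunks up the ranking by topic overlap. The reranker is a simple rank-minus-boost scheme: each chunk keeps its retrieval rank and is pulled up by `alpha` for every topic it shares with the user's list. A minimal, self-contained sketch of that behaviour (the `Chunk` class and the sample topics are made up for illustration; only the `metadata["topics"]` convention and the scoring rule come from the diff):

```python
# Hypothetical stand-in for a retrieved document; only the metadata layout
# mirrors what rerank_with_topics expects.
class Chunk:
    def __init__(self, name, topics):
        self.page_content = name
        self.metadata = {"topics": topics}

def rerank_with_topics(chunks, topics, alpha=0.2):
    """Re-order chunks by (rank - alpha * matching topics); lower is better."""
    topics_lower = set(t.lower() for t in topics)

    def score(chunk, rank):
        chunk_topics = [t.lower() for t in chunk.metadata.get("topics", [])]
        return rank - alpha * len(topics_lower.intersection(chunk_topics))

    reranked = sorted(enumerate(chunks), key=lambda x: score(x[1], x[0]))
    return [chunk for _, chunk in reranked]

docs = [
    Chunk("budget chunk", ["budget"]),
    Chunk("mission chunk", ["mission", "edtech"]),
    Chunk("impact chunk", ["impact", "edtech"]),
]

# With the default alpha=0.2 a single match only breaks near-ties; an alpha
# above 1.0 lets one matched topic move a chunk past the chunk ranked above it.
print([c.page_content for c in rerank_with_topics(docs, ["edtech"], alpha=1.5)])
# -> ['mission chunk', 'budget chunk', 'impact chunk']
```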
 
455
 
456
  k_value = st.sidebar.slider("How many chunks to retrieve (k)", min_value=5, max_value=40, step=5, value=10)
457
  score_threshold = st.sidebar.slider("Minimum relevance score", min_value=0.0, max_value=1.0, step=0.05, value=0.75)
458
+ topic_input=st.sidebar.text_input("Optional: Focus on specific topics (comma-separated)")
459
+ topics=[t.strip() for t in topic_input.split(",") if t.strip()]
460
+ topic_weight= st.sidebar.slider("Topic relevance score", min_value=0.0, max_value=1.0, step=0.05, value=0.2)
461
  st.sidebar.markdown("### Generation Settings")
462
  max_tokens = st.sidebar.number_input("Max tokens in response", min_value=100, max_value=1500, value=700, step=50)
463
 
 
465
  st.session_state.generated_queries = {}
466
 
467
  manual_context = st.text_area("📝 Optional: Add your own context (e.g., mission, goals)", height=150)
468
+
469
+ # retriever = init_vector_search().as_retriever(search_kwargs={"k": k_value, "score_threshold": score_threshold})
470
+ retriever = init_vector_search().as_retriever()
471
+
472
+ pre_k = k_value*4 # Retrieve more chunks first
473
+ context_docs = retriever.get_relevant_documents(query, k=pre_k)
474
+ if topics:
475
+ context_docs = rerank_with_topics(context_docs, topics, alpha=topic_weight)
476
+ context_docs = context_docs[:k_value] # Final top-k used in RAG
477
  rag_chain = get_rag_chain(retriever, use_openai=USE_OPENAI, max_tokens=max_tokens)
478
 
479
  uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
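
Two notes on the retrieval block above. First, `init_vector_search().as_retriever()` is now created without `search_kwargs`, so the `k_value` and `score_threshold` sliders no longer reach the retriever directly; instead the code over-retrieves `pre_k = k_value * 4` candidates, reranks them by topic, and keeps the top `k_value`. Second, `query` is referenced at this point but appears to be assigned only later in `main()` (in the unchanged single-question section), and `get_relevant_documents` may not accept a per-call `k=` depending on the LangChain version. A hedged sketch of the intended retrieve-rerank-truncate step, written as a helper so the question text and slider values are passed in explicitly (names follow the diff; `retrieve_with_topic_rerank` itself is not part of the commit):

```python
from typing import List

from langchain.schema import Document

def retrieve_with_topic_rerank(vector_store, question: str, topics: List[str],
                               k_value: int, topic_weight: float) -> List[Document]:
    """Over-retrieve, optionally rerank by topic overlap, then keep the top k_value.

    pre_k is baked into search_kwargs here rather than passed per call, since
    per-call k handling differs between LangChain releases.
    """
    pre_k = k_value * 4                                   # wider candidate pool first
    retriever = vector_store.as_retriever(search_kwargs={"k": pre_k})
    candidates = retriever.get_relevant_documents(question)
    if topics:                                            # only rerank when topics were given
        candidates = rerank_with_topics(candidates, topics, alpha=topic_weight)
    return candidates[:k_value]                           # final top-k handed to the RAG chain
```

This keeps the sliders meaningful (k still caps how many chunks reach the prompt) while letting the reranker see a broader candidate pool.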
 
520
  else:
521
  response = rag_chain.invoke({
522
  "question": q,
523
+ "manual_context": combined_context,
524
+ "context_docs": context_docs
525
  })
526
  st.session_state.generated_queries[q] = response
527
  answers.append({"question": q, "answer": response})
 
548
  combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
549
  with st.spinner("🤖 Thinking..."):
550
  # response = rag_chain.invoke(full_query)
551
+ response = rag_chain.invoke({"question":query,"manual_context": combined_context, "context_docs": context_docs})
552
  st.text_area("Grant Buddy says:", value=response["answer"], height=250, disabled=True)
553
  tokens=response.get("tokens",{})
554
  if tokens:
 
556
  f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
557
 
558
  with st.expander("🔍 Retrieved Chunks"):
559
+ # context_docs = retriever.get_relevant_documents(query)
560
  for doc in context_docs:
561
  # st.json(doc.metadata)
562
  st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')} | **Title:** {doc.metadata['metadata'].get('title', 'unknown')}")
563
  st.markdown(doc.page_content[:700] + "...")
564
+ if topics:
565
+ matched_topics=set(doc.metadata['metadata'].get('topics',[])).intersection(topics)
566
+ st.markdown(f"**Matched Topics**{','.join(matched_topics)")
567
  st.markdown("---")
568
 
569
 
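One detail worth flagging in the expander above: `rerank_with_topics` lowercases both sides before comparing, but the matched-topics display intersects the raw strings, so a chunk that was boosted can still show an empty "Matched Topics" line if the stored casing differs. It also reads topics from `doc.metadata['metadata']` while the reranker reads `chunk.metadata` directly, so the two should agree on where topics actually live. A small hypothetical helper that keeps the display consistent with the reranker (not part of the commit):

```python
from typing import Iterable, List

def matched_topics_for_display(chunk_topics: Iterable[str],
                               user_topics: Iterable[str]) -> List[str]:
    """Case-insensitive topic intersection that preserves the chunk's original casing."""
    wanted = {t.lower() for t in user_topics}
    return [t for t in chunk_topics if t.lower() in wanted]

# matched_topics_for_display(["EdTech", "Budget"], ["edtech"]) -> ["EdTech"]
```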
 
574
  main()
575
 
576
 
577
+
578
+ # # app.py
579
+ # import os
580
+ # import re
581
+ # import openai
582
+ # from huggingface_hub import InferenceClient
583
+ # import json
584
+ # from huggingface_hub import HfApi
585
+ # import streamlit as st
586
+ # from typing import List, Dict, Any
587
+ # from urllib.parse import quote_plus
588
+ # from pymongo import MongoClient
589
+ # from PyPDF2 import PdfReader
590
+ # st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")
591
+
592
+ # from typing import List
593
+
594
+ # from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
595
+ # from langchain.embeddings import HuggingFaceEmbeddings
596
+
597
+ # from langchain_community.vectorstores import MongoDBAtlasVectorSearch
598
+ # from langchain.prompts import PromptTemplate
599
+ # from langchain.schema import Document
600
+ # from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
601
+ # from huggingface_hub import InferenceClient
602
+
603
+ # # =================== Secure Env via Hugging Face Secrets ===================
604
+ # user = quote_plus(os.getenv("MONGO_USERNAME"))
605
+ # password = quote_plus(os.getenv("MONGO_PASSWORD"))
606
+ # cluster = os.getenv("MONGO_CLUSTER")
607
+ # db_name = os.getenv("MONGO_DB_NAME", "files")
608
+ # collection_name = os.getenv("MONGO_COLLECTION", "files_collection")
609
+ # index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index")
610
+
611
+ # HF_TOKEN = os.getenv("HF_TOKEN")
612
+ # OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip()
613
+ # if OPENAI_API_KEY:
614
+ # openai.api_key = OPENAI_API_KEY
615
+ # from openai import OpenAI
616
+ # client = OpenAI(api_key=OPENAI_API_KEY)
617
+
618
+ # # MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority"
619
+ # MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true"
620
+
621
+
622
+ # # =================== Prompt ===================
623
+ # grantbuddy_prompt = PromptTemplate.from_template(
624
+ # """You are Grant Buddy, a specialized language model fine-tuned with instruction-tuning and RLHF.
625
+ # You help a nonprofit focused on social entrepreneurship, BIPOC empowerment, and edtech write clear, mission-aligned grant responses.
626
+
627
+ # **Instructions:**
628
+ # - Start with reasoning or context for your answer.
629
+ # - Always align with the nonprofit's mission.
630
+ # - Use structured formatting: headings, bullet points, numbered lists.
631
+ # - Include impact data or examples if relevant.
632
+ # - Do NOT repeat the same sentence or answer multiple times.
633
+ # - If no answer exists in the context, say: "This information is not available in the current context."
634
+
635
+ # CONTEXT:
636
+ # {context}
637
+
638
+ # QUESTION:
639
+ # {question}
640
+ # """
641
+ # )
642
+
643
+
644
+
645
+ # # =================== Vector Search Setup ===================
646
+ # @st.cache_resource
647
+ # def init_embedding_model():
648
+ # return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
649
+
650
+
651
+ # @st.cache_resource
652
+
653
+
654
+ # def init_vector_search() -> MongoDBAtlasVectorSearch:
655
+ # HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
656
+ # model_name = "sentence-transformers/all-MiniLM-L6-v2"
657
+ # st.write(f"🔌 Connecting to Hugging Face model: `{model_name}`")
658
+
659
+ # embedding_model = HuggingFaceEmbeddings(model_name=model_name)
660
+
661
+ # # ✅ Manual MongoClient with TLS settings
662
+ # user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
663
+ # password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
664
+ # cluster = os.getenv("MONGO_CLUSTER", "").strip()
665
+ # db_name = os.getenv("MONGO_DB_NAME", "files").strip()
666
+ # collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
667
+ # index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
668
+
669
+ # mongo_uri = f"mongodb+srv://{user}:{password}@{cluster}/?retryWrites=true&w=majority"
670
+
671
+ # try:
672
+ # client = MongoClient(mongo_uri, tls=True, tlsAllowInvalidCertificates=True, serverSelectionTimeoutMS=20000)
673
+ # db = client[db_name]
674
+ # collection = db[collection_name]
675
+ # st.success("✅ MongoClient connected successfully")
676
+
677
+ # return MongoDBAtlasVectorSearch(
678
+ # collection=collection,
679
+ # embedding=embedding_model,
680
+ # index_name=index_name,
681
+ # )
682
+
683
+ # except Exception as e:
684
+ # st.error("❌ Failed to connect to MongoDB Atlas manually")
685
+ # st.error(str(e))
686
+ # raise e
687
+ # # =================== Question/Headers Extraction ===================
688
+ # # def extract_questions_and_headers(text: str) -> List[str]:
689
+ # # header_patterns = [
690
+ # # r'\d+\.\s+\*\*([^\*]+)\*\*',
691
+ # # r'\*\*([^*]+)\*\*',
692
+ # # r'^([A-Z][^a-z]*[A-Z])$',
693
+ # # r'^([A-Z][A-Za-z\s]{3,})$',
694
+ # # r'^[A-Z][A-Za-z\s]+:$'
695
+ # # ]
696
+ # # question_patterns = [
697
+ # # r'^.+\?$',
698
+ # # r'^\*?Please .+',
699
+ # # r'^How .+',
700
+ # # r'^What .+',
701
+ # # r'^Describe .+',
702
+ # # ]
703
+ # # combined_header_re = re.compile("|".join(header_patterns), re.MULTILINE)
704
+ # # combined_question_re = re.compile("|".join(question_patterns), re.MULTILINE)
705
+
706
+ # # headers = [match for group in combined_header_re.findall(text) for match in group if match]
707
+ # # questions = combined_question_re.findall(text)
708
+
709
+ # # return headers + questions
710
+ # # def extract_with_llm(text: str) -> List[str]:
711
+ # # client = InferenceClient(api_key=HF_TOKEN.strip())
712
+ # # try:
713
+ # # response = client.chat.completions.create(
714
+ # # model="mistralai/Mistral-Nemo-Instruct-2407", # or "HuggingFaceH4/zephyr-7b-beta"
715
+ # # messages=[
716
+ # # {
717
+ # # "role": "system",
718
+ # # "content": "You are an assistant helping extract questions and headers from grant applications.",
719
+ # # },
720
+ # # {
721
+ # # "role": "user",
722
+ # # "content": (
723
+ # # "Please extract all the grant application headers and questions from the following text. "
724
+ # # "Include section titles, prompts, and any question-like content. Return them as a numbered list.\n\n"
725
+ # # f"{text[:3000]}"
726
+ # # ),
727
+ # # },
728
+ # # ],
729
+ # # temperature=0.2,
730
+ # # max_tokens=512,
731
+ # # )
732
+ # # return [
733
+ # # line.strip("•-1234567890. ").strip()
734
+ # # for line in response.choices[0].message.content.strip().split("\n")
735
+ # # if line.strip()
736
+ # # ]
737
+ # # except Exception as e:
738
+ # # st.error("❌ LLM extraction failed")
739
+ # # st.error(str(e))
740
+ # # return []
741
+ # # def extract_with_llm_local(text: str) -> List[str]:
742
+ # # prompt = (
743
+ # # "You are an assistant helping extract useful questions and section headers from a grant application.\n"
744
+ # # "Return only the important prompts as a numbered list.\n\n"
745
+ # # "TEXT:\n"
746
+ # # f"{text[:3000]}\n\n"
747
+ # # "PROMPTS:"
748
+ # # )
749
+ # # inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
750
+ # # outputs = model.generate(
751
+ # # **inputs,
752
+ # # max_new_tokens=512,
753
+ # # temperature=0.3,
754
+ # # do_sample=False
755
+ # # )
756
+ # # raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
757
+
758
+ # # # Extract prompts from the numbered list in the output
759
+ # # lines = raw_output.split("\n")
760
+ # # prompts = []
761
+ # # for line in lines:
762
+ # # line = line.strip("•-1234567890. ").strip()
763
+ # # if len(line) > 10:
764
+ # # prompts.append(line)
765
+ # # return prompts
766
+ # # def extract_with_llm_local(text: str) -> List[str]:
767
+ # # example_text = """TEXT:
768
+ # # 1. Project Summary: Please describe the main goals of your project.
769
+ # # 2. Contact Information: Address, phone, email.
770
+ # # 3. What is the mission of your organization?
771
+ # # 4. Who are the beneficiaries?
772
+ # # 5. Budget Breakdown
773
+ # # 6. Please describe how the funding will be used.
774
+ # # 7. Website: www.example.org
775
+
776
+ # # PROMPTS:
777
+ # # 1. Project Summary
778
+ # # 2. What is the mission of your organization?
779
+ # # 3. Who are the beneficiaries?
780
+ # # 4. Please describe how the funding will be used.
781
+ # # """
782
+
783
+ # # prompt = (
784
+ # # "You are an assistant helping extract important grant application prompts and section headers.\n"
785
+ # # "Return only questions and meaningful section titles that require thoughtful answers.\n"
786
+ # # "Avoid metadata like phone numbers, dates, contact info, or websites.\n\n"
787
+ # # f"{example_text}\n"
788
+ # # f"TEXT:\n{text[:3000]}\n\n"
789
+ # # "PROMPTS:"
790
+ # # )
791
+
792
+ # # inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
793
+ # # outputs = model.generate(
794
+ # # **inputs,
795
+ # # max_new_tokens=512,
796
+ # # temperature=0.3,
797
+ # # do_sample=False
798
+ # # )
799
+ # # raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
800
+
801
+ # # # Clean and extract numbered or bulleted lines
802
+ # # lines = raw_output.split("\n")
803
+ # # prompts = []
804
+ # # for line in lines:
805
+ # # clean = line.strip("•-1234567890. ").strip()
806
+ # # if len(clean) > 10 and not any(bad in clean.lower() for bad in ["phone", "email", "address", "website"]):
807
+ # # prompts.append(clean)
808
+ # # return prompts
809
+
810
+
811
+ # def extract_with_llm_local(text: str, use_openai: bool = False) -> List[str]:
812
+ # # Example context to prime the model
813
+ # example_text = """TEXT:
814
+ # 1. Project Summary: Please describe the main goals of your project.
815
+ # 2. Contact Information: Address, phone, email.
816
+ # 3. What is the mission of your organization?
817
+ # 4. Who are the beneficiaries?
818
+ # 5. Budget Breakdown
819
+ # 6. Please describe how the funding will be used.
820
+ # 7. Website: www.example.org
821
+
822
+ # PROMPTS:
823
+ # 1. Project Summary
824
+ # 2. What is the mission of your organization?
825
+ # 3. Who are the beneficiaries?
826
+ # 4. Please describe how the funding will be used.
827
+ # """
828
+
829
+ # prompt = (
830
+ # "You are an assistant helping extract important grant application prompts and section headers.\n"
831
+ # "Return only questions and meaningful section titles that require thoughtful answers.\n"
832
+ # "Avoid metadata like phone numbers, dates, contact info, or websites.\n\n"
833
+ # f"{example_text}\n"
834
+ # f"TEXT:\n{text[:3000]}\n\n"
835
+ # "PROMPTS:"
836
+ # )
837
+
838
+ # if use_openai:
839
+ # if not openai.api_key:
840
+ # st.error("❌ OPENAI_API_KEY is not set.")
841
+ # return "⚠️ OpenAI key missing."
842
+ # try:
843
+ # response = client.chat.completions.create(
844
+ # model="gpt-4o-mini",
845
+ # messages=[
846
+ # {"role": "system", "content": "You extract prompts and headers from grant text."},
847
+ # {"role": "user", "content": prompt},
848
+ # ],
849
+ # temperature=0.2,
850
+ # max_tokens=500,
851
+ # )
852
+ # # raw_output = response["choices"][0]["message"]["content"]
853
+ # raw_output = response.choices[0].message.content
854
+ # st.markdown(f"🧮 Extract Tokens: Prompt = {response.usage.prompt_tokens}, "
855
+ # f"Completion = {response.usage.completion_tokens}, Total = {response.usage.total_tokens}")
856
+ # except Exception as e:
857
+ # st.error(f"❌ OpenAI extraction failed: {e}")
858
+ # return []
859
+ # else:
860
+ # inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
861
+ # outputs = model.generate(
862
+ # **inputs,
863
+ # max_new_tokens=min(max_tokens,512),
864
+ # temperature=0.3,
865
+ # do_sample=False,
866
+ # pad_token_id=tokenizer.eos_token_id
867
+ # )
868
+ # raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
869
+
870
+ # # Clean and deduplicate prompts
871
+ # lines = raw_output.split("\n")
872
+ # prompts = []
873
+ # seen = set()
874
+ # for line in lines:
875
+ # clean = line.strip("•-1234567890. ").strip()
876
+ # if (
877
+ # len(clean) > 10
878
+ # and not any(bad in clean.lower() for bad in ["phone", "email", "address", "website"])
879
+ # and clean not in seen
880
+ # ):
881
+ # prompts.append(clean)
882
+ # seen.add(clean)
883
+
884
+ # return prompts
885
+
886
+
887
+ # # def is_meaningful_prompt(text: str) -> bool:
888
+ # # too_short = len(text.strip()) < 10
889
+ # # banned_keywords = ["phone", "email", "fax", "address", "date", "contact", "website"]
890
+ # # contains_bad_word = any(word in text.lower() for word in banned_keywords)
891
+ # # is_just_punctuation = all(c in ":.*- " for c in text.strip())
892
+
893
+ # # return not (too_short or contains_bad_word or is_just_punctuation)
894
+
895
+ # # =================== Format Retrieved Chunks ===================
896
+ # def format_docs(docs: List[Document]) -> str:
897
+ # return "\n\n".join(doc.page_content or doc.metadata.get("content", "") for doc in docs)
898
+
899
+ # # =================== Generate Response from Hugging Face Model ===================
900
+ # # def generate_response(input_dict: Dict[str, Any]) -> str:
901
+ # # client = InferenceClient(api_key=HF_TOKEN.strip())
902
+ # # prompt = grantbuddy_prompt.format(**input_dict)
903
+
904
+ # # try:
905
+ # # response = client.chat.completions.create(
906
+ # # model="HuggingFaceH4/zephyr-7b-beta",
907
+ # # messages=[
908
+ # # {"role": "system", "content": prompt},
909
+ # # {"role": "user", "content": input_dict["question"]},
910
+ # # ],
911
+ # # max_tokens=1000,
912
+ # # temperature=0.2,
913
+ # # )
914
+ # # return response.choices[0].message.content
915
+ # # except Exception as e:
916
+ # # st.error(f"❌ Error from model: {e}")
917
+ # # return "⚠️ Failed to generate response. Please check your model, HF token, or request format."
918
+ # from transformers import AutoModelForCausalLM, AutoTokenizer
919
+ # import torch
920
+
921
+ # @st.cache_resource
922
+ # def load_local_model():
923
+ # model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
924
+ # tokenizer = AutoTokenizer.from_pretrained(model_name)
925
+ # model = AutoModelForCausalLM.from_pretrained(model_name)
926
+ # return tokenizer, model
927
+
928
+ # tokenizer, model = load_local_model()
929
+
930
+ # def generate_response(input_dict, use_openai=False, max_tokens=700):
931
+ # prompt = grantbuddy_prompt.format(**input_dict)
932
+
933
+ # if use_openai:
934
+ # try:
935
+ # response = client.chat.completions.create(
936
+ # model="gpt-4o-mini",
937
+ # messages=[
938
+ # {"role": "system", "content": prompt},
939
+ # {"role": "user", "content": input_dict["question"]},
940
+ # ],
941
+ # temperature=0.2,
942
+ # max_tokens=max_tokens,
943
+ # )
944
+ # answer = response.choices[0].message.content.strip()
945
+
946
+ # # ✅ Token logging
947
+ # prompt_tokens = response.usage.prompt_tokens
948
+ # completion_tokens = response.usage.completion_tokens
949
+ # total_tokens = response.usage.total_tokens
950
+
951
+ # return {
952
+ # "answer": answer,
953
+ # "tokens": {
954
+ # "prompt": prompt_tokens,
955
+ # "completion": completion_tokens,
956
+ # "total": total_tokens
957
+ # }
958
+ # }
959
+
960
+ # except Exception as e:
961
+ # st.error(f"❌ OpenAI error: {e}")
962
+ # return {
963
+ # "answer": "⚠️ OpenAI request failed.",
964
+ # "tokens": {}
965
+ # }
966
+
967
+ # else:
968
+ # inputs = tokenizer(prompt, return_tensors="pt")
969
+ # outputs = model.generate(
970
+ # **inputs,
971
+ # max_new_tokens=512,
972
+ # temperature=0.7,
973
+ # do_sample=True,
974
+ # pad_token_id=tokenizer.eos_token_id
975
+ # )
976
+ # decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
977
+ # return {
978
+ # "answer": decoded[len(prompt):].strip(),
979
+ # "tokens": {}
980
+ # }
981
+
982
+
983
+
984
+
985
+ # # =================== RAG Chain ===================
986
+ # def get_rag_chain(retriever, use_openai=False, max_tokens=700):
987
+ # def merge_contexts(inputs):
988
+ # retrieved_chunks = format_docs(retriever.invoke(inputs["question"]))
989
+ # combined = "\n\n".join(filter(None, [
990
+ # inputs.get("manual_context", ""),
991
+ # retrieved_chunks
992
+ # ]))
993
+ # return {
994
+ # "context": combined,
995
+ # "question": inputs["question"]
996
+ # }
997
+
998
+ # return RunnableLambda(merge_contexts) | RunnableLambda(
999
+ # lambda input_dict: generate_response(input_dict, use_openai=use_openai, max_tokens=max_tokens)
1000
+ # )
1001
+
1002
+
1003
+ # # =================== Streamlit UI ===================
1004
+ # def main():
1005
+ # # st.set_page_config(page_title="Grant Buddy RAG", page_icon="๐Ÿค–")
1006
+ # st.title("🤖 Grant Buddy: Grant-Writing Assistant")
1007
+ # USE_OPENAI = st.sidebar.checkbox("Use OpenAI (Costs Tokens)", value=False)
1008
+ # st.sidebar.markdown("### Retrieval Settings")
1009
+
1010
+ # k_value = st.sidebar.slider("How many chunks to retrieve (k)", min_value=5, max_value=40, step=5, value=10)
1011
+ # score_threshold = st.sidebar.slider("Minimum relevance score", min_value=0.0, max_value=1.0, step=0.05, value=0.75)
1012
+
1013
+ # st.sidebar.markdown("### Generation Settings")
1014
+ # max_tokens = st.sidebar.number_input("Max tokens in response", min_value=100, max_value=1500, value=700, step=50)
1015
+
1016
+ # if "generated_queries" not in st.session_state:
1017
+ # st.session_state.generated_queries = {}
1018
+
1019
+ # manual_context = st.text_area("📝 Optional: Add your own context (e.g., mission, goals)", height=150)
1020
+
1021
+ # retriever = init_vector_search().as_retriever(search_kwargs={"k": k_value, "score_threshold": score_threshold})
1022
+ # rag_chain = get_rag_chain(retriever, use_openai=USE_OPENAI, max_tokens=max_tokens)
1023
+
1024
+ # uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
1025
+ # uploaded_text = ""
1026
+
1027
+ # if uploaded_file:
1028
+ # with st.spinner("📄 Processing uploaded file..."):
1029
+ # if uploaded_file.name.endswith(".pdf"):
1030
+ # reader = PdfReader(uploaded_file)
1031
+ # uploaded_text = "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
1032
+ # elif uploaded_file.name.endswith(".txt"):
1033
+ # uploaded_text = uploaded_file.read().decode("utf-8")
1034
+
1035
+ # # extract qs and headers using llms
1036
+ # questions = extract_with_llm_local(uploaded_text, use_openai=USE_OPENAI)
1037
+
1038
+ # # filter out irrelevant text
1039
+ # def is_meaningful_prompt(text: str) -> bool:
1040
+ # too_short = len(text.strip()) < 10
1041
+ # banned_keywords = ["phone", "email", "fax", "address", "date", "contact", "website"]
1042
+ # contains_bad_word = any(word in text.lower() for word in banned_keywords)
1043
+ # is_just_punctuation = all(c in ":.*- " for c in text.strip())
1044
+ # return not (too_short or contains_bad_word or is_just_punctuation)
1045
+
1046
+ # filtered_questions = [q for q in questions if is_meaningful_prompt(q)]
1047
+ # with st.form("question_selection_form"):
1048
+ # st.subheader("Choose prompts to answer:")
1049
+ # selected_questions=[]
1050
+ # for i,q in enumerate(filtered_questions):
1051
+ # if st.checkbox(q, key=f"q_{i}", value=True):
1052
+ # selected_questions.append(q)
1053
+ # submit_button = st.form_submit_button("Submit")
1054
+
1055
+ # #Multi-Select Question
1056
+ # if 'submit_button' in locals() and submit_button:
1057
+ # if selected_questions:
1058
+ # with st.spinner("💡 Generating answers..."):
1059
+ # answers = []
1060
+ # for q in selected_questions:
1061
+ # # full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
1062
+ # combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
1063
+ # if q in st.session_state.generated_queries:
1064
+ # response = st.session_state.generated_queries[q]
1065
+ # else:
1066
+ # response = rag_chain.invoke({
1067
+ # "question": q,
1068
+ # "manual_context": combined_context
1069
+ # })
1070
+ # st.session_state.generated_queries[q] = response
1071
+ # answers.append({"question": q, "answer": response})
1072
+ # for item in answers:
1073
+ # st.markdown(f"### โ“ {item['question']}")
1074
+ # st.markdown(f"๐Ÿ’ฌ {item['answer']['answer']}")
1075
+ # tokens = item['answer'].get("tokens", {})
1076
+ # if tokens:
1077
+ # st.markdown(f"🧮 **Token Usage:** Prompt = {tokens.get('prompt')}, "
1078
+ # f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
1079
+
1080
+ # else:
1081
+ # st.info("No prompts selected for answering.")
1082
+
1083
+
1084
+ # # ✍️ Manual single-question input
1085
+ # query = st.text_input("Ask a grant-related question")
1086
+ # if st.button("Submit"):
1087
+ # if not query:
1088
+ # st.warning("Please enter a question.")
1089
+ # return
1090
+
1091
+ # # full_query = f"{query}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else query
1092
+ # combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
1093
+ # with st.spinner("🤖 Thinking..."):
1094
+ # # response = rag_chain.invoke(full_query)
1095
+ # response = rag_chain.invoke({"question":query,"manual_context": combined_context})
1096
+ # st.text_area("Grant Buddy says:", value=response["answer"], height=250, disabled=True)
1097
+ # tokens=response.get("tokens",{})
1098
+ # if tokens:
1099
+ # st.markdown(f"🧮 **Token Usage:** Prompt = {tokens.get('prompt')}, "
1100
+ # f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
1101
+
1102
+ # with st.expander("🔍 Retrieved Chunks"):
1103
+ # context_docs = retriever.get_relevant_documents(query)
1104
+ # for doc in context_docs:
1105
+ # # st.json(doc.metadata)
1106
+ # st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')} | **Title:** {doc.metadata['metadata'].get('title', 'unknown')}")
1107
+ # st.markdown(doc.page_content[:700] + "...")
1108
+ # st.markdown("---")
1109
+
1110
+
1111
+
1112
+
1113
+
1114
+ # if __name__ == "__main__":
1115
+ # main()
1116
+
1117
+