Update app.py
app.py
CHANGED
@@ -408,7 +408,10 @@ def generate_response(input_dict, use_openai=False, max_tokens=700):
 # =================== RAG Chain ===================
 def get_rag_chain(retriever, use_openai=False, max_tokens=700):
     def merge_contexts(inputs):
-        retrieved_chunks = format_docs(retriever.invoke(inputs["question"]))
+        # use chunks if provided
+        retrieved_chunks = format_docs(inputs["context_docs"]) if "context_docs" in inputs \
+            else format_docs(retriever.invoke(inputs["question"]))
+
         combined = "\n\n".join(filter(None, [
             inputs.get("manual_context", ""),
             retrieved_chunks
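With this hunk, merge_contexts prefers chunks the caller has already retrieved and passed in as "context_docs", and only falls back to querying the retriever itself when none are supplied; the surrounding return shape ({"context": ..., "question": ...}) can be seen in the commented-out original further down. A minimal runnable sketch of that dispatch, in which FakeDoc, FakeRetriever and this local format_docs are illustrative stand-ins rather than the app's real classes:

# Sketch only: FakeDoc/FakeRetriever/format_docs stand in for the app's Document,
# MongoDB-backed retriever and format_docs helper.
from dataclasses import dataclass, field

@dataclass
class FakeDoc:
    page_content: str
    metadata: dict = field(default_factory=dict)

def format_docs(docs):
    return "\n\n".join(d.page_content for d in docs)

class FakeRetriever:
    def invoke(self, question):
        return [FakeDoc(f"retrieved for: {question}")]

def merge_contexts(inputs, retriever):
    # Use pre-retrieved chunks if the caller supplied them; otherwise retrieve now.
    retrieved_chunks = format_docs(inputs["context_docs"]) if "context_docs" in inputs \
        else format_docs(retriever.invoke(inputs["question"]))
    combined = "\n\n".join(filter(None, [inputs.get("manual_context", ""), retrieved_chunks]))
    return {"context": combined, "question": inputs["question"]}

print(merge_contexts({"question": "What is your budget?"}, FakeRetriever())["context"])
print(merge_contexts({"question": "What is your budget?",
                      "context_docs": [FakeDoc("pre-reranked chunk")]}, FakeRetriever())["context"])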
@@ -422,6 +425,26 @@ def get_rag_chain(retriever, use_openai=False, max_tokens=700):
         lambda input_dict: generate_response(input_dict, use_openai=use_openai, max_tokens=max_tokens)
     )
 
+    )
+def rerank_with_topics(chunks, topics, alpha=0.2):
+    """
+    Boosts similarity based on topic overlap.
+    Since chunks don't have scores, we use rank order and topic matches.
+    """
+    topics_lower = set(t.lower() for t in topics)
+
+    def score(chunk, rank):
+        chunk_topics = [t.lower() for t in chunk.metadata.get("topics", [])]
+        topic_matches = len(topics_lower.intersection(chunk_topics))
+        # Lower is better: original rank minus boost
+        return rank - alpha * topic_matches
+
+    reranked = sorted(
+        enumerate(chunks),
+        key=lambda x: score(x[1], x[0])  # x[0] is rank, x[1] is chunk
+    )
+    return [chunk for _, chunk in reranked]
+
 
 # =================== Streamlit UI ===================
 def main():
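The reranker added in this hunk scores each chunk as its original rank minus alpha times the number of topic matches, so lower scores win and topic overlap can only pull a chunk forward. A self-contained toy demonstration of that ordering (Chunk is a stand-in for the retrieved Document objects; only .metadata is used):

# Toy data: B sits at rank 1 but matches two of the requested topics.
from dataclasses import dataclass, field

@dataclass
class Chunk:
    name: str
    metadata: dict = field(default_factory=dict)

def rerank_with_topics(chunks, topics, alpha=0.2):
    # Same idea as the hunk above: score = rank - alpha * (# matching topics), lower is better.
    topics_lower = set(t.lower() for t in topics)
    def score(chunk, rank):
        chunk_topics = [t.lower() for t in chunk.metadata.get("topics", [])]
        return rank - alpha * len(topics_lower.intersection(chunk_topics))
    return [c for _, c in sorted(enumerate(chunks), key=lambda x: score(x[1], x[0]))]

chunks = [
    Chunk("A"),
    Chunk("B", {"topics": ["edtech", "bipoc"]}),
    Chunk("C", {"topics": ["edtech"]}),
]
print([c.name for c in rerank_with_topics(chunks, ["edtech", "BIPOC"], alpha=0.2)])  # ['A', 'B', 'C']
print([c.name for c in rerank_with_topics(chunks, ["edtech", "BIPOC"], alpha=0.6)])  # ['B', 'A', 'C']

Since a chunk overtakes its neighbour only when alpha times its extra topic matches exceeds one full rank, the default alpha of 0.2 needs six or more extra matches to move a chunk up one place, while values near 1.0 need only two.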
@@ -432,7 +455,9 @@ def main():
 
     k_value = st.sidebar.slider("How many chunks to retrieve (k)", min_value=5, max_value=40, step=5, value=10)
     score_threshold = st.sidebar.slider("Minimum relevance score", min_value=0.0, max_value=1.0, step=0.05, value=0.75)
-
+    topic_input=st.sidebar.text_input("Optional: Focus on specific topics (comma-separated)")
+    topics=[t.strip() for t in topic_input.split(",") if t.strip()]
+    topic_weight= st.sidebar.slider("Topic relevance score", min_value=0.0, max_value=1.0, step=0.05, value=0.2)
     st.sidebar.markdown("### Generation Settings")
     max_tokens = st.sidebar.number_input("Max tokens in response", min_value=100, max_value=1500, value=700, step=50)
 
@@ -440,8 +465,15 @@
         st.session_state.generated_queries = {}
 
     manual_context = st.text_area("📝 Optional: Add your own context (e.g., mission, goals)", height=150)
-
-    retriever = init_vector_search().as_retriever(search_kwargs={"k": k_value, "score_threshold": score_threshold})
+
+    # retriever = init_vector_search().as_retriever(search_kwargs={"k": k_value, "score_threshold": score_threshold})
+    retriever = init_vector_search().as_retriever()
+
+    pre_k = k_value*4 # Retrieve more chunks first
+    context_docs = retriever.get_relevant_documents(query, k=pre_k)
+    if topics:
+        context_docs = rerank_with_topics(context_docs, topics, alpha=topic_weight)
+    context_docs = context_docs[:k_value] # Final top-k used in RAG
     rag_chain = get_rag_chain(retriever, use_openai=USE_OPENAI, max_tokens=max_tokens)
 
     uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
@@ -488,7 +520,8 @@ def main():
                     else:
                         response = rag_chain.invoke({
                             "question": q,
-                            "manual_context": combined_context
+                            "manual_context": combined_context,
+                            "context_docs": context_docs
                         })
                         st.session_state.generated_queries[q] = response
                     answers.append({"question": q, "answer": response})
@@ -515,7 +548,7 @@ def main():
         combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
         with st.spinner("🤖 Thinking..."):
             # response = rag_chain.invoke(full_query)
-            response = rag_chain.invoke({"question":query,"manual_context": combined_context})
+            response = rag_chain.invoke({"question":query,"manual_context": combined_context, "context_docs": context_docs})
             st.text_area("Grant Buddy says:", value=response["answer"], height=250, disabled=True)
             tokens=response.get("tokens",{})
             if tokens:
@@ -523,11 +556,14 @@ def main():
                             f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
 
             with st.expander("📄 Retrieved Chunks"):
-                context_docs = retriever.get_relevant_documents(query)
+                # context_docs = retriever.get_relevant_documents(query)
                 for doc in context_docs:
                     # st.json(doc.metadata)
                     st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')} | **Title:** {doc.metadata['metadata'].get('title', 'unknown')}")
                     st.markdown(doc.page_content[:700] + "...")
+                    if topics:
+                        matched_topics=set(doc.metadata['metadata'].get('topics',[])).intersection(topics)
+                        st.markdown(f"**Matched Topics** {','.join(matched_topics)}")
                     st.markdown("---")
 
 
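Read together, the retrieval hunks above now over-retrieve (pre_k = 4 * k), optionally rerank by topic overlap, keep the top k, and hand the surviving chunks to the chain as "context_docs" alongside "question" and "manual_context". The sketch below wires those steps into one helper; Doc, StubRetriever and select_context_docs are illustrative names rather than code from app.py, and the stub retriever simply honours the k keyword the way the diff calls get_relevant_documents.

# Sketch of the over-retrieve -> topic-rerank -> truncate flow feeding rag_chain.invoke.
from dataclasses import dataclass, field

@dataclass
class Doc:
    page_content: str
    metadata: dict = field(default_factory=dict)

class StubRetriever:
    def __init__(self, docs):
        self.docs = docs
    def get_relevant_documents(self, query, k=None):
        return self.docs[:k] if k else list(self.docs)

def rerank_with_topics(chunks, topics, alpha=0.2):
    wanted = {t.lower() for t in topics}
    key = lambda pair: pair[0] - alpha * len(wanted & {t.lower() for t in pair[1].metadata.get("topics", [])})
    return [c for _, c in sorted(enumerate(chunks), key=key)]

def select_context_docs(retriever, query, k, topics, topic_weight):
    pre_k = k * 4  # retrieve more chunks first
    docs = retriever.get_relevant_documents(query, k=pre_k)
    if topics:
        docs = rerank_with_topics(docs, topics, alpha=topic_weight)
    return docs[:k]  # final top-k handed to the chain

retriever = StubRetriever([Doc("chunk 0"), Doc("chunk 1", {"topics": ["edtech", "bipoc"]}), Doc("chunk 2")])
context_docs = select_context_docs(retriever, "What is the mission?", k=2, topics=["edtech", "BIPOC"], topic_weight=0.9)
payload = {"question": "What is the mission?", "manual_context": "Mission: ...", "context_docs": context_docs}
print([d.page_content for d in context_docs])  # ['chunk 1', 'chunk 0']
# payload is the shape main() would pass to rag_chain.invoke(...) under this change.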
@@ -538,3 +574,544 @@ if __name__ == "__main__":
     main()
 
 
577 |
+
|
578 |
+
# # app.py
|
579 |
+
# import os
|
580 |
+
# import re
|
581 |
+
# import openai
|
582 |
+
# from huggingface_hub import InferenceClient
|
583 |
+
# import json
|
584 |
+
# from huggingface_hub import HfApi
|
585 |
+
# import streamlit as st
|
586 |
+
# from typing import List, Dict, Any
|
587 |
+
# from urllib.parse import quote_plus
|
588 |
+
# from pymongo import MongoClient
|
589 |
+
# from PyPDF2 import PdfReader
|
590 |
+
# st.set_page_config(page_title="Grant Buddy RAG", page_icon="๐ค")
|
591 |
+
|
592 |
+
# from typing import List
|
593 |
+
|
594 |
+
# from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
|
595 |
+
# from langchain.embeddings import HuggingFaceEmbeddings
|
596 |
+
|
597 |
+
# from langchain_community.vectorstores import MongoDBAtlasVectorSearch
|
598 |
+
# from langchain.prompts import PromptTemplate
|
599 |
+
# from langchain.schema import Document
|
600 |
+
# from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
|
601 |
+
# from huggingface_hub import InferenceClient
|
602 |
+
|
603 |
+
# # =================== Secure Env via Hugging Face Secrets ===================
|
604 |
+
# user = quote_plus(os.getenv("MONGO_USERNAME"))
|
605 |
+
# password = quote_plus(os.getenv("MONGO_PASSWORD"))
|
606 |
+
# cluster = os.getenv("MONGO_CLUSTER")
|
607 |
+
# db_name = os.getenv("MONGO_DB_NAME", "files")
|
608 |
+
# collection_name = os.getenv("MONGO_COLLECTION", "files_collection")
|
609 |
+
# index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index")
|
610 |
+
|
611 |
+
# HF_TOKEN = os.getenv("HF_TOKEN")
|
612 |
+
# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip()
|
613 |
+
# if OPENAI_API_KEY:
|
614 |
+
# openai.api_key = OPENAI_API_KEY
|
615 |
+
# from openai import OpenAI
|
616 |
+
# client = OpenAI(api_key=OPENAI_API_KEY)
|
617 |
+
|
618 |
+
# # MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority"
|
619 |
+
# MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true"
|
620 |
+
|
621 |
+
|
622 |
+
# # =================== Prompt ===================
|
623 |
+
# grantbuddy_prompt = PromptTemplate.from_template(
|
624 |
+
# """You are Grant Buddy, a specialized language model fine-tuned with instruction-tuning and RLHF.
|
625 |
+
# You help a nonprofit focused on social entrepreneurship, BIPOC empowerment, and edtech write clear, mission-aligned grant responses.
|
626 |
+
|
627 |
+
# **Instructions:**
|
628 |
+
# - Start with reasoning or context for your answer.
|
629 |
+
# - Always align with the nonprofitโs mission.
|
630 |
+
# - Use structured formatting: headings, bullet points, numbered lists.
|
631 |
+
# - Include impact data or examples if relevant.
|
632 |
+
# - Do NOT repeat the same sentence or answer multiple times.
|
633 |
+
# - If no answer exists in the context, say: "This information is not available in the current context."
|
634 |
+
|
635 |
+
# CONTEXT:
|
636 |
+
# {context}
|
637 |
+
|
638 |
+
# QUESTION:
|
639 |
+
# {question}
|
640 |
+
# """
|
641 |
+
# )
|
642 |
+
|
643 |
+
|
644 |
+
|
645 |
+
# # =================== Vector Search Setup ===================
|
646 |
+
# @st.cache_resource
|
647 |
+
# def init_embedding_model():
|
648 |
+
# return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
649 |
+
|
650 |
+
|
651 |
+
# @st.cache_resource
|
652 |
+
|
653 |
+
|
654 |
+
# def init_vector_search() -> MongoDBAtlasVectorSearch:
|
655 |
+
# HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
|
656 |
+
# model_name = "sentence-transformers/all-MiniLM-L6-v2"
|
657 |
+
# st.write(f"๐ Connecting to Hugging Face model: `{model_name}`")
|
658 |
+
|
659 |
+
# embedding_model = HuggingFaceEmbeddings(model_name=model_name)
|
660 |
+
|
661 |
+
# ✅ Manual MongoClient with TLS settings
|
662 |
+
# user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
|
663 |
+
# password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
|
664 |
+
# cluster = os.getenv("MONGO_CLUSTER", "").strip()
|
665 |
+
# db_name = os.getenv("MONGO_DB_NAME", "files").strip()
|
666 |
+
# collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
|
667 |
+
# index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
|
668 |
+
|
669 |
+
# mongo_uri = f"mongodb+srv://{user}:{password}@{cluster}/?retryWrites=true&w=majority"
|
670 |
+
|
671 |
+
# try:
|
672 |
+
# client = MongoClient(mongo_uri, tls=True, tlsAllowInvalidCertificates=True, serverSelectionTimeoutMS=20000)
|
673 |
+
# db = client[db_name]
|
674 |
+
# collection = db[collection_name]
|
675 |
+
# ✅ MongoClient connected successfully
|
676 |
+
|
677 |
+
# return MongoDBAtlasVectorSearch(
|
678 |
+
# collection=collection,
|
679 |
+
# embedding=embedding_model,
|
680 |
+
# index_name=index_name,
|
681 |
+
# )
|
682 |
+
|
683 |
+
# except Exception as e:
|
684 |
+
# st.error("โ Failed to connect to MongoDB Atlas manually")
|
685 |
+
# st.error(str(e))
|
686 |
+
# raise e
|
687 |
+
# # =================== Question/Headers Extraction ===================
|
688 |
+
# # def extract_questions_and_headers(text: str) -> List[str]:
|
689 |
+
# # header_patterns = [
|
690 |
+
# # r'\d+\.\s+\*\*([^\*]+)\*\*',
|
691 |
+
# # r'\*\*([^*]+)\*\*',
|
692 |
+
# # r'^([A-Z][^a-z]*[A-Z])$',
|
693 |
+
# # r'^([A-Z][A-Za-z\s]{3,})$',
|
694 |
+
# # r'^[A-Z][A-Za-z\s]+:$'
|
695 |
+
# # ]
|
696 |
+
# # question_patterns = [
|
697 |
+
# # r'^.+\?$',
|
698 |
+
# # r'^\*?Please .+',
|
699 |
+
# # r'^How .+',
|
700 |
+
# # r'^What .+',
|
701 |
+
# # r'^Describe .+',
|
702 |
+
# # ]
|
703 |
+
# # combined_header_re = re.compile("|".join(header_patterns), re.MULTILINE)
|
704 |
+
# # combined_question_re = re.compile("|".join(question_patterns), re.MULTILINE)
|
705 |
+
|
706 |
+
# # headers = [match for group in combined_header_re.findall(text) for match in group if match]
|
707 |
+
# # questions = combined_question_re.findall(text)
|
708 |
+
|
709 |
+
# # return headers + questions
|
710 |
+
# # def extract_with_llm(text: str) -> List[str]:
|
711 |
+
# # client = InferenceClient(api_key=HF_TOKEN.strip())
|
712 |
+
# # try:
|
713 |
+
# # response = client.chat.completions.create(
|
714 |
+
# # model="mistralai/Mistral-Nemo-Instruct-2407", # or "HuggingFaceH4/zephyr-7b-beta"
|
715 |
+
# # messages=[
|
716 |
+
# # {
|
717 |
+
# # "role": "system",
|
718 |
+
# # "content": "You are an assistant helping extract questions and headers from grant applications.",
|
719 |
+
# # },
|
720 |
+
# # {
|
721 |
+
# # "role": "user",
|
722 |
+
# # "content": (
|
723 |
+
# # "Please extract all the grant application headers and questions from the following text. "
|
724 |
+
# # "Include section titles, prompts, and any question-like content. Return them as a numbered list.\n\n"
|
725 |
+
# # f"{text[:3000]}"
|
726 |
+
# # ),
|
727 |
+
# # },
|
728 |
+
# # ],
|
729 |
+
# # temperature=0.2,
|
730 |
+
# # max_tokens=512,
|
731 |
+
# # )
|
732 |
+
# # return [
|
733 |
+
# # line.strip("โข-1234567890. ").strip()
|
734 |
+
# # for line in response.choices[0].message.content.strip().split("\n")
|
735 |
+
# # if line.strip()
|
736 |
+
# # ]
|
737 |
+
# # except Exception as e:
|
738 |
+
# # st.error("โ LLM extraction failed")
|
739 |
+
# # st.error(str(e))
|
740 |
+
# # return []
|
741 |
+
# # def extract_with_llm_local(text: str) -> List[str]:
|
742 |
+
# # prompt = (
|
743 |
+
# # "You are an assistant helping extract useful questions and section headers from a grant application.\n"
|
744 |
+
# # "Return only the important prompts as a numbered list.\n\n"
|
745 |
+
# # "TEXT:\n"
|
746 |
+
# # f"{text[:3000]}\n\n"
|
747 |
+
# # "PROMPTS:"
|
748 |
+
# # )
|
749 |
+
# # inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
|
750 |
+
# # outputs = model.generate(
|
751 |
+
# # **inputs,
|
752 |
+
# # max_new_tokens=512,
|
753 |
+
# # temperature=0.3,
|
754 |
+
# # do_sample=False
|
755 |
+
# # )
|
756 |
+
# # raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
757 |
+
|
758 |
+
# # # Extract prompts from the numbered list in the output
|
759 |
+
# # lines = raw_output.split("\n")
|
760 |
+
# # prompts = []
|
761 |
+
# # for line in lines:
|
762 |
+
# # line = line.strip("โข-1234567890. ").strip()
|
763 |
+
# # if len(line) > 10:
|
764 |
+
# # prompts.append(line)
|
765 |
+
# # return prompts
|
766 |
+
# # def extract_with_llm_local(text: str) -> List[str]:
|
767 |
+
# # example_text = """TEXT:
|
768 |
+
# # 1. Project Summary: Please describe the main goals of your project.
|
769 |
+
# # 2. Contact Information: Address, phone, email.
|
770 |
+
# # 3. What is the mission of your organization?
|
771 |
+
# # 4. Who are the beneficiaries?
|
772 |
+
# # 5. Budget Breakdown
|
773 |
+
# # 6. Please describe how the funding will be used.
|
774 |
+
# # 7. Website: www.example.org
|
775 |
+
|
776 |
+
# # PROMPTS:
|
777 |
+
# # 1. Project Summary
|
778 |
+
# # 2. What is the mission of your organization?
|
779 |
+
# # 3. Who are the beneficiaries?
|
780 |
+
# # 4. Please describe how the funding will be used.
|
781 |
+
# # """
|
782 |
+
|
783 |
+
# # prompt = (
|
784 |
+
# # "You are an assistant helping extract important grant application prompts and section headers.\n"
|
785 |
+
# # "Return only questions and meaningful section titles that require thoughtful answers.\n"
|
786 |
+
# # "Avoid metadata like phone numbers, dates, contact info, or websites.\n\n"
|
787 |
+
# # f"{example_text}\n"
|
788 |
+
# # f"TEXT:\n{text[:3000]}\n\n"
|
789 |
+
# # "PROMPTS:"
|
790 |
+
# # )
|
791 |
+
|
792 |
+
# # inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
|
793 |
+
# # outputs = model.generate(
|
794 |
+
# # **inputs,
|
795 |
+
# # max_new_tokens=512,
|
796 |
+
# # temperature=0.3,
|
797 |
+
# # do_sample=False
|
798 |
+
# # )
|
799 |
+
# # raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
800 |
+
|
801 |
+
# # # Clean and extract numbered or bulleted lines
|
802 |
+
# # lines = raw_output.split("\n")
|
803 |
+
# # prompts = []
|
804 |
+
# # for line in lines:
|
805 |
+
# # clean = line.strip("โข-1234567890. ").strip()
|
806 |
+
# # if len(clean) > 10 and not any(bad in clean.lower() for bad in ["phone", "email", "address", "website"]):
|
807 |
+
# # prompts.append(clean)
|
808 |
+
# # return prompts
|
809 |
+
|
810 |
+
|
811 |
+
# def extract_with_llm_local(text: str, use_openai: bool = False) -> List[str]:
|
812 |
+
# # Example context to prime the model
|
813 |
+
# example_text = """TEXT:
|
814 |
+
# 1. Project Summary: Please describe the main goals of your project.
|
815 |
+
# 2. Contact Information: Address, phone, email.
|
816 |
+
# 3. What is the mission of your organization?
|
817 |
+
# 4. Who are the beneficiaries?
|
818 |
+
# 5. Budget Breakdown
|
819 |
+
# 6. Please describe how the funding will be used.
|
820 |
+
# 7. Website: www.example.org
|
821 |
+
|
822 |
+
# PROMPTS:
|
823 |
+
# 1. Project Summary
|
824 |
+
# 2. What is the mission of your organization?
|
825 |
+
# 3. Who are the beneficiaries?
|
826 |
+
# 4. Please describe how the funding will be used.
|
827 |
+
# """
|
828 |
+
|
829 |
+
# prompt = (
|
830 |
+
# "You are an assistant helping extract important grant application prompts and section headers.\n"
|
831 |
+
# "Return only questions and meaningful section titles that require thoughtful answers.\n"
|
832 |
+
# "Avoid metadata like phone numbers, dates, contact info, or websites.\n\n"
|
833 |
+
# f"{example_text}\n"
|
834 |
+
# f"TEXT:\n{text[:3000]}\n\n"
|
835 |
+
# "PROMPTS:"
|
836 |
+
# )
|
837 |
+
|
838 |
+
# if use_openai:
|
839 |
+
# if not openai.api_key:
|
840 |
+
# st.error("โ OPENAI_API_KEY is not set.")
|
841 |
+
# return "โ ๏ธ OpenAI key missing."
|
842 |
+
# try:
|
843 |
+
# response = client.chat.completions.create(
|
844 |
+
# model="gpt-4o-mini",
|
845 |
+
# messages=[
|
846 |
+
# {"role": "system", "content": "You extract prompts and headers from grant text."},
|
847 |
+
# {"role": "user", "content": prompt},
|
848 |
+
# ],
|
849 |
+
# temperature=0.2,
|
850 |
+
# max_tokens=500,
|
851 |
+
# )
|
852 |
+
# # raw_output = response["choices"][0]["message"]["content"]
|
853 |
+
# raw_output = response.choices[0].message.content
|
854 |
+
# st.markdown(f"๐งฎ Extract Tokens: Prompt = {response.usage.prompt_tokens}, "
|
855 |
+
# f"Completion = {response.usage.completion_tokens}, Total = {response.usage.total_tokens}")
|
856 |
+
# except Exception as e:
|
857 |
+
# st.error(f"โ OpenAI extraction failed: {e}")
|
858 |
+
# return []
|
859 |
+
# else:
|
860 |
+
# inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
|
861 |
+
# outputs = model.generate(
|
862 |
+
# **inputs,
|
863 |
+
# max_new_tokens=min(max_tokens, 512),
|
864 |
+
# temperature=0.3,
|
865 |
+
# do_sample=False,
|
866 |
+
# pad_token_id=tokenizer.eos_token_id
|
867 |
+
# )
|
868 |
+
# raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
869 |
+
|
870 |
+
# # Clean and deduplicate prompts
|
871 |
+
# lines = raw_output.split("\n")
|
872 |
+
# prompts = []
|
873 |
+
# seen = set()
|
874 |
+
# for line in lines:
|
875 |
+
# clean = line.strip("โข-1234567890. ").strip()
|
876 |
+
# if (
|
877 |
+
# len(clean) > 10
|
878 |
+
# and not any(bad in clean.lower() for bad in ["phone", "email", "address", "website"])
|
879 |
+
# and clean not in seen
|
880 |
+
# ):
|
881 |
+
# prompts.append(clean)
|
882 |
+
# seen.add(clean)
|
883 |
+
|
884 |
+
# return prompts
|
885 |
+
|
886 |
+
|
887 |
+
# # def is_meaningful_prompt(text: str) -> bool:
|
888 |
+
# # too_short = len(text.strip()) < 10
|
889 |
+
# # banned_keywords = ["phone", "email", "fax", "address", "date", "contact", "website"]
|
890 |
+
# # contains_bad_word = any(word in text.lower() for word in banned_keywords)
|
891 |
+
# # is_just_punctuation = all(c in ":.*- " for c in text.strip())
|
892 |
+
|
893 |
+
# # return not (too_short or contains_bad_word or is_just_punctuation)
|
894 |
+
|
895 |
+
# # =================== Format Retrieved Chunks ===================
|
896 |
+
# def format_docs(docs: List[Document]) -> str:
|
897 |
+
# return "\n\n".join(doc.page_content or doc.metadata.get("content", "") for doc in docs)
|
898 |
+
|
899 |
+
# # =================== Generate Response from Hugging Face Model ===================
|
900 |
+
# # def generate_response(input_dict: Dict[str, Any]) -> str:
|
901 |
+
# # client = InferenceClient(api_key=HF_TOKEN.strip())
|
902 |
+
# # prompt = grantbuddy_prompt.format(**input_dict)
|
903 |
+
|
904 |
+
# # try:
|
905 |
+
# # response = client.chat.completions.create(
|
906 |
+
# # model="HuggingFaceH4/zephyr-7b-beta",
|
907 |
+
# # messages=[
|
908 |
+
# # {"role": "system", "content": prompt},
|
909 |
+
# # {"role": "user", "content": input_dict["question"]},
|
910 |
+
# # ],
|
911 |
+
# # max_tokens=1000,
|
912 |
+
# # temperature=0.2,
|
913 |
+
# # )
|
914 |
+
# # return response.choices[0].message.content
|
915 |
+
# # except Exception as e:
|
916 |
+
# # st.error(f"โ Error from model: {e}")
|
917 |
+
# # return "โ ๏ธ Failed to generate response. Please check your model, HF token, or request format."
|
918 |
+
# from transformers import AutoModelForCausalLM, AutoTokenizer
|
919 |
+
# import torch
|
920 |
+
|
921 |
+
# @st.cache_resource
|
922 |
+
# def load_local_model():
|
923 |
+
# model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
924 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_name)
|
925 |
+
# model = AutoModelForCausalLM.from_pretrained(model_name)
|
926 |
+
# return tokenizer, model
|
927 |
+
|
928 |
+
# tokenizer, model = load_local_model()
|
929 |
+
|
930 |
+
# def generate_response(input_dict, use_openai=False, max_tokens=700):
|
931 |
+
# prompt = grantbuddy_prompt.format(**input_dict)
|
932 |
+
|
933 |
+
# if use_openai:
|
934 |
+
# try:
|
935 |
+
# response = client.chat.completions.create(
|
936 |
+
# model="gpt-4o-mini",
|
937 |
+
# messages=[
|
938 |
+
# {"role": "system", "content": prompt},
|
939 |
+
# {"role": "user", "content": input_dict["question"]},
|
940 |
+
# ],
|
941 |
+
# temperature=0.2,
|
942 |
+
# max_tokens=max_tokens,
|
943 |
+
# )
|
944 |
+
# answer = response.choices[0].message.content.strip()
|
945 |
+
|
946 |
+
# ✅ Token logging
|
947 |
+
# prompt_tokens = response.usage.prompt_tokens
|
948 |
+
# completion_tokens = response.usage.completion_tokens
|
949 |
+
# total_tokens = response.usage.total_tokens
|
950 |
+
|
951 |
+
# return {
|
952 |
+
# "answer": answer,
|
953 |
+
# "tokens": {
|
954 |
+
# "prompt": prompt_tokens,
|
955 |
+
# "completion": completion_tokens,
|
956 |
+
# "total": total_tokens
|
957 |
+
# }
|
958 |
+
# }
|
959 |
+
|
960 |
+
# except Exception as e:
|
961 |
+
# st.error(f"โ OpenAI error: {e}")
|
962 |
+
# return {
|
963 |
+
# "answer": "โ ๏ธ OpenAI request failed.",
|
964 |
+
# "tokens": {}
|
965 |
+
# }
|
966 |
+
|
967 |
+
# else:
|
968 |
+
# inputs = tokenizer(prompt, return_tensors="pt")
|
969 |
+
# outputs = model.generate(
|
970 |
+
# **inputs,
|
971 |
+
# max_new_tokens=512,
|
972 |
+
# temperature=0.7,
|
973 |
+
# do_sample=True,
|
974 |
+
# pad_token_id=tokenizer.eos_token_id
|
975 |
+
# )
|
976 |
+
# decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
977 |
+
# return {
|
978 |
+
# "answer": decoded[len(prompt):].strip(),
|
979 |
+
# "tokens": {}
|
980 |
+
# }
|
981 |
+
|
982 |
+
|
983 |
+
|
984 |
+
|
985 |
+
# # =================== RAG Chain ===================
|
986 |
+
# def get_rag_chain(retriever, use_openai=False, max_tokens=700):
|
987 |
+
# def merge_contexts(inputs):
|
988 |
+
# retrieved_chunks = format_docs(retriever.invoke(inputs["question"]))
|
989 |
+
# combined = "\n\n".join(filter(None, [
|
990 |
+
# inputs.get("manual_context", ""),
|
991 |
+
# retrieved_chunks
|
992 |
+
# ]))
|
993 |
+
# return {
|
994 |
+
# "context": combined,
|
995 |
+
# "question": inputs["question"]
|
996 |
+
# }
|
997 |
+
|
998 |
+
# return RunnableLambda(merge_contexts) | RunnableLambda(
|
999 |
+
# lambda input_dict: generate_response(input_dict, use_openai=use_openai, max_tokens=max_tokens)
|
1000 |
+
# )
|
1001 |
+
|
1002 |
+
|
1003 |
+
# # =================== Streamlit UI ===================
|
1004 |
+
# def main():
|
1005 |
+
# # st.set_page_config(page_title="Grant Buddy RAG", page_icon="๐ค")
|
1006 |
+
# st.title("๐ค Grant Buddy: Grant-Writing Assistant")
|
1007 |
+
# USE_OPENAI = st.sidebar.checkbox("Use OpenAI (Costs Tokens)", value=False)
|
1008 |
+
# st.sidebar.markdown("### Retrieval Settings")
|
1009 |
+
|
1010 |
+
# k_value = st.sidebar.slider("How many chunks to retrieve (k)", min_value=5, max_value=40, step=5, value=10)
|
1011 |
+
# score_threshold = st.sidebar.slider("Minimum relevance score", min_value=0.0, max_value=1.0, step=0.05, value=0.75)
|
1012 |
+
|
1013 |
+
# st.sidebar.markdown("### Generation Settings")
|
1014 |
+
# max_tokens = st.sidebar.number_input("Max tokens in response", min_value=100, max_value=1500, value=700, step=50)
|
1015 |
+
|
1016 |
+
# if "generated_queries" not in st.session_state:
|
1017 |
+
# st.session_state.generated_queries = {}
|
1018 |
+
|
1019 |
+
# manual_context = st.text_area("๐ Optional: Add your own context (e.g., mission, goals)", height=150)
|
1020 |
+
|
1021 |
+
# retriever = init_vector_search().as_retriever(search_kwargs={"k": k_value, "score_threshold": score_threshold})
|
1022 |
+
# rag_chain = get_rag_chain(retriever, use_openai=USE_OPENAI, max_tokens=max_tokens)
|
1023 |
+
|
1024 |
+
# uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
|
1025 |
+
# uploaded_text = ""
|
1026 |
+
|
1027 |
+
# if uploaded_file:
|
1028 |
+
# with st.spinner("๐ Processing uploaded file..."):
|
1029 |
+
# if uploaded_file.name.endswith(".pdf"):
|
1030 |
+
# reader = PdfReader(uploaded_file)
|
1031 |
+
# uploaded_text = "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
|
1032 |
+
# elif uploaded_file.name.endswith(".txt"):
|
1033 |
+
# uploaded_text = uploaded_file.read().decode("utf-8")
|
1034 |
+
|
1035 |
+
# # extract qs and headers using llms
|
1036 |
+
# questions = extract_with_llm_local(uploaded_text, use_openai=USE_OPENAI)
|
1037 |
+
|
1038 |
+
# # filter out irrelevant text
|
1039 |
+
# def is_meaningful_prompt(text: str) -> bool:
|
1040 |
+
# too_short = len(text.strip()) < 10
|
1041 |
+
# banned_keywords = ["phone", "email", "fax", "address", "date", "contact", "website"]
|
1042 |
+
# contains_bad_word = any(word in text.lower() for word in banned_keywords)
|
1043 |
+
# is_just_punctuation = all(c in ":.*- " for c in text.strip())
|
1044 |
+
# return not (too_short or contains_bad_word or is_just_punctuation)
|
1045 |
+
|
1046 |
+
# filtered_questions = [q for q in questions if is_meaningful_prompt(q)]
|
1047 |
+
# with st.form("question_selection_form"):
|
1048 |
+
# st.subheader("Choose prompts to answer:")
|
1049 |
+
# selected_questions=[]
|
1050 |
+
# for i,q in enumerate(filtered_questions):
|
1051 |
+
# if st.checkbox(q, key=f"q_{i}", value=True):
|
1052 |
+
# selected_questions.append(q)
|
1053 |
+
# submit_button = st.form_submit_button("Submit")
|
1054 |
+
|
1055 |
+
# #Multi-Select Question
|
1056 |
+
# if 'submit_button' in locals() and submit_button:
|
1057 |
+
# if selected_questions:
|
1058 |
+
# with st.spinner("๐ก Generating answers..."):
|
1059 |
+
# answers = []
|
1060 |
+
# for q in selected_questions:
|
1061 |
+
# # full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
|
1062 |
+
# combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
|
1063 |
+
# if q in st.session_state.generated_queries:
|
1064 |
+
# response = st.session_state.generated_queries[q]
|
1065 |
+
# else:
|
1066 |
+
# response = rag_chain.invoke({
|
1067 |
+
# "question": q,
|
1068 |
+
# "manual_context": combined_context
|
1069 |
+
# })
|
1070 |
+
# st.session_state.generated_queries[q] = response
|
1071 |
+
# answers.append({"question": q, "answer": response})
|
1072 |
+
# for item in answers:
|
1073 |
+
# st.markdown(f"### โ {item['question']}")
|
1074 |
+
# st.markdown(f"๐ฌ {item['answer']['answer']}")
|
1075 |
+
# tokens = item['answer'].get("tokens", {})
|
1076 |
+
# if tokens:
|
1077 |
+
# st.markdown(f"๐งฎ **Token Usage:** Prompt = {tokens.get('prompt')}, "
|
1078 |
+
# f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
|
1079 |
+
|
1080 |
+
# else:
|
1081 |
+
# st.info("No prompts selected for answering.")
|
1082 |
+
|
1083 |
+
|
1084 |
+
# # โ๏ธ Manual single-question input
|
1085 |
+
# query = st.text_input("Ask a grant-related question")
|
1086 |
+
# if st.button("Submit"):
|
1087 |
+
# if not query:
|
1088 |
+
# st.warning("Please enter a question.")
|
1089 |
+
# return
|
1090 |
+
|
1091 |
+
# # full_query = f"{query}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else query
|
1092 |
+
# combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
|
1093 |
+
# with st.spinner("๐ค Thinking..."):
|
1094 |
+
# # response = rag_chain.invoke(full_query)
|
1095 |
+
# response = rag_chain.invoke({"question":query,"manual_context": combined_context})
|
1096 |
+
# st.text_area("Grant Buddy says:", value=response["answer"], height=250, disabled=True)
|
1097 |
+
# tokens=response.get("tokens",{})
|
1098 |
+
# if tokens:
|
1099 |
+
# st.markdown(f"๐งฎ **Token Usage:** Prompt = {tokens.get('prompt')}, "
|
1100 |
+
# f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
|
1101 |
+
|
1102 |
+
# with st.expander("๐ Retrieved Chunks"):
|
1103 |
+
# context_docs = retriever.get_relevant_documents(query)
|
1104 |
+
# for doc in context_docs:
|
1105 |
+
# # st.json(doc.metadata)
|
1106 |
+
# st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')} | **Title:** {doc.metadata['metadata'].get('title', 'unknown')}")
|
1107 |
+
# st.markdown(doc.page_content[:700] + "...")
|
1108 |
+
# st.markdown("---")
|
1109 |
+
|
1110 |
+
|
1111 |
+
|
1112 |
+
|
1113 |
+
|
1114 |
+
# if __name__ == "__main__":
|
1115 |
+
# main()
|
1116 |
+
|
1117 |
+
|