# app.py
import os
import re
import json

import openai
import streamlit as st
from typing import List, Dict, Any
from urllib.parse import quote_plus
from pymongo import MongoClient
from PyPDF2 import PdfReader

# st.set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")

from huggingface_hub import InferenceClient, HfApi
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
# =================== Secure Env via Hugging Face Secrets ===================
user = quote_plus(os.getenv("MONGO_USERNAME", ""))
password = quote_plus(os.getenv("MONGO_PASSWORD", ""))
cluster = os.getenv("MONGO_CLUSTER", "")
db_name = os.getenv("MONGO_DB_NAME", "files")
collection_name = os.getenv("MONGO_COLLECTION", "files_collection")
index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index")
HF_TOKEN = os.getenv("HF_TOKEN")

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip()
if OPENAI_API_KEY:
    openai.api_key = OPENAI_API_KEY

from openai import OpenAI

# The OpenAI client is optional: only created when a key is configured.
client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None

# MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority"
MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true"
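# Expected Space secrets (illustrative placeholder values only, not real credentials;
# set these in the Hugging Face "Secrets" settings — names mirror the os.getenv calls above):
#   MONGO_USERNAME     = grantbuddy-user
#   MONGO_PASSWORD     = ********
#   MONGO_CLUSTER      = cluster0.xxxxx.mongodb.net
#   MONGO_DB_NAME      = files
#   MONGO_COLLECTION   = files_collection
#   MONGO_VECTOR_INDEX = vector_index
#   HF_TOKEN           = hf_xxx
#   OPENAI_API_KEY     = sk-xxx   # optional; only used when "Use OpenAI" is checked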
# =================== Prompt ===================
grantbuddy_prompt = PromptTemplate.from_template(
    """You are Grant Buddy, a specialized language model fine-tuned with instruction-tuning and RLHF.
You help a nonprofit focused on social entrepreneurship, BIPOC empowerment, and edtech write clear, mission-aligned grant responses.

**Instructions:**
- Start with reasoning or context for your answer.
- Always align with the nonprofit’s mission.
- Use structured formatting: headings, bullet points, numbered lists.
- Include impact data or examples if relevant.
- Do NOT repeat the same sentence or answer multiple times.
- If no answer exists in the context, say: "This information is not available in the current context."

CONTEXT:
{context}

QUESTION:
{question}
"""
)
# =================== Vector Search Setup ===================
@st.cache_resource
def init_embedding_model():
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


@st.cache_resource
def init_vector_search() -> MongoDBAtlasVectorSearch:
    HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    st.write(f"🔌 Connecting to Hugging Face model: `{model_name}`")
    embedding_model = HuggingFaceEmbeddings(model_name=model_name)

    # ✅ Manual MongoClient with TLS settings
    user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
    password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
    cluster = os.getenv("MONGO_CLUSTER", "").strip()
    db_name = os.getenv("MONGO_DB_NAME", "files").strip()
    collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
    index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
    mongo_uri = f"mongodb+srv://{user}:{password}@{cluster}/?retryWrites=true&w=majority"

    try:
        client = MongoClient(mongo_uri, tls=True, tlsAllowInvalidCertificates=True, serverSelectionTimeoutMS=20000)
        db = client[db_name]
        collection = db[collection_name]
        st.success("✅ MongoClient connected successfully")
        return MongoDBAtlasVectorSearch(
            collection=collection,
            embedding=embedding_model,
            index_name=index_name,
        )
    except Exception as e:
        st.error("❌ Failed to connect to MongoDB Atlas manually")
        st.error(str(e))
        raise
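# A minimal sketch (an assumption, not pulled from this repo) of the Atlas Search index
# that MONGO_VECTOR_INDEX is expected to point at. sentence-transformers/all-MiniLM-L6-v2
# emits 384-dimensional vectors; "embedding" is LangChain's default embedding field name.
# {
#   "mappings": {
#     "dynamic": true,
#     "fields": {
#       "embedding": { "type": "knnVector", "dimensions": 384, "similarity": "cosine" }
#     }
#   }
# }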
# =================== Question/Headers Extraction ===================
def extract_with_llm_local(text: str, use_openai: bool = False) -> List[str]:
    # Example context to prime the model
    example_text = """TEXT:
1. Project Summary: Please describe the main goals of your project.
2. Contact Information: Address, phone, email.
3. What is the mission of your organization?
4. Who are the beneficiaries?
5. Budget Breakdown
6. Please describe how the funding will be used.
7. Website: www.example.org

PROMPTS:
1. Project Summary
2. What is the mission of your organization?
3. Who are the beneficiaries?
4. Please describe how the funding will be used.
"""
    prompt = (
        "You are an assistant helping extract important grant application prompts and section headers.\n"
        "Return only questions and meaningful section titles that require thoughtful answers.\n"
        "Avoid metadata like phone numbers, dates, contact info, or websites.\n\n"
        f"{example_text}\n"
        f"TEXT:\n{text[:3000]}\n\n"
        "PROMPTS:"
    )

    if use_openai:
        if not openai.api_key:
            st.error("❌ OPENAI_API_KEY is not set.")
            return []
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You extract prompts and headers from grant text."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.2,
                max_tokens=500,
            )
            raw_output = response.choices[0].message.content
            st.markdown(f"🧮 Extract Tokens: Prompt = {response.usage.prompt_tokens}, "
                        f"Completion = {response.usage.completion_tokens}, Total = {response.usage.total_tokens}")
        except Exception as e:
            st.error(f"❌ OpenAI extraction failed: {e}")
            return []
    else:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.3,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
        raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Clean and deduplicate prompts
    lines = raw_output.split("\n")
    prompts = []
    seen = set()
    for line in lines:
        clean = line.strip("•-1234567890. ").strip()
        if (
            len(clean) > 10
            and not any(bad in clean.lower() for bad in ["phone", "email", "address", "website"])
            and clean not in seen
        ):
            prompts.append(clean)
            seen.add(clean)
    return prompts
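# Minimal usage sketch (illustrative input and output, not a real grant document):
#   extract_with_llm_local("What is the mission of your organization? ...", use_openai=False)
#   -> ["What is the mission of your organization?", ...]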
# =================== Format Retrieved Chunks ===================
def format_docs(docs: List[Document]) -> str:
    return "\n\n".join(doc.page_content or doc.metadata.get("content", "") for doc in docs)
# =================== Generate Response (local model or OpenAI) ===================
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


@st.cache_resource
def load_local_model():
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model


tokenizer, model = load_local_model()
def generate_response(input_dict, use_openai=False, max_tokens=700):
    prompt = grantbuddy_prompt.format(**input_dict)

    if use_openai:
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": prompt},
                    {"role": "user", "content": input_dict["question"]},
                ],
                temperature=0.2,
                max_tokens=max_tokens,
            )
            answer = response.choices[0].message.content.strip()

            # ✅ Token logging
            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            total_tokens = response.usage.total_tokens

            return {
                "answer": answer,
                "tokens": {
                    "prompt": prompt_tokens,
                    "completion": completion_tokens,
                    "total": total_tokens,
                },
            }
        except Exception as e:
            st.error(f"❌ OpenAI error: {e}")
            return {
                "answer": "⚠️ OpenAI request failed.",
                "tokens": {},
            }
    else:
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(
            **inputs,
            max_new_tokens=min(max_tokens, 512),
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {
            "answer": decoded[len(prompt):].strip(),
            "tokens": {},
        }
# =================== RAG Chain ===================
def get_rag_chain(retriever, use_openai=False, max_tokens=700):
    def merge_contexts(inputs):
        # Use pre-retrieved chunks if provided; otherwise query the retriever.
        retrieved_chunks = format_docs(inputs["context_docs"]) if "context_docs" in inputs \
            else format_docs(retriever.invoke(inputs["question"]))
        combined = "\n\n".join(filter(None, [
            inputs.get("manual_context", ""),
            retrieved_chunks,
        ]))
        return {
            "context": combined,
            "question": inputs["question"],
        }

    return RunnableLambda(merge_contexts) | RunnableLambda(
        lambda input_dict: generate_response(input_dict, use_openai=use_openai, max_tokens=max_tokens)
    )
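# Minimal usage sketch (hypothetical values). The chain takes a dict and returns the
# dict built by generate_response ({"answer": ..., "tokens": {...}}):
#   chain = get_rag_chain(retriever, use_openai=False, max_tokens=700)
#   result = chain.invoke({
#       "question": "What is the mission of your organization?",
#       "manual_context": "We build edtech programs for BIPOC founders.",
#   })
#   result["answer"]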
def rerank_with_topics(chunks, topics, alpha=0.2):
    """
    Boosts similarity based on topic overlap.
    Since chunks don't have scores, we use rank order and topic matches.
    """
    topics_lower = set(t.lower() for t in topics)

    def score(chunk, rank):
        chunk_topics = [t.lower() for t in chunk.metadata.get("topics", [])]
        topic_matches = len(topics_lower.intersection(chunk_topics))
        # Lower is better: original rank minus boost
        return rank - alpha * topic_matches

    reranked = sorted(
        enumerate(chunks),
        key=lambda x: score(x[1], x[0])  # x[0] is rank, x[1] is chunk
    )
    return [chunk for _, chunk in reranked]
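# Worked example (hypothetical chunks): with topics=["edtech", "BIPOC"] and alpha=0.6,
# a chunk at rank 3 whose metadata lists both topics scores 3 - 0.6 * 2 = 1.8 and moves
# ahead of an untagged rank-2 chunk (score 2.0). At the default alpha=0.2, a chunk needs
# more than five matching topics to overtake the chunk ranked just above it.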
# =================== Streamlit UI ===================
def main():
    # st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")
    st.title("🤖 Grant Buddy: Grant-Writing Assistant")

    USE_OPENAI = st.sidebar.checkbox("Use OpenAI (Costs Tokens)", value=False)

    st.sidebar.markdown("### Retrieval Settings")
    k_value = st.sidebar.slider("How many chunks to retrieve (k)", min_value=5, max_value=40, step=5, value=10)
    score_threshold = st.sidebar.slider("Minimum relevance score", min_value=0.0, max_value=1.0, step=0.05, value=0.75)
    topic_input = st.sidebar.text_input("Optional: Focus on specific topics (comma-separated)")
    topics = [t.strip() for t in topic_input.split(",") if t.strip()]
    topic_weight = st.sidebar.slider("Topic relevance weight", min_value=0.0, max_value=1.0, step=0.05, value=0.2)

    st.sidebar.markdown("### Generation Settings")
    max_tokens = st.sidebar.number_input("Max tokens in response", min_value=100, max_value=1500, value=700, step=50)

    if "generated_queries" not in st.session_state:
        st.session_state.generated_queries = {}

    manual_context = st.text_area("📝 Optional: Add your own context (e.g., mission, goals)", height=150)

    # retriever = init_vector_search().as_retriever(search_kwargs={"k": k_value, "score_threshold": score_threshold})
    retriever = init_vector_search().as_retriever()
    rag_chain = get_rag_chain(retriever, use_openai=USE_OPENAI, max_tokens=max_tokens)

    uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
    uploaded_text = ""
    if uploaded_file:
        with st.spinner("📄 Processing uploaded file..."):
            if uploaded_file.name.endswith(".pdf"):
                reader = PdfReader(uploaded_file)
                uploaded_text = "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
            elif uploaded_file.name.endswith(".txt"):
                uploaded_text = uploaded_file.read().decode("utf-8")

        # Extract questions and headers with the LLM
        questions = extract_with_llm_local(uploaded_text, use_openai=USE_OPENAI)

        # Filter out irrelevant text
        def is_meaningful_prompt(text: str) -> bool:
            too_short = len(text.strip()) < 10
            banned_keywords = ["phone", "email", "fax", "address", "date", "contact", "website"]
            contains_bad_word = any(word in text.lower() for word in banned_keywords)
            is_just_punctuation = all(c in ":.*- " for c in text.strip())
            return not (too_short or contains_bad_word or is_just_punctuation)

        filtered_questions = [q for q in questions if is_meaningful_prompt(q)]

        with st.form("question_selection_form"):
            st.subheader("Choose prompts to answer:")
            selected_questions = []
            for i, q in enumerate(filtered_questions):
                if st.checkbox(q, key=f"q_{i}", value=True):
                    selected_questions.append(q)
            submit_button = st.form_submit_button("Submit")

    # Multi-select questions (the form above only exists when a file was uploaded)
    if 'submit_button' in locals() and submit_button:
        if selected_questions:
            with st.spinner("💡 Generating answers..."):
                answers = []
                for q in selected_questions:
                    combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
                    pre_k = k_value * 4  # Retrieve more chunks first, then rerank and trim
                    context_docs = retriever.get_relevant_documents(q, k=pre_k)
                    if topics:
                        context_docs = rerank_with_topics(context_docs, topics, alpha=topic_weight)
                    context_docs = context_docs[:k_value]
                    if q in st.session_state.generated_queries:
                        response = st.session_state.generated_queries[q]
                    else:
                        response = rag_chain.invoke({
                            "question": q,
                            "manual_context": combined_context,
                            "context_docs": context_docs,
                        })
                        st.session_state.generated_queries[q] = response
                    answers.append({"question": q, "answer": response})

                for item in answers:
                    st.markdown(f"### ❓ {item['question']}")
                    st.markdown(f"💬 {item['answer']['answer']}")
                    tokens = item['answer'].get("tokens", {})
                    if tokens:
                        st.markdown(f"🧮 **Token Usage:** Prompt = {tokens.get('prompt')}, "
                                    f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
        else:
            st.info("No prompts selected for answering.")

    # ✍️ Manual single-question input
    query = st.text_input("Ask a grant-related question")
    if st.button("Submit"):
        if not query:
            st.warning("Please enter a question.")
            return
        pre_k = k_value * 4
        context_docs = retriever.get_relevant_documents(query, k=pre_k)
        if topics:
            context_docs = rerank_with_topics(context_docs, topics, alpha=topic_weight)
        context_docs = context_docs[:k_value]
        combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))

        with st.spinner("🤖 Thinking..."):
            response = rag_chain.invoke({"question": query, "manual_context": combined_context, "context_docs": context_docs})

        st.text_area("Grant Buddy says:", value=response["answer"], height=250, disabled=True)
        tokens = response.get("tokens", {})
        if tokens:
            st.markdown(f"🧮 **Token Usage:** Prompt = {tokens.get('prompt')}, "
                        f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")

        with st.expander("🔍 Retrieved Chunks"):
            for doc in context_docs:
                st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')} | **Title:** {doc.metadata['metadata'].get('title', 'unknown')}")
                st.markdown(doc.page_content[:700] + "...")
                if topics:
                    matched_topics = set(doc.metadata['metadata'].get('topics', [])).intersection(topics)
                    st.markdown(f"**Matched Topics:** {', '.join(matched_topics)}")
                st.markdown("---")


if __name__ == "__main__":
    main()
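# Likely requirements.txt for this Space (an assumption inferred from the imports above;
# pin versions as appropriate):
#   streamlit
#   pymongo
#   PyPDF2
#   langchain
#   langchain-community
#   sentence-transformers
#   transformers
#   torch
#   openai
#   huggingface_hub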