# Hugging Face Spaces page status: Sleeping
# ############################################################################################################################# | |
# # Filename : app.py | |
# # Description: A Streamlit application to showcase how RAG works. | |
# # Author : Georgios Ioannou | |
# # | |
# # Copyright © 2024 by Georgios Ioannou | |
# ############################################################################################################################# | |
# app.py | |
import os | |
import re | |
from huggingface_hub import InferenceClient | |
import json | |
from huggingface_hub import HfApi | |
import streamlit as st | |
from typing import List, Dict, Any | |
from urllib.parse import quote_plus | |
from pymongo import MongoClient | |
from PyPDF2 import PdfReader | |
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import MongoDBAtlasVectorSearch | |
from langchain.prompts import PromptTemplate | |
from langchain.schema import Document | |
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough | |
from huggingface_hub import InferenceClient | |
# =================== Secure Env via Hugging Face Secrets =================== | |
# Connection settings come from the environment (Hugging Face Space secrets).
# Unset variables fall back to "" (or the documented default) so importing the
# module cannot crash: quote_plus(None) raises TypeError, and the LLM helpers
# call HF_TOKEN.strip(), which would fail on None. Values are stripped the same
# way init_vector_search() strips them.
user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
cluster = os.getenv("MONGO_CLUSTER", "").strip()
db_name = os.getenv("MONGO_DB_NAME", "files").strip()
collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
# NOTE(review): tlsAllowInvalidCertificates=true turns off server certificate
# validation (man-in-the-middle risk). It is kept so existing deployments keep
# connecting; the secure form simply omits "&tls=true&tlsAllowInvalidCertificates=true".
MONGO_URI = (
    f"mongodb+srv://{user}:{password}@{cluster}/{db_name}"
    "?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true"
)
# =================== Prompt =================== | |
# Instruction text for "Grant Buddy". The two template variables are
# {context} (retrieved chunks) and {question} (the user's query).
_GRANTBUDDY_TEMPLATE = """You are Grant Buddy, a specialized language model fine-tuned with instruction-tuning and RLHF.
You help a nonprofit focused on social entrepreneurship, BIPOC empowerment, and edtech write clear, mission-aligned grant responses.
**Instructions:**
- Start with reasoning or context for your answer.
- Always align with the nonprofit’s mission.
- Use structured formatting: headings, bullet points, numbered lists.
- Include impact data or examples if relevant.
- Do NOT repeat the same sentence or answer multiple times.
- If no answer exists in the context, say: "This information is not available in the current context."
CONTEXT:
{context}
QUESTION:
{question}
"""

grantbuddy_prompt = PromptTemplate.from_template(_GRANTBUDDY_TEMPLATE)
# =================== Vector Search Setup =================== | |
def init_embedding_model(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Create a HuggingFaceEmbeddings model.

    Args:
        model_name: Hugging Face model id to load. Defaults to
            sentence-transformers/all-MiniLM-L6-v2, the same model
            init_vector_search() loads.

    Returns:
        A HuggingFaceEmbeddings instance for ``model_name``.
    """
    return HuggingFaceEmbeddings(model_name=model_name)
def init_vector_search() -> MongoDBAtlasVectorSearch:
    """Build a MongoDB Atlas vector store using all-MiniLM-L6-v2 embeddings.

    Reads MONGO_USERNAME, MONGO_PASSWORD, MONGO_CLUSTER, MONGO_DB_NAME,
    MONGO_COLLECTION and MONGO_VECTOR_INDEX from the environment. The last
    three default to "files", "files_collection" and "vector_index".
    Progress and failures are reported in the Streamlit UI.

    Returns:
        A MongoDBAtlasVectorSearch over the configured collection and index.

    Raises:
        Exception: Any error from creating the client, the connectivity ping,
            or building the store. It is re-raised after being displayed.
    """
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    st.write(f"🔌 Connecting to Hugging Face model: `{model_name}`")
    embedding_model = HuggingFaceEmbeddings(model_name=model_name)

    user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
    password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
    cluster = os.getenv("MONGO_CLUSTER", "").strip()
    db_name = os.getenv("MONGO_DB_NAME", "files").strip()
    collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
    index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
    mongo_uri = f"mongodb+srv://{user}:{password}@{cluster}/?retryWrites=true&w=majority"

    client = None
    try:
        # NOTE(review): tlsAllowInvalidCertificates=True disables server
        # certificate validation (man-in-the-middle risk). Kept so the current
        # deployment keeps working; remove once verification is confirmed.
        client = MongoClient(
            mongo_uri,
            tls=True,
            tlsAllowInvalidCertificates=True,
            serverSelectionTimeoutMS=20000,
        )
        # MongoClient connects lazily. The ping forces a round-trip, so an
        # unreachable cluster or bad credentials fail here instead of the
        # success message being shown anyway.
        client.admin.command("ping")
        st.success("✅ MongoClient connected successfully")
        return MongoDBAtlasVectorSearch(
            collection=client[db_name][collection_name],
            embedding=embedding_model,
            index_name=index_name,
        )
    except Exception as e:
        if client is not None:
            client.close()  # release the failed client's resources
        st.error("❌ Failed to connect to MongoDB Atlas manually")
        st.error(str(e))
        raise
# =================== Question/Headers Extraction =================== | |
# def extract_questions_and_headers(text: str) -> List[str]: | |
# header_patterns = [ | |
# r'\d+\.\s+\*\*([^\*]+)\*\*', | |
# r'\*\*([^*]+)\*\*', | |
# r'^([A-Z][^a-z]*[A-Z])$', | |
# r'^([A-Z][A-Za-z\s]{3,})$', | |
# r'^[A-Z][A-Za-z\s]+:$' | |
# ] | |
# question_patterns = [ | |
# r'^.+\?$', | |
# r'^\*?Please .+', | |
# r'^How .+', | |
# r'^What .+', | |
# r'^Describe .+', | |
# ] | |
# combined_header_re = re.compile("|".join(header_patterns), re.MULTILINE) | |
# combined_question_re = re.compile("|".join(question_patterns), re.MULTILINE) | |
# headers = [match for group in combined_header_re.findall(text) for match in group if match] | |
# questions = combined_question_re.findall(text) | |
# return headers + questions | |
# A leading list marker ("1.", "2)", "-", "*", "•") followed by whitespace or end of line.
_LIST_MARKER_RE = re.compile(r"^\s*(?:\d+[.)]|[-*•])(?:\s+|$)")


def _parse_list_items(content: str) -> List[str]:
    """Split a list-formatted LLM reply into one string per item.

    Blank lines are dropped. Only a *leading* list marker is removed, so
    numbers that belong to the item itself (e.g. "Budget for 2024") are kept.
    """
    items = []
    for line in content.splitlines():
        item = _LIST_MARKER_RE.sub("", line, count=1).strip()
        if item:
            items.append(item)
    return items


def extract_with_llm(text: str, max_chars: int = 3000) -> List[str]:
    """Ask an LLM to list the headers and questions in a grant document.

    Args:
        text: Plain text of the uploaded document.
        max_chars: Only the first ``max_chars`` characters are sent to the
            model. The default of 3000 is the previous hard-coded limit.

    Returns:
        One string per extracted header or question, with list markers
        removed. Returns [] for blank input, or when the inference call fails
        (the error is shown in the Streamlit UI).
    """
    if not text or not text.strip():
        return []  # nothing to extract; skip the paid API call
    try:
        # An empty token becomes None so huggingface_hub applies its default
        # token lookup instead of sending an empty credential.
        client = InferenceClient(api_key=(HF_TOKEN or "").strip() or None)
        response = client.chat.completions.create(
            model="mistralai/Mistral-Nemo-Instruct-2407",  # or "HuggingFaceH4/zephyr-7b-beta"
            messages=[
                {
                    "role": "system",
                    "content": "You are an assistant helping extract questions and headers from grant applications.",
                },
                {
                    "role": "user",
                    "content": (
                        "Please extract all the grant application headers and questions from the following text. "
                        "Include section titles, prompts, and any question-like content. Return them as a numbered list.\n\n"
                        f"{text[:max_chars]}"
                    ),
                },
            ],
            temperature=0.2,
            max_tokens=512,
        )
        content = response.choices[0].message.content or ""
    except Exception as e:
        st.error("❌ LLM extraction failed")
        st.error(str(e))
        return []
    return _parse_list_items(content)
# =================== Format Retrieved Chunks =================== | |
def format_docs(docs: List[Document]) -> str:
    """Concatenate retrieved documents into one context string.

    Each document contributes its ``page_content``. When that is empty, the
    ``"content"`` metadata entry (or "") is used instead. Pieces are
    separated by a blank line.
    """
    pieces = []
    for document in docs:
        text = document.page_content
        if not text:
            text = document.metadata.get("content", "")
        pieces.append(text)
    return "\n\n".join(pieces)
# =================== Generate Response from Hugging Face Model =================== | |
def generate_response(input_dict: Dict[str, Any]) -> str:
    """Answer a question with the Grant Buddy prompt via the HF Inference API.

    Args:
        input_dict: Must contain "context" and "question". Both are
            substituted into grantbuddy_prompt (sent as the system message),
            and "question" is also sent as the user message.

    Returns:
        The model's reply ("" if the model returned no content), or a fallback
        warning string if the call fails. The error itself is shown in the
        Streamlit UI.
    """
    prompt = grantbuddy_prompt.format(**input_dict)
    try:
        # Client creation is inside the try so a missing or empty token
        # produces the friendly error below instead of an uncaught exception.
        client = InferenceClient(api_key=(HF_TOKEN or "").strip() or None)
        response = client.chat.completions.create(
            model="HuggingFaceH4/zephyr-7b-beta",
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": input_dict["question"]},
            ],
            max_tokens=1000,
            temperature=0.2,
        )
        return response.choices[0].message.content or ""
    except Exception as e:
        st.error(f"❌ Error from model: {e}")
        return "⚠️ Failed to generate response. Please check your model, HF token, or request format."
# =================== RAG Chain =================== | |
def get_rag_chain(retriever):
    """Assemble the retrieval-augmented generation pipeline.

    The input query goes two ways. It goes through ``retriever`` and
    format_docs to become "context", and it is passed through unchanged as
    "question". The resulting dict is handed to generate_response.

    Args:
        retriever: A LangChain retriever mapping a query string to documents.

    Returns:
        A Runnable whose ``invoke(query)`` yields the answer string.
    """
    context_branch = retriever | RunnableLambda(format_docs)
    parallel_inputs = {"context": context_branch, "question": RunnablePassthrough()}
    return parallel_inputs | RunnableLambda(generate_response)
# =================== Streamlit UI =================== | |
def _read_uploaded_text(uploaded_file) -> str:
    """Return the plain text of an uploaded ``.pdf`` or ``.txt`` file ("" otherwise).

    Extension matching is case-insensitive, so e.g. "REPORT.PDF" is read too.
    """
    file_name = uploaded_file.name.lower()
    if file_name.endswith(".pdf"):
        reader = PdfReader(uploaded_file)
        # Call extract_text() once per page; pages that yield no text are skipped.
        page_texts = (page.extract_text() for page in reader.pages)
        return "\n".join(text for text in page_texts if text)
    if file_name.endswith(".txt"):
        # getvalue() returns the whole buffer regardless of the read position.
        # Undecodable bytes are replaced rather than crashing the page.
        return uploaded_file.getvalue().decode("utf-8", errors="replace")
    return ""


def _process_upload(uploaded_file, rag_chain) -> str:
    """Answer every header/question found in the upload and render the answers.

    Returns:
        The uploaded file's text ("" when none could be extracted), which is
        reused as extra context for the manual query.
    """
    with st.spinner("📄 Processing uploaded file..."):
        uploaded_text = _read_uploaded_text(uploaded_file)
        if not uploaded_text.strip():
            st.warning("⚠️ No text could be extracted from the uploaded file.")
            return ""
        questions = extract_with_llm(uploaded_text)
        st.success(f"✅ Found {len(questions)} questions or headers.")
        with st.expander("🧠 Extracted Prompts from Upload"):
            st.write(questions)
        answers = [
            (question, rag_chain.invoke(f"{question}\n\nAdditional context:\n{uploaded_text}"))
            for question in questions
        ]
    for question, answer in answers:
        st.markdown(f"### ❓ {question}")
        st.markdown(f"💬 {answer}")
    return uploaded_text


def _show_retrieved_chunks(retriever, query: str) -> None:
    """Render, inside an expander, the documents ``retriever`` returns for ``query``."""
    with st.expander("🔍 Retrieved Chunks"):
        for doc in retriever.invoke(query):
            # dict.get takes a single default, so chunk id and title are looked up separately.
            chunk_id = doc.metadata.get("chunk_id", "unknown")
            title = doc.metadata.get("title")
            label = f"{chunk_id} ({title})" if title else f"{chunk_id}"
            st.markdown(f"**Chunk ID:** {label}")
            preview = doc.page_content[:700]
            if len(doc.page_content) > 700:
                preview += "..."  # only mark chunks that were actually truncated
            st.markdown(preview)
            st.markdown("---")


def main():
    """Render the Grant Buddy Streamlit page.

    Connects to the vector store and answers every prompt found in an
    optional uploaded PDF/TXT. On "Submit", it answers the typed question
    (with the upload's text as extra context) and shows the retrieved chunks.
    """
    st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")
    st.title("🤖 Grant Buddy: Grant-Writing Assistant")
    uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
    # NOTE(review): Streamlit reruns this function on every interaction. The
    # vector-store setup and the per-prompt LLM calls for an upload are therefore
    # repeated each time "Submit" is pressed; consider st.cache_resource /
    # st.session_state.
    # NOTE(review): score_threshold is presumably only honoured with
    # search_type="similarity_score_threshold". Confirm it has an effect here.
    retriever = init_vector_search().as_retriever(search_kwargs={"k": 10, "score_threshold": 0.75})
    rag_chain = get_rag_chain(retriever)  # must exist before the upload is processed
    uploaded_text = _process_upload(uploaded_file, rag_chain) if uploaded_file else ""

    query = st.text_input("Ask a grant-related question")
    if not st.button("Submit"):
        return
    full_query = f"{query}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else query
    # Catches both "no question and no upload" and an upload with no usable text.
    if not full_query.strip():
        st.warning("Please enter a question.")
        return
    with st.spinner("🤖 Thinking..."):
        response = rag_chain.invoke(full_query)
    st.text_area("Grant Buddy says:", value=response, height=250, disabled=True)
    _show_retrieved_chunks(retriever, full_query)


if __name__ == "__main__":
    main()