Spaces:
Sleeping
Sleeping
# #############################################################################################################################
# # Filename   : app.py
# # Description: A Streamlit application to showcase how RAG works.
# # Author     : Georgios Ioannou
# #
# # Copyright © 2024 by Georgios Ioannou
# #############################################################################################################################
# app.py
import os | |
import re | |
import json | |
from huggingface_hub import HfApi | |
import streamlit as st | |
from typing import List, Dict, Any | |
from urllib.parse import quote_plus | |
from pymongo import MongoClient | |
from PyPDF2 import PdfReader | |
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import MongoDBAtlasVectorSearch | |
from langchain.prompts import PromptTemplate | |
from langchain.schema import Document | |
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough | |
from huggingface_hub import InferenceClient | |
# =================== Secure Env via Hugging Face Secrets ===================
# Defaults of "" (instead of None) keep quote_plus() from raising TypeError at
# import time when a secret is missing; .strip() drops stray whitespace or
# newlines pasted into secret values (same normalization init_vector_search uses).
user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
cluster = os.getenv("MONGO_CLUSTER", "").strip()
db_name = os.getenv("MONGO_DB_NAME", "files").strip()
collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
# None when the secret is unset; generate_response() handles that case.
HF_TOKEN = os.getenv("HF_TOKEN")
# SECURITY NOTE(review): tlsAllowInvalidCertificates=true disables server
# certificate verification (man-in-the-middle risk); prefer a proper CA bundle
# (tlsCAFile) once the host allows it. init_vector_search() builds its own URI
# rather than reading this one.
MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true"
# =================== Prompt ===================
# Prompt template for the chat model. generate_response() fills {context}
# (formatted retrieved chunks) and {question} (the user query) via .format().
grantbuddy_prompt = PromptTemplate.from_template(
    """You are Grant Buddy, a specialized language model fine-tuned with instruction-tuning and RLHF.
You help a nonprofit focused on social entrepreneurship, BIPOC empowerment, and edtech write clear, mission-aligned grant responses.
**Instructions:**
- Start with reasoning or context for your answer.
- Always align with the nonprofit's mission.
- Use structured formatting: headings, bullet points, numbered lists.
- Include impact data or examples if relevant.
- Do NOT repeat the same sentence or answer multiple times.
- If no answer exists in the context, say: "This information is not available in the current context."
CONTEXT:
{context}
QUESTION:
{question}
"""
)
# =================== Vector Search Setup ===================
def init_embedding_model(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Create a local sentence-transformers embedding model.

    Args:
        model_name: Hugging Face model id to load. Defaults to
            all-MiniLM-L6-v2, the same model init_vector_search() loads, so
            existing zero-argument calls behave exactly as before.

    Returns:
        A ``HuggingFaceEmbeddings`` instance for ``model_name``.
    """
    return HuggingFaceEmbeddings(model_name=model_name)
@st.cache_resource
def init_vector_search() -> MongoDBAtlasVectorSearch:
    """Connect to MongoDB Atlas and wrap the collection as a LangChain vector store.

    Settings come from environment variables (Hugging Face Space secrets):
    MONGO_USERNAME, MONGO_PASSWORD, MONGO_CLUSTER, plus optional
    MONGO_DB_NAME (default "files"), MONGO_COLLECTION (default
    "files_collection") and MONGO_VECTOR_INDEX (default "vector_index").

    Streamlit reruns the whole script on every interaction. The result is
    therefore cached with ``st.cache_resource``; otherwise each rerun would
    reload the embedding model and open a new MongoClient that is never closed.

    Returns:
        A ``MongoDBAtlasVectorSearch`` over ``<db_name>.<collection_name>``
        using all-MiniLM-L6-v2 embeddings and the configured Atlas index.

    Raises:
        Exception: any configuration or connection error, re-raised after it
            has been shown in the UI.
    """
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    st.write(f"π Connecting to Hugging Face model: `{model_name}`")
    embedding_model = HuggingFaceEmbeddings(model_name=model_name)

    # Read Atlas settings; .strip() guards against whitespace pasted into secrets.
    user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
    password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
    cluster = os.getenv("MONGO_CLUSTER", "").strip()
    db_name = os.getenv("MONGO_DB_NAME", "files").strip()
    collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
    index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
    mongo_uri = f"mongodb+srv://{user}:{password}@{cluster}/?retryWrites=true&w=majority"

    client = None
    try:
        # SECURITY NOTE(review): tlsAllowInvalidCertificates=True disables server
        # certificate verification; prefer a CA bundle (tlsCAFile) if possible.
        client = MongoClient(mongo_uri, tls=True, tlsAllowInvalidCertificates=True, serverSelectionTimeoutMS=20000)
        # MongoClient connects lazily. Ping so that a bad host or bad credentials
        # fail here, before the success message is shown.
        client.admin.command("ping")
        collection = client[db_name][collection_name]
        st.success("β MongoClient connected successfully")
        return MongoDBAtlasVectorSearch(
            collection=collection,
            embedding=embedding_model,
            index_name=index_name,
        )
    except Exception as e:
        st.error("β Failed to connect to MongoDB Atlas manually")
        st.error(str(e))
        # Release the client's connection pool / monitor threads on failure.
        if client is not None:
            client.close()
        raise
# =================== Question/Headers Extraction ===================
# Every header alternative has exactly one capture group, so each findall()
# tuple holds exactly one non-empty element. The "[ \t]" and "\n" exclusions
# keep every match on a single line: under re.MULTILINE, "\s" and "[^a-z]" also
# match newlines and would otherwise glue consecutive lines into one "header".
_HEADER_PATTERNS = [
    r"\d+\.[ \t]+\*\*([^*\n]+)\*\*",  # 1. **Numbered bold header**
    r"\*\*([^*\n]+)\*\*",             # **Bold header** anywhere in a line
    r"^([A-Z][^a-z\n]*[A-Z])$",       # ALL-CAPS LINE
    r"^([A-Z][A-Za-z \t]{3,})$",      # Capitalized line of letters/spaces
    r"^([A-Z][A-Za-z \t]+):$",        # Header followed by a colon:
]
_QUESTION_PATTERNS = [
    r"^.+\?$",         # any line ending in "?"
    r"^\*?Please .+",  # "Please ..." prompts, optionally starting with "*"
    r"^How .+",
    r"^What .+",
    r"^Describe .+",
]
_HEADER_RE = re.compile("|".join(_HEADER_PATTERNS), re.MULTILINE)
_QUESTION_RE = re.compile("|".join(_QUESTION_PATTERNS), re.MULTILINE)


def extract_questions_and_headers(text: str) -> List[str]:
    """Extract section headers and question-like lines from a grant document.

    Headers are bold markdown spans (optionally numbered), ALL-CAPS lines,
    capitalized letter-only lines, and "Header:" lines. Questions are lines
    ending in "?" or starting with "Please", "How", "What" or "Describe".

    Args:
        text: Plain text of the uploaded document; may be empty.

    Returns:
        Headers first, then questions. Each entry is stripped of surrounding
        whitespace, and duplicates are removed while keeping first-seen order.
    """
    # Normalize Windows/old-Mac line endings so "$" sees a clean line end and
    # "\r" never ends up inside a captured question.
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    headers = [group for groups in _HEADER_RE.findall(text) for group in groups if group]
    questions = _QUESTION_RE.findall(text)
    # A line such as "Describe Your Program" matches both a header and a question
    # pattern; dedupe so the caller does not query the model twice for it.
    cleaned = (item.strip() for item in headers + questions)
    return list(dict.fromkeys(item for item in cleaned if item))
# =================== Format Retrieved Chunks ===================
def format_docs(docs: List[Document]) -> str:
    """Concatenate retrieved documents into one context string.

    Each document contributes its ``page_content``. When that is empty, the
    ``"content"`` metadata entry is used instead, or ``""`` if it is absent.
    Consecutive chunks are separated by a blank line.
    """
    pieces = []
    for doc in docs:
        text = doc.page_content
        if not text:
            text = doc.metadata.get("content", "")
        pieces.append(text)
    return "\n\n".join(pieces)
# =================== Generate Response from Hugging Face Model ===================
def generate_response(input_dict: Dict[str, Any]) -> str:
    """Answer a question with the hosted chat model using the Grant Buddy prompt.

    Args:
        input_dict: Mapping with the ``"context"`` and ``"question"`` keys that
            ``grantbuddy_prompt`` expects (as built by ``get_rag_chain``).

    Returns:
        The model's reply text. If the inference call fails, the error is shown
        in the UI and a fallback message is returned instead.
    """
    # HF_TOKEN is None when the secret is unset, and calling .strip() on it
    # crashed outside the try block. A missing or blank token is passed as None
    # so InferenceClient applies its own default instead of an empty credential.
    token = (HF_TOKEN or "").strip() or None
    client = InferenceClient(api_key=token)
    prompt = grantbuddy_prompt.format(**input_dict)
    try:
        # NOTE(review): the question (which may embed the whole uploaded document)
        # is already inside `prompt` and is sent again as the user message,
        # roughly doubling prompt tokens. Confirm this is intended.
        response = client.chat.completions.create(
            model="HuggingFaceH4/zephyr-7b-beta",
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": input_dict["question"]},
            ],
            max_tokens=1000,
            temperature=0.2,
        )
        # Message content may be None; keep the declared str return type.
        return response.choices[0].message.content or ""
    except Exception as e:
        st.error(f"β Error from model: {e}")
        return "β οΈ Failed to generate response. Please check your model, HF token, or request format."
# =================== RAG Chain ===================
def get_rag_chain(retriever):
    """Assemble the retrieval-augmented generation pipeline.

    The input string is fanned out into two branches. ``retriever`` fetches
    matching documents, which ``format_docs`` joins into ``"context"``; the raw
    input is forwarded unchanged as ``"question"``. The resulting mapping is
    handed to ``generate_response``, whose string output is the chain's result.
    """
    branches = {
        "context": retriever | RunnableLambda(format_docs),
        "question": RunnablePassthrough(),
    }
    answer_step = RunnableLambda(generate_response)
    return branches | answer_step
# =================== Streamlit UI ===================
def _read_uploaded_file(uploaded_file) -> str:
    """Return the text of an uploaded ``.pdf`` or ``.txt`` file.

    Any other extension yields ``""``; the uploader in main() only accepts these two.
    """
    if uploaded_file.name.endswith(".pdf"):
        reader = PdfReader(uploaded_file)
        # Call extract_text() once per page and skip pages that yield no text.
        page_texts = (page.extract_text() for page in reader.pages)
        return "\n".join(text for text in page_texts if text)
    if uploaded_file.name.endswith(".txt"):
        # getvalue() returns the whole upload regardless of the read position;
        # errors="replace" keeps a non-UTF-8 file from crashing the app.
        return uploaded_file.getvalue().decode("utf-8", errors="replace")
    return ""


def _answer_questions(rag_chain, questions: List[str], uploaded_text: str) -> List[Dict[str, Any]]:
    """Run ``rag_chain`` once per question, appending the uploaded text as extra context."""
    answers = []
    for question in questions:
        full_query = f"{question}\n\nAdditional context:\n{uploaded_text}"
        answers.append({"question": question, "answer": rag_chain.invoke(full_query)})
    return answers


def _show_retrieved_chunks(retriever, full_query: str) -> None:
    """Render the chunks ``retriever`` returns for ``full_query`` as 700-character previews."""
    # invoke() is the Runnable entry point; get_relevant_documents() is deprecated.
    for doc in retriever.invoke(full_query):
        # dict.get() accepts only one default; the previous 3-argument call raised
        # TypeError. Show chunk_id, else title, else "unknown".
        chunk_label = doc.metadata.get("chunk_id", doc.metadata.get("title", "unknown"))
        st.markdown(f"**Chunk ID:** {chunk_label}")
        content = doc.page_content
        # Only add an ellipsis when the preview actually truncates the chunk.
        st.markdown(content[:700] + ("..." if len(content) > 700 else ""))
        st.markdown("---")


def main():
    """Render the Grant Buddy page: optional uploaded-file Q&A plus a manual question box."""
    st.set_page_config(page_title="Grant Buddy RAG", page_icon="π€")
    st.title("π€ Grant Buddy: Grant-Writing Assistant")
    uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
    uploaded_text = ""
    # NOTE(review): with the retriever's default "similarity" search type,
    # "score_threshold" appears to be ignored; it is only applied with
    # search_type="similarity_score_threshold". Confirm before relying on it.
    retriever = init_vector_search().as_retriever(search_kwargs={"k": 10, "score_threshold": 0.75})
    rag_chain = get_rag_chain(retriever)  # must exist before the file Q&A below

    # Uploaded-file Q&A. Streamlit reruns this function on every interaction
    # (including clicking Submit), so the extracted text and answers are kept in
    # session_state, keyed by file name and content. The per-question LLM calls
    # then run once per distinct upload instead of on every rerun.
    if uploaded_file:
        file_key = (uploaded_file.name, hash(uploaded_file.getvalue()))
        cached = st.session_state.get("grant_buddy_uploaded_file")
        if cached is None or cached["key"] != file_key:
            with st.spinner("π Processing uploaded file..."):
                text = _read_uploaded_file(uploaded_file)
                questions = extract_questions_and_headers(text)
                answers = _answer_questions(rag_chain, questions, text)
            cached = {"key": file_key, "text": text, "answers": answers}
            st.session_state["grant_buddy_uploaded_file"] = cached
        uploaded_text = cached["text"]
        st.success(f"β Found {len(cached['answers'])} questions or headers.")
        for item in cached["answers"]:
            st.markdown(f"### β {item['question']}")
            st.markdown(f"π¬ {item['answer']}")

    # Manual query box
    query = st.text_input("Ask a grant-related question")
    if st.button("Submit"):
        if not query and not uploaded_file:
            st.warning("Please enter a question.")
            return
        full_query = f"{query}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else query
        with st.spinner("π€ Thinking..."):
            response = rag_chain.invoke(full_query)
        st.text_area("Grant Buddy says:", value=response, height=250, disabled=True)
        with st.expander("π Retrieved Chunks"):
            _show_retrieved_chunks(retriever, full_query)


if __name__ == "__main__":
    main()