# #############################################################################################################################
# # Filename : app.py
# # Description: A Streamlit application to showcase how RAG works.
# # Author : Georgios Ioannou
# #
# # Copyright © 2024 by Georgios Ioannou
# #############################################################################################################################
import os
import re

import streamlit as st
from typing import Any, Dict, List
from urllib.parse import quote_plus

from huggingface_hub import InferenceClient
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from pymongo import MongoClient
from PyPDF2 import PdfReader
# =================== Secure Env via Hugging Face Secrets ===================
# Default to empty strings so a missing secret fails with a clear Mongo error
# rather than a TypeError inside quote_plus().
user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
cluster = os.getenv("MONGO_CLUSTER", "").strip()
db_name = os.getenv("MONGO_DB_NAME", "files").strip()
collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority"
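# Example secrets (hypothetical placeholder values) as they would appear in the
# Space's settings; the names match the os.getenv() calls above:
#   MONGO_USERNAME=grantbuddy
#   MONGO_PASSWORD=<password>
#   MONGO_CLUSTER=cluster0.abcde.mongodb.net
#   MONGO_DB_NAME=files
#   MONGO_COLLECTION=files_collection
#   MONGO_VECTOR_INDEX=vector_index
#   HF_TOKEN=hf_...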
# =================== Prompt ===================
grantbuddy_prompt = PromptTemplate.from_template(
"""You are Grant Buddy, a specialized language model fine-tuned with instruction-tuning and RLHF.
You help a nonprofit focused on social entrepreneurship, BIPOC empowerment, and edtech write clear, mission-aligned grant responses.
**Instructions:**
- Start with reasoning or context for your answer.
- Always align with the nonprofit's mission.
- Use structured formatting: headings, bullet points, numbered lists.
- Include impact data or examples if relevant.
- Do NOT repeat the same sentence or answer multiple times.
- If no answer exists in the context, say: "This information is not available in the current context."
CONTEXT:
{context}
QUESTION:
{question}
"""
)
# =================== Vector Search Setup ===================
@st.cache_resource
def init_embedding_model() -> HuggingFaceEmbeddings:
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


@st.cache_resource
def init_vector_search() -> MongoDBAtlasVectorSearch:
    st.write("🔌 Loading embedding model: `sentence-transformers/all-MiniLM-L6-v2`")
    embedding_model = init_embedding_model()
    try:
        # tlsAllowInvalidCertificates disables certificate verification; it is
        # a workaround for intercepted/self-signed TLS chains and should not
        # be used in production.
        client = MongoClient(
            MONGO_URI,
            tls=True,
            tlsAllowInvalidCertificates=True,
            serverSelectionTimeoutMS=20000,
        )
        # MongoClient connects lazily, so ping to verify the connection now.
        client.admin.command("ping")
        collection = client[db_name][collection_name]
        st.success("✅ MongoClient connected successfully")
        return MongoDBAtlasVectorSearch(
            collection=collection,
            embedding=embedding_model,
            index_name=index_name,
        )
    except Exception as e:
        st.error("❌ Failed to connect to MongoDB Atlas")
        st.error(str(e))
        raise
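# A sketch of the Atlas vector index (type "vectorSearch") this store assumes.
# Illustrative only: the "embedding" path follows the MongoDBAtlasVectorSearch
# default, and 384 dimensions matches all-MiniLM-L6-v2; adjust both to your
# actual collection and embedding model.
#
#   {
#     "fields": [
#       {
#         "type": "vector",
#         "path": "embedding",
#         "numDimensions": 384,
#         "similarity": "cosine"
#       }
#     ]
#   }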
# =================== Question/Headers Extraction ===================
def extract_questions_and_headers(text: str) -> List[str]:
    # Every alternative needs a capture group, otherwise findall() returns an
    # all-empty tuple for its matches and they are silently dropped below.
    header_patterns = [
        r'\d+\.\s+\*\*([^\*]+)\*\*',  # numbered bold headers: 1. **Header**
        r'\*\*([^*]+)\*\*',           # bold headers: **Header**
        r'^([A-Z][^a-z]*[A-Z])$',     # ALL-CAPS lines
        r'^([A-Z][A-Za-z\s]{3,})$',   # title-style lines
        r'^([A-Z][A-Za-z\s]+):$'      # lines ending with a colon
    ]
question_patterns = [
r'^.+\?$',
r'^\*?Please .+',
r'^How .+',
r'^What .+',
r'^Describe .+',
]
combined_header_re = re.compile("|".join(header_patterns), re.MULTILINE)
combined_question_re = re.compile("|".join(question_patterns), re.MULTILINE)
headers = [match for group in combined_header_re.findall(text) for match in group if match]
questions = combined_question_re.findall(text)
return headers + questions
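# Illustrative example (hypothetical input) of what the extractor returns:
#   sample = "1. **Organization Background**\nWhat is your mission?"
#   extract_questions_and_headers(sample)
#   # -> ['Organization Background', 'What is your mission?']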
# =================== Format Retrieved Chunks ===================
def format_docs(docs: List[Document]) -> str:
return "\n\n".join(doc.page_content or doc.metadata.get("content", "") for doc in docs)
# =================== Generate Response from Hugging Face Model ===================
def generate_response(input_dict: Dict[str, Any]) -> str:
    client = InferenceClient(api_key=HF_TOKEN)
    # The template already interpolates both the context and the question, so
    # send the formatted prompt as a single user message instead of repeating
    # the question in a second message.
    prompt = grantbuddy_prompt.format(**input_dict)
    try:
        response = client.chat.completions.create(
            model="HuggingFaceH4/zephyr-7b-beta",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000,
            temperature=0.2,
        )
return response.choices[0].message.content
except Exception as e:
st.error(f"❌ Error from model: {e}")
return "⚠️ Failed to generate response. Please check your model, HF token, or request format."
# =================== RAG Chain ===================
def get_rag_chain(retriever):
return {
"context": retriever | RunnableLambda(format_docs),
"question": RunnablePassthrough()
} | RunnableLambda(generate_response)
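# Rough data flow when the chain is invoked (illustrative):
#   rag_chain.invoke("How do you measure impact?")
#   1. retriever fetches matching chunks; format_docs joins them into "context".
#   2. RunnablePassthrough forwards the raw question unchanged.
#   3. generate_response receives {"context": ..., "question": ...} and
#      formats the Grant Buddy prompt before calling the model.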
# =================== Streamlit UI ===================
def main():
    st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")
    st.title("🤖 Grant Buddy: Grant-Writing Assistant")
    uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
    uploaded_text = ""
    retriever = init_vector_search().as_retriever(search_kwargs={"k": 10, "score_threshold": 0.75})
    rag_chain = get_rag_chain(retriever)  # ✅ Initialize before usage
    # 🔍 Process uploaded file
    if uploaded_file:
        with st.spinner("📄 Processing uploaded file..."):
            if uploaded_file.name.endswith(".pdf"):
                reader = PdfReader(uploaded_file)
                uploaded_text = "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
            elif uploaded_file.name.endswith(".txt"):
                uploaded_text = uploaded_file.read().decode("utf-8")
            questions = extract_questions_and_headers(uploaded_text)
            st.success(f"✅ Found {len(questions)} questions or headers.")
            # Generate answers
            answers = []
            for q in questions:
                full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
                response = rag_chain.invoke(full_query)
                answers.append({"question": q, "answer": response})
            for item in answers:
                st.markdown(f"### ❓ {item['question']}")
                st.markdown(f"💬 {item['answer']}")
    # ✅ Manual query box
    query = st.text_input("Ask a grant-related question")
    if st.button("Submit"):
        if not query and not uploaded_file:
            st.warning("Please enter a question or upload a file.")
            return
        full_query = f"{query}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else query
        with st.spinner("🤖 Thinking..."):
            response = rag_chain.invoke(full_query)
            st.text_area("Grant Buddy says:", value=response, height=250, disabled=True)
            with st.expander("🔍 Retrieved Chunks"):
                context_docs = retriever.get_relevant_documents(full_query)
                for doc in context_docs:
                    # dict.get() takes a single default; the extra third
                    # argument previously raised a TypeError here.
                    st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')}")
                    st.markdown(doc.page_content[:700] + "...")
                    st.markdown("---")
if __name__ == "__main__":
main()