# Tesneem's picture
# Update app.py
# da1cdfe verified
# raw
# history blame
# 6.79 kB
# #############################################################################################################################
# # Filename : app.py
# # Description: A Streamlit application to showcase how RAG works.
# # Author : Georgios Ioannou
# #
# # Copyright © 2024 by Georgios Ioannou
# #############################################################################################################################
# app.py
import os
import json
from huggingface_hub import HfApi
import streamlit as st
from typing import List, Dict, Any
from urllib.parse import quote_plus
from pymongo import MongoClient
from PyPDF2 import PdfReader
# Requires the sentence-transformers package (install with: pip install sentence-transformers).
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from huggingface_hub import InferenceClient
# =================== Secure Env via Hugging Face Secrets ===================
# Missing secrets default to "" (and surrounding whitespace is stripped) so that
# importing this module never fails with TypeError from quote_plus(None); a bad
# configuration is reported later, when MongoDB is actually contacted.
user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
cluster = os.getenv("MONGO_CLUSTER", "").strip()
db_name = os.getenv("MONGO_DB_NAME", "files").strip()
collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
# Hugging Face Inference API token; None when the secret is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority"
# =================== Prompt ===================
# Raw template text; {context} and {question} are the two input variables that
# generate_response() fills in via grantbuddy_prompt.format(**input_dict).
_GRANTBUDDY_TEMPLATE = """You are Grant Buddy, a specialized assistant helping nonprofits apply for grants.
Always align answers with the nonprofit’s mission to combat systemic poverty through education, technology, and social innovation.
Use the following context to answer the question. Be concise and mission-aligned.
CONTEXT:
{context}
QUESTION:
{question}
Respond truthfully. If the answer is not available, say "This information is not available in the current context."
"""

grantbuddy_prompt = PromptTemplate.from_template(_GRANTBUDDY_TEMPLATE)
# =================== Vector Search Setup ===================
@st.cache_resource
def init_embedding_model():
    """Load the local sentence-transformers embedding model.

    Decorated with ``st.cache_resource`` so the model is built once and the
    same instance is reused on later calls instead of being reloaded.

    Returns:
        A ``HuggingFaceEmbeddings`` wrapper around all-mpnet-base-v2.
    """
    model_id = "sentence-transformers/all-mpnet-base-v2"
    embedder = HuggingFaceEmbeddings(model_name=model_id)
    return embedder
@st.cache_resource
def init_vector_search() -> MongoDBAtlasVectorSearch:
    """Create the MongoDB Atlas vector store used for retrieval.

    Loads the local embedding model and checks that it can embed a probe
    query. It then connects to MongoDB using the ``MONGO_*`` environment
    variables, verifies the connection with a ``ping``, and wraps the target
    collection in a ``MongoDBAtlasVectorSearch``. Progress and failures are
    shown in the Streamlit UI. Because of ``st.cache_resource`` the store is
    built once and reused.

    Returns:
        Vector store bound to ``<MONGO_DB_NAME>.<MONGO_COLLECTION>`` and the
        ``MONGO_VECTOR_INDEX`` search index.

    Raises:
        Exception: whatever the embedding model or the MongoDB driver raised,
            re-raised after being displayed with ``st.error``.
    """
    embedding_model = init_embedding_model()
    try:
        # Embed a probe sentence so a broken model fails here, not mid-query.
        test_vector = embedding_model.embed_query("Test query for Grant Buddy")
        st.success(f"✅ Local embedding model loaded. Vector length: {len(test_vector)}")
    except Exception as e:
        st.error("❌ Failed to compute embedding locally")
        st.error(f"Error: {e}")
        raise

    # MongoDB settings (stripped, with the same defaults as the module-level config).
    user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
    password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
    cluster = os.getenv("MONGO_CLUSTER", "").strip()
    db_name = os.getenv("MONGO_DB_NAME", "files").strip()
    collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
    index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
    mongo_uri = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority"

    try:
        client = MongoClient(mongo_uri)
        # MongoClient connects lazily. Without this round trip, bad credentials
        # or an unreachable cluster would only fail on the first user query,
        # after "Connected" had already been shown.
        client.admin.command("ping")
        # Indexing the client directly avoids splitting a "db.collection"
        # namespace string, which breaks if the collection name contains dots.
        vector_store = MongoDBAtlasVectorSearch(
            collection=client[db_name][collection_name],
            embedding=embedding_model,
            index_name=index_name,
        )
    except Exception as e:
        st.error("❌ Failed to connect to MongoDB Atlas Vector Search")
        st.error(f"Error: {e}")
        raise

    st.success("✅ Connected to MongoDB Vector Search")
    return vector_store
# =================== Format Retrieved Chunks ===================
def format_docs(docs: List[Document]) -> str:
    """Join retrieved chunks into a single context string.

    Each document contributes its ``page_content``. When that is empty, the
    document's ``"content"`` metadata entry is used instead, defaulting to "".
    Chunks are separated by a blank line.

    Args:
        docs: Documents returned by the retriever.

    Returns:
        The concatenated text ("" for an empty list).
    """
    pieces = []
    for chunk in docs:
        text = chunk.page_content
        if not text:
            text = chunk.metadata.get("content", "")
        pieces.append(text)
    return "\n\n".join(pieces)
# =================== Generate Response from Hugging Face Model ===================
def generate_response(input_dict: Dict[str, Any]) -> str:
    """Answer a question with the hosted chat model using retrieved context.

    Args:
        input_dict: mapping with ``"context"`` and ``"question"`` keys. Both
            fill ``grantbuddy_prompt`` (sent as the system message), and the
            question is also sent as the user message.

    Returns:
        The model's reply text. If the inference call fails, the error is
        shown with ``st.error`` and a fallback warning string is returned.
    """
    # HF_TOKEN is None when the secret is missing; calling .strip() on it would
    # raise AttributeError outside the try block below. A missing or blank
    # token is passed as None (no explicit token) instead of crashing.
    token = (HF_TOKEN or "").strip() or None
    client = InferenceClient(api_key=token)
    prompt = grantbuddy_prompt.format(**input_dict)
    try:
        response = client.chat.completions.create(
            model="HuggingFaceH4/zephyr-7b-beta",
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": input_dict["question"]},
            ],
            max_tokens=1000,
            temperature=0.2,
        )
        # Guard against a None message content so the declared str return holds.
        return response.choices[0].message.content or ""
    except Exception as e:
        st.error(f"❌ Error from model: {e}")
        return "⚠️ Failed to generate response. Please check your model, HF token, or request format."
# =================== RAG Chain ===================
def get_rag_chain(retriever):
    """Wire retrieval and generation into a single LCEL runnable.

    The input query is used in two ways. It goes through ``retriever`` and
    ``format_docs`` to build ``"context"``, and it is passed through unchanged
    as ``"question"``. The resulting dict is handed to ``generate_response``.

    Args:
        retriever: a LangChain retriever, such as the one ``as_retriever()``
            returns.

    Returns:
        A runnable whose ``invoke(query)`` yields the generated answer.
    """
    context_step = retriever | RunnableLambda(format_docs)
    question_step = RunnablePassthrough()
    fan_out = {"context": context_step, "question": question_step}
    return fan_out | RunnableLambda(generate_response)
# =================== Streamlit UI ===================
def _read_uploaded_text(uploaded_file) -> str:
    """Return the text of an uploaded PDF or TXT file ("" for any other type)."""
    # Compare extensions case-insensitively so e.g. "REPORT.PDF" is not silently ignored.
    name = uploaded_file.name.lower()
    if name.endswith(".pdf"):
        reader = PdfReader(uploaded_file)
        # Defensive: treat a page that yields no text as empty instead of failing the join.
        return "\n".join(page.extract_text() or "" for page in reader.pages)
    if name.endswith(".txt"):
        # Replace undecodable bytes instead of crashing the app on non-UTF-8 files.
        return uploaded_file.read().decode("utf-8", errors="replace")
    return ""


def main():
    """Render the Grant Buddy page and answer a question when Submit is pressed.

    Text from an optional PDF/TXT upload is appended to the question as extra
    context. The answer comes from the RAG chain, and the retrieved chunks are
    listed in an expander below it.
    """
    st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")
    st.title("🤖 Grant Buddy: Grant-Writing Assistant")
    uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
    uploaded_text = _read_uploaded_text(uploaded_file) if uploaded_file else ""
    # NOTE(review): with the default "similarity" search type, LangChain may not
    # apply score_threshold at all; search_type="similarity_score_threshold" is
    # the documented way to filter by score. Confirm the intended behaviour.
    retriever = init_vector_search().as_retriever(search_kwargs={"k": 10, "score_threshold": 0.75})
    rag_chain = get_rag_chain(retriever)
    query = st.text_input("Ask a grant-related question")
    if st.button("Submit"):
        if not query:
            st.warning("Please enter a question.")
            return
        full_query = f"{query}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else query
        with st.spinner("Thinking..."):
            response = rag_chain.invoke(full_query)
        st.text_area("Grant Buddy says:", value=response, height=250, disabled=True)
        with st.expander("🔍 Retrieved Chunks"):
            # invoke() is the replacement for the deprecated get_relevant_documents().
            # NOTE(review): this repeats the retrieval the chain already ran.
            context_docs = retriever.invoke(full_query)
            for doc in context_docs:
                st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')}")
                preview = doc.page_content[:700]
                # Only add an ellipsis when the chunk was actually truncated.
                st.markdown(preview + "..." if len(doc.page_content) > 700 else preview)
                st.markdown("---")


if __name__ == "__main__":
    main()