# Spaces:
# Sleeping
# Sleeping
# ############################################################################################################################# | |
# # Filename : app.py | |
# # Description: A Streamlit application to showcase how RAG works. | |
# # Author : Georgios Ioannou | |
# # | |
# # Copyright © 2024 by Georgios Ioannou | |
# ############################################################################################################################# | |
# app.py | |
import os | |
import json | |
from huggingface_hub import HfApi | |
import streamlit as st | |
from typing import List, Dict, Any | |
from urllib.parse import quote_plus | |
from pymongo import MongoClient | |
from PyPDF2 import PdfReader | |
# Dependency (install from a shell, not from Python): pip install sentence-transformers
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import MongoDBAtlasVectorSearch | |
from langchain.prompts import PromptTemplate | |
from langchain.schema import Document | |
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough | |
from huggingface_hub import InferenceClient | |
# =================== Secure Env via Hugging Face Secrets =================== | |
# Connection settings are read from environment variables. Every lookup falls
# back to a string default so that importing this module never raises:
# quote_plus(None) fails with a TypeError, and an unset cluster would otherwise
# be rendered as the literal text "None" inside the URI.
user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
cluster = os.getenv("MONGO_CLUSTER", "").strip()
db_name = os.getenv("MONGO_DB_NAME", "files").strip()
collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
# Left as None when the variable is not set; consumers must handle that case.
HF_TOKEN = os.getenv("HF_TOKEN")
# Username and password are percent-encoded above so that reserved characters
# in the credentials cannot corrupt the connection string.
MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority"
# =================== Prompt =================== | |
# Raw template text for the assistant prompt. The two placeholders, {context}
# and {question}, are substituted when the prompt is formatted.
_GRANTBUDDY_TEMPLATE = """You are Grant Buddy, a specialized assistant helping nonprofits apply for grants.
Always align answers with the nonprofit’s mission to combat systemic poverty through education, technology, and social innovation.
Use the following context to answer the question. Be concise and mission-aligned.
CONTEXT:
{context}
QUESTION:
{question}
Respond truthfully. If the answer is not available, say "This information is not available in the current context."
"""

# Prompt object used when generating a response.
grantbuddy_prompt = PromptTemplate.from_template(_GRANTBUDDY_TEMPLATE)
# =================== Vector Search Setup =================== | |
def init_embedding_model(model_name: str = "sentence-transformers/all-mpnet-base-v2"):
    """Create the embedding model used to vectorise queries.

    Args:
        model_name: Model identifier passed to ``HuggingFaceEmbeddings``. The
            default is the model this app was written against, so existing
            zero-argument callers behave exactly as before.

    Returns:
        A ``HuggingFaceEmbeddings`` instance configured for ``model_name``.
    """
    # NOTE(review): the Atlas vector index was presumably built with the default
    # model's vector size - confirm before passing a different model_name.
    return HuggingFaceEmbeddings(model_name=model_name)
def init_vector_search() -> MongoDBAtlasVectorSearch:
    """Build the MongoDB Atlas vector store used for retrieval.

    Steps:
      1. Check that the required MongoDB environment variables are set.
      2. Load the embedding model and embed a test query as a smoke test.
      3. Create the vector store from the assembled connection string.

    Progress and failures are reported in the Streamlit UI.

    Returns:
        The ``MongoDBAtlasVectorSearch`` instance.

    Raises:
        ValueError: If MONGO_USERNAME, MONGO_PASSWORD or MONGO_CLUSTER is
            unset or blank.
        Exception: Any error raised while loading/using the embedding model or
            creating the vector store is reported and then re-raised unchanged.
    """
    # Fail fast, before the embedding model is loaded, when credentials are
    # missing; otherwise the URI would degrade to "mongodb+srv://:@/..." and
    # surface as an opaque error.
    required = ("MONGO_USERNAME", "MONGO_PASSWORD", "MONGO_CLUSTER")
    settings = {name: os.getenv(name, "").strip() for name in required}
    missing = [name for name in required if not settings[name]]
    if missing:
        message = "Missing required environment variable(s): " + ", ".join(missing)
        st.error("❌ Failed to connect to MongoDB Atlas Vector Search")
        st.error(f"Error: {message}")
        raise ValueError(message)

    # Load the local embedding model and run a test embedding. Loading is
    # inside the try so that load failures are reported in the UI as well.
    try:
        embedding_model = init_embedding_model()
        test_vector = embedding_model.embed_query("Test query for Grant Buddy")
    except Exception as exc:
        st.error("❌ Failed to compute embedding locally")
        st.error(f"Error: {exc}")
        raise  # bare raise keeps the original traceback intact
    st.success(f"✅ Local embedding model loaded. Vector length: {len(test_vector)}")

    # MongoDB setup. Local names are distinct from the module-level settings
    # so they do not shadow them.
    mongo_user = quote_plus(settings["MONGO_USERNAME"])
    mongo_password = quote_plus(settings["MONGO_PASSWORD"])
    mongo_cluster = settings["MONGO_CLUSTER"]
    database = os.getenv("MONGO_DB_NAME", "files").strip()
    collection = os.getenv("MONGO_COLLECTION", "files_collection").strip()
    vector_index = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
    mongo_uri = (
        f"mongodb+srv://{mongo_user}:{mongo_password}@{mongo_cluster}/{database}"
        "?retryWrites=true&w=majority"
    )

    try:
        vector_store = MongoDBAtlasVectorSearch.from_connection_string(
            connection_string=mongo_uri,
            namespace=f"{database}.{collection}",
            embedding=embedding_model,
            index_name=vector_index,
        )
    except Exception as exc:
        st.error("❌ Failed to connect to MongoDB Atlas Vector Search")
        st.error(f"Error: {exc}")
        raise
    st.success("✅ Connected to MongoDB Vector Search")
    return vector_store
# =================== Format Retrieved Chunks =================== | |
def format_docs(docs: List[Document]) -> str:
    """Join the text of the given documents into a single context string.

    Each document contributes its ``page_content``. When that value is falsy,
    the ``"content"`` entry of the document's metadata is used instead, or an
    empty string if there is no such entry. Pieces are separated by a blank
    line.

    Args:
        docs: Documents whose text should be combined.

    Returns:
        The combined text; an empty string when ``docs`` is empty.
    """
    pieces = []
    for document in docs:
        text = document.page_content
        if not text:
            # Fall back to the copy of the text kept in the metadata.
            text = document.metadata.get("content", "")
        pieces.append(text)
    return "\n\n".join(pieces)
# =================== Generate Response from Hugging Face Model =================== | |
def generate_response(input_dict: Dict[str, Any]) -> str:
    """Ask the hosted chat model to answer a question using retrieved context.

    Args:
        input_dict: Values for the prompt template. It must contain a
            ``"question"`` entry (sent as the user message) and whatever
            placeholders ``grantbuddy_prompt`` expects.

    Returns:
        The model's reply text, or a fixed failure message when the request
        could not be completed. The error itself is shown in the Streamlit UI.

    Raises:
        KeyError: If ``input_dict`` lacks a placeholder required by the prompt
            (formatting happens outside the guarded section, as before).
    """
    prompt = grantbuddy_prompt.format(**input_dict)
    # HF_TOKEN is None when the variable is not configured; calling .strip() on
    # it directly would raise AttributeError before any error handling applies.
    token = (HF_TOKEN or "").strip()
    try:
        # Client construction is inside the try so that a bad or missing token
        # is reported through the same path as request failures.
        # NOTE(review): passing None presumably lets the client fall back to
        # its own default credentials - confirm against huggingface_hub docs.
        client = InferenceClient(api_key=token or None)
        response = client.chat.completions.create(
            model="HuggingFaceH4/zephyr-7b-beta",
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": input_dict["question"]},
            ],
            max_tokens=1000,
            temperature=0.2,
        )
        # Coerce a missing reply body to "" to honour the declared str return.
        return response.choices[0].message.content or ""
    except Exception as exc:
        st.error(f"❌ Error from model: {exc}")
        return "⚠️ Failed to generate response. Please check your model, HF token, or request format."
# =================== RAG Chain =================== | |
def get_rag_chain(retriever):
    """Assemble the retrieval-augmented generation chain.

    The chain is a mapping piped into ``generate_response``: its ``"context"``
    entry is ``retriever`` piped into ``format_docs``, and its ``"question"``
    entry is a ``RunnablePassthrough``.

    Args:
        retriever: Retriever object that can be piped into a runnable.

    Returns:
        The composed chain.
    """
    context_step = retriever | RunnableLambda(format_docs)
    chain_inputs = {
        "context": context_step,
        "question": RunnablePassthrough(),
    }
    answer_step = RunnableLambda(generate_response)
    return chain_inputs | answer_step
# =================== Streamlit UI =================== | |
def _read_uploaded_text(uploaded_file) -> str:
    """Return the text of an uploaded PDF/TXT file, or "" if there is none."""
    if not uploaded_file:
        return ""
    # Compare case-insensitively so "REPORT.PDF" is handled like "report.pdf".
    filename = uploaded_file.name.lower()
    if filename.endswith(".pdf"):
        reader = PdfReader(uploaded_file)
        # Guard against pages for which no text is returned, so the join
        # cannot fail on a non-string value.
        return "\n".join(page.extract_text() or "" for page in reader.pages)
    if filename.endswith(".txt"):
        # Undecodable bytes are replaced instead of crashing the app.
        return uploaded_file.read().decode("utf-8", errors="replace")
    return ""


def _chunk_preview(doc, limit: int = 700) -> str:
    """Return at most ``limit`` characters of a chunk, with "..." only if cut."""
    # Same fallback as format_docs: use the metadata copy when page_content
    # is empty or missing.
    text = doc.page_content or doc.metadata.get("content", "")
    if len(text) <= limit:
        return text
    return text[:limit] + "..."


def main():
    """Render the Streamlit UI and answer grant questions through the RAG chain.

    Flow: optional PDF/TXT upload -> build retriever and chain -> on "Submit",
    run the chain on the question (plus any uploaded text) and show both the
    answer and the retrieved chunks.
    """
    st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")
    st.title("🤖 Grant Buddy: Grant-Writing Assistant")

    uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
    uploaded_text = _read_uploaded_text(uploaded_file)

    # NOTE(review): "score_threshold" may only take effect with
    # search_type="similarity_score_threshold" - confirm for this vector store.
    retriever = init_vector_search().as_retriever(search_kwargs={"k": 10, "score_threshold": 0.75})
    rag_chain = get_rag_chain(retriever)

    query = st.text_input("Ask a grant-related question")
    if st.button("Submit"):
        # Treat whitespace-only input the same as no input.
        question = (query or "").strip()
        if not question:
            st.warning("Please enter a question.")
            return

        full_query = f"{question}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else question

        with st.spinner("Thinking..."):
            response = rag_chain.invoke(full_query)
        st.text_area("Grant Buddy says:", value=response, height=250, disabled=True)

        with st.expander("🔍 Retrieved Chunks"):
            # The chain already pipes the retriever like a runnable, so invoke()
            # is used here too instead of the older get_relevant_documents().
            # This repeats the retrieval purely for display.
            context_docs = retriever.invoke(full_query)
            for doc in context_docs:
                st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')}")
                st.markdown(_chunk_preview(doc))
                st.markdown("---")


if __name__ == "__main__":
    main()