Spaces:
Running
Running
# ############################################################################################################################# | |
# # Filename : app.py | |
# # Description: A Streamlit application to showcase how RAG works. | |
# # Author : Georgios Ioannou | |
# # | |
# # Copyright © 2024 by Georgios Ioannou | |
# ############################################################################################################################# | |
# app.py | |
import os | |
import json | |
from huggingface_hub import HfApi | |
import streamlit as st | |
from typing import List, Dict, Any | |
from urllib.parse import quote_plus | |
from pymongo import MongoClient | |
from PyPDF2 import PdfReader | |
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings | |
from langchain_community.vectorstores import MongoDBAtlasVectorSearch | |
from langchain.prompts import PromptTemplate | |
from langchain.schema import Document | |
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough | |
from huggingface_hub import InferenceClient | |
# =================== Secure Env via Hugging Face Secrets =================== | |
# Connection settings are read from environment variables (Hugging Face
# Spaces secrets). Every lookup falls back to a string so that a missing
# secret does not crash the module at import time: quote_plus(None) raises
# TypeError, and later code calls HF_TOKEN.strip().
user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
cluster = os.getenv("MONGO_CLUSTER", "").strip()
db_name = os.getenv("MONGO_DB_NAME", "files").strip()
collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
# Username and password are percent-encoded above so that reserved
# characters (e.g. "@" or ":") cannot break the URI.
MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority"
# =================== Prompt =================== | |
# Prompt template with two placeholders, {context} and {question}. It is
# filled in by generate_response() and sent to the chat model as the system
# message.
grantbuddy_prompt = PromptTemplate.from_template(
    """You are Grant Buddy, a specialized assistant helping nonprofits apply for grants.
Always align answers with the nonprofit’s mission to combat systemic poverty through education, technology, and social innovation.
Use the following context to answer the question. Be concise and mission-aligned.
CONTEXT:
{context}
QUESTION:
{question}
Respond truthfully. If the answer is not available, say "This information is not available in the current context."
"""
)
# =================== Vector Search Setup =================== | |
def init_vector_search() -> MongoDBAtlasVectorSearch:
    """Create the MongoDB Atlas vector store used for retrieval.

    The function first builds the Hugging Face Inference API embedding model
    and embeds a short test query to confirm the API is reachable, then
    connects to the Atlas collection named by the environment variables.
    Progress and failures are reported in the Streamlit UI.

    Returns:
        A ``MongoDBAtlasVectorSearch`` bound to ``<db>.<collection>`` and the
        configured vector index.

    Raises:
        ValueError: If ``MONGO_USERNAME``, ``MONGO_PASSWORD`` or
            ``MONGO_CLUSTER`` is unset or blank.
        Exception: Any error raised by the embedding API or the MongoDB
            connection is shown in the UI and then re-raised unchanged.
    """
    hf_token = os.getenv("HF_TOKEN", "").strip()
    model_name = "thenlper/gte-small"

    try:
        st.write(f"🔌 Connecting to Hugging Face model: `{model_name}`")
        embedding_model = HuggingFaceInferenceAPIEmbeddings(
            api_key=hf_token,
            model_name=model_name,
        )
        # Embed a throwaway query so a bad token/model fails here, not later
        # in the middle of a user request.
        test_vector = embedding_model.embed_query("Test query for Grant Buddy")
        st.success(f"✅ HF embedding model connected. Vector length: {len(test_vector)}")
    except Exception as e:
        st.error("❌ Failed to connect to Hugging Face Embedding API")
        st.error(f"Error: {e}")
        raise  # Stop the app here if embedding fails (keeps the traceback).

    # MongoDB setup
    required = {
        "MONGO_USERNAME": os.getenv("MONGO_USERNAME", "").strip(),
        "MONGO_PASSWORD": os.getenv("MONGO_PASSWORD", "").strip(),
        "MONGO_CLUSTER": os.getenv("MONGO_CLUSTER", "").strip(),
    }
    missing = [name for name, value in required.items() if not value]
    if missing:
        # Without these the URI below would be malformed; report which
        # secrets are absent instead of surfacing an opaque driver error.
        message = f"missing environment variable(s): {', '.join(missing)}"
        st.error("❌ Failed to connect to MongoDB Atlas Vector Search")
        st.error(f"Error: {message}")
        raise ValueError(message)

    # Percent-encode credentials so reserved characters cannot break the URI.
    user = quote_plus(required["MONGO_USERNAME"])
    password = quote_plus(required["MONGO_PASSWORD"])
    cluster = required["MONGO_CLUSTER"]
    db_name = os.getenv("MONGO_DB_NAME", "files").strip()
    collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
    index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
    mongo_uri = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority"

    # Connect to vector search
    try:
        vector_store = MongoDBAtlasVectorSearch.from_connection_string(
            connection_string=mongo_uri,
            namespace=f"{db_name}.{collection_name}",
            embedding=embedding_model,
            index_name=index_name,
        )
        st.success("✅ Connected to MongoDB Vector Search")
        return vector_store
    except Exception as e:
        st.error("❌ Failed to connect to MongoDB Atlas Vector Search")
        st.error(f"Error: {e}")
        raise
# =================== Format Retrieved Chunks =================== | |
def format_docs(docs: List[Document]) -> str:
    """Concatenate retrieved chunks into one context string.

    Each document contributes its ``page_content``; when that is empty the
    ``"content"`` entry of its metadata is used instead (or an empty string
    if that is absent too). Chunks are separated by a blank line.
    """
    pieces = []
    for doc in docs:
        text = doc.page_content
        if not text:
            text = doc.metadata.get("content", "")
        pieces.append(text)
    return "\n\n".join(pieces)
# =================== Generate Response from Hugging Face Model =================== | |
def generate_response(input_dict: Dict[str, Any]) -> str:
    """Answer a question with the Hugging Face chat model.

    Args:
        input_dict: Mapping holding the ``context`` and ``question`` values
            used to fill ``grantbuddy_prompt``. ``question`` is also sent as
            the user message.

    Returns:
        The text of the model's first choice. If creating the client or the
        request fails, the error is shown in the Streamlit UI and a fixed
        warning string is returned instead.
    """
    prompt = grantbuddy_prompt.format(**input_dict)
    try:
        # HF_TOKEN is None when the secret is unset; normalise it so the
        # failure is reported through the handler below rather than as an
        # AttributeError on None.strip().
        client = InferenceClient(api_key=(HF_TOKEN or "").strip())
        response = client.chat.completions.create(
            model="HuggingFaceH4/zephyr-7b-beta",
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": input_dict["question"]},
            ],
            max_tokens=1000,
            temperature=0.2,
        )
        return response.choices[0].message.content
    except Exception as e:
        st.error(f"❌ Error from model: {e}")
        return "⚠️ Failed to generate response. Please check your model, HF token, or request format."
# =================== RAG Chain =================== | |
def get_rag_chain(retriever):
    """Compose the retrieval-augmented generation chain.

    The input is fanned out into two keys: ``context`` (the retriever's
    documents passed through ``format_docs``) and ``question`` (the input
    passed through unchanged). The resulting mapping is then handed to
    ``generate_response``.
    """
    context_branch = retriever | RunnableLambda(format_docs)
    chain_inputs = {
        "context": context_branch,
        "question": RunnablePassthrough(),
    }
    return chain_inputs | RunnableLambda(generate_response)
# =================== Streamlit UI =================== | |
def _read_uploaded_text(uploaded_file) -> str:
    """Return the text of an uploaded PDF or TXT file, or "" if there is none."""
    if not uploaded_file:
        return ""
    # Compare case-insensitively so "REPORT.PDF" is not silently ignored.
    filename = uploaded_file.name.lower()
    if filename.endswith(".pdf"):
        reader = PdfReader(uploaded_file)
        # Guard against pages for which no text is returned.
        return "\n".join(page.extract_text() or "" for page in reader.pages)
    if filename.endswith(".txt"):
        # Replace undecodable bytes instead of crashing on non-UTF-8 uploads.
        return uploaded_file.read().decode("utf-8", errors="replace")
    return ""


def main():
    """Render the Grant Buddy page and answer a question through the RAG chain.

    The page offers an optional PDF/TXT upload whose text is appended to the
    question, a question box, and a Submit button. On submit, the answer is
    shown together with an expandable list of the retrieved chunks.
    """
    st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")
    st.title("🤖 Grant Buddy: Grant-Writing Assistant")

    uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
    uploaded_text = _read_uploaded_text(uploaded_file)

    retriever = init_vector_search().as_retriever(search_kwargs={"k": 10, "score_threshold": 0.75})
    rag_chain = get_rag_chain(retriever)

    query = st.text_input("Ask a grant-related question")
    if st.button("Submit"):
        # Treat a whitespace-only question the same as an empty one.
        if not query or not query.strip():
            st.warning("Please enter a question.")
            return

        full_query = f"{query}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else query
        with st.spinner("Thinking..."):
            response = rag_chain.invoke(full_query)
        st.text_area("Grant Buddy says:", value=response, height=250, disabled=True)

        with st.expander("🔍 Retrieved Chunks"):
            # The retriever is a Runnable (it is piped in get_rag_chain), so
            # use invoke() rather than the deprecated get_relevant_documents().
            context_docs = retriever.invoke(full_query)
            for doc in context_docs:
                # Same fallback as format_docs, so the preview shows the text
                # that was actually given to the model.
                text = doc.page_content or doc.metadata.get("content", "")
                preview = text[:700] + ("..." if len(text) > 700 else "")
                st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')}")
                st.markdown(preview)
                st.markdown("---")
# Script entry point: build the UI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
# # Import libraries. | |
# import os | |
# import streamlit as st | |
# from dotenv import load_dotenv, find_dotenv | |
# from huggingface_hub import InferenceClient | |
# from langchain.prompts import PromptTemplate | |
# from langchain.schema import Document | |
# from langchain.schema.runnable import RunnablePassthrough, RunnableLambda | |
# # from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings | |
# from langchain.embeddings import OpenAIEmbeddings | |
# from langchain_community.vectorstores import MongoDBAtlasVectorSearch | |
# from pymongo import MongoClient | |
# from pymongo.collection import Collection | |
# from typing import Dict, Any | |
# from langchain.chat_models import ChatOpenAI | |
# ############################################################################################################################# | |
# class RAGQuestionAnswering: | |
# def __init__(self): | |
# """ | |
# Parameters | |
# ---------- | |
# None | |
# Output | |
# ------ | |
# None | |
# Purpose | |
# ------- | |
# Initializes the RAG Question Answering system by setting up configuration | |
# and loading environment variables. | |
# Assumptions | |
# ----------- | |
# - Expects .env file with MONGO_URI and HF_TOKEN | |
# - Requires proper MongoDB setup with vector search index | |
# - Needs connection to Hugging Face API | |
# Notes | |
# ----- | |
# This is the main class that handles all RAG operations | |
# """ | |
# self.load_environment() | |
# self.setup_mongodb() | |
# self.setup_embedding_model() | |
# self.setup_vector_search() | |
# self.setup_rag_chain() | |
# def load_environment(self) -> None: | |
# """ | |
# Parameters | |
# ---------- | |
# None | |
# Output | |
# ------ | |
# None | |
# Purpose | |
# ------- | |
# Loads environment variables from .env file and sets up configuration constants. | |
# Assumptions | |
# ----------- | |
# Expects a .env file with MONGO_URI and HF_TOKEN defined | |
# Notes | |
# ----- | |
# Will stop the application if required environment variables are missing | |
# """ | |
# load_dotenv(find_dotenv()) | |
# self.MONGO_URI = os.getenv("MONGO_URI") | |
# # self.HF_TOKEN = os.getenv("HF_TOKEN") | |
# self.OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") | |
# if not self.MONGO_URI or not self.OPENAI_API_KEY: | |
# st.error("Please ensure MONGO_URI and OPENAI_API_KEY are set in your .env file") | |
# st.stop() | |
# # MongoDB configuration. | |
# self.DB_NAME = "txts" | |
# self.COLLECTION_NAME = "txts_collection" | |
# self.VECTOR_SEARCH_INDEX = "vector_index" | |
# def setup_mongodb(self) -> None: | |
# """ | |
# Parameters | |
# ---------- | |
# None | |
# Output | |
# ------ | |
# None | |
# Purpose | |
# ------- | |
# Initializes the MongoDB connection and sets up the collection. | |
# Assumptions | |
# ----------- | |
# - Valid MongoDB URI is available | |
# - Database and collection exist in MongoDB Atlas | |
# Notes | |
# ----- | |
# Uses st.cache_resource for efficient connection management | |
# """ | |
# @st.cache_resource | |
# def init_mongodb() -> Collection: | |
# cluster = MongoClient(self.MONGO_URI) | |
# return cluster[self.DB_NAME][self.COLLECTION_NAME] | |
# self.mongodb_collection = init_mongodb() | |
# def setup_embedding_model(self) -> None: | |
# """ | |
# Parameters | |
# ---------- | |
# None | |
# Output | |
# ------ | |
# None | |
# Purpose | |
# ------- | |
# Initializes the embedding model for vector search. | |
# Assumptions | |
# ----------- | |
# - Valid Hugging Face API token | |
# - Internet connection to access the model | |
# Notes | |
# ----- | |
# Uses the all-mpnet-base-v2 model from sentence-transformers | |
# """ | |
# # @st.cache_resource | |
# # def init_embedding_model() -> HuggingFaceInferenceAPIEmbeddings: | |
# # return HuggingFaceInferenceAPIEmbeddings( | |
# # api_key=self.HF_TOKEN, | |
# # model_name="sentence-transformers/all-mpnet-base-v2", | |
# # ) | |
# @st.cache_resource | |
# def init_embedding_model() -> OpenAIEmbeddings: | |
# return OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key=self.OPENAI_API_KEY) | |
# self.embedding_model = init_embedding_model() | |
# def setup_vector_search(self) -> None: | |
# """ | |
# Parameters | |
# ---------- | |
# None | |
# Output | |
# ------ | |
# None | |
# Purpose | |
# ------- | |
# Sets up the vector search functionality using MongoDB Atlas. | |
# Assumptions | |
# ----------- | |
# - MongoDB Atlas vector search index is properly configured | |
# - Valid embedding model is initialized | |
# Notes | |
# ----- | |
# Creates a retriever with similarity search and score threshold | |
# """ | |
# @st.cache_resource | |
# def init_vector_search() -> MongoDBAtlasVectorSearch: | |
# return MongoDBAtlasVectorSearch.from_connection_string( | |
# connection_string=self.MONGO_URI, | |
# namespace=f"{self.DB_NAME}.{self.COLLECTION_NAME}", | |
# embedding=self.embedding_model, | |
# index_name=self.VECTOR_SEARCH_INDEX, | |
# ) | |
# self.vector_search = init_vector_search() | |
# self.retriever = self.vector_search.as_retriever( | |
# search_type="similarity", search_kwargs={"k": 10, "score_threshold": 0.85} | |
# ) | |
# def format_docs(self, docs: list[Document]) -> str: | |
# """ | |
# Parameters | |
# ---------- | |
# **docs:** list[Document] - List of documents to be formatted | |
# Output | |
# ------ | |
# str: Formatted string containing concatenated document content | |
# Purpose | |
# ------- | |
# Formats the retrieved documents into a single string for processing | |
# Assumptions | |
# ----------- | |
# Documents have page_content attribute | |
# Notes | |
# ----- | |
# Joins documents with double newlines for better readability | |
# """ | |
# return "\n\n".join(doc.page_content for doc in docs) | |
# # def generate_response(self, input_dict: Dict[str, Any]) -> str: | |
# # """ | |
# # Parameters | |
# # ---------- | |
# # **input_dict:** Dict[str, Any] - Dictionary containing context and question | |
# # Output | |
# # ------ | |
# # str: Generated response from the model | |
# # Purpose | |
# # ------- | |
# # Generates a response using the Hugging Face model based on context and question | |
# # Assumptions | |
# # ----------- | |
# # - Valid Hugging Face API token | |
# # - Input dictionary contains 'context' and 'question' keys | |
# # Notes | |
# # ----- | |
# # Uses Zephyr model with controlled temperature | |
# # """ | |
# # hf_client = InferenceClient(api_key=self.HF_TOKEN) | |
# # formatted_prompt = self.prompt.format(**input_dict) | |
# # response = hf_client.chat.completions.create( | |
# # model="HuggingFaceH4/zephyr-7b-beta" | |
# # messages=[ | |
# # {"role": "system", "content": formatted_prompt}, | |
# # {"role": "user", "content": input_dict["question"]}, | |
# # ], | |
# # max_tokens=1000, | |
# # temperature=0.2, | |
# # ) | |
# # return response.choices[0].message.content | |
# from langchain.chat_models import ChatOpenAI | |
# from langchain.schema.messages import SystemMessage, HumanMessage | |
# def generate_response(self, input_dict: Dict[str, Any]) -> str: | |
# llm = ChatOpenAI( | |
# model="gpt-4", # or "gpt-3.5-turbo" | |
# temperature=0.2, | |
# openai_api_key=self.OPENAI_API_KEY, | |
# ) | |
# messages = [ | |
# SystemMessage(content=self.prompt.format(**input_dict)), | |
# HumanMessage(content=input_dict["question"]), | |
# ] | |
# return llm(messages).content | |
# def setup_rag_chain(self) -> None: | |
# """ | |
# Parameters | |
# ---------- | |
# None | |
# Output | |
# ------ | |
# None | |
# Purpose | |
# ------- | |
# Sets up the RAG chain for processing questions and generating answers | |
# Assumptions | |
# ----------- | |
# Retriever and response generator are properly initialized | |
# Notes | |
# ----- | |
# Creates a chain that combines retrieval and response generation | |
# """ | |
# self.prompt = PromptTemplate.from_template( | |
# """Use the following pieces of context to answer the question at the end. | |
# START OF CONTEXT: | |
# {context} | |
# END OF CONTEXT: | |
# START OF QUESTION: | |
# {question} | |
# END OF QUESTION: | |
# If you do not know the answer, just say that you do not know. | |
# NEVER assume things. | |
# """ | |
# ) | |
# self.rag_chain = { | |
# "context": self.retriever | RunnableLambda(self.format_docs), | |
# "question": RunnablePassthrough(), | |
# } | RunnableLambda(self.generate_response) | |
# def process_question(self, question: str) -> str: | |
# """ | |
# Parameters | |
# ---------- | |
# **question:** str - The user's question to be answered | |
# Output | |
# ------ | |
# str: The generated answer to the question | |
# Purpose | |
# ------- | |
# Processes a user question through the RAG chain and returns an answer | |
# Assumptions | |
# ----------- | |
# - Question is a non-empty string | |
# - RAG chain is properly initialized | |
# Notes | |
# ----- | |
# Main interface for question-answering functionality | |
# """ | |
# return self.rag_chain.invoke(question) | |
# ############################################################################################################################# | |
# def setup_streamlit_ui() -> None: | |
# """ | |
# Parameters | |
# ---------- | |
# None | |
# Output | |
# ------ | |
# None | |
# Purpose | |
# ------- | |
# Sets up the Streamlit user interface with proper styling and layout | |
# Assumptions | |
# ----------- | |
# - CSS file exists at ./static/styles/style.css | |
# - Image file exists at ./static/images/ctp.png | |
# Notes | |
# ----- | |
# Handles all UI-related setup and styling | |
# """ | |
# st.set_page_config(page_title="RAG Question Answering", page_icon="🤖") | |
# # Load CSS. | |
# with open("./static/styles/style.css") as f: | |
# st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True) | |
# # Title and subtitles. | |
# st.markdown( | |
# '<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">RAG Question Answering</h1>', | |
# unsafe_allow_html=True, | |
# ) | |
# st.markdown( | |
# '<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">Using Zoom Closed Captioning From The Lectures</h3>', | |
# unsafe_allow_html=True, | |
# ) | |
# st.markdown( | |
# '<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">CUNY Tech Prep Tutorial 5</h2>', | |
# unsafe_allow_html=True, | |
# ) | |
# # Display logo. | |
# left_co, cent_co, last_co = st.columns(3) | |
# with cent_co: | |
# st.image("./static/images/ctp.png") | |
# ############################################################################################################################# | |
# def main(): | |
# """ | |
# Parameters | |
# ---------- | |
# None | |
# Output | |
# ------ | |
# None | |
# Purpose | |
# ------- | |
# Main function that runs the Streamlit application | |
# Assumptions | |
# ----------- | |
# All required environment variables and files are present | |
# Notes | |
# ----- | |
# Entry point for the application | |
# """ | |
# # Setup UI. | |
# setup_streamlit_ui() | |
# # Initialize RAG system. | |
# rag_system = RAGQuestionAnswering() | |
# # Create input elements. | |
# query = st.text_input("Question:", key="question_input") | |
# # Handle submission. | |
# if st.button("Submit", type="primary"): | |
# if query: | |
# with st.spinner("Generating response..."): | |
# response = rag_system.process_question(query) | |
# st.text_area("Answer:", value=response, height=200, disabled=True) | |
# else: | |
# st.warning("Please enter a question.") | |
# # Add GitHub link. | |
# st.markdown( | |
# """ | |
# <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"> | |
# <b>Check out our <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;">GitHub repository</a></b> | |
# </p> | |
# """, | |
# unsafe_allow_html=True, | |
# ) | |
# ############################################################################################################################# | |
# if __name__ == "__main__": | |
# main() | |