##############################################################################################################################
# Filename   : app.py
# Description: A Streamlit application to showcase how RAG works.
# Author     : Georgios Ioannou
#
# Copyright © 2024 by Georgios Ioannou
##############################################################################################################################
import os
import re  # used by the regex-based extractor kept (commented out) below

import streamlit as st
from typing import Any, Dict, List
from urllib.parse import quote_plus

from huggingface_hub import InferenceClient
from pymongo import MongoClient
from PyPDF2 import PdfReader

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

# =================== Secure Env via Hugging Face Secrets ===================
user = quote_plus(os.getenv("MONGO_USERNAME", ""))
password = quote_plus(os.getenv("MONGO_PASSWORD", ""))
cluster = os.getenv("MONGO_CLUSTER", "")
db_name = os.getenv("MONGO_DB_NAME", "files")
collection_name = os.getenv("MONGO_COLLECTION", "files_collection")
index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index")
HF_TOKEN = os.getenv("HF_TOKEN", "")

# NOTE: tlsAllowInvalidCertificates disables certificate verification. Useful while
# debugging connection issues, but it should be removed in production.
MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true"
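
# Retrieval below assumes a MongoDB Atlas Vector Search index already exists on the
# target collection. A minimal sketch of a matching index definition (384 dimensions
# for all-MiniLM-L6-v2, cosine similarity; the "embedding" field path is an
# assumption and must match how the documents were ingested):
#
#   {
#     "fields": [
#       {"type": "vector", "path": "embedding", "numDimensions": 384, "similarity": "cosine"}
#     ]
#   }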
# =================== Prompt ===================
grantbuddy_prompt = PromptTemplate.from_template(
"""You are Grant Buddy, a specialized language model fine-tuned with instruction-tuning and RLHF.
You help a nonprofit focused on social entrepreneurship, BIPOC empowerment, and edtech write clear, mission-aligned grant responses.
**Instructions:**
- Start with reasoning or context for your answer.
- Always align with the nonprofit’s mission.
- Use structured formatting: headings, bullet points, numbered lists.
- Include impact data or examples if relevant.
- Do NOT repeat the same sentence or answer multiple times.
- If no answer exists in the context, say: "This information is not available in the current context."
CONTEXT:
{context}
QUESTION:
{question}
"""
)
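
# Sketch of how the template renders (illustrative values, not from the app):
#
#   grantbuddy_prompt.format(
#       context="Chunk 1...\n\nChunk 2...",
#       question="Describe your program's impact.",
#   )
#
# The resulting string (instructions + CONTEXT + QUESTION) is what
# generate_response() below sends as the system message.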
# =================== Vector Search Setup ===================
@st.cache_resource
def init_embedding_model():
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


@st.cache_resource
def init_vector_search() -> MongoDBAtlasVectorSearch:
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    st.write(f"🔌 Loading embedding model: `{model_name}`")
    embedding_model = init_embedding_model()

    # Manual MongoClient with explicit TLS settings.
    user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
    password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
    cluster = os.getenv("MONGO_CLUSTER", "").strip()
    db_name = os.getenv("MONGO_DB_NAME", "files").strip()
    collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
    index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
    mongo_uri = f"mongodb+srv://{user}:{password}@{cluster}/?retryWrites=true&w=majority"

    try:
        client = MongoClient(
            mongo_uri,
            tls=True,
            tlsAllowInvalidCertificates=True,
            serverSelectionTimeoutMS=20000,
        )
        collection = client[db_name][collection_name]
        st.success("✅ MongoClient connected successfully")
        return MongoDBAtlasVectorSearch(
            collection=collection,
            embedding=embedding_model,
            index_name=index_name,
        )
    except Exception as e:
        st.error("❌ Failed to connect to MongoDB Atlas")
        st.error(str(e))
        raise
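
# Optional sanity check, a sketch assuming PyMongo 4.5+ against an Atlas cluster:
#
#   for idx in collection.list_search_indexes():
#       st.write(idx["name"], idx.get("status"))
#
# Listing the search indexes confirms that the configured MONGO_VECTOR_INDEX
# actually exists before any similarity query is issued.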
# =================== Question/Headers Extraction ===================
# Regex-based extractor, kept for reference; superseded by extract_with_llm below.
# def extract_questions_and_headers(text: str) -> List[str]:
#     header_patterns = [
#         r'\d+\.\s+\*\*([^\*]+)\*\*',
#         r'\*\*([^*]+)\*\*',
#         r'^([A-Z][^a-z]*[A-Z])$',
#         r'^([A-Z][A-Za-z\s]{3,})$',
#         r'^[A-Z][A-Za-z\s]+:$',
#     ]
#     question_patterns = [
#         r'^.+\?$',
#         r'^\*?Please .+',
#         r'^How .+',
#         r'^What .+',
#         r'^Describe .+',
#     ]
#     combined_header_re = re.compile("|".join(header_patterns), re.MULTILINE)
#     combined_question_re = re.compile("|".join(question_patterns), re.MULTILINE)
#     headers = [match for group in combined_header_re.findall(text) for match in group if match]
#     questions = combined_question_re.findall(text)
#     return headers + questions

def extract_with_llm(text: str) -> List[str]:
    client = InferenceClient(api_key=HF_TOKEN.strip())
    try:
        response = client.chat.completions.create(
            model="mistralai/Mistral-Nemo-Instruct-2407",  # or "HuggingFaceH4/zephyr-7b-beta"
            messages=[
                {
                    "role": "system",
                    "content": "You are an assistant helping extract questions and headers from grant applications.",
                },
                {
                    "role": "user",
                    "content": (
                        "Please extract all the grant application headers and questions from the following text. "
                        "Include section titles, prompts, and any question-like content. Return them as a numbered list.\n\n"
                        f"{text[:3000]}"  # truncated to keep the request within the model's context budget
                    ),
                },
            ],
            temperature=0.2,
            max_tokens=512,
        )
        # Strip list markers (bullets, numbers, dots) to return bare items.
        return [
            line.strip("•-1234567890. ").strip()
            for line in response.choices[0].message.content.strip().split("\n")
            if line.strip()
        ]
    except Exception as e:
        st.error("❌ LLM extraction failed")
        st.error(str(e))
        return []
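
# Usage sketch (illustrative; output depends on the model and is not deterministic
# even at low temperature):
#
#   extract_with_llm("1. **Organization Background**\nWhat is your mission?")
#   -> ["Organization Background", "What is your mission?"]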
# =================== Format Retrieved Chunks ===================
def format_docs(docs: List[Document]) -> str:
    # Fall back to the "content" metadata field when page_content is empty.
    return "\n\n".join(doc.page_content or doc.metadata.get("content", "") for doc in docs)
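
# Behavior sketch:
#   format_docs([Document(page_content="A"), Document(page_content="B")]) -> "A\n\nB"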
# =================== Generate Response from Hugging Face Model ===================
def generate_response(input_dict: Dict[str, Any]) -> str:
    client = InferenceClient(api_key=HF_TOKEN.strip())
    # The rendered template (instructions + context + question) becomes the system
    # message; the bare question is repeated as the user turn.
    prompt = grantbuddy_prompt.format(**input_dict)
    try:
        response = client.chat.completions.create(
            model="HuggingFaceH4/zephyr-7b-beta",
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": input_dict["question"]},
            ],
            max_tokens=1000,
            temperature=0.2,
        )
        return response.choices[0].message.content
    except Exception as e:
        st.error(f"❌ Error from model: {e}")
        return "⚠️ Failed to generate response. Please check your model, HF token, or request format."
# =================== RAG Chain ===================
def get_rag_chain(retriever):
    return {
        "context": retriever | RunnableLambda(format_docs),
        "question": RunnablePassthrough(),
    } | RunnableLambda(generate_response)
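
# Dataflow sketch: rag_chain.invoke(question) fans the input out to both branches.
# "context" runs the retriever and joins the chunks via format_docs, while
# "question" passes the raw string through; the resulting dict
# {"context": ..., "question": ...} then feeds generate_response. For example:
#
#   chain = get_rag_chain(init_vector_search().as_retriever())
#   chain.invoke("What is the organization's mission?")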
# =================== Streamlit UI ===================
def main():
    st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")
    st.title("🤖 Grant Buddy: Grant-Writing Assistant")

    uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
    uploaded_text = ""

    retriever = init_vector_search().as_retriever(search_kwargs={"k": 10, "score_threshold": 0.75})
    rag_chain = get_rag_chain(retriever)  # Initialize before usage.

    # Process the uploaded file.
    if uploaded_file:
        with st.spinner("📄 Processing uploaded file..."):
            if uploaded_file.name.endswith(".pdf"):
                reader = PdfReader(uploaded_file)
                uploaded_text = "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
            elif uploaded_file.name.endswith(".txt"):
                uploaded_text = uploaded_file.read().decode("utf-8")

        questions = extract_with_llm(uploaded_text)
        st.success(f"✅ Found {len(questions)} questions or headers.")
        with st.expander("🧠 Extracted Prompts from Upload"):
            st.write(questions)

        # Generate an answer for every extracted question.
        answers = []
        for q in questions:
            full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
            response = rag_chain.invoke(full_query)
            answers.append({"question": q, "answer": response})

        for item in answers:
            st.markdown(f"### ❓ {item['question']}")
            st.markdown(f"💬 {item['answer']}")

    # Manual query box.
    query = st.text_input("Ask a grant-related question")
    if st.button("Submit"):
        if not query and not uploaded_file:
            st.warning("Please enter a question.")
            return

        full_query = f"{query}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else query
        with st.spinner("🤖 Thinking..."):
            response = rag_chain.invoke(full_query)

        st.text_area("Grant Buddy says:", value=response, height=250, disabled=True)

        with st.expander("🔍 Retrieved Chunks"):
            context_docs = retriever.get_relevant_documents(full_query)
            for doc in context_docs:
                st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')}")
                st.markdown(doc.page_content[:700] + "...")
                st.markdown("---")


if __name__ == "__main__":
    main()