# main.py
import os
import streamlit as st
import anthropic
from requests import JSONDecodeError
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_community.llms import HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from supabase import Client, create_client
from streamlit.logger import get_logger
from stats import get_usage, add_usage
# ─────── supabase + secrets ────────────────────────────────────────────────────
supabase_url = st.secrets.SUPABASE_URL
supabase_key = st.secrets.SUPABASE_KEY
openai_api_key = st.secrets.openai_api_key
anthropic_api_key = st.secrets.anthropic_api_key
hf_api_key = st.secrets.hf_api_key
username = st.secrets.username
supabase: Client = create_client(supabase_url, supabase_key)
logger = get_logger(__name__)
# ─────── embeddings ─────────────────────────────────────────────────────────────
# Switch to local BGE embeddings (no JSONDecode errors, no HTTP batch issues)
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)
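# NOTE: loading bge-large-en-v1.5 on CPU is the slowest part of each Streamlit re-run. A cached
# variant (sketch only, not wired in here; load_embeddings is an illustrative name) could avoid
# reloading the model on every interaction:
#
#   @st.cache_resource
#   def load_embeddings() -> HuggingFaceBgeEmbeddings:
#       return HuggingFaceBgeEmbeddings(
#           model_name="BAAI/bge-large-en-v1.5",
#           model_kwargs={"device": "cpu"},
#           encode_kwargs={"normalize_embeddings": True},
#       )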
# ─────── vector store + memory ─────────────────────────────────────────────────
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    query_name="match_documents",
    table_name="documents",
)
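# The store assumes the standard LangChain/Supabase setup: a "documents" table with
# id / content / metadata / embedding columns and a "match_documents" SQL function exposed
# via RPC. A rough sketch of that schema (the vector dimension must match the embedding
# model, 1024 for bge-large-en-v1.5) looks like:
#
#   create table documents (
#     id uuid primary key,
#     content text,
#     metadata jsonb,
#     embedding vector(1024)
#   );
#
#   create function match_documents(query_embedding vector(1024), filter jsonb default '{}')
#   returns table (id uuid, content text, metadata jsonb, similarity float)
#   language plpgsql as $$
#   begin
#     return query
#       select documents.id, documents.content, documents.metadata,
#              1 - (documents.embedding <=> query_embedding) as similarity
#       from documents
#       where metadata @> filter
#       order by documents.embedding <=> query_embedding;
#   end;
#   $$;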
memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
    return_messages=True,
)
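# NOTE: Streamlit re-runs this script on every interaction, so this buffer is rebuilt each run
# and does not actually carry context between turns. A per-session variant (sketch only, not
# wired in here; "rag_memory" is an illustrative key) would be:
#
#   if "rag_memory" not in st.session_state:
#       st.session_state.rag_memory = ConversationBufferMemory(
#           memory_key="chat_history", input_key="question",
#           output_key="answer", return_messages=True,
#       )
#   memory = st.session_state.rag_memory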
# ─────── LLM setup ──────────────────────────────────────────────────────────────
model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
temperature = 0.1
max_tokens = 500
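# max_tokens caps the tokens generated per answer (passed to the endpoint as max_new_tokens);
# it does not limit the prompt or retrieved context length.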
def response_generator(query: str) -> str:
    """Ask the RAG chain to answer `query`, with a fallback for JSON decode errors."""
    # log usage
    add_usage(supabase, "chat", "prompt:" + query, {"model": model, "temperature": temperature})
    logger.info("Using HF model %s", model)
    # prepare HF text-generation LLM
    hf = HuggingFaceEndpoint(
        # endpoint_url=f"https://api-inference.huggingface.co/models/{model}",
        endpoint_url=f"https://router.huggingface.co/hf-inference/models/{model}",
        task="text-generation",
        huggingfacehub_api_token=hf_api_key,
        # generation settings are declared fields on HuggingFaceEndpoint; recent versions
        # reject them when passed inside model_kwargs, so set them explicitly
        temperature=temperature,
        max_new_tokens=max_tokens,
        return_full_text=False,
    )
    # conversational RAG chain
    qa = ConversationalRetrievalChain.from_llm(
        llm=hf,
        retriever=vector_store.as_retriever(
            search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
        ),
        memory=memory,
        verbose=True,
        return_source_documents=True,
    )
    try:
        result = qa.invoke({"question": query})
    except JSONDecodeError as e:
        # fallback logging
        logger.error("Embedding JSONDecodeError: %s", e)
        return "Sorry, I had trouble understanding the embedded data. Please try again."
    answer = result.get("answer", "")
    sources = result.get("source_documents", [])
    if not sources:
        return (
            "I'm sorry, I don't have enough information to answer that. "
            "If you have a public data source to add, please email [email protected]."
        )
    return answer
# ─────── Streamlit UI ──────────────────────────────────────────────────────────
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items={
        "About": "# Securade.ai Safety Copilot v0.1\n[https://securade.ai](https://securade.ai)",
        "Get Help": "https://securade.ai",
        "Report a Bug": "mailto:[email protected]",
    },
)
st.title("👷‍♂️ Safety Copilot 🦺")
stats = get_usage(supabase)
st.markdown(f"_{stats} queries answered!_")
st.markdown(
    "Chat with your personal safety assistant about any health & safety related queries. "
    "[[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)"
    "|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
)
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
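# st.session_state.chat_history only drives the transcript rendered below; the context the
# chain actually sees comes from the ConversationBufferMemory defined above.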
# show history
for msg in st.session_state.chat_history:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
# new user input
if prompt := st.chat_input("Ask a question"):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.spinner("Safety briefing in progress..."):
        answer = response_generator(prompt)
    with st.chat_message("assistant"):
        st.markdown(answer)
    st.session_state.chat_history.append({"role": "assistant", "content": answer})