# main.py
import os
import streamlit as st
import anthropic
from requests import JSONDecodeError
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_community.llms import HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from supabase import Client, create_client
from streamlit.logger import get_logger
from stats import get_usage, add_usage
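# stats.py is a local helper module; judging from the calls below, get_usage /
# add_usage read and write a usage-tracking table in the same Supabase project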
# ─────── supabase + secrets ─────────────────────────────────────────────────
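# Configuration comes from .streamlit/secrets.toml; a minimal sketch with
# placeholder values (key names match the st.secrets lookups below):
#
#   SUPABASE_URL      = "https://<project>.supabase.co"
#   SUPABASE_KEY      = "<supabase-api-key>"
#   openai_api_key    = "<read but unused in this file>"
#   anthropic_api_key = "<read but unused in this file>"
#   hf_api_key        = "<huggingface-api-token>"
#   username          = "<owner used to filter retrieved documents>"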
supabase_url = st.secrets.SUPABASE_URL
supabase_key = st.secrets.SUPABASE_KEY
openai_api_key = st.secrets.openai_api_key
anthropic_api_key = st.secrets.anthropic_api_key
hf_api_key = st.secrets.hf_api_key
username = st.secrets.username
supabase: Client = create_client(supabase_url, supabase_key)
logger = get_logger(__name__)
# ─────── embeddings ─────────────────────────────────────────────────────────
# Switch to local BGE embeddings (no JSONDecode errors, no HTTP-batch issues)
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)
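# Note: bge-large-en-v1.5 produces 1024-dimensional vectors, so the embedding
# column of the Supabase "documents" table must be declared as vector(1024)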
# ─────── vector store + memory ──────────────────────────────────────────────
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    query_name="match_documents",
    table_name="documents",
)
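# query_name refers to a Postgres function that must already exist in the
# Supabase project; "match_documents" is the pgvector similarity-search
# function from LangChain's standard SupabaseVectorStore setup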
memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
    return_messages=True,
)
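# memory_key/input_key/output_key match ConversationalRetrievalChain's
# "chat_history"/"question"/"answer" keys; output_key="answer" tells the
# memory which output to store once source documents are also returned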
# ─────── LLM setup ──────────────────────────────────────────────────────────
model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
temperature = 0.1
max_tokens = 500
def response_generator(query: str) -> str:
    """Ask the RAG chain to answer `query`, with JSON-error fallback."""
    # log usage
    add_usage(supabase, "chat", "prompt:" + query, {"model": model, "temperature": temperature})
    logger.info("Using HF model %s", model)
    # prepare HF text-generation LLM; temperature, max_new_tokens, and
    # return_full_text are explicit HuggingFaceEndpoint fields, so pass them
    # directly rather than inside model_kwargs (which rejects those keys)
    hf = HuggingFaceEndpoint(
        # endpoint_url=f"https://api-inference.huggingface.co/models/{model}",
        endpoint_url=f"https://router.huggingface.co/hf-inference/models/{model}",
        task="text-generation",
        huggingfacehub_api_token=hf_api_key,
        temperature=temperature,
        max_new_tokens=max_tokens,
        return_full_text=False,
    )
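    # (api-inference.huggingface.co is Hugging Face's legacy Inference API
    # host; the router.huggingface.co/hf-inference URL is its replacement)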
    # conversational RAG chain
    qa = ConversationalRetrievalChain.from_llm(
        llm=hf,
        retriever=vector_store.as_retriever(
            search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
        ),
        memory=memory,
        verbose=True,
        return_source_documents=True,
    )
    try:
        result = qa.invoke({"question": query})
    except JSONDecodeError as e:
        # fallback logging
        logger.error("Embedding JSONDecodeError: %s", e)
        return "Sorry, I had trouble understanding the embedded data. Please try again."
    answer = result.get("answer", "")
    sources = result.get("source_documents", [])
    if not sources:
        return (
            "I'm sorry, I don't have enough information to answer that. "
            "If you have a public data source to add, please email [email protected]."
        )
    return answer
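# Example (illustrative query): response_generator("What PPE is required for
# hot work?") returns either a retrieved answer or the fallback message above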
# ─────── Streamlit UI ───────────────────────────────────────────────────────
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items={
        "About": "# Securade.ai Safety Copilot v0.1\n[https://securade.ai](https://securade.ai)",
        "Get Help": "https://securade.ai",
        "Report a Bug": "mailto:[email protected]",
    },
)
st.title("👷‍♂️ Safety Copilot 🦺")
stats = get_usage(supabase)
st.markdown(f"_{stats} queries answered!_")
st.markdown(
    "Chat with your personal safety assistant about any health & safety related queries. "
    "[[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)"
    "|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
)
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
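# st.session_state.chat_history only mirrors the on-screen transcript; the
# chain keeps its own copy of the conversation in ConversationBufferMemory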
# show history
for msg in st.session_state.chat_history:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
# new user input
if prompt := st.chat_input("Ask a question"):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.spinner("Safety briefing in progress..."):
        answer = response_generator(prompt)
    with st.chat_message("assistant"):
        st.markdown(answer)
    st.session_state.chat_history.append({"role": "assistant", "content": answer})