# main.py

import os
import streamlit as st
import anthropic
from requests import JSONDecodeError

from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_community.llms import HuggingFaceEndpoint

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

from supabase import Client, create_client
from streamlit.logger import get_logger
from stats import get_usage, add_usage

# ─────── supabase + secrets ────────────────────────────────────────────────
supabase_url    = st.secrets.SUPABASE_URL
supabase_key    = st.secrets.SUPABASE_KEY
openai_api_key  = st.secrets.openai_api_key
anthropic_api_key = st.secrets.anthropic_api_key
hf_api_key      = st.secrets.hf_api_key
username        = st.secrets.username

supabase: Client = create_client(supabase_url, supabase_key)
logger = get_logger(__name__)

# ─────── embeddings ─────────────────────────────────────────────────────────
# Switch to local BGE embeddings (no JSONDecode errors, no HTTP-batch issues)
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)
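# (normalize_embeddings=True gives unit-length vectors, so similarity scores are
#  effectively cosine similarities, assuming the standard pgvector match_documents
#  function is used on the Supabase side)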

# ─────── vector store + memory ──────────────────────────────────────────────
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    query_name="match_documents",
    table_name="documents",
)
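# output_key="answer" tells the memory which of the chain's outputs to store in
# the chat history, since return_source_documents=True below makes the chain
# return two keys ("answer" and "source_documents").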
memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
    return_messages=True,
)

# ─────── LLM setup ──────────────────────────────────────────────────────────
model        = "mistralai/Mixtral-8x7B-Instruct-v0.1"
temperature  = 0.1
max_tokens   = 500

def response_generator(query: str) -> str:
    """Ask the RAG chain to answer `query`, with JSON‑error fallback."""
    # log usage
    add_usage(supabase, "chat", "prompt:" + query, {"model": model, "temperature": temperature})
    logger.info("Using HF model %s", model)

    # prepare HF text-generation LLM
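    # (requests are routed through the HF Inference router; the legacy
    #  api-inference URL is kept below, commented out, for reference)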
    hf = HuggingFaceEndpoint(
        # endpoint_url=f"https://api-inference.huggingface.co/models/{model}",
        endpoint_url=f"https://router.huggingface.co/hf-inference/models/{model}",
        task="text-generation",
        huggingfacehub_api_token=hf_api_key,
        model_kwargs={
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            "return_full_text": False,
        },
    )

    # conversational RAG chain
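    # (the chain condenses the chat history + new question into a standalone
    #  query, retrieves up to k=4 documents scoped to this user, then answers)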
    qa = ConversationalRetrievalChain.from_llm(
        llm=hf,
        retriever=vector_store.as_retriever(
            search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
        ),
        memory=memory,
        verbose=True,
        return_source_documents=True,
    )

    try:
        result = qa({"question": query})
    except JSONDecodeError as e:
        # log the malformed response and return a graceful fallback message
        logger.error("JSONDecodeError while answering the query: %s", e)
        return "Sorry, I had trouble processing that request. Please try again."

    answer = result.get("answer", "")
    sources = result.get("source_documents", [])

    if not sources:
        return (
            "I’m sorry, I don’t have enough information to answer that. "
            "If you have a public data source to add, please email [email protected]."
        )
    return answer

# ─────── Streamlit UI ───────────────────────────────────────────────────────
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items={
        "About": "# Securade.ai Safety Copilot v0.1\n[https://securade.ai](https://securade.ai)",
        "Get Help": "https://securade.ai",
        "Report a Bug": "mailto:[email protected]",
    },
)

st.title("πŸ‘·β€β™‚οΈ Safety Copilot 🦺")
stats = get_usage(supabase)
st.markdown(f"_{stats} queries answered!_")
st.markdown(
    "Chat with your personal safety assistant about any health & safety related queries. "
    "[[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)"
    "|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
)

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# show history
for msg in st.session_state.chat_history:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# new user input
if prompt := st.chat_input("Ask a question"):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.spinner("Safety briefing in progress..."):
        answer = response_generator(prompt)

    with st.chat_message("assistant"):
        st.markdown(answer)
    st.session_state.chat_history.append({"role": "assistant", "content": answer})