Spaces:
Running
Running
File size: 5,395 Bytes
63c0a0b 4e00df7 63c0a0b 78308ba 63c0a0b dfd217b 63c0a0b dfd217b 63c0a0b 4e00df7 dfd217b 63c0a0b a91d644 2200d67 4e00df7 dfd217b 63c0a0b 4e00df7 63c0a0b 78308ba 1b4efee 4e00df7 63c0a0b 4e00df7 63c0a0b 4e00df7 d035a6e 63c0a0b 3fd401e dfd217b 63c0a0b 4e00df7 dfd217b 63c0a0b 1ca7761 63c0a0b 1ca7761 cae23e1 1ca7761 4e00df7 63c0a0b dfd217b 63c0a0b dfd217b 63c0a0b dfd217b 63c0a0b dfd217b 63c0a0b dfd217b 63c0a0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# main.py
import os
import streamlit as st
import anthropic
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import SupabaseVectorStore
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from supabase import Client, create_client
from streamlit.logger import get_logger
from stats import get_usage, add_usage
# --- Configuration and shared resources --------------------------------------
# Credentials and the tenant/user name come from Streamlit's secrets store.
# NOTE(review): openai_api_key and anthropic_api_key are not referenced in the
# visible part of this file; they are kept because reading them also validates
# that the secrets exist — confirm before removing.
supabase_url = st.secrets.SUPABASE_URL
supabase_key = st.secrets.SUPABASE_KEY
openai_api_key = st.secrets.openai_api_key
anthropic_api_key = st.secrets.anthropic_api_key
hf_api_key = st.secrets.hf_api_key
username = st.secrets.username

supabase: Client = create_client(supabase_url, supabase_key)
logger = get_logger(__name__)

# Embedding client used by the vector store to embed incoming queries.
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=hf_api_key,
    model_name="BAAI/bge-large-en-v1.5",
    api_url="https://router.huggingface.co/hf-inference/pipeline/feature-extraction/",
)

# Messages rendered in the chat UI, stored as {"role": ..., "content": ...}.
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

vector_store = SupabaseVectorStore(supabase, embeddings, query_name='match_documents', table_name="documents")

# Streamlit re-executes this script from the top on every user interaction, so
# a memory object built at module level would be recreated (empty) before each
# query and the chain would never see earlier turns. Keeping it in
# st.session_state lets the conversation context survive reruns for the
# lifetime of the browser session.
if 'memory' not in st.session_state:
    st.session_state['memory'] = ConversationBufferMemory(
        memory_key="chat_history",
        input_key='question',
        output_key='answer',
        return_messages=True,
    )
memory = st.session_state['memory']

# Generation settings for the hosted LLM.
model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
temperature = 0.1
max_tokens = 500

# Usage figure shown in the page header; converted to text for concatenation.
stats = str(get_usage(supabase))
def response_generator(query):
    """Answer ``query`` with a retrieval-augmented conversational chain.

    The call is first reported through ``add_usage``. A Hugging Face
    text-generation endpoint for the module-level ``model`` is then wrapped in
    a ``ConversationalRetrievalChain`` whose retriever searches
    ``vector_store`` with a filter on the configured ``username``.

    Args:
        query: The user's question as plain text.

    Returns:
        The chain's answer when at least one source document was returned,
        otherwise a fixed apology message.
    """
    add_usage(supabase, "chat", "prompt" + query, {"model": model, "temperature": temperature})
    logger.info('Using HF model %s', model)

    llm = HuggingFaceEndpoint(
        endpoint_url="https://api-inference.huggingface.co/models/" + model,
        task="text-generation",
        huggingfacehub_api_token=hf_api_key,
        model_kwargs={
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            # "repetition_penalty": 1.1,
            "return_full_text": False,
        },
    )

    # Retrieval settings are passed through to the vector store unchanged.
    retriever = vector_store.as_retriever(
        search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
    )
    chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        memory=memory,
        verbose=True,
        return_source_documents=True,
    )

    result = chain({"question": query})
    answer = result["answer"]
    logger.info('Result: %s', answer)
    retrieved = result["source_documents"]
    logger.info('Sources: %s', retrieved)

    # An answer with no supporting documents is replaced by the fallback text.
    # NOTE(review): the e-mail address in the message looks like an
    # obfuscation placeholder — confirm the intended contact address.
    return answer if len(retrieved) > 0 else "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email [email protected]."
# Page chrome: title, icon, layout and the entries of the hamburger menu.
# NOTE(review): the mailto address below looks like an obfuscation
# placeholder — confirm the intended contact address.
menu_entries = {
    "About": "# Securade.ai Safety Copilot v0.1\n [https://securade.ai](https://securade.ai)",
    "Get Help": "https://securade.ai",
    "Report a Bug": "mailto:[email protected]",
}
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items=menu_entries,
)

# Header: title, short description with links, and the usage counter.
st.title("👷♂️ Safety Copilot 🦺")
st.markdown("Chat with your personal safety assistant about any health & safety related queries. [[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]")
st.markdown(f"_{stats} queries answered!_")

# Make sure the transcript list exists before it is read below.
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
transcript = st.session_state['chat_history']

# Replay the stored conversation so it is visible again after a rerun.
for entry in transcript:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])

# Handle a newly submitted question, if any.
prompt = st.chat_input("Ask a question")
if prompt:
    # Record and echo the user's message.
    transcript.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Show a spinner while the answer is being generated.
    with st.spinner('Safety briefing in progress...'):
        response = response_generator(prompt)

    # Render the assistant's reply and record it in the transcript.
    with st.chat_message("assistant"):
        st.markdown(response)
    transcript.append({"role": "assistant", "content": response})
# query = st.text_area("## Ask a question (" + stats + " queries answered so far)", max_chars=500)
# columns = st.columns(2)
# with columns[0]:
# button = st.button("Ask")
# with columns[1]:
# clear_history = st.button("Clear History", type='secondary')
# st.markdown("---\n\n")
# if clear_history:
# # Clear memory in Langchain
# memory.clear()
# st.session_state['chat_history'] = []
# st.experimental_rerun()