codelion committed
Commit a91d644 · verified · 1 Parent(s): 09ffec2

Update main.py

Files changed (1):
  1. main.py +140 -91
main.py CHANGED
@@ -1,81 +1,152 @@
- # main.py
  import os
  import streamlit as st
- import anthropic
-
  from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
- from langchain_community.embeddings import HuggingFaceBgeEmbeddings
  from langchain_community.vectorstores import SupabaseVectorStore
  from langchain_community.llms import HuggingFaceEndpoint
- from langchain_community.vectorstores import SupabaseVectorStore
-
  from langchain.chains import ConversationalRetrievalChain
  from langchain.memory import ConversationBufferMemory
-
  from supabase import Client, create_client
  from streamlit.logger import get_logger
- from stats import get_usage, add_usage

- supabase_url = st.secrets.SUPABASE_URL
- supabase_key = st.secrets.SUPABASE_KEY
- openai_api_key = st.secrets.openai_api_key
- anthropic_api_key = st.secrets.anthropic_api_key
- hf_api_key = st.secrets.hf_api_key
- username = st.secrets.username

  supabase: Client = create_client(supabase_url, supabase_key)
- logger = get_logger(__name__)

- embeddings = HuggingFaceInferenceAPIEmbeddings(
      api_key=hf_api_key,
      model_name="BAAI/bge-large-en-v1.5",
-     api_url="https://router.huggingface.co/hf-inference/pipeline/feature-extraction/",
  )

- if 'chat_history' not in st.session_state:
-     st.session_state['chat_history'] = []

- vector_store = SupabaseVectorStore(supabase, embeddings, query_name='match_documents', table_name="documents")
- memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)

  model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-
  temperature = 0.1
  max_tokens = 500
- stats = str(get_usage(supabase))

  def response_generator(query):
-     qa = None
-     add_usage(supabase, "chat", "prompt" + query, {"model": model, "temperature": temperature})
-     logger.info('Using HF model %s', model)
-     # print(st.session_state['max_tokens'])
-     endpoint_url = ("https://api-inference.huggingface.co/models/"+ model)
-     model_kwargs = {"temperature" : temperature,
-                     "max_new_tokens" : max_tokens,
-                     # "repetition_penalty" : 1.1,
-                     "return_full_text" : False}
-     hf = HuggingFaceEndpoint(
-         endpoint_url=endpoint_url,
-         task="text-generation",
-         huggingfacehub_api_token=hf_api_key,
-         model_kwargs=model_kwargs
-     )
-     qa = ConversationalRetrievalChain.from_llm(hf, retriever=vector_store.as_retriever(search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}), memory=memory, verbose=True, return_source_documents=True)
-
-     # Generate model's response
-     model_response = qa({"question": query})
-     logger.info('Result: %s', model_response["answer"])
-     sources = model_response["source_documents"]
-     logger.info('Sources: %s', model_response["source_documents"])
-
-     if len(sources) > 0:
-         response = model_response["answer"]
-     else:
-         response = "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email [email protected]."
-
-     return response
-
- # Set the theme
  st.set_page_config(
      page_title="Securade.ai - Safety Copilot",
      page_icon="https://securade.ai/favicon.ico",
@@ -83,55 +154,33 @@ st.set_page_config(
      initial_sidebar_state="collapsed",
      menu_items={
          "About": "# Securade.ai Safety Copilot v0.1\n [https://securade.ai](https://securade.ai)",
-         "Get Help" : "https://securade.ai",
-         "Report a Bug": "mailto:[email protected]"
-     }
  )

  st.title("👷‍♂️ Safety Copilot 🦺")

- st.markdown("Chat with your personal safety assistant about any health & safety related queries. [[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]")
- # st.markdown("Up-to-date with latest OSH regulations for Singapore, Indonesia, Malaysia & other parts of Asia.")
- st.markdown("_"+ stats + " queries answered!_")
-
- if 'chat_history' not in st.session_state:
-     st.session_state['chat_history'] = []
-
- # Display chat messages from history on app rerun
  for message in st.session_state.chat_history:
      with st.chat_message(message["role"]):
          st.markdown(message["content"])
-
- # Accept user input
- if prompt := st.chat_input("Ask a question"):
-     # print(prompt)
-     # Add user message to chat history
      st.session_state.chat_history.append({"role": "user", "content": prompt})
-     # Display user message in chat message container
      with st.chat_message("user"):
          st.markdown(prompt)

-     with st.spinner('Safety briefing in progress...'):
          response = response_generator(prompt)

-     # Display assistant response in chat message container
      with st.chat_message("assistant"):
          st.markdown(response)
-     # Add assistant response to chat history
-     # print(response)
-     st.session_state.chat_history.append({"role": "assistant", "content": response})
-
- # query = st.text_area("## Ask a question (" + stats + " queries answered so far)", max_chars=500)
- # columns = st.columns(2)
- # with columns[0]:
- #     button = st.button("Ask")
- # with columns[1]:
- #     clear_history = st.button("Clear History", type='secondary')
-
- # st.markdown("---\n\n")
-
- # if clear_history:
- #     # Clear memory in Langchain
- #     memory.clear()
- #     st.session_state['chat_history'] = []
- #     st.experimental_rerun()

  import os
  import streamlit as st
+ import logging
+ from requests.exceptions import JSONDecodeError
  from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
  from langchain_community.vectorstores import SupabaseVectorStore
  from langchain_community.llms import HuggingFaceEndpoint
  from langchain.chains import ConversationalRetrievalChain
  from langchain.memory import ConversationBufferMemory
  from supabase import Client, create_client
  from streamlit.logger import get_logger

+ # Configure logging
+ logger = get_logger(__name__)
+ logging.basicConfig(level=logging.INFO)
+
+ # Load secrets
+ supabase_url = st.secrets["SUPABASE_URL"]
+ supabase_key = st.secrets["SUPABASE_KEY"]
+ hf_api_key = st.secrets["hf_api_key"]
+ username = st.secrets["username"]

+ # Initialize Supabase client
  supabase: Client = create_client(supabase_url, supabase_key)

+ # Custom HuggingFaceInferenceAPIEmbeddings to handle JSONDecodeError
+ class CustomHuggingFaceInferenceAPIEmbeddings(HuggingFaceInferenceAPIEmbeddings):
+     def embed_query(self, text: str):
+         try:
+             response = self.client.post(
+                 json={"inputs": text, "options": {"use_cache": False}},
+                 task="feature-extraction",
+             )
+             if response.status_code != 200:
+                 logger.error(f"API request failed with status {response.status_code}: {response.text}")
+                 return [0.0] * 1024  # Zero vector matching bge-large-en-v1.5's 1024-dim embeddings
+             try:
+                 embeddings = response.json()
+                 if not isinstance(embeddings, list) or not embeddings:
+                     logger.error(f"Invalid embeddings response: {embeddings}")
+                     return [0.0] * 1024
+                 return embeddings[0]
+             except JSONDecodeError as e:
+                 logger.error(f"JSON decode error: {str(e)}, response: {response.text}")
+                 return [0.0] * 1024
+         except Exception as e:
+             logger.error(f"Error embedding query: {str(e)}")
+             return [0.0] * 1024
+
+     def embed_documents(self, texts):
+         try:
+             response = self.client.post(
+                 json={"inputs": texts, "options": {"use_cache": False}},
+                 task="feature-extraction",
+             )
+             if response.status_code != 200:
+                 logger.error(f"API request failed with status {response.status_code}: {response.text}")
+                 return [[0.0] * 1024 for _ in texts]
+             try:
+                 embeddings = response.json()
+                 if not isinstance(embeddings, list) or not embeddings:
+                     logger.error(f"Invalid embeddings response: {embeddings}")
+                     return [[0.0] * 1024 for _ in texts]
+                 return [emb[0] for emb in embeddings]
+             except JSONDecodeError as e:
+                 logger.error(f"JSON decode error: {str(e)}, response: {response.text}")
+                 return [[0.0] * 1024 for _ in texts]
+         except Exception as e:
+             logger.error(f"Error embedding documents: {str(e)}")
+             return [[0.0] * 1024 for _ in texts]
+
+ # Initialize embeddings
+ embeddings = CustomHuggingFaceInferenceAPIEmbeddings(
      api_key=hf_api_key,
      model_name="BAAI/bge-large-en-v1.5",
  )

+ # Initialize session state
+ if "chat_history" not in st.session_state:
+     st.session_state["chat_history"] = []

+ # Initialize vector store and memory
+ vector_store = SupabaseVectorStore(
+     client=supabase,
+     embedding=embeddings,
+     query_name="match_documents",
+     table_name="documents",
+ )
+ memory = ConversationBufferMemory(
+     memory_key="chat_history",
+     input_key="question",
+     output_key="answer",
+     return_messages=True,
+ )

+ # Model configuration
  model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
  temperature = 0.1
  max_tokens = 500
+
+ # Mock stats function (replace with your actual implementation)
+ def get_usage(supabase):
+     return 100  # Replace with actual logic
+
+ def add_usage(supabase, action, prompt, metadata):
+     pass  # Replace with actual logic
+
+ stats = str(get_usage(supabase))

  def response_generator(query):
+     try:
+         add_usage(supabase, "chat", f"prompt: {query}", {"model": model, "temperature": temperature})
+         logger.info("Using HF model %s", model)
+
+         endpoint_url = f"https://api-inference.huggingface.co/models/{model}"
+         model_kwargs = {
+             "temperature": temperature,
+             "max_new_tokens": max_tokens,
+             "return_full_text": False,
+         }
+         hf = HuggingFaceEndpoint(
+             endpoint_url=endpoint_url,
+             task="text-generation",
+             huggingfacehub_api_token=hf_api_key,
+             model_kwargs=model_kwargs,
+         )
+         qa = ConversationalRetrievalChain.from_llm(
+             llm=hf,
+             retriever=vector_store.as_retriever(search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}),
+             memory=memory,
+             verbose=True,
+             return_source_documents=True,
+         )
+
+         # Use invoke instead of deprecated __call__
+         model_response = qa.invoke({"question": query})
+         logger.info("Result: %s", model_response["answer"])
+         sources = model_response["source_documents"]
+         logger.info("Sources: %s", sources)
+
+         if sources:
+             return model_response["answer"]
+         else:
+             return "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email [email protected]."
+     except Exception as e:
+         logger.error(f"Error generating response: {str(e)}")
+         return "An error occurred while processing your request. Please try again later."
+
+ # Streamlit UI
  st.set_page_config(
      page_title="Securade.ai - Safety Copilot",
      page_icon="https://securade.ai/favicon.ico",
  ...
      initial_sidebar_state="collapsed",
      menu_items={
          "About": "# Securade.ai Safety Copilot v0.1\n [https://securade.ai](https://securade.ai)",
+         "Get Help": "https://securade.ai",
+         "Report a Bug": "mailto:[email protected]",
+     },
  )

  st.title("👷‍♂️ Safety Copilot 🦺")
+ st.markdown(
+     "Chat with your personal safety assistant about any health & safety related queries. "
+     "[[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)|"
+     "[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
+ )
+ st.markdown(f"_{stats} queries answered!_")

+ # Display chat history
  for message in st.session_state.chat_history:
      with st.chat_message(message["role"]):
          st.markdown(message["content"])
+
+ # Handle user input
+ if prompt := st.chat_input("Ask a question"):
      st.session_state.chat_history.append({"role": "user", "content": prompt})
      with st.chat_message("user"):
          st.markdown(prompt)

+     with st.spinner("Safety briefing in progress..."):
          response = response_generator(prompt)

      with st.chat_message("assistant"):
          st.markdown(response)
+     st.session_state.chat_history.append({"role": "assistant", "content": response})
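
The new zero-vector fallback is easy to sanity-check outside Streamlit. A minimal sketch, assuming the class is lifted out of main.py into its own module (here called embeddings_fallback.py, since importing main.py would run the Streamlit app) and that an invalid token makes the Inference API call fail:

```python
# Hypothetical smoke test for the fallback path; module name is an assumption.
from embeddings_fallback import CustomHuggingFaceInferenceAPIEmbeddings

# A deliberately invalid token forces the error branch: the call should log
# the failure and return an all-zero vector instead of raising.
emb = CustomHuggingFaceInferenceAPIEmbeddings(
    api_key="hf_invalid_token",
    model_name="BAAI/bge-large-en-v1.5",
)

vec = emb.embed_query("Is a hard hat required on site?")
print(len(vec))                    # 1024, matching bge-large-en-v1.5
print(all(v == 0.0 for v in vec))  # True when the request failed
```

Because the fallback keeps the expected dimensionality, an API outage degrades into the "not enough information" reply from response_generator rather than a crash; surfacing that state to the user explicitly may be worth a follow-up.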