codelion committed
Commit 63c0a0b · verified · Parent: 1b4efee

Update main.py

Files changed (1): main.py (+84, -135)

main.py CHANGED
@@ -1,18 +1,20 @@
+# main.py
 import os
 import streamlit as st
-import logging
-from requests.exceptions import JSONDecodeError
+import anthropic
+
 from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 from langchain_community.vectorstores import SupabaseVectorStore
 from langchain_community.llms import HuggingFaceEndpoint
+from langchain_community.vectorstores import SupabaseVectorStore
+
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory
+
 from supabase import Client, create_client
 from streamlit.logger import get_logger
-
-# Configure logging
-logger = get_logger(__name__)
-logging.basicConfig(level=logging.INFO)
+from stats import get_usage, add_usage
 
 supabase_url = st.secrets.SUPABASE_URL
 supabase_key = st.secrets.SUPABASE_KEY
@@ -21,134 +23,59 @@ anthropic_api_key = st.secrets.anthropic_api_key
 hf_api_key = st.secrets.hf_api_key
 username = st.secrets.username
 
-# Initialize Supabase client
 supabase: Client = create_client(supabase_url, supabase_key)
+logger = get_logger(__name__)
 
-# Custom HuggingFaceInferenceAPIEmbeddings to handle JSONDecodeError
-class CustomHuggingFaceInferenceAPIEmbeddings(HuggingFaceInferenceAPIEmbeddings):
-    def embed_query(self, text: str):
-        try:
-            response = self.client.post(
-                json={"inputs": text, "options": {"use_cache": False}},
-                task="feature-extraction",
-            )
-            if response.status_code != 200:
-                logger.error(f"API request failed with status {response.status_code}: {response.text}")
-                return [0.0] * 384  # Return zero vector of expected dimension
-            try:
-                embeddings = response.json()
-                if not isinstance(embeddings, list) or not embeddings:
-                    logger.error(f"Invalid embeddings response: {embeddings}")
-                    return [0.0] * 384
-                return embeddings[0]
-            except JSONDecodeError as e:
-                logger.error(f"JSON decode error: {str(e)}, response: {response.text}")
-                return [0.0] * 384
-        except Exception as e:
-            logger.error(f"Error embedding query: {str(e)}")
-            return [0.0] * 384
-
-    def embed_documents(self, texts):
-        try:
-            response = self.client.post(
-                json={"inputs": texts, "options": {"use_cache": False}},
-                task="feature-extraction",
-            )
-            if response.status_code != 200:
-                logger.error(f"API request failed with status {response.status_code}: {response.text}")
-                return [[0.0] * 384 for _ in texts]
-            try:
-                embeddings = response.json()
-                if not isinstance(embeddings, list) or not embeddings:
-                    logger.error(f"Invalid embeddings response: {embeddings}")
-                    return [[0.0] * 384 for _ in texts]
-                return [emb[0] for emb in embeddings]
-            except JSONDecodeError as e:
-                logger.error(f"JSON decode error: {str(e)}, response: {response.text}")
-                return [[0.0] * 384 for _ in texts]
-        except Exception as e:
-            logger.error(f"Error embedding documents: {str(e)}")
-            return [[0.0] * 384 for _ in texts]
-
-# Initialize embeddings
-embeddings = CustomHuggingFaceInferenceAPIEmbeddings(
+embeddings = HuggingFaceInferenceAPIEmbeddings(
     api_key=hf_api_key,
     model_name="BAAI/bge-large-en-v1.5",
     api_url="https://router.huggingface.co/hf-inference/pipeline/feature-extraction/",
 )
 
-# Initialize session state
-if "chat_history" not in st.session_state:
-    st.session_state["chat_history"] = []
+if 'chat_history' not in st.session_state:
+    st.session_state['chat_history'] = []
 
-# Initialize vector store and memory
-vector_store = SupabaseVectorStore(
-    client=supabase,
-    embedding=embeddings,
-    query_name="match_documents",
-    table_name="documents",
-)
-memory = ConversationBufferMemory(
-    memory_key="chat_history",
-    input_key="question",
-    output_key="answer",
-    return_messages=True,
-)
+vector_store = SupabaseVectorStore(supabase, embeddings, query_name='match_documents', table_name="documents")
+memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
 
-# Model configuration
 model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
 temperature = 0.1
 max_tokens = 500
-
-# Mock stats function (replace with your actual implementation)
-def get_usage(supabase):
-    return 100  # Replace with actual logic
-
-def add_usage(supabase, action, prompt, metadata):
-    pass  # Replace with actual logic
-
-stats = str(get_usage(supabase))
+stats = str(get_usage(supabase))
 
 def response_generator(query):
-    try:
-        add_usage(supabase, "chat", f"prompt: {query}", {"model": model, "temperature": temperature})
-        logger.info("Using HF model %s", model)
-
-        endpoint_url = f"https://api-inference.huggingface.co/models/{model}"
-        model_kwargs = {
-            "temperature": temperature,
-            "max_new_tokens": max_tokens,
-            "return_full_text": False,
-        }
-        hf = HuggingFaceEndpoint(
-            endpoint_url=endpoint_url,
-            task="text-generation",
-            huggingfacehub_api_token=hf_api_key,
-            model_kwargs=model_kwargs,
-        )
-        qa = ConversationalRetrievalChain.from_llm(
-            llm=hf,
-            retriever=vector_store.as_retriever(search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}),
-            memory=memory,
-            verbose=True,
-            return_source_documents=True,
-        )
-
-        # Use invoke instead of deprecated __call__
-        model_response = qa.invoke({"question": query})
-        logger.info("Result: %s", model_response["answer"])
-        sources = model_response["source_documents"]
-        logger.info("Sources: %s", sources)
-
-        if sources:
-            return model_response["answer"]
-        else:
-            return "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email [email protected]."
-    except Exception as e:
-        logger.error(f"Error generating response: {str(e)}")
-        return "An error occurred while processing your request. Please try again later."
-
-# Streamlit UI
+    qa = None
+    add_usage(supabase, "chat", "prompt" + query, {"model": model, "temperature": temperature})
+    logger.info('Using HF model %s', model)
+    # print(st.session_state['max_tokens'])
+    endpoint_url = ("https://api-inference.huggingface.co/models/"+ model)
+    model_kwargs = {"temperature" : temperature,
+                    "max_new_tokens" : max_tokens,
+                    # "repetition_penalty" : 1.1,
+                    "return_full_text" : False}
+    hf = HuggingFaceEndpoint(
+        endpoint_url=endpoint_url,
+        task="text-generation",
+        huggingfacehub_api_token=hf_api_key,
+        model_kwargs=model_kwargs
+    )
+    qa = ConversationalRetrievalChain.from_llm(hf, retriever=vector_store.as_retriever(search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}), memory=memory, verbose=True, return_source_documents=True)
+
+    # Generate model's response
+    model_response = qa({"question": query})
+    logger.info('Result: %s', model_response["answer"])
+    sources = model_response["source_documents"]
+    logger.info('Sources: %s', model_response["source_documents"])
+
+    if len(sources) > 0:
+        response = model_response["answer"]
+    else:
+        response = "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email [email protected]."
+
+    return response
+
+# Set the theme
 st.set_page_config(
     page_title="Securade.ai - Safety Copilot",
     page_icon="https://securade.ai/favicon.ico",
@@ -156,33 +83,55 @@ st.set_page_config(
     initial_sidebar_state="collapsed",
     menu_items={
         "About": "# Securade.ai Safety Copilot v0.1\n [https://securade.ai](https://securade.ai)",
-        "Get Help": "https://securade.ai",
-        "Report a Bug": "mailto:[email protected]",
-    },
+        "Get Help" : "https://securade.ai",
+        "Report a Bug": "mailto:[email protected]"
+    }
 )
 
 st.title("👷‍♂️ Safety Copilot 🦺")
-st.markdown(
-    "Chat with your personal safety assistant about any health & safety related queries. "
-    "[[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)|"
-    "[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
-)
-st.markdown(f"_{stats} queries answered!_")
 
-# Display chat history
+st.markdown("Chat with your personal safety assistant about any health & safety related queries. [[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]")
+# st.markdown("Up-to-date with latest OSH regulations for Singapore, Indonesia, Malaysia & other parts of Asia.")
+st.markdown("_"+ stats + " queries answered!_")
+
+if 'chat_history' not in st.session_state:
+    st.session_state['chat_history'] = []
+
+# Display chat messages from history on app rerun
 for message in st.session_state.chat_history:
     with st.chat_message(message["role"]):
         st.markdown(message["content"])
-
-# Handle user input
-if prompt := st.chat_input("Ask a question"):
+
+# Accept user input
+if prompt := st.chat_input("Ask a question"):
+    # print(prompt)
+    # Add user message to chat history
     st.session_state.chat_history.append({"role": "user", "content": prompt})
+    # Display user message in chat message container
     with st.chat_message("user"):
         st.markdown(prompt)
 
-    with st.spinner("Safety briefing in progress..."):
+    with st.spinner('Safety briefing in progress...'):
         response = response_generator(prompt)
 
+    # Display assistant response in chat message container
     with st.chat_message("assistant"):
         st.markdown(response)
-        st.session_state.chat_history.append({"role": "assistant", "content": response})
+    # Add assistant response to chat history
+    # print(response)
+    st.session_state.chat_history.append({"role": "assistant", "content": response})
+
+# query = st.text_area("## Ask a question (" + stats + " queries answered so far)", max_chars=500)
+# columns = st.columns(2)
+# with columns[0]:
+#     button = st.button("Ask")
+# with columns[1]:
+#     clear_history = st.button("Clear History", type='secondary')
+
+# st.markdown("---\n\n")
+
+# if clear_history:
+#     # Clear memory in Langchain
+#     memory.clear()
+#     st.session_state['chat_history'] = []
+#     st.experimental_rerun()
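
The commit replaces the custom JSONDecodeError-handling subclass with the stock HuggingFaceInferenceAPIEmbeddings class. A quick standalone check of the embeddings endpoint can be run outside Streamlit; this is a sketch, and it assumes the token is exported as an environment variable named HF_API_KEY rather than read from st.secrets:

# Standalone smoke test for the embeddings endpoint — a sketch,
# not part of this commit. HF_API_KEY is an assumed stand-in for
# the st.secrets.hf_api_key value used in main.py.
import os
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings

embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=os.environ["HF_API_KEY"],
    model_name="BAAI/bge-large-en-v1.5",
)

vector = embeddings.embed_query("What PPE is required for working at height?")
# BAAI/bge-large-en-v1.5 produces 1024-dimensional embeddings.
print(len(vector))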