MilanM commited on
Commit
5142da8
·
verified ·
1 Parent(s): 14de980

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +367 -0
app.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from io import BytesIO
3
+ import ibm_watsonx_ai
4
+ import secretsload
5
+ import genparam
6
+ import requests
7
+ import time
8
+ import re
9
+ import json
10
+
11
+ from ibm_watsonx_ai.foundation_models import ModelInference
12
+ from ibm_watsonx_ai import Credentials, APIClient
13
+ from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
14
+ from ibm_watsonx_ai.metanames import GenTextReturnOptMetaNames as RetParams
15
+
16
+ from ibm_watsonx_ai.foundation_models import Embeddings
17
+ from ibm_watsonx_ai.foundation_models.utils.enums import EmbeddingTypes
18
+ from pymilvus import MilvusClient
19
+
20
+ from secretsload import load_stsecrets
21
+
22
# Load watsonx credentials from Streamlit secrets (see secretsload module).
# NOTE(review): this module-level `credentials` appears unused afterwards —
# setup_client() builds its own Credentials object; confirm before removing.
credentials = load_stsecrets()

# Page chrome: wide layout for the three bot columns, sidebar collapsed.
st.set_page_config(
    page_title="The Tribunal",
    page_icon="🥸",
    initial_sidebar_state="collapsed",
    layout="wide"
)
30
+
31
# Password protection
def check_password():
    """Gate the app behind the shared password in st.secrets["app_password"].

    Returns True once the correct password has been entered. Otherwise
    renders the password prompt (plus an error banner after a failed
    attempt) and returns False so the caller can st.stop().
    """
    def password_entered():
        # Compare against the secret, then drop the plaintext from state
        # so it never lingers in st.session_state.
        if st.session_state["password"] == st.secrets["app_password"]:
            st.session_state["password_correct"] = True
            del st.session_state["password"]
        else:
            st.session_state["password_correct"] = False

    if st.session_state.get("password_correct"):
        return True

    # Not authenticated yet (first visit, or a failed attempt): render the
    # prompt once instead of duplicating it per branch as before.
    st.markdown("\n\n")
    st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
    st.divider()
    if st.session_state.get("password_correct") is False:
        st.error("😕 Incorrect password")
    st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
    return False
55
+
56
def initialize_session_state():
    """Seed every session-state key this app relies on, once per session."""
    defaults = {
        "chat_history_1": [],   # one transcript per bot column
        "chat_history_2": [],
        "chat_history_3": [],
        "first_question": False,
        "counter": 0,
        "token_capture": [],    # accumulated token-usage log lines
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
69
+
70
+
71
+
72
# CSS injected via st.markdown(..., unsafe_allow_html=True): pads each
# Streamlit column, draws a divider between columns (none after the last),
# and sizes chat containers so each column scrolls independently.
three_column_style = """
<style>
.stColumn {
    padding: 0.5rem;
    border-right: 1px solid #dedede;
}
.stColumn:last-child {
    border-right: none;
}
.chat-container {
    height: calc(100vh - 200px);
    overflow-y: auto;
}
</style>
"""
87
+
88
def setup_client(project_id):
    """Build watsonx.ai Credentials and an APIClient bound to *project_id*.

    Returns (credentials, client) so callers can reuse the raw credentials
    (e.g. for the Embeddings service) alongside the API client.
    """
    credentials = Credentials(
        url=st.secrets["url"],
        api_key=st.secrets["api_key"]
    )
    # Removed dead local `apo` that pointlessly copied the API key into
    # an unused variable.
    client = APIClient(credentials, project_id=project_id)
    return credentials, client
96
+
97
# Module-level credentials/client shared by every handler in this script.
wml_credentials, client = setup_client(st.secrets["project_id"])
98
+
99
def setup_vector_index(client, wml_credentials, vector_index_id):
    """Resolve a watsonx vector-index asset into a live Milvus connection.

    Looks up the vector-index asset, builds the matching Embeddings model,
    and opens a MilvusClient using the connection referenced by the asset.

    Returns (milvus_client, embeddings, vector_index_properties, schema_fields).
    """
    index_details = client.data_assets.get_details(vector_index_id)
    index_properties = index_details["entity"]["vector_index"]
    settings = index_properties["settings"]

    # The embedding model must be the same one the index was built with.
    emb = Embeddings(
        model_id=settings["embedding_model_id"],
        credentials=wml_credentials,
        project_id=st.secrets["project_id"],
        params={"truncate_input_tokens": 512},
    )

    schema_fields = settings["schema_fields"]
    conn_details = client.connections.get_details(index_properties["store"]["connection_id"])
    conn_props = conn_details["entity"]["properties"]

    milvus_client = MilvusClient(
        uri=f'https://{conn_props.get("host")}:{conn_props.get("port")}',
        user=conn_props.get("username"),
        password=conn_props.get("password"),
        db_name=index_properties["store"]["database"],
    )

    return milvus_client, emb, index_properties, schema_fields
125
+
126
def proximity_search(question, milvus_client, emb, vector_index_properties, vector_store_schema):
    """Embed *question*, vector-search Milvus, and format hits for grounding.

    Returns one string: a retrieved-document count header followed by a
    "Document/Content/Page" record per hit.
    """
    embedded_question = emb.embed_query(question)
    milvus_response = milvus_client.search(
        collection_name=vector_index_properties["store"]["index"],
        data=[embedded_question],
        limit=vector_index_properties["settings"]["top_k"],
        metric_type="L2",
        output_fields=[
            vector_store_schema.get("text"),
            vector_store_schema.get("document_name"),
            vector_store_schema.get("page_number"),
        ],
    )

    documents = []
    # Only one query vector was sent, so all hits live in milvus_response[0].
    for hit in milvus_response[0]:
        entity = hit["entity"]
        documents.append(
            "Document: {}\nContent: {}\nPage: {}\n".format(
                entity.get(vector_store_schema.get("document_name"), "Unknown Document"),
                entity.get(vector_store_schema.get("text"), ""),
                entity.get(vector_store_schema.get("page_number"), "N/A"),
            )
        )

    joined = "\n".join(documents)
    return f"Number of Retrieved Documents: {len(documents)}\n\n{joined}"
154
+
155
def prepare_prompt(prompt, chat_history):
    """Wrap the user input with a __grounding__ placeholder and, in chat
    mode, the serialized conversation history.

    The placeholder is later replaced with retrieved documents by
    fetch_response().
    """
    if genparam.TYPE == "chat" and chat_history:
        history_lines = [f'{turn["role"]}: "{turn["content"]}"' for turn in chat_history]
        chats = "\n".join(history_lines)
        return f"Retrieved Contextual Information:\n__grounding__\n\nConversation History:\n{chats}\n\nNew User Input: {prompt}"
    return f"Retrieved Contextual Information:\n__grounding__\n\nUser Input: {prompt}"
163
+
164
def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
    """Wrap *prompt* (and optionally *system_prompt*) in a model family's
    chat template.

    When bake_in_prompt_syntax is falsy the prompt is returned untouched.
    Raises KeyError if prompt_template is not a known template name.
    """
    model_family_syntax = {
        "llama3-instruct (llama-3, 3.1 & 3.2) - system": """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "llama3-instruct (llama-3, 3.1 & 3.2) - user": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "granite-13b-chat & instruct - system": """<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "granite-13b-chat & instruct - user": """<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "mistral & mixtral v2 tokenizer - system": """<s>[INST] System Prompt: {system_prompt} [/INST][INST] {prompt} [/INST]\n\n""",
        "mistral & mixtral v2 tokenizer - user": """<s>[INST] {prompt} [/INST]\n\n""",
        "no syntax - system": """{system_prompt}\n\n{prompt}""",
        "no syntax - user": """{prompt}"""
    }

    if bake_in_prompt_syntax:
        template = model_family_syntax[prompt_template]
        # Fix: "- user" templates contain no {system_prompt} field, yet the
        # old code only applied the template when system_prompt was truthy —
        # so user-variant templates were silently skipped. Apply them
        # unconditionally; system-variant behavior with an empty
        # system_prompt is unchanged (raw prompt returned).
        if system_prompt or prompt_template.endswith("- user"):
            return template.format(system_prompt=system_prompt, prompt=prompt)
    return prompt
181
+
182
def generate_response(watsonx_llm, prompt_data, params):
    """Yield streamed text chunks from the model one piece at a time."""
    yield from watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)
186
+
187
def fetch_response(user_input, milvus_client, emb, vector_index_properties, vector_store_schema, system_prompt, chat_history):
    """Run one full RAG turn for a bot: retrieve, build prompt, stream, render.

    Renders the streamed answer inside an st.chat_message container and
    returns the complete response text. Uses the module-level `client`.
    """
    # Retrieve grounding documents for the question.
    grounding = proximity_search(
        question=user_input,
        milvus_client=milvus_client,
        emb=emb,
        vector_index_properties=vector_index_properties,
        vector_store_schema=vector_store_schema
    )
    prompt = prepare_prompt(user_input, chat_history)

    prompt_data = apply_prompt_syntax(
        prompt,
        system_prompt,
        genparam.PROMPT_TEMPLATE,
        genparam.BAKE_IN_PROMPT_SYNTAX
    )

    # Substitute the retrieved context into the placeholder left by prepare_prompt.
    prompt_data = prompt_data.replace("__grounding__", grounding)

    watsonx_llm = ModelInference(
        api_client=client,
        model_id=genparam.SELECTED_MODEL,
        verify=genparam.VERIFY
    )

    params = {
        GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
        GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
        GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
        GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
        GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
    }

    with st.chat_message("Tribunal", avatar="🥸"):
        if genparam.TOKEN_CAPTURE_ENABLED:
            # Expose the fully assembled prompt for debugging when capture is on.
            st.code(prompt_data, line_numbers=True, wrap_lines=True)
        stream = generate_response(watsonx_llm, prompt_data, params)
        response = st.write_stream(stream)
        # response = st.write_stream(stream, f"<span style='color: {color};'>", unsafe_allow_html=True)

    # chat_history already holds the new user turn, so // 2 numbers the chats.
    # NOTE(review): token accounting renders to the sidebar, independent of
    # the chat bubble above — original indentation was ambiguous; confirm.
    if genparam.TOKEN_CAPTURE_ENABLED:
        chat_number = len(chat_history) // 2
        token_calculations = capture_tokens(prompt_data, response, chat_number)
        if token_calculations:
            st.sidebar.code(token_calculations)

    return response
234
+
235
def capture_tokens(prompt_data, response, chat_number):
    """Tokenize prompt and response to record per-chat token usage.

    Appends a "chat N: in + out = total" line to
    st.session_state.token_capture and returns the full accumulated log as
    one newline-joined string. Returns None when capture is disabled.
    """
    if not genparam.TOKEN_CAPTURE_ENABLED:
        return

    llm = ModelInference(
        api_client=client,
        model_id=genparam.SELECTED_MODEL,
        verify=genparam.VERIFY
    )

    # Two extra tokenize calls per turn — only paid when capture is enabled.
    input_tokens = llm.tokenize(prompt=prompt_data)["result"]["token_count"]
    output_tokens = llm.tokenize(prompt=response)["result"]["token_count"]

    entry = f"chat {chat_number}: {input_tokens} + {output_tokens} = {input_tokens + output_tokens}"
    st.session_state.token_capture.append(entry)

    return "\n".join(st.session_state.token_capture)
253
+
254
def _render_bot_column(bot_name, system_prompt, chat_history, user_input):
    """Render one tribunal column: replay history, then answer user_input.

    Mutates *chat_history* in place, appending the user turn and the bot's
    response (mirroring what the UI shows on the next rerun).
    """
    st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
    st.subheader(bot_name)
    # Replay this bot's conversation so far.
    for message in chat_history:
        with st.chat_message(message["role"], avatar="👤" if message["role"] == "user" else "🥸"):
            st.markdown(message['content'])

    # Record the user turn before generating, so prepare_prompt sees it.
    chat_history.append({"role": "user", "content": user_input, "avatar": "👤"})
    # NOTE(review): the vector index is re-resolved per column (as in the
    # original); hoist and share it across columns if setup latency matters.
    milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
        client,
        wml_credentials,
        st.secrets["vector_index_id"]
    )

    response = fetch_response(
        user_input,
        milvus_client,
        emb,
        vector_index_properties,
        vector_store_schema,
        system_prompt,
        chat_history
    )
    chat_history.append({"role": "Tribunal", "content": response, "avatar": "🥸"})
    st.markdown("</div>", unsafe_allow_html=True)


def main():
    """App entry point: password gate, styling, and the three-bot chat UI."""
    initialize_session_state()

    # Apply custom styles
    st.markdown(three_column_style, unsafe_allow_html=True)

    # Sidebar configuration
    st.sidebar.header('The Tribunal')
    st.sidebar.write('')
    st.sidebar.write('')

    if not check_password():
        st.stop()

    # Main chat interface
    user_input = st.chat_input("Ask your question here", key="user_input")

    if user_input:
        # One column per bot; previously this was the same ~35 lines
        # copy-pasted three times.
        columns = st.columns(3)
        bots = [
            (genparam.BOT_1_NAME, genparam.BOT_1_PROMPT, st.session_state.chat_history_1),
            (genparam.BOT_2_NAME, genparam.BOT_2_PROMPT, st.session_state.chat_history_2),
            (genparam.BOT_3_NAME, genparam.BOT_3_PROMPT, st.session_state.chat_history_3),
        ]
        for column, (bot_name, bot_prompt, history) in zip(columns, bots):
            with column:
                _render_bot_column(bot_name, bot_prompt, history, user_input)
365
+
366
# Standard script entry point guard.
if __name__ == "__main__":
    main()