Update tribunal_2.py
Browse files- tribunal_2.py +50 -22
tribunal_2.py
CHANGED
@@ -85,6 +85,17 @@ three_column_style = """
|
|
85 |
</style>
|
86 |
"""
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
def setup_client(project_id):
|
89 |
credentials = Credentials(
|
90 |
url=st.secrets["url"],
|
@@ -197,7 +208,7 @@ def fetch_response(user_input, milvus_client, emb, vector_index_properties, vect
|
|
197 |
prompt_data = apply_prompt_syntax(
|
198 |
prompt,
|
199 |
system_prompt,
|
200 |
-
|
201 |
genparam.BAKE_IN_PROMPT_SYNTAX
|
202 |
)
|
203 |
|
@@ -205,7 +216,7 @@ def fetch_response(user_input, milvus_client, emb, vector_index_properties, vect
|
|
205 |
|
206 |
watsonx_llm = ModelInference(
|
207 |
api_client=client,
|
208 |
-
model_id=
|
209 |
verify=genparam.VERIFY
|
210 |
)
|
211 |
|
@@ -217,7 +228,19 @@ def fetch_response(user_input, milvus_client, emb, vector_index_properties, vect
|
|
217 |
GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
|
218 |
}
|
219 |
|
220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
if genparam.TOKEN_CAPTURE_ENABLED:
|
222 |
st.code(prompt_data, line_numbers=True, wrap_lines=True)
|
223 |
stream = generate_response(watsonx_llm, prompt_data, params)
|
@@ -278,17 +301,18 @@ def main():
|
|
278 |
st.subheader(genparam.BOT_1_NAME)
|
279 |
# Display chat history for bot 1
|
280 |
for message in st.session_state.chat_history_1:
|
281 |
-
with st.chat_message(message["role"], avatar=
|
282 |
#st.markdown(f"<span style='color: #1565C0;'>{message['content']}</span>", unsafe_allow_html=True)
|
283 |
st.markdown(message['content'])
|
284 |
|
285 |
# Add user message and get bot 1 response
|
286 |
-
st.session_state.chat_history_1.append({"role": "user", "content": user_input, "avatar":
|
287 |
milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
|
|
292 |
system_prompt = genparam.BOT_1_PROMPT
|
293 |
|
294 |
response = fetch_response(
|
@@ -300,7 +324,7 @@ def main():
|
|
300 |
system_prompt,
|
301 |
st.session_state.chat_history_1
|
302 |
)
|
303 |
-
st.session_state.chat_history_1.append({"role":
|
304 |
st.markdown("</div>", unsafe_allow_html=True)
|
305 |
|
306 |
with col2:
|
@@ -308,18 +332,20 @@ def main():
|
|
308 |
st.subheader(genparam.BOT_2_NAME)
|
309 |
# Display chat history for bot 2
|
310 |
for message in st.session_state.chat_history_2:
|
311 |
-
with st.chat_message(message["role"], avatar=
|
312 |
#st.markdown(f"<span style='color: #2E7D32;'>{message['content']}</span>", unsafe_allow_html=True)
|
313 |
st.markdown(message['content'])
|
314 |
|
315 |
|
316 |
# Add user message and get bot 2 response
|
317 |
-
st.session_state.chat_history_2.append({"role": "user", "content": user_input, "avatar":
|
318 |
milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
|
319 |
client,
|
320 |
-
wml_credentials,
|
321 |
-
|
|
|
322 |
)
|
|
|
323 |
|
324 |
response = fetch_response(
|
325 |
user_input,
|
@@ -327,10 +353,10 @@ def main():
|
|
327 |
emb,
|
328 |
vector_index_properties,
|
329 |
vector_store_schema,
|
330 |
-
|
331 |
st.session_state.chat_history_2
|
332 |
)
|
333 |
-
st.session_state.chat_history_2.append({"role":
|
334 |
st.markdown("</div>", unsafe_allow_html=True)
|
335 |
|
336 |
with col3:
|
@@ -338,18 +364,20 @@ def main():
|
|
338 |
st.subheader(genparam.BOT_3_NAME)
|
339 |
# Display chat history for bot 3
|
340 |
for message in st.session_state.chat_history_3:
|
341 |
-
with st.chat_message(message["role"], avatar=
|
342 |
#st.markdown(f"<span style='color: #6A1B9A;'>{message['content']}</span>", unsafe_allow_html=True)
|
343 |
st.markdown(message['content'])
|
344 |
|
345 |
|
346 |
# Add user message and get bot 3 response
|
347 |
-
st.session_state.chat_history_3.append({"role": "user", "content": user_input, "avatar":
|
348 |
milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
|
349 |
client,
|
350 |
-
wml_credentials,
|
351 |
-
|
|
|
352 |
)
|
|
|
353 |
|
354 |
response = fetch_response(
|
355 |
user_input,
|
@@ -357,10 +385,10 @@ def main():
|
|
357 |
emb,
|
358 |
vector_index_properties,
|
359 |
vector_store_schema,
|
360 |
-
|
361 |
st.session_state.chat_history_3
|
362 |
)
|
363 |
-
st.session_state.chat_history_3.append({"role":
|
364 |
st.markdown("</div>", unsafe_allow_html=True)
|
365 |
|
366 |
if __name__ == "__main__":
|
|
|
85 |
</style>
|
86 |
"""
|
87 |
|
88 |
+
#-----
|
89 |
+
def get_active_model():
|
90 |
+
return genparam.SELECTED_MODEL_1 if genparam.ACTIVE_MODEL == 0 else genparam.SELECTED_MODEL_2
|
91 |
+
|
92 |
+
def get_active_prompt_template():
|
93 |
+
return genparam.PROMPT_TEMPLATE_1 if genparam.ACTIVE_MODEL == 0 else genparam.PROMPT_TEMPLATE_2
|
94 |
+
|
95 |
+
def get_active_vector_index():
|
96 |
+
return st.secrets["vector_index_id_1"] if genparam.ACTIVE_INDEX == 0 else st.secrets["vector_index_id_2"]
|
97 |
+
#-----
|
98 |
+
|
99 |
def setup_client(project_id):
|
100 |
credentials = Credentials(
|
101 |
url=st.secrets["url"],
|
|
|
208 |
prompt_data = apply_prompt_syntax(
|
209 |
prompt,
|
210 |
system_prompt,
|
211 |
+
get_active_prompt_template(),
|
212 |
genparam.BAKE_IN_PROMPT_SYNTAX
|
213 |
)
|
214 |
|
|
|
216 |
|
217 |
watsonx_llm = ModelInference(
|
218 |
api_client=client,
|
219 |
+
model_id=get_active_model(),
|
220 |
verify=genparam.VERIFY
|
221 |
)
|
222 |
|
|
|
228 |
GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
|
229 |
}
|
230 |
|
231 |
+
bot_name = None
|
232 |
+
bot_avatar = None
|
233 |
+
if chat_history == st.session_state.chat_history_1:
|
234 |
+
bot_name = genparam.BOT_1_NAME
|
235 |
+
bot_avatar = genparam.BOT_1_AVATAR
|
236 |
+
elif chat_history == st.session_state.chat_history_2:
|
237 |
+
bot_name = genparam.BOT_2_NAME
|
238 |
+
bot_avatar = genparam.BOT_2_AVATAR
|
239 |
+
else:
|
240 |
+
bot_name = genparam.BOT_3_NAME
|
241 |
+
bot_avatar = genparam.BOT_3_AVATAR
|
242 |
+
|
243 |
+
with st.chat_message(bot_name, avatar=bot_avatar):
|
244 |
if genparam.TOKEN_CAPTURE_ENABLED:
|
245 |
st.code(prompt_data, line_numbers=True, wrap_lines=True)
|
246 |
stream = generate_response(watsonx_llm, prompt_data, params)
|
|
|
301 |
st.subheader(genparam.BOT_1_NAME)
|
302 |
# Display chat history for bot 1
|
303 |
for message in st.session_state.chat_history_1:
|
304 |
+
with st.chat_message(message["role"], avatar=genparam.USER_AVATAR if message["role"] == "user" else genparam.BOT_1_AVATAR):
|
305 |
#st.markdown(f"<span style='color: #1565C0;'>{message['content']}</span>", unsafe_allow_html=True)
|
306 |
st.markdown(message['content'])
|
307 |
|
308 |
# Add user message and get bot 1 response
|
309 |
+
st.session_state.chat_history_1.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
|
310 |
milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
|
311 |
+
client,
|
312 |
+
wml_credentials,
|
313 |
+
get_active_vector_index()
|
314 |
+
#st.secrets["vector_index_id"]
|
315 |
+
)
|
316 |
system_prompt = genparam.BOT_1_PROMPT
|
317 |
|
318 |
response = fetch_response(
|
|
|
324 |
system_prompt,
|
325 |
st.session_state.chat_history_1
|
326 |
)
|
327 |
+
st.session_state.chat_history_1.append({"role": genparam.BOT_1_NAME, "content": response, "avatar": genparam.BOT_1_AVATAR})
|
328 |
st.markdown("</div>", unsafe_allow_html=True)
|
329 |
|
330 |
with col2:
|
|
|
332 |
st.subheader(genparam.BOT_2_NAME)
|
333 |
# Display chat history for bot 2
|
334 |
for message in st.session_state.chat_history_2:
|
335 |
+
with st.chat_message(message["role"], avatar=genparam.USER_AVATAR if message["role"] == "user" else genparam.BOT_2_AVATAR):
|
336 |
#st.markdown(f"<span style='color: #2E7D32;'>{message['content']}</span>", unsafe_allow_html=True)
|
337 |
st.markdown(message['content'])
|
338 |
|
339 |
|
340 |
# Add user message and get bot 2 response
|
341 |
+
st.session_state.chat_history_2.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
|
342 |
milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
|
343 |
client,
|
344 |
+
wml_credentials,
|
345 |
+
get_active_vector_index()
|
346 |
+
#st.secrets["vector_index_id"]
|
347 |
)
|
348 |
+
system_prompt = genparam.BOT_2_PROMPT
|
349 |
|
350 |
response = fetch_response(
|
351 |
user_input,
|
|
|
353 |
emb,
|
354 |
vector_index_properties,
|
355 |
vector_store_schema,
|
356 |
+
system_prompt,
|
357 |
st.session_state.chat_history_2
|
358 |
)
|
359 |
+
st.session_state.chat_history_2.append({"role": genparam.BOT_2_NAME, "content": response, "avatar": genparam.BOT_2_AVATAR})
|
360 |
st.markdown("</div>", unsafe_allow_html=True)
|
361 |
|
362 |
with col3:
|
|
|
364 |
st.subheader(genparam.BOT_3_NAME)
|
365 |
# Display chat history for bot 3
|
366 |
for message in st.session_state.chat_history_3:
|
367 |
+
with st.chat_message(message["role"], avatar=genparam.USER_AVATAR if message["role"] == "user" else genparam.BOT_3_AVATAR):
|
368 |
#st.markdown(f"<span style='color: #6A1B9A;'>{message['content']}</span>", unsafe_allow_html=True)
|
369 |
st.markdown(message['content'])
|
370 |
|
371 |
|
372 |
# Add user message and get bot 3 response
|
373 |
+
st.session_state.chat_history_3.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
|
374 |
milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
|
375 |
client,
|
376 |
+
wml_credentials,
|
377 |
+
get_active_vector_index()
|
378 |
+
#st.secrets["vector_index_id"]
|
379 |
)
|
380 |
+
system_prompt = genparam.BOT_3_PROMPT
|
381 |
|
382 |
response = fetch_response(
|
383 |
user_input,
|
|
|
385 |
emb,
|
386 |
vector_index_properties,
|
387 |
vector_store_schema,
|
388 |
+
system_prompt,
|
389 |
st.session_state.chat_history_3
|
390 |
)
|
391 |
+
st.session_state.chat_history_3.append({"role": genparam.BOT_3_NAME, "content": response, "avatar": genparam.BOT_3_AVATAR})
|
392 |
st.markdown("</div>", unsafe_allow_html=True)
|
393 |
|
394 |
if __name__ == "__main__":
|