MilanM committed on
Commit
f8948b6
·
verified ·
1 Parent(s): 3b42419

Update tribunal_2.py

Browse files
Files changed (1) hide show
  1. tribunal_2.py +50 -22
tribunal_2.py CHANGED
@@ -85,6 +85,17 @@ three_column_style = """
85
  </style>
86
  """
87
 
 
 
 
 
 
 
 
 
 
 
 
88
  def setup_client(project_id):
89
  credentials = Credentials(
90
  url=st.secrets["url"],
@@ -197,7 +208,7 @@ def fetch_response(user_input, milvus_client, emb, vector_index_properties, vect
197
  prompt_data = apply_prompt_syntax(
198
  prompt,
199
  system_prompt,
200
- genparam.PROMPT_TEMPLATE,
201
  genparam.BAKE_IN_PROMPT_SYNTAX
202
  )
203
 
@@ -205,7 +216,7 @@ def fetch_response(user_input, milvus_client, emb, vector_index_properties, vect
205
 
206
  watsonx_llm = ModelInference(
207
  api_client=client,
208
- model_id=genparam.SELECTED_MODEL,
209
  verify=genparam.VERIFY
210
  )
211
 
@@ -217,7 +228,19 @@ def fetch_response(user_input, milvus_client, emb, vector_index_properties, vect
217
  GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
218
  }
219
 
220
- with st.chat_message("Tribunal", avatar="🥸"):
 
 
 
 
 
 
 
 
 
 
 
 
221
  if genparam.TOKEN_CAPTURE_ENABLED:
222
  st.code(prompt_data, line_numbers=True, wrap_lines=True)
223
  stream = generate_response(watsonx_llm, prompt_data, params)
@@ -278,17 +301,18 @@ def main():
278
  st.subheader(genparam.BOT_1_NAME)
279
  # Display chat history for bot 1
280
  for message in st.session_state.chat_history_1:
281
- with st.chat_message(message["role"], avatar="👤" if message["role"] == "user" else "🥸"):
282
  #st.markdown(f"<span style='color: #1565C0;'>{message['content']}</span>", unsafe_allow_html=True)
283
  st.markdown(message['content'])
284
 
285
  # Add user message and get bot 1 response
286
- st.session_state.chat_history_1.append({"role": "user", "content": user_input, "avatar":"👤"})
287
  milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
288
- client,
289
- wml_credentials,
290
- st.secrets["vector_index_id"]
291
- )
 
292
  system_prompt = genparam.BOT_1_PROMPT
293
 
294
  response = fetch_response(
@@ -300,7 +324,7 @@ def main():
300
  system_prompt,
301
  st.session_state.chat_history_1
302
  )
303
- st.session_state.chat_history_1.append({"role": "Tribunal", "content": response, "avatar":"🥸"})
304
  st.markdown("</div>", unsafe_allow_html=True)
305
 
306
  with col2:
@@ -308,18 +332,20 @@ def main():
308
  st.subheader(genparam.BOT_2_NAME)
309
  # Display chat history for bot 2
310
  for message in st.session_state.chat_history_2:
311
- with st.chat_message(message["role"], avatar="👤" if message["role"] == "user" else "🥸"):
312
  #st.markdown(f"<span style='color: #2E7D32;'>{message['content']}</span>", unsafe_allow_html=True)
313
  st.markdown(message['content'])
314
 
315
 
316
  # Add user message and get bot 2 response
317
- st.session_state.chat_history_2.append({"role": "user", "content": user_input, "avatar":"👤"})
318
  milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
319
  client,
320
- wml_credentials,
321
- st.secrets["vector_index_id"]
 
322
  )
 
323
 
324
  response = fetch_response(
325
  user_input,
@@ -327,10 +353,10 @@ def main():
327
  emb,
328
  vector_index_properties,
329
  vector_store_schema,
330
- genparam.BOT_2_PROMPT,
331
  st.session_state.chat_history_2
332
  )
333
- st.session_state.chat_history_2.append({"role": "Tribunal", "content": response, "avatar":"🥸"})
334
  st.markdown("</div>", unsafe_allow_html=True)
335
 
336
  with col3:
@@ -338,18 +364,20 @@ def main():
338
  st.subheader(genparam.BOT_3_NAME)
339
  # Display chat history for bot 3
340
  for message in st.session_state.chat_history_3:
341
- with st.chat_message(message["role"], avatar="👤" if message["role"] == "user" else "🥸"):
342
  #st.markdown(f"<span style='color: #6A1B9A;'>{message['content']}</span>", unsafe_allow_html=True)
343
  st.markdown(message['content'])
344
 
345
 
346
  # Add user message and get bot 3 response
347
- st.session_state.chat_history_3.append({"role": "user", "content": user_input, "avatar":"👤"})
348
  milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
349
  client,
350
- wml_credentials,
351
- st.secrets["vector_index_id"]
 
352
  )
 
353
 
354
  response = fetch_response(
355
  user_input,
@@ -357,10 +385,10 @@ def main():
357
  emb,
358
  vector_index_properties,
359
  vector_store_schema,
360
- genparam.BOT_3_PROMPT,
361
  st.session_state.chat_history_3
362
  )
363
- st.session_state.chat_history_3.append({"role": "Tribunal", "content": response, "avatar":"🥸"})
364
  st.markdown("</div>", unsafe_allow_html=True)
365
 
366
  if __name__ == "__main__":
 
85
  </style>
86
  """
87
 
88
+ #-----
89
def get_active_model():
    """Return the watsonx model id chosen by the ACTIVE_MODEL switch.

    ACTIVE_MODEL == 0 selects SELECTED_MODEL_1; any other value selects
    SELECTED_MODEL_2.
    """
    if genparam.ACTIVE_MODEL == 0:
        return genparam.SELECTED_MODEL_1
    return genparam.SELECTED_MODEL_2
91
+
92
def get_active_prompt_template():
    """Return the prompt template paired with the active model.

    Mirrors get_active_model(): ACTIVE_MODEL == 0 maps to
    PROMPT_TEMPLATE_1, anything else to PROMPT_TEMPLATE_2.
    """
    if genparam.ACTIVE_MODEL == 0:
        return genparam.PROMPT_TEMPLATE_1
    return genparam.PROMPT_TEMPLATE_2
94
+
95
def get_active_vector_index():
    """Return the Milvus vector-index id for the active index selection.

    ACTIVE_INDEX == 0 reads st.secrets["vector_index_id_1"]; any other
    value reads st.secrets["vector_index_id_2"].
    """
    if genparam.ACTIVE_INDEX == 0:
        return st.secrets["vector_index_id_1"]
    return st.secrets["vector_index_id_2"]
97
+ #-----
98
+
99
  def setup_client(project_id):
100
  credentials = Credentials(
101
  url=st.secrets["url"],
 
208
  prompt_data = apply_prompt_syntax(
209
  prompt,
210
  system_prompt,
211
+ get_active_prompt_template(),
212
  genparam.BAKE_IN_PROMPT_SYNTAX
213
  )
214
 
 
216
 
217
  watsonx_llm = ModelInference(
218
  api_client=client,
219
+ model_id=get_active_model(),
220
  verify=genparam.VERIFY
221
  )
222
 
 
228
  GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
229
  }
230
 
231
+ bot_name = None
232
+ bot_avatar = None
233
+ if chat_history == st.session_state.chat_history_1:
234
+ bot_name = genparam.BOT_1_NAME
235
+ bot_avatar = genparam.BOT_1_AVATAR
236
+ elif chat_history == st.session_state.chat_history_2:
237
+ bot_name = genparam.BOT_2_NAME
238
+ bot_avatar = genparam.BOT_2_AVATAR
239
+ else:
240
+ bot_name = genparam.BOT_3_NAME
241
+ bot_avatar = genparam.BOT_3_AVATAR
242
+
243
+ with st.chat_message(bot_name, avatar=bot_avatar):
244
  if genparam.TOKEN_CAPTURE_ENABLED:
245
  st.code(prompt_data, line_numbers=True, wrap_lines=True)
246
  stream = generate_response(watsonx_llm, prompt_data, params)
 
301
  st.subheader(genparam.BOT_1_NAME)
302
  # Display chat history for bot 1
303
  for message in st.session_state.chat_history_1:
304
+ with st.chat_message(message["role"], avatar=genparam.USER_AVATAR if message["role"] == "user" else genparam.BOT_1_AVATAR):
305
  #st.markdown(f"<span style='color: #1565C0;'>{message['content']}</span>", unsafe_allow_html=True)
306
  st.markdown(message['content'])
307
 
308
  # Add user message and get bot 1 response
309
+ st.session_state.chat_history_1.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
310
  milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
311
+ client,
312
+ wml_credentials,
313
+ get_active_vector_index()
314
+ #st.secrets["vector_index_id"]
315
+ )
316
  system_prompt = genparam.BOT_1_PROMPT
317
 
318
  response = fetch_response(
 
324
  system_prompt,
325
  st.session_state.chat_history_1
326
  )
327
+ st.session_state.chat_history_1.append({"role": genparam.BOT_1_NAME, "content": response, "avatar": genparam.BOT_1_AVATAR})
328
  st.markdown("</div>", unsafe_allow_html=True)
329
 
330
  with col2:
 
332
  st.subheader(genparam.BOT_2_NAME)
333
  # Display chat history for bot 2
334
  for message in st.session_state.chat_history_2:
335
+ with st.chat_message(message["role"], avatar=genparam.USER_AVATAR if message["role"] == "user" else genparam.BOT_2_AVATAR):
336
  #st.markdown(f"<span style='color: #2E7D32;'>{message['content']}</span>", unsafe_allow_html=True)
337
  st.markdown(message['content'])
338
 
339
 
340
  # Add user message and get bot 2 response
341
+ st.session_state.chat_history_2.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
342
  milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
343
  client,
344
+ wml_credentials,
345
+ get_active_vector_index()
346
+ #st.secrets["vector_index_id"]
347
  )
348
+ system_prompt = genparam.BOT_2_PROMPT
349
 
350
  response = fetch_response(
351
  user_input,
 
353
  emb,
354
  vector_index_properties,
355
  vector_store_schema,
356
+ system_prompt,
357
  st.session_state.chat_history_2
358
  )
359
+ st.session_state.chat_history_2.append({"role": genparam.BOT_2_NAME, "content": response, "avatar": genparam.BOT_2_AVATAR})
360
  st.markdown("</div>", unsafe_allow_html=True)
361
 
362
  with col3:
 
364
  st.subheader(genparam.BOT_3_NAME)
365
  # Display chat history for bot 3
366
  for message in st.session_state.chat_history_3:
367
+ with st.chat_message(message["role"], avatar=genparam.USER_AVATAR if message["role"] == "user" else genparam.BOT_3_AVATAR):
368
  #st.markdown(f"<span style='color: #6A1B9A;'>{message['content']}</span>", unsafe_allow_html=True)
369
  st.markdown(message['content'])
370
 
371
 
372
  # Add user message and get bot 3 response
373
+ st.session_state.chat_history_3.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
374
  milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
375
  client,
376
+ wml_credentials,
377
+ get_active_vector_index()
378
+ #st.secrets["vector_index_id"]
379
  )
380
+ system_prompt = genparam.BOT_3_PROMPT
381
 
382
  response = fetch_response(
383
  user_input,
 
385
  emb,
386
  vector_index_properties,
387
  vector_store_schema,
388
+ system_prompt,
389
  st.session_state.chat_history_3
390
  )
391
+ st.session_state.chat_history_3.append({"role": genparam.BOT_3_NAME, "content": response, "avatar": genparam.BOT_3_AVATAR})
392
  st.markdown("</div>", unsafe_allow_html=True)
393
 
394
  if __name__ == "__main__":