", unsafe_allow_html=True)
if genparam.TOKEN_CAPTURE_ENABLED:
chat_number = len(chat_history) // 2
token_calculations = capture_tokens(prompt_data, response, chat_number)
if token_calculations:
st.sidebar.code(token_calculations)
return response
def capture_tokens(prompt_data, response, chat_number):
    """Record token usage for one prompt/response exchange and return the running log.

    Does nothing and returns ``None`` when ``genparam.TOKEN_CAPTURE_ENABLED`` is
    falsy. Otherwise it builds a ``ModelInference`` for ``genparam.SELECTED_MODEL``
    and tokenizes the prompt first, then the response. It appends a line of the form
    ``"chat N: <in> + <out> = <total>"`` to ``st.session_state.token_capture``.

    Args:
        prompt_data: The full prompt text sent to the model.
        response: The model's reply text.
        chat_number: Index of this exchange, used only as a label in the log line.

    Returns:
        All captured log lines joined with newlines, or ``None`` when capture is
        disabled.
    """
    # NOTE(review): assumes st.session_state.token_capture was initialised to a
    # list elsewhere (presumably initialize_session_state) — confirm.
    if genparam.TOKEN_CAPTURE_ENABLED:
        tokenizer = ModelInference(
            api_client=client,
            model_id=genparam.SELECTED_MODEL,
            verify=genparam.VERIFY
        )

        def count_tokens(text):
            # tokenize() reports the count under result.token_count
            return tokenizer.tokenize(prompt=text)["result"]["token_count"]

        n_in = count_tokens(prompt_data)
        n_out = count_tokens(response)
        usage_log = st.session_state.token_capture
        usage_log.append(f"chat {chat_number}: {n_in} + {n_out} = {n_in + n_out}")
        return "\n".join(usage_log)
    return None
def _render_bot_column(bot_name, system_prompt, chat_history, user_input, vector_resources):
    """Render one bot's column: its previous turns, then its answer to ``user_input``.

    Args:
        bot_name: Heading shown at the top of the column.
        system_prompt: System prompt handed to ``fetch_response`` for this bot.
        chat_history: This bot's message list from ``st.session_state``. It is
            mutated in place: the user message and the bot reply are appended.
        user_input: The question just submitted through the chat input.
        vector_resources: The 4-tuple ``(milvus_client, emb,
            vector_index_properties, vector_store_schema)`` returned by
            ``setup_vector_index``.
    """
    # NOTE(review): this literal appears to have lost its HTML content (likely an
    # opening wrapper such as a <div>); as written it renders nothing — confirm.
    st.markdown("", unsafe_allow_html=True)
    st.subheader(bot_name)

    # Replay earlier turns before the new exchange is added to the history.
    for message in chat_history:
        avatar = "👤" if message["role"] == "user" else "🥸"
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message['content'])

    # fetch_response receives the history *including* the new user message,
    # matching the original per-column order of operations.
    chat_history.append({"role": "user", "content": user_input, "avatar": "👤"})
    milvus_client, emb, vector_index_properties, vector_store_schema = vector_resources
    response = fetch_response(
        user_input,
        milvus_client,
        emb,
        vector_index_properties,
        vector_store_schema,
        system_prompt,
        chat_history,
    )
    chat_history.append({"role": "Tribunal", "content": response, "avatar": "🥸"})
    # NOTE(review): originally an unterminated literal spanning two lines; this is
    # its literal content. It was probably a closing HTML tag (e.g. </div>) — confirm.
    st.markdown("\n", unsafe_allow_html=True)


def main():
    """Run the Tribunal page: password gate, then one question answered by three bots.

    Each bot renders in its own column. It uses its own system prompt
    (``genparam.BOT_{1,2,3}_PROMPT``) and its own history
    (``st.session_state.chat_history_{1,2,3}``). All three bots share one vector
    index.
    """
    initialize_session_state()

    # Apply custom styles
    st.markdown(three_column_style, unsafe_allow_html=True)

    # Sidebar configuration
    st.sidebar.header('The Tribunal')
    st.sidebar.write('')
    st.sidebar.write('')

    if not check_password():
        st.stop()

    # Main chat interface
    user_input = st.chat_input("Ask your question here", key="user_input")
    if not user_input:
        return

    # Every bot queries the same index with the same arguments, so set it up
    # once per question rather than once per column.
    vector_resources = setup_vector_index(
        client,
        wml_credentials,
        st.secrets["vector_index_id"],
    )

    bots = (
        (genparam.BOT_1_NAME, genparam.BOT_1_PROMPT, st.session_state.chat_history_1),
        (genparam.BOT_2_NAME, genparam.BOT_2_PROMPT, st.session_state.chat_history_2),
        (genparam.BOT_3_NAME, genparam.BOT_3_PROMPT, st.session_state.chat_history_3),
    )
    for column, (bot_name, system_prompt, chat_history) in zip(st.columns(3), bots):
        with column:
            _render_bot_column(bot_name, system_prompt, chat_history, user_input, vector_resources)
# Entry point: build the app only when this file is executed, not when imported.
if __name__ == "__main__":
    main()