# fading_moments.py
import streamlit as st
from knowledge_bases import KNOWLEDGE_BASE_OPTIONS, SYSTEM_PROMPTS
import genparam
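# NOTE (assumed shapes): KNOWLEDGE_BASE_OPTIONS is used below as a list of
# knowledge base names, and SYSTEM_PROMPTS as a dict keyed by those names,
# where each value carries "bot_1"/"bot_2"/"bot_3" system prompt strings.
# genparam is expected to expose ACTIVE_MODEL, SELECTED_MODEL_1/2, USER_AVATAR,
# BOT_{1,2,3}_NAME, BOT_{1,2,3}_AVATAR and TOKEN_CAPTURE_ENABLED; the actual
# definitions live in genparam.py and knowledge_bases.py.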
from functions import (
check_password,
initialize_session_state,
setup_client,
fetch_response,
capture_tokens
)
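# NOTE (assumed behaviour of the helpers above, inferred from how they are
# called in this file -- see functions.py for the authoritative definitions):
#   check_password()            -> True once the password gate has been passed
#   initialize_session_state()  -> seeds st.session_state with selected_kb,
#                                  chat_history_1/2/3 (lists) and
#                                  token_statistics (list)
#   setup_client()              -> returns (wml_credentials, client)
#   fetch_response(user_input, client, system_prompt, chat_history)
#                               -> returns (stream, prompt_data), where stream
#                                  is consumable by st.write_stream
#   capture_tokens(prompt_data, response, client, bot_name)
#                               -> returns a stats dict (or a falsy value)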
# Custom CSS for the three-column layout
three_column_style = """
<style>
.stColumn {
padding: 0.5rem;
border-right: 1px solid #dedede;
}
.stColumn:last-child {
border-right: none;
}
.chat-container {
height: calc(100vh - 200px);
overflow-y: auto;
display: flex;
flex-direction: column;
}
.chat-messages {
display: flex;
flex-direction: column;
gap: 1rem;
}
</style>
"""
def main():
# Page configuration
st.set_page_config(
page_title="Fading Moments",
page_icon="🌫️",
initial_sidebar_state="collapsed",
layout="wide"
)
initialize_session_state()
st.markdown(three_column_style, unsafe_allow_html=True)
# Sidebar configuration
st.sidebar.header('The Solutioning Sages')
st.sidebar.divider()
# Knowledge Base Selection
selected_kb = st.sidebar.selectbox(
"Select Knowledge Base",
KNOWLEDGE_BASE_OPTIONS,
index=KNOWLEDGE_BASE_OPTIONS.index(st.session_state.selected_kb)
)
# Update knowledge base if selection changes
if selected_kb != st.session_state.selected_kb:
st.session_state.selected_kb = selected_kb
# Display current knowledge base contents
with st.sidebar.expander("Knowledge Base Contents"):
st.write("📄 [Knowledge base files would be listed here]")
# Display active model information
st.sidebar.divider()
active_model = genparam.SELECTED_MODEL_1 if genparam.ACTIVE_MODEL == 0 else genparam.SELECTED_MODEL_2
st.sidebar.markdown("**Active Model:**")
st.sidebar.code(active_model)
st.sidebar.divider()
# Display token statistics in sidebar
st.sidebar.subheader("Token Usage Statistics")
if st.session_state.token_statistics:
interaction_count = 0
stats_by_time = {}
# Group stats by timestamp
for stat in st.session_state.token_statistics:
if stat["timestamp"] not in stats_by_time:
stats_by_time[stat["timestamp"]] = []
stats_by_time[stat["timestamp"]].append(stat)
# Display grouped stats
for timestamp, stats in stats_by_time.items():
interaction_count += 1
st.sidebar.markdown(f"**Interaction {interaction_count}** ({timestamp})")
total_input = sum(stat['input_tokens'] for stat in stats)
total_output = sum(stat['output_tokens'] for stat in stats)
total = total_input + total_output
for stat in stats:
st.sidebar.markdown(
f"_{stat['bot_name']}_  \n"
f"Input: {stat['input_tokens']} tokens  \n"
f"Output: {stat['output_tokens']} tokens  \n"
f"Total: {stat['total_tokens']} tokens"
)
st.sidebar.markdown("**Interaction Totals:**")
st.sidebar.markdown(
f"Total Input: {total_input} tokens  \n"
f"Total Output: {total_output} tokens  \n"
f"Total Usage: {total} tokens"
)
st.sidebar.markdown("---")
if not check_password():
st.stop()
# Initialize WatsonX client
wml_credentials, client = setup_client()
# Get user input
user_input = st.chat_input("Ask your question here", key="user_input")
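# st.chat_input returns None until the user submits a question, so the three
# bot columns below are only rendered on reruns where an input has arrived.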
if user_input:
# Create three columns
col1, col2, col3 = st.columns(3)
# First column - PATH-er B.
with col1:
st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
st.subheader(f"{genparam.BOT_1_AVATAR} {genparam.BOT_1_NAME}")
st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)
# Display chat history
for message in st.session_state.chat_history_1:
with st.chat_message(message["role"], avatar=message.get("avatar", None)):
st.markdown(message['content'])
# Display new messages
with st.chat_message("user", avatar=genparam.USER_AVATAR):
st.markdown(user_input)
st.session_state.chat_history_1.append({
"role": "user",
"content": user_input,
"avatar": genparam.USER_AVATAR
})
# Get bot response
system_prompt = SYSTEM_PROMPTS[st.session_state.selected_kb]["bot_1"]
stream, prompt_data = fetch_response(
user_input,
client,
system_prompt,
st.session_state.chat_history_1
)
with st.chat_message(genparam.BOT_1_NAME, avatar=genparam.BOT_1_AVATAR):
response = st.write_stream(stream)
st.session_state.chat_history_1.append({
"role": genparam.BOT_1_NAME,
"content": response,
"avatar": genparam.BOT_1_AVATAR
})
# Capture tokens if enabled
if genparam.TOKEN_CAPTURE_ENABLED:
token_stats = capture_tokens(prompt_data, response, client, genparam.BOT_1_NAME)
if token_stats:
st.session_state.token_statistics.append(token_stats)
st.markdown("</div></div>", unsafe_allow_html=True)
# Second column - MOD-ther S.
with col2:
st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
st.subheader(f"{genparam.BOT_2_AVATAR} {genparam.BOT_2_NAME}")
st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)
# Display chat history
for message in st.session_state.chat_history_2:
with st.chat_message(message["role"], avatar=message.get("avatar", None)):
st.markdown(message['content'])
# Display new messages
with st.chat_message("user", avatar=genparam.USER_AVATAR):
st.markdown(user_input)
st.session_state.chat_history_2.append({
"role": "user",
"content": user_input,
"avatar": genparam.USER_AVATAR
})
# Get bot response
system_prompt = SYSTEM_PROMPTS[st.session_state.selected_kb]["bot_2"]
stream, prompt_data = fetch_response(
user_input,
client,
system_prompt,
st.session_state.chat_history_2
)
with st.chat_message(genparam.BOT_2_NAME, avatar=genparam.BOT_2_AVATAR):
response = st.write_stream(stream)
st.session_state.chat_history_2.append({
"role": genparam.BOT_2_NAME,
"content": response,
"avatar": genparam.BOT_2_AVATAR
})
# Capture tokens if enabled
if genparam.TOKEN_CAPTURE_ENABLED:
token_stats = capture_tokens(prompt_data, response, client, genparam.BOT_2_NAME)
if token_stats:
st.session_state.token_statistics.append(token_stats)
st.markdown("</div></div>", unsafe_allow_html=True)
# Third column - SYS-ter V.
with col3:
st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
st.subheader(f"{genparam.BOT_3_AVATAR} {genparam.BOT_3_NAME}")
st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)
# Display chat history
for message in st.session_state.chat_history_3:
with st.chat_message(message["role"], avatar=message.get("avatar", None)):
st.markdown(message['content'])
# Display new messages
with st.chat_message("user", avatar=genparam.USER_AVATAR):
st.markdown(user_input)
st.session_state.chat_history_3.append({
"role": "user",
"content": user_input,
"avatar": genparam.USER_AVATAR
})
# Get bot response
system_prompt = SYSTEM_PROMPTS[st.session_state.selected_kb]["bot_3"]
stream, prompt_data = fetch_response(
user_input,
client,
system_prompt,
st.session_state.chat_history_3
)
with st.chat_message(genparam.BOT_3_NAME, avatar=genparam.BOT_3_AVATAR):
response = st.write_stream(stream)
st.session_state.chat_history_3.append({
"role": genparam.BOT_3_NAME,
"content": response,
"avatar": genparam.BOT_3_AVATAR
})
# Capture tokens if enabled
if genparam.TOKEN_CAPTURE_ENABLED:
token_stats = capture_tokens(prompt_data, response, client, genparam.BOT_3_NAME)
if token_stats:
st.session_state.token_statistics.append(token_stats)
st.markdown("</div></div>", unsafe_allow_html=True)
# Update sidebar with new question
st.sidebar.markdown("---")
st.sidebar.markdown("**Latest Question:**")
st.sidebar.markdown(f"_{user_input}_")
if __name__ == "__main__":
main()