# Fading_Moments / neo_sages2.py
import streamlit as st
from io import BytesIO
import ibm_watsonx_ai
import secretsload
import genparam
import requests
import time
import re
import json
from ibm_watsonx_ai.foundation_models import ModelInference
from ibm_watsonx_ai import Credentials, APIClient
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from ibm_watsonx_ai.metanames import GenTextReturnOptMetaNames as RetParams
from ibm_watsonx_ai.foundation_models import Embeddings
from ibm_watsonx_ai.foundation_models.utils.enums import EmbeddingTypes
from pymilvus import MilvusClient
from secretsload import load_stsecrets
credentials = load_stsecrets()
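# Expected st.secrets keys, inferred from how they are read below. The values
# are illustrative placeholders, not the real configuration; load_stsecrets()
# is a local helper from secretsload (not shown here).
#
#   app_password      = "..."                              # gate for check_password()
#   url               = "https://us-south.ml.cloud.ibm.com"
#   api_key           = "..."
#   project_id        = "..."
#   vector_index_id_1 = "..."                              # asset id of the first vector index
#   vector_index_id_2 = "..."                              # asset id of the second vector index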
st.set_page_config(
    page_title="The Solutioning Sages",
    page_icon="🪄",
    initial_sidebar_state="collapsed",
    layout="wide"
)
# Password protection
def check_password():
    def password_entered():
        if st.session_state["password"] == st.secrets["app_password"]:
            st.session_state["password_correct"] = True
            del st.session_state["password"]
        else:
            st.session_state["password_correct"] = False

    if "password_correct" not in st.session_state:
        st.markdown("\n\n")
        st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
        st.divider()
        st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
        return False
    elif not st.session_state["password_correct"]:
        st.markdown("\n\n")
        st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
        st.divider()
        st.error("😕 Incorrect password")
        st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
        return False
    else:
        return True
def initialize_session_state():
    if 'chat_history_1' not in st.session_state:
        st.session_state.chat_history_1 = []
    if 'chat_history_2' not in st.session_state:
        st.session_state.chat_history_2 = []
    if 'chat_history_3' not in st.session_state:
        st.session_state.chat_history_3 = []
    if 'first_question' not in st.session_state:
        st.session_state.first_question = False
    if "counter" not in st.session_state:
        st.session_state["counter"] = 0
    if 'token_statistics' not in st.session_state:
        st.session_state.token_statistics = []
three_column_style = """
<style>
.stColumn {
padding: 0.5rem;
border-right: 1px solid #dedede;
}
.stColumn:last-child {
border-right: none;
}
.chat-container {
height: calc(100vh - 200px);
overflow-y: auto;
display: flex;
flex-direction: column;
}
.chat-messages {
display: flex;
flex-direction: column;
gap: 1rem;
}
</style>
""" # Alt Style
#-----
def get_active_model():
    return genparam.SELECTED_MODEL_1 if genparam.ACTIVE_MODEL == 0 else genparam.SELECTED_MODEL_2

def get_active_prompt_template():
    return genparam.PROMPT_TEMPLATE_1 if genparam.ACTIVE_MODEL == 0 else genparam.PROMPT_TEMPLATE_2

def get_active_vector_index():
    return st.secrets["vector_index_id_1"] if genparam.ACTIVE_INDEX == 0 else st.secrets["vector_index_id_2"]
#-----
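# genparam is a local config module. A minimal sketch, inferred from the
# attributes this script references — the values shown are illustrative
# assumptions, not the original settings:
#
#   ACTIVE_MODEL = 0                      # 0 -> SELECTED_MODEL_1, 1 -> SELECTED_MODEL_2
#   ACTIVE_INDEX = 0                      # 0 -> vector_index_id_1, 1 -> vector_index_id_2
#   SELECTED_MODEL_1 = "meta-llama/llama-3-1-70b-instruct"
#   SELECTED_MODEL_2 = "ibm/granite-13b-instruct-v2"
#   PROMPT_TEMPLATE_1 = "llama3-instruct (llama-3, 3.1 & 3.2) - system"
#   PROMPT_TEMPLATE_2 = "granite-13b-chat & instruct - system"
#   TYPE = "chat"                         # "chat" keeps conversation history in the prompt
#   BAKE_IN_PROMPT_SYNTAX = True
#   VERIFY = False
#   DECODING_METHOD = "greedy"
#   MAX_NEW_TOKENS = 500
#   MIN_NEW_TOKENS = 1
#   REPETITION_PENALTY = 1.0
#   STOP_SEQUENCES = []
#   TOKEN_CAPTURE_ENABLED = True
#   INPUT_DEBUG_VIEW = 0                  # 1 shows the raw prompt sent to each bot
#   USER_AVATAR = "👤"
#   BOT_1_NAME, BOT_1_AVATAR, BOT_1_PROMPT = "PATH-er B.", "📜", "..."
#   BOT_2_NAME, BOT_2_AVATAR, BOT_2_PROMPT = "MOD-ther S.", "🧙", "..."
#   BOT_3_NAME, BOT_3_AVATAR, BOT_3_PROMPT = "SYS-ter V.", "🔮", "..."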
def setup_client(project_id):
    credentials = Credentials(
        url=st.secrets["url"],
        api_key=st.secrets["api_key"]
    )
    client = APIClient(credentials, project_id=project_id)
    return credentials, client

wml_credentials, client = setup_client(st.secrets["project_id"])
def setup_vector_index(client, wml_credentials, vector_index_id):
    vector_index_details = client.data_assets.get_details(vector_index_id)
    vector_index_properties = vector_index_details["entity"]["vector_index"]
    emb = Embeddings(
        model_id=vector_index_properties["settings"]["embedding_model_id"],
        # model_id="sentence-transformers/all-minilm-l12-v2",
        credentials=wml_credentials,
        project_id=st.secrets["project_id"],
        params={
            "truncate_input_tokens": 512
        }
    )
    vector_store_schema = vector_index_properties["settings"]["schema_fields"]
    connection_details = client.connections.get_details(vector_index_properties["store"]["connection_id"])
    connection_properties = connection_details["entity"]["properties"]
    milvus_client = MilvusClient(
        uri=f'https://{connection_properties.get("host")}:{connection_properties.get("port")}',
        user=connection_properties.get("username"),
        password=connection_properties.get("password"),
        db_name=vector_index_properties["store"]["database"]
    )
    return milvus_client, emb, vector_index_properties, vector_store_schema
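# setup_vector_index returns everything proximity_search needs: a MilvusClient
# connected to the index's backing collection, the Embeddings model used to
# vectorize queries, the vector_index_properties dict (this code relies on
# settings.embedding_model_id, settings.schema_fields, settings.top_k,
# store.connection_id, store.database and store.index), and the schema-field
# mapping for the text, document_name and page_number columns.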
def proximity_search(question, milvus_client, emb, vector_index_properties, vector_store_schema):
    query_vectors = emb.embed_query(question)
    milvus_response = milvus_client.search(
        collection_name=vector_index_properties["store"]["index"],
        data=[query_vectors],
        limit=vector_index_properties["settings"]["top_k"],
        metric_type="L2",
        output_fields=[
            vector_store_schema.get("text"),
            vector_store_schema.get("document_name"),
            vector_store_schema.get("page_number")
        ]
    )
    documents = []
    for hit in milvus_response[0]:
        text = hit["entity"].get(vector_store_schema.get("text"), "")
        doc_name = hit["entity"].get(vector_store_schema.get("document_name"), "Unknown Document")
        page_num = hit["entity"].get(vector_store_schema.get("page_number"), "N/A")
        formatted_result = f"Document: {doc_name}\nContent: {text}\nPage: {page_num}\n"
        documents.append(formatted_result)
    joined = "\n".join(documents)
    retrieved = f"Number of Retrieved Documents: {len(documents)}\n\n{joined}"
    return retrieved
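# The retrieved string therefore looks like this (illustrative values):
#
#   Number of Retrieved Documents: 2
#
#   Document: handbook.pdf
#   Content: ...chunk text...
#   Page: 12
#
#   Document: faq.pdf
#   Content: ...chunk text...
#   Page: 3
#
# fetch_response and the column-1 renderer in main() both parse this exact
# layout with split("\n\n")[1:], so the blank-line separators matter.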
def prepare_prompt(prompt, chat_history):
    if genparam.TYPE == "chat" and chat_history:
        chats = "\n".join([f"{message['role']}: \"{message['content']}\"" for message in chat_history])
        return f"""Retrieved Contextual Information:\n__grounding__\n\nConversation History:\n{chats}\n\nNew User Input: {prompt}"""
    else:
        return f"""Retrieved Contextual Information:\n__grounding__\n\nUser Input: {prompt}"""
def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
    model_family_syntax = {
        "llama3-instruct (llama-3, 3.1 & 3.2) - system": """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "llama3-instruct (llama-3, 3.1 & 3.2) - user": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "granite-13b-chat & instruct - system": """<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "granite-13b-chat & instruct - user": """<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "mistral & mixtral v2 tokenizer - system": """<s>[INST] System Prompt: {system_prompt} [/INST][INST] {prompt} [/INST]\n\n""",
        "mistral & mixtral v2 tokenizer - user": """<s>[INST] {prompt} [/INST]\n\n""",
        "no syntax - system": """{system_prompt}\n\n{prompt}""",
        "no syntax - user": """{prompt}"""
    }
    if bake_in_prompt_syntax:
        template = model_family_syntax[prompt_template]
        if system_prompt:
            return template.format(system_prompt=system_prompt, prompt=prompt)
        # Without a system prompt, the "- user" templates still need formatting;
        # the original fell through and returned the raw prompt here.
        return template.format(prompt=prompt)
    return prompt
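# For example, baking the llama3 "- system" template with the system prompt
# "You are concise." and a prepared prompt P yields (illustrative rendering):
#
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
#   You are concise.<|eot_id|><|start_header_id|>user<|end_header_id|>
#
#   P<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#
# The __grounding__ placeholder inside P is substituted by fetch_response
# after the syntax is applied.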
def generate_response(watsonx_llm, prompt_data, params):
    generated_response = watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)
    for chunk in generated_response:
        yield chunk
def fetch_response(user_input, milvus_client, emb, vector_index_properties, vector_store_schema, system_prompt, chat_history):
    # Get grounding documents
    grounding = proximity_search(
        question=user_input,
        milvus_client=milvus_client,
        emb=emb,
        vector_index_properties=vector_index_properties,
        vector_store_schema=vector_store_schema
    )

    # Identity (not equality) checks: two histories with equal contents must
    # not be mistaken for one another.
    # Special handling for PATH-er B. (first column): no generation, just
    # display the retrieved documents.
    if chat_history is st.session_state.chat_history_1:
        # Display the user question first
        with st.chat_message("user", avatar=genparam.USER_AVATAR):
            st.markdown(user_input)
        # Parse and display each document from the grounding
        documents = grounding.split("\n\n")[1:]  # Skip the count line
        for doc in documents:
            if doc.strip():  # Only process non-empty strings
                parts = doc.split("\n")
                doc_name = parts[0].replace("Document: ", "")
                content = parts[1].replace("Content: ", "")
                # Display each document with a short delay
                time.sleep(0.5)
                st.markdown(f"**{doc_name}**")
                st.code(content)
        # The caller stores the grounding in chat history
        return grounding

    # For MOD-ther S. (second column)
    elif chat_history is st.session_state.chat_history_2:
        prompt = prepare_prompt(user_input, chat_history)
        bot_name, bot_avatar = genparam.BOT_2_NAME, genparam.BOT_2_AVATAR
    # For SYS-ter V. (third column): reuse MOD-ther S.'s chat history
    else:
        prompt = prepare_prompt(user_input, st.session_state.chat_history_2)
        bot_name, bot_avatar = genparam.BOT_3_NAME, genparam.BOT_3_AVATAR

    prompt_data = apply_prompt_syntax(
        prompt,
        system_prompt,
        get_active_prompt_template(),
        genparam.BAKE_IN_PROMPT_SYNTAX
    )
    prompt_data = prompt_data.replace("__grounding__", grounding)

    # Add debug information to the first column if enabled
    if genparam.INPUT_DEBUG_VIEW == 1:
        with st.columns(3)[0]:  # Access first column
            st.markdown(f"**{bot_avatar} {bot_name} Prompt Data:**")
            st.code(prompt_data, language="text")

    # Normal processing for columns 2 and 3
    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY
    )
    params = {
        GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
        GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
        GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
        GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
        GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
    }

    with st.chat_message(bot_name, avatar=bot_avatar):
        stream = generate_response(watsonx_llm, prompt_data, params)
        response = st.write_stream(stream)
        # Only capture tokens for MOD-ther S. and SYS-ter V.
        if genparam.TOKEN_CAPTURE_ENABLED:
            token_stats = capture_tokens(prompt_data, response, bot_name)
            if token_stats:
                st.session_state.token_statistics.append(token_stats)
    return response
def capture_tokens(prompt_data, response, bot_name):
    if not genparam.TOKEN_CAPTURE_ENABLED:
        return None
    # Tokenize against the active model (the original referenced a
    # genparam.SELECTED_MODEL attribute that is not defined anywhere else)
    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY
    )
    input_tokens = watsonx_llm.tokenize(prompt=prompt_data)["result"]["token_count"]
    output_tokens = watsonx_llm.tokenize(prompt=response)["result"]["token_count"]
    total_tokens = input_tokens + output_tokens
    return {
        "bot_name": bot_name,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": total_tokens,
        "timestamp": time.strftime("%H:%M:%S")
    }
def main():
    initialize_session_state()

    # Apply custom styles
    st.markdown(three_column_style, unsafe_allow_html=True)

    # Sidebar configuration
    st.sidebar.header('The Solutioning Sages')
    st.sidebar.divider()

    # Display token statistics in the sidebar
    st.sidebar.subheader("Token Usage Statistics")
    # Group token statistics by interaction (MOD-ther S. and SYS-ter V. only)
    if st.session_state.token_statistics:
        interaction_count = 0
        stats_by_time = {}
        # Group stats by timestamp
        for stat in st.session_state.token_statistics:
            if stat["timestamp"] not in stats_by_time:
                stats_by_time[stat["timestamp"]] = []
            stats_by_time[stat["timestamp"]].append(stat)
        # Display grouped stats
        for timestamp, stats in stats_by_time.items():
            interaction_count += 1
            st.sidebar.markdown(f"**Interaction {interaction_count}** ({timestamp})")
            # Calculate total tokens for this interaction
            total_input = sum(stat['input_tokens'] for stat in stats)
            total_output = sum(stat['output_tokens'] for stat in stats)
            total = total_input + total_output
            # Display individual bot statistics (two trailing spaces force
            # Markdown line breaks)
            for stat in stats:
                st.sidebar.markdown(
                    f"_{stat['bot_name']}_  \n"
                    f"Input: {stat['input_tokens']} tokens  \n"
                    f"Output: {stat['output_tokens']} tokens  \n"
                    f"Total: {stat['total_tokens']} tokens"
                )
            # Display interaction totals
            st.sidebar.markdown("**Interaction Totals:**")
            st.sidebar.markdown(
                f"Total Input: {total_input} tokens  \n"
                f"Total Output: {total_output} tokens  \n"
                f"Total Usage: {total} tokens"
            )
            st.sidebar.markdown("---")
    st.sidebar.markdown("")

    if not check_password():
        st.stop()

    # Get user input before column creation
    user_input = st.chat_input("Ask your question here", key="user_input")
    if user_input:
        # Create three columns
        col1, col2, col3 = st.columns(3)

        # First column - PATH-er B. (document display)
        with col1:
            st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_1_AVATAR} {genparam.BOT_1_NAME}")
            st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)
            # Display previous messages
            for message in st.session_state.chat_history_1:
                if message["role"] == "user":
                    with st.chat_message(message["role"], avatar=genparam.USER_AVATAR):
                        st.markdown(message['content'])
                else:
                    # Parse and display stored documents
                    documents = message['content'].split("\n\n")[1:]  # Skip the count line
                    for doc in documents:
                        if doc.strip():
                            parts = doc.split("\n")
                            doc_name = parts[0].replace("Document: ", "")
                            content = parts[1].replace("Content: ", "")
                            st.markdown(f"**{doc_name}**")
                            st.code(content)
            # Add user message and get new response
            st.session_state.chat_history_1.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
                client,
                wml_credentials,
                st.secrets["vector_index_id_1"]  # Use the first vector index
            )
            system_prompt = genparam.BOT_1_PROMPT
            response = fetch_response(
                user_input,
                milvus_client,
                emb,
                vector_index_properties,
                vector_store_schema,
                system_prompt,
                st.session_state.chat_history_1
            )
            st.session_state.chat_history_1.append({"role": genparam.BOT_1_NAME, "content": response, "avatar": genparam.BOT_1_AVATAR})
            st.markdown("</div></div>", unsafe_allow_html=True)
        # Second column - MOD-ther S. (uses documents from the first vector index)
        with col2:
            st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_2_AVATAR} {genparam.BOT_2_NAME}")
            st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)
            for message in st.session_state.chat_history_2:
                if message["role"] != "user":
                    with st.chat_message(message["role"], avatar=genparam.BOT_2_AVATAR):
                        st.markdown(message['content'])
            st.session_state.chat_history_2.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
                client,
                wml_credentials,
                st.secrets["vector_index_id_1"]  # Use the first vector index
            )
            system_prompt = genparam.BOT_2_PROMPT
            response = fetch_response(
                user_input,
                milvus_client,
                emb,
                vector_index_properties,
                vector_store_schema,
                system_prompt,
                st.session_state.chat_history_2
            )
            st.session_state.chat_history_2.append({"role": genparam.BOT_2_NAME, "content": response, "avatar": genparam.BOT_2_AVATAR})
            st.markdown("</div></div>", unsafe_allow_html=True)
        # Third column - SYS-ter V. (uses the second vector index and chat history from the second column)
        with col3:
            st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_3_AVATAR} {genparam.BOT_3_NAME}")
            st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)
            for message in st.session_state.chat_history_3:
                if message["role"] != "user":
                    with st.chat_message(message["role"], avatar=genparam.BOT_3_AVATAR):
                        st.markdown(message['content'])
            st.session_state.chat_history_3.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
                client,
                wml_credentials,
                st.secrets["vector_index_id_2"]  # Use the second vector index
            )
            system_prompt = genparam.BOT_3_PROMPT
            response = fetch_response(
                user_input,
                milvus_client,
                emb,
                vector_index_properties,
                vector_store_schema,
                system_prompt,
                st.session_state.chat_history_3
            )
            st.session_state.chat_history_3.append({"role": genparam.BOT_3_NAME, "content": response, "avatar": genparam.BOT_3_AVATAR})
            st.markdown("</div></div>", unsafe_allow_html=True)
        # Update the sidebar with the new question
        st.sidebar.markdown("---")
        st.sidebar.markdown("**Latest Question:**")
        st.sidebar.markdown(f"_{user_input}_")

if __name__ == "__main__":
    main()