"""Streamlit app "The Solutioning Sages".

Renders three side-by-side columns for one user question:

* column 1 (BOT_1) lists the documents retrieved from the first vector index;
* column 2 (BOT_2) streams an LLM answer grounded on the first vector index;
* column 3 (BOT_3) streams an LLM answer grounded on the second vector index,
  using column 2's chat history as conversation context.

All model, prompt and persona settings are read from ``genparam``; secrets
come from ``st.secrets``.
"""

import hmac
import json
import re
import time
from io import BytesIO

import ibm_watsonx_ai
import requests
import streamlit as st
from ibm_watsonx_ai import APIClient, Credentials
from ibm_watsonx_ai.foundation_models import Embeddings, ModelInference
from ibm_watsonx_ai.foundation_models.utils.enums import EmbeddingTypes
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from ibm_watsonx_ai.metanames import GenTextReturnOptMetaNames as RetParams
from pymilvus import MilvusClient

import genparam
import secretsload
from secretsload import load_stsecrets

credentials = load_stsecrets()

# Must remain the first Streamlit command executed by the script.
st.set_page_config(
    page_title="The Solutioning Sages",
    page_icon="🪄",
    initial_sidebar_state="collapsed",
    layout="wide"
)


# Password protection
def check_password():
    """Gate the app behind the shared password stored in ``st.secrets``.

    Returns:
        True once the correct password has been entered in this session;
        otherwise renders the password prompt and returns False.
    """

    def password_entered():
        # Constant-time comparison; both sides are encoded so that non-ASCII
        # passwords do not make ``compare_digest`` raise TypeError.
        entered = str(st.session_state["password"]).encode("utf-8")
        expected = str(st.secrets["app_password"]).encode("utf-8")
        if hmac.compare_digest(entered, expected):
            st.session_state["password_correct"] = True
            # Do not keep the plain-text password in session state.
            del st.session_state["password"]
        else:
            st.session_state["password_correct"] = False

    if st.session_state.get("password_correct", False):
        return True

    st.markdown("\n\n")
    st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
    st.divider()
    # The key only exists after at least one attempt, so its presence here
    # means the last attempt was wrong.
    if "password_correct" in st.session_state:
        st.error("😕 Incorrect password")
    st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
    return False


def initialize_session_state():
    """Create every session-state key the app relies on, if it is missing."""
    defaults = {
        "chat_history_1": [],
        "chat_history_2": [],
        "chat_history_3": [],
        "first_question": False,
        "counter": 0,
        "token_statistics": [],
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


# three_column_style = """
#
# """
# NOTE(review): the CSS/HTML that belongs in this literal appears to have been
# stripped from the file as received — restore it from version control.
three_column_style = """ """  # Alt Style


#-----
def get_active_model():
    """Return the model id selected by ``genparam.ACTIVE_MODEL`` (0 -> first, else second)."""
    return genparam.SELECTED_MODEL_1 if genparam.ACTIVE_MODEL == 0 else genparam.SELECTED_MODEL_2


def get_active_prompt_template():
    """Return the prompt-template key selected by ``genparam.ACTIVE_MODEL``."""
    return genparam.PROMPT_TEMPLATE_1 if genparam.ACTIVE_MODEL == 0 else genparam.PROMPT_TEMPLATE_2


def get_active_vector_index():
    """Return the vector index id selected by ``genparam.ACTIVE_INDEX``."""
    return st.secrets["vector_index_id_1"] if genparam.ACTIVE_INDEX == 0 else st.secrets["vector_index_id_2"]
#-----


def setup_client(project_id):
    """Build watsonx.ai credentials and an API client bound to ``project_id``.

    Returns:
        A ``(credentials, client)`` tuple.
    """
    credentials = Credentials(
        url=st.secrets["url"],
        api_key=st.secrets["api_key"]
    )
    client = APIClient(credentials, project_id=project_id)
    return credentials, client


wml_credentials, client = setup_client(st.secrets["project_id"])


def setup_vector_index(client, wml_credentials, vector_index_id):
    """Resolve a vector index asset into a Milvus client and an embedder.

    Args:
        client: watsonx.ai ``APIClient`` used to read the asset and connection.
        wml_credentials: credentials passed to the ``Embeddings`` model.
        vector_index_id: id of the vector index data asset.

    Returns:
        ``(milvus_client, emb, vector_index_properties, vector_store_schema)``.
    """
    vector_index_details = client.data_assets.get_details(vector_index_id)
    vector_index_properties = vector_index_details["entity"]["vector_index"]

    emb = Embeddings(
        model_id=vector_index_properties["settings"]["embedding_model_id"],
        #model_id="sentence-transformers/all-minilm-l12-v2",
        credentials=wml_credentials,
        project_id=st.secrets["project_id"],
        params={
            "truncate_input_tokens": 512
        }
    )

    vector_store_schema = vector_index_properties["settings"]["schema_fields"]
    connection_details = client.connections.get_details(
        vector_index_details["entity"]["vector_index"]["store"]["connection_id"]
    )
    connection_properties = connection_details["entity"]["properties"]

    milvus_client = MilvusClient(
        uri=f'https://{connection_properties.get("host")}:{connection_properties.get("port")}',
        user=connection_properties.get("username"),
        password=connection_properties.get("password"),
        db_name=vector_index_properties["store"]["database"]
    )

    return milvus_client, emb, vector_index_properties, vector_store_schema


def proximity_search(question, milvus_client, emb, vector_index_properties, vector_store_schema):
    """Embed ``question``, search the Milvus collection and format the hits.

    Returns:
        A string that starts with ``Number of Retrieved Documents: N`` followed
        by a blank line and one ``Document:/Content:/Page:`` section per hit,
        sections separated by a blank line.
    """
    query_vectors = emb.embed_query(question)
    text_field = vector_store_schema.get("text")
    name_field = vector_store_schema.get("document_name")
    page_field = vector_store_schema.get("page_number")

    milvus_response = milvus_client.search(
        collection_name=vector_index_properties["store"]["index"],
        data=[query_vectors],
        limit=vector_index_properties["settings"]["top_k"],
        metric_type="L2",
        output_fields=[text_field, name_field, page_field]
    )

    documents = []
    for hit in milvus_response[0]:
        entity = hit["entity"]
        text = entity.get(text_field, "")
        doc_name = entity.get(name_field, "Unknown Document")
        page_num = entity.get(page_field, "N/A")
        documents.append(f"Document: {doc_name}\nContent: {text}\nPage: {page_num}\n")

    joined = "\n".join(documents)
    retrieved = f"""Number of Retrieved Documents: {len(documents)}\n\n{joined}"""
    return retrieved


# Matches one section produced by ``proximity_search``. The content group is
# non-greedy and DOTALL so multi-line chunk text (including blank lines) is
# kept whole; the look-ahead pins the section end to the next section or EOF.
_GROUNDING_DOC_PATTERN = re.compile(
    r"Document: (?P<name>[^\n]*)\n"
    r"Content: (?P<content>.*?)\n"
    r"Page: (?P<page>[^\n]*)\n"
    r"(?=\nDocument: |\Z)",
    re.DOTALL,
)


def _parse_grounding_documents(grounding):
    """Return ``[(document_name, content), ...]`` parsed from ``proximity_search`` output."""
    return [
        (match.group("name"), match.group("content"))
        for match in _GROUNDING_DOC_PATTERN.finditer(grounding)
    ]


def _render_grounding_documents(grounding, delay=0.0):
    """Render every retrieved document as a bold title plus a code block.

    Args:
        grounding: string returned by ``proximity_search``.
        delay: seconds to wait before each document (used for a staggered
            reveal of freshly retrieved documents).
    """
    for doc_name, content in _parse_grounding_documents(grounding):
        if delay:
            time.sleep(delay)
        st.markdown(f"**{doc_name}**")
        st.code(content)


def prepare_prompt(prompt, chat_history):
    """Wrap the user input with the grounding placeholder and, in chat mode, the history.

    The literal ``__grounding__`` is a placeholder that the caller replaces
    with the retrieved documents after the prompt syntax has been applied.
    """
    if genparam.TYPE == "chat" and chat_history:
        chats = "\n".join([f"{message['role']}: \"{message['content']}\"" for message in chat_history])
        return f"""Retrieved Contextual Information:\n__grounding__\n\nConversation History:\n{chats}\n\nNew User Input: {prompt}"""
    return f"""Retrieved Contextual Information:\n__grounding__\n\nUser Input: {prompt}"""


# Prompt wrappers per model family; keys are the values expected in
# ``genparam.PROMPT_TEMPLATE_1`` / ``PROMPT_TEMPLATE_2``.
_MODEL_FAMILY_SYNTAX = {
    "llama3-instruct (llama-3, 3.1 & 3.2) - system": """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
    "llama3-instruct (llama-3, 3.1 & 3.2) - user": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
    "granite-13b-chat & instruct - system": """<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
    "granite-13b-chat & instruct - user": """<|user|>\n{prompt}\n<|assistant|>\n\n""",
    "mistral & mixtral v2 tokenizer - system": """[INST] System Prompt: {system_prompt} [/INST][INST] {prompt} [/INST]\n\n""",
    "mistral & mixtral v2 tokenizer - user": """[INST] {prompt} [/INST]\n\n""",
    "no syntax - system": """{system_prompt}\n\n{prompt}""",
    "no syntax - user": """{prompt}"""
}


def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
    """Wrap ``prompt`` in the chat syntax of the selected model family.

    Args:
        prompt: user-side prompt text.
        system_prompt: system instructions; may be empty.
        prompt_template: key into the model-family syntax table.
        bake_in_prompt_syntax: when falsy the prompt is returned unchanged.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: if ``prompt_template`` is not a known template key.
    """
    if not bake_in_prompt_syntax:
        return prompt
    template = _MODEL_FAMILY_SYNTAX[prompt_template]
    # Always apply the template: previously an empty system prompt skipped the
    # syntax entirely, which made the "- user" templates unusable.
    return template.format(system_prompt=system_prompt or "", prompt=prompt)


def generate_response(watsonx_llm, prompt_data, params):
    """Yield the text chunks streamed back by the model for ``prompt_data``."""
    yield from watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)


def fetch_response(user_input, milvus_client, emb, vector_index_properties, vector_store_schema, system_prompt, chat_history):
    """Produce and render the content of one column for ``user_input``.

    The column is identified by comparing ``chat_history`` with the three
    histories held in session state:

    * ``chat_history_1``: show the question and the retrieved documents, and
      return the raw grounding string (no LLM call);
    * ``chat_history_2``: stream an LLM answer using that history;
    * anything else (column 3): stream an LLM answer using column 2's history.

    Returns:
        The grounding string for column 1, otherwise the streamed response.
    """
    # Get grounding documents
    grounding = proximity_search(
        question=user_input,
        milvus_client=milvus_client,
        emb=emb,
        vector_index_properties=vector_index_properties,
        vector_store_schema=vector_store_schema
    )

    # Special handling for PATH-er B. (first column): documents only.
    if chat_history == st.session_state.chat_history_1:
        # Display user question first
        with st.chat_message("user", avatar=genparam.USER_AVATAR):
            st.markdown(user_input)
        # Reveal the retrieved documents one by one with a short delay.
        _render_grounding_documents(grounding, delay=0.5)
        return grounding

    if chat_history == st.session_state.chat_history_2:
        # MOD-ther S. (second column)
        history_for_prompt = chat_history
        bot_name = genparam.BOT_2_NAME
        bot_avatar = genparam.BOT_2_AVATAR
    else:
        # SYS-ter V. (third column) reuses the chat history from MOD-ther S.
        history_for_prompt = st.session_state.chat_history_2
        bot_name = genparam.BOT_3_NAME
        bot_avatar = genparam.BOT_3_AVATAR

    prompt = prepare_prompt(user_input, history_for_prompt)
    prompt_data = apply_prompt_syntax(
        prompt,
        system_prompt,
        get_active_prompt_template(),
        genparam.BAKE_IN_PROMPT_SYNTAX
    )
    prompt_data = prompt_data.replace("__grounding__", grounding)

    # Add debug information to column 1 if enabled
    if genparam.INPUT_DEBUG_VIEW == 1:
        with st.columns(3)[0]:  # Access first column
            st.markdown(f"**{bot_avatar} {bot_name} Prompt Data:**")
            st.code(prompt_data, language="text")

    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY
    )
    params = {
        GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
        GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
        GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
        GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
        GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
    }

    with st.chat_message(bot_name, avatar=bot_avatar):
        stream = generate_response(watsonx_llm, prompt_data, params)
        response = st.write_stream(stream)

        # Only MOD-ther S. and SYS-ter V. reach this point, so both are counted.
        if genparam.TOKEN_CAPTURE_ENABLED:
            token_stats = capture_tokens(prompt_data, response, bot_name)
            if token_stats:
                st.session_state.token_statistics.append(token_stats)

    return response


def capture_tokens(prompt_data, response, chat_number):
    """Count input/output tokens of one exchange with the active model's tokenizer.

    Args:
        prompt_data: the full prompt sent to the model.
        response: the text generated by the model.
        chat_number: label of the bot that answered; stored as ``bot_name``.

    Returns:
        A dict with ``bot_name``, ``input_tokens``, ``output_tokens``,
        ``total_tokens`` and a ``HH:MM:SS`` ``timestamp``, or None when token
        capture is disabled.
    """
    if not genparam.TOKEN_CAPTURE_ENABLED:
        return None

    # Tokenize with the same model that generated the answer so the counts
    # match what was actually consumed.
    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY
    )

    input_tokens = watsonx_llm.tokenize(prompt=prompt_data)["result"]["token_count"]
    output_tokens = watsonx_llm.tokenize(prompt=response)["result"]["token_count"]

    return {
        # The original referenced an undefined ``bot_name`` here (NameError).
        "bot_name": chat_number,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
        "timestamp": time.strftime("%H:%M:%S")
    }


def main():
    """Render the sidebar, enforce the password gate and run the three columns."""
    initialize_session_state()

    # Apply custom styles
    st.markdown(three_column_style, unsafe_allow_html=True)

    # Sidebar configuration
    st.sidebar.header('The Solutioning Sages')
    st.sidebar.divider()

    # Display token statistics in sidebar
    st.sidebar.subheader("Token Usage Statistics")

    # Group token statistics by interaction (for MOD-ther S. and SYS-ter V. only)
    if st.session_state.token_statistics:
        # Group stats by timestamp (dicts keep insertion order, so interactions
        # are listed chronologically).
        stats_by_time = {}
        for stat in st.session_state.token_statistics:
            stats_by_time.setdefault(stat["timestamp"], []).append(stat)

        # Display grouped stats
        for interaction_count, (timestamp, stats) in enumerate(stats_by_time.items(), start=1):
            st.sidebar.markdown(f"**Interaction {interaction_count}** ({timestamp})")

            # Calculate total tokens for this interaction
            total_input = sum(stat['input_tokens'] for stat in stats)
            total_output = sum(stat['output_tokens'] for stat in stats)
            total = total_input + total_output

            # Display individual bot statistics
            for stat in stats:
                st.sidebar.markdown(
                    f"_{stat['bot_name']}_ \n"
                    f"Input: {stat['input_tokens']} tokens \n"
                    f"Output: {stat['output_tokens']} tokens \n"
                    f"Total: {stat['total_tokens']} tokens"
                )

            # Display interaction totals
            st.sidebar.markdown("**Interaction Totals:**")
            st.sidebar.markdown(
                f"Total Input: {total_input} tokens \n"
                f"Total Output: {total_output} tokens \n"
                f"Total Usage: {total} tokens"
            )
            st.sidebar.markdown("---")
        st.sidebar.markdown("")

    if not check_password():
        st.stop()

    # Get user input before column creation
    user_input = st.chat_input("Ask your question here", key="user_input")

    if user_input:
        # Create three columns
        col1, col2, col3 = st.columns(3)

        # NOTE(review): the HTML passed to the ``st.markdown(...,
        # unsafe_allow_html=True)`` wrapper calls below appears to have been
        # stripped from the file as received; only a line break was left
        # inside each literal. Restore the original markup from version control.

        # First column - PATH-er B. (Document Display)
        with col1:
            st.markdown("\n", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_1_AVATAR} {genparam.BOT_1_NAME}")
            st.markdown("\n", unsafe_allow_html=True)

            # Display previous messages
            for message in st.session_state.chat_history_1:
                if message["role"] == "user":
                    with st.chat_message(message["role"], avatar=genparam.USER_AVATAR):
                        st.markdown(message['content'])
                else:
                    # Parse and display stored documents
                    _render_grounding_documents(message['content'])

            # Add user message and get new response
            st.session_state.chat_history_1.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            # The first vector index serves columns 1 and 2, so resolve it once.
            first_index = setup_vector_index(
                client,
                wml_credentials,
                st.secrets["vector_index_id_1"]  # Use first vector index
            )
            system_prompt = genparam.BOT_1_PROMPT

            response = fetch_response(
                user_input,
                *first_index,
                system_prompt,
                st.session_state.chat_history_1
            )
            st.session_state.chat_history_1.append({"role": genparam.BOT_1_NAME, "content": response, "avatar": genparam.BOT_1_AVATAR})
            st.markdown("\n", unsafe_allow_html=True)

        # Second column - MOD-ther S. (Uses documents from first vector index)
        with col2:
            st.markdown("\n", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_2_AVATAR} {genparam.BOT_2_NAME}")
            st.markdown("\n", unsafe_allow_html=True)

            for message in st.session_state.chat_history_2:
                if message["role"] != "user":
                    with st.chat_message(message["role"], avatar=genparam.BOT_2_AVATAR):
                        st.markdown(message['content'])

            st.session_state.chat_history_2.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            system_prompt = genparam.BOT_2_PROMPT

            response = fetch_response(
                user_input,
                *first_index,
                system_prompt,
                st.session_state.chat_history_2
            )
            st.session_state.chat_history_2.append({"role": genparam.BOT_2_NAME, "content": response, "avatar": genparam.BOT_2_AVATAR})
            st.markdown("\n", unsafe_allow_html=True)

        # Third column - SYS-ter V. (Uses second vector index and chat history from second column)
        with col3:
            st.markdown("\n", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_3_AVATAR} {genparam.BOT_3_NAME}")
            st.markdown("\n", unsafe_allow_html=True)

            for message in st.session_state.chat_history_3:
                if message["role"] != "user":
                    with st.chat_message(message["role"], avatar=genparam.BOT_3_AVATAR):
                        st.markdown(message['content'])

            st.session_state.chat_history_3.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            second_index = setup_vector_index(
                client,
                wml_credentials,
                st.secrets["vector_index_id_2"]  # Use second vector index
            )
            system_prompt = genparam.BOT_3_PROMPT

            response = fetch_response(
                user_input,
                *second_index,
                system_prompt,
                st.session_state.chat_history_3
            )
            st.session_state.chat_history_3.append({"role": genparam.BOT_3_NAME, "content": response, "avatar": genparam.BOT_3_AVATAR})
            st.markdown("\n", unsafe_allow_html=True)

        # Update sidebar with new question
        st.sidebar.markdown("---")
        st.sidebar.markdown("**Latest Question:**")
        st.sidebar.markdown(f"_{user_input}_")


if __name__ == "__main__":
    main()