import time

import streamlit as st
from ibm_watsonx_ai import Credentials, APIClient
from ibm_watsonx_ai.foundation_models import ModelInference, Embeddings
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from pymilvus import MilvusClient

import genparam
from knowledge_bases import KNOWLEDGE_BASE_OPTIONS, SYSTEM_PROMPTS, VECTOR_INDEXES
from secretsload import load_stsecrets

credentials = load_stsecrets()
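
# genparam is the app's local configuration module. From its usage below it is
# assumed to provide at least: ACTIVE_MODEL, ACTIVE_INDEX, SELECTED_MODEL_1/2,
# PROMPT_TEMPLATE_1/2, TYPE, BAKE_IN_PROMPT_SYNTAX, VERIFY, DECODING_METHOD,
# MAX_NEW_TOKENS, MIN_NEW_TOKENS, REPETITION_PENALTY, STOP_SEQUENCES,
# TOKEN_CAPTURE_ENABLED, INPUT_DEBUG_VIEW, USER_AVATAR, and the
# BOT_{1,2,3}_NAME / BOT_{1,2,3}_AVATAR / BOT_1_PROMPT constants.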

st.set_page_config(
    page_title="The Solutioning Sages",
    page_icon="🪄",
    initial_sidebar_state="collapsed",
    layout="wide"
)


def check_password():
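    """Gate the app behind the password stored in st.secrets["app_password"]."""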
    def password_entered():
        if st.session_state["password"] == st.secrets["app_password"]:
            st.session_state["password_correct"] = True
            del st.session_state["password"]
        else:
            st.session_state["password_correct"] = False

    if "password_correct" not in st.session_state:
        st.markdown("\n\n")
        st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
        st.divider()
        st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
        return False
    elif not st.session_state["password_correct"]:
        st.markdown("\n\n")
        st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
        st.divider()
        st.error("😕 Incorrect password")
        st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
        return False
    else:
        return True


def initialize_session_state():
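    """Create the per-session defaults used across reruns (one chat history per bot column)."""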
    if 'chat_history_1' not in st.session_state:
        st.session_state.chat_history_1 = []
    if 'chat_history_2' not in st.session_state:
        st.session_state.chat_history_2 = []
    if 'chat_history_3' not in st.session_state:
        st.session_state.chat_history_3 = []
    if 'first_question' not in st.session_state:
        st.session_state.first_question = False
    if "counter" not in st.session_state:
        st.session_state["counter"] = 0
    if 'token_statistics' not in st.session_state:
        st.session_state.token_statistics = []
    if 'selected_kb' not in st.session_state:
        st.session_state.selected_kb = KNOWLEDGE_BASE_OPTIONS[0]
    if 'current_system_prompts' not in st.session_state:
        st.session_state.current_system_prompts = SYSTEM_PROMPTS[st.session_state.selected_kb]


three_column_style = """ |
|
<style> |
|
.stColumn { |
|
padding: 0.5rem; |
|
border-right: 1px solid #dedede; |
|
} |
|
.stColumn:last-child { |
|
border-right: none; |
|
} |
|
.chat-container { |
|
height: calc(100vh - 200px); |
|
overflow-y: auto; |
|
display: flex; |
|
flex-direction: column; |
|
} |
|
.chat-messages { |
|
display: flex; |
|
flex-direction: column; |
|
gap: 1rem; |
|
} |
|
</style> |
|
""" |


def get_active_model():
    return genparam.SELECTED_MODEL_1 if genparam.ACTIVE_MODEL == 0 else genparam.SELECTED_MODEL_2


def get_active_prompt_template():
    return genparam.PROMPT_TEMPLATE_1 if genparam.ACTIVE_MODEL == 0 else genparam.PROMPT_TEMPLATE_2


def get_active_vector_index():
    selected_kb = st.session_state.selected_kb
    if genparam.ACTIVE_INDEX == 0:
        return VECTOR_INDEXES[selected_kb]["index_1"]
    return VECTOR_INDEXES[selected_kb]["index_2"]


def setup_client(project_id=None):
    credentials = Credentials(
        url=st.secrets["url"],
        api_key=st.secrets["api_key"]
    )
    project_id = project_id or st.secrets["project_id"]
    client = APIClient(credentials, project_id=project_id)
    return credentials, client
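
# Module-level defaults; main() reassigns these when the selected knowledge
# base points at a different watsonx.ai project.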
wml_credentials, client = setup_client(st.secrets["project_id"])


def setup_vector_index(client, wml_credentials, vector_index_id):
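    """Resolve a watsonx.ai vector index asset into a live Milvus connection.

    Returns (milvus_client, embeddings, vector_index_properties, vector_store_schema).
    """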
    vector_index_details = client.data_assets.get_details(vector_index_id)
    vector_index_properties = vector_index_details["entity"]["vector_index"]

    emb = Embeddings(
        model_id=vector_index_properties["settings"]["embedding_model_id"],
        credentials=wml_credentials,
        project_id=st.secrets["project_id"],
        params={
            "truncate_input_tokens": 512
        }
    )

    vector_store_schema = vector_index_properties["settings"]["schema_fields"]
    connection_details = client.connections.get_details(vector_index_properties["store"]["connection_id"])
    connection_properties = connection_details["entity"]["properties"]

    milvus_client = MilvusClient(
        uri=f'https://{connection_properties.get("host")}:{connection_properties.get("port")}',
        user=connection_properties.get("username"),
        password=connection_properties.get("password"),
        db_name=vector_index_properties["store"]["database"]
    )

    return milvus_client, emb, vector_index_properties, vector_store_schema


def proximity_search(question, milvus_client, emb, vector_index_properties, vector_store_schema):
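    """Embed the question, run an L2 similarity search in Milvus, and format the hits."""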
    query_vectors = emb.embed_query(question)
    milvus_response = milvus_client.search(
        collection_name=vector_index_properties["store"]["index"],
        data=[query_vectors],
        limit=vector_index_properties["settings"]["top_k"],
        metric_type="L2",
        output_fields=[
            vector_store_schema.get("text"),
            vector_store_schema.get("document_name"),
            vector_store_schema.get("page_number")
        ]
    )

    documents = []
    for hit in milvus_response[0]:
        text = hit["entity"].get(vector_store_schema.get("text"), "")
        doc_name = hit["entity"].get(vector_store_schema.get("document_name"), "Unknown Document")
        page_num = hit["entity"].get(vector_store_schema.get("page_number"), "N/A")

        formatted_result = f"Document: {doc_name}\nContent: {text}\nPage: {page_num}\n"
        documents.append(formatted_result)

    joined = "\n".join(documents)
    retrieved = f"Number of Retrieved Documents: {len(documents)}\n\n{joined}"

    return retrieved


def prepare_prompt(prompt, chat_history):
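    """Assemble the user-facing prompt; the "__grounding__" placeholder is filled in later."""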
    if genparam.TYPE == "chat" and chat_history:
        chats = "\n".join([f"{message['role']}: \"{message['content']}\"" for message in chat_history])
        return f"""Retrieved Contextual Information:\n__grounding__\n\nConversation History:\n{chats}\n\nNew User Input: {prompt}"""
    return f"""Retrieved Contextual Information:\n__grounding__\n\nUser Input: {prompt}"""


def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
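    """Wrap the prompt in the chat-template syntax for the selected model family."""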
    model_family_syntax = {
        "llama3-instruct (llama-3, 3.1 & 3.2) - system": """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "llama3-instruct (llama-3, 3.1 & 3.2) - user": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "granite-13b-chat & instruct - system": """<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "granite-13b-chat & instruct - user": """<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "mistral & mixtral v2 tokenizer - system": """<s>[INST] System Prompt: {system_prompt} [/INST][INST] {prompt} [/INST]\n\n""",
        "mistral & mixtral v2 tokenizer - user": """<s>[INST] {prompt} [/INST]\n\n""",
        "no syntax - system": """{system_prompt}\n\n{prompt}""",
        "no syntax - user": """{prompt}"""
    }

    if bake_in_prompt_syntax:
        template = model_family_syntax[prompt_template]
        if system_prompt:
            return template.format(system_prompt=system_prompt, prompt=prompt)
    return prompt


def generate_response(watsonx_llm, prompt_data, params):
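    """Yield text chunks from watsonx.ai so st.write_stream can render them live."""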
    generated_response = watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)
    for chunk in generated_response:
        yield chunk


def fetch_response(user_input, milvus_client, emb, vector_index_properties, vector_store_schema, system_prompt, chat_history):
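    """Retrieve grounding documents, then either display them (bot 1) or stream a grounded answer (bots 2 and 3)."""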
    grounding = proximity_search(
        question=user_input,
        milvus_client=milvus_client,
        emb=emb,
        vector_index_properties=vector_index_properties,
        vector_store_schema=vector_store_schema
    )

    # Identity check: each column passes its own session-state list, and "=="
    # would wrongly match histories whose contents happen to be equal.
    if chat_history is st.session_state.chat_history_1:
        with st.chat_message("user", avatar=genparam.USER_AVATAR):
            st.markdown(user_input)

        # Skip the "Number of Retrieved Documents" header; each remaining
        # chunk is one "Document/Content/Page" block.
        documents = grounding.split("\n\n")[1:]
        for doc in documents:
            if doc.strip():
                parts = doc.split("\n")
                doc_name = parts[0].replace("Document: ", "")
                content = parts[1].replace("Content: ", "")

                time.sleep(0.5)
                st.markdown(f"**{doc_name}**")
                st.code(content)

        return grounding

    prompt = prepare_prompt(user_input, chat_history)
    prompt_data = apply_prompt_syntax(
        prompt,
        system_prompt,
        get_active_prompt_template(),
        genparam.BAKE_IN_PROMPT_SYNTAX
    )
    prompt_data = prompt_data.replace("__grounding__", grounding)
    # Keep the assembled prompt around so the debug view in main() can display it.
    st.session_state["last_prompt_data"] = prompt_data

    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY
    )

    params = {
        GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
        GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
        GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
        GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
        GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
    }

    # Bot 1 already returned above, so only bots 2 and 3 reach this point.
    if chat_history is st.session_state.chat_history_2:
        bot_name = genparam.BOT_2_NAME
        bot_avatar = genparam.BOT_2_AVATAR
    else:
        bot_name = genparam.BOT_3_NAME
        bot_avatar = genparam.BOT_3_AVATAR

    with st.chat_message(bot_name, avatar=bot_avatar):
        stream = generate_response(watsonx_llm, prompt_data, params)
        response = st.write_stream(stream)

        if genparam.TOKEN_CAPTURE_ENABLED:
            token_stats = capture_tokens(prompt_data, response, bot_name)
            if token_stats:
                st.session_state.token_statistics.append(token_stats)

    return response


def capture_tokens(prompt_data, response, bot_name):
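    """Tokenize the prompt and response to record per-bot token usage for the sidebar."""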
    if not genparam.TOKEN_CAPTURE_ENABLED:
        return None

    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY
    )

    input_tokens = watsonx_llm.tokenize(prompt=prompt_data)["result"]["token_count"]
    output_tokens = watsonx_llm.tokenize(prompt=response)["result"]["token_count"]
    total_tokens = input_tokens + output_tokens

    return {
        "bot_name": bot_name,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": total_tokens,
        "timestamp": time.strftime("%H:%M:%S")
    }


def main():
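    """Render the three-column, three-bot RAG chat with a configuration sidebar."""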
    initialize_session_state()

    st.markdown(three_column_style, unsafe_allow_html=True)

    st.sidebar.header('The Solutioning Sages')
    st.sidebar.divider()

    selected_kb = st.sidebar.selectbox(
        "Select Knowledge Base",
        KNOWLEDGE_BASE_OPTIONS,
        index=KNOWLEDGE_BASE_OPTIONS.index(st.session_state.selected_kb)
    )

    if selected_kb != st.session_state.selected_kb:
        st.session_state.selected_kb = selected_kb
        # Re-point the module-level client at the project backing the new knowledge base.
        global client, wml_credentials
        wml_credentials, client = setup_client(VECTOR_INDEXES[selected_kb]["project_id"])

    with st.sidebar.expander("Knowledge Base Contents"):
        for doc in VECTOR_INDEXES[selected_kb]["contents"]:
            st.write(f"📄 {doc}")

    st.sidebar.divider()
    st.sidebar.markdown("**Active Model:**")
    st.sidebar.code(get_active_model())

    st.sidebar.divider()

    st.sidebar.subheader("Token Usage Statistics")

    if st.session_state.token_statistics:
        interaction_count = 0
        stats_by_time = {}

        # Group the recorded stats by timestamp so each user turn shows as one interaction.
        for stat in st.session_state.token_statistics:
            if stat["timestamp"] not in stats_by_time:
                stats_by_time[stat["timestamp"]] = []
            stats_by_time[stat["timestamp"]].append(stat)

        for timestamp, stats in stats_by_time.items():
            interaction_count += 1
            st.sidebar.markdown(f"**Interaction {interaction_count}** ({timestamp})")

            total_input = sum(stat['input_tokens'] for stat in stats)
            total_output = sum(stat['output_tokens'] for stat in stats)
            total = total_input + total_output

            for stat in stats:
                st.sidebar.markdown(
                    f"_{stat['bot_name']}_  \n"
                    f"Input: {stat['input_tokens']} tokens  \n"
                    f"Output: {stat['output_tokens']} tokens  \n"
                    f"Total: {stat['total_tokens']} tokens"
                )

            st.sidebar.markdown("**Interaction Totals:**")
            st.sidebar.markdown(
                f"Total Input: {total_input} tokens  \n"
                f"Total Output: {total_output} tokens  \n"
                f"Total Usage: {total} tokens"
            )
            st.sidebar.markdown("---")

    st.sidebar.markdown("")

    if not check_password():
        st.stop()

    user_input = st.chat_input("Ask your question here", key="user_input")

    if user_input:
        col1, col2, col3 = st.columns(3)

        with col1:
            st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_1_AVATAR} {genparam.BOT_1_NAME}")
            st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)

            for message in st.session_state.chat_history_1:
                if message["role"] == "user":
                    with st.chat_message(message["role"], avatar=genparam.USER_AVATAR):
                        st.markdown(message['content'])
                else:
                    # Bot 1 history entries are raw grounding strings; skip the
                    # header chunk and re-render each document block.
                    documents = message['content'].split("\n\n")[1:]
                    for doc in documents:
                        if doc.strip():
                            parts = doc.split("\n")
                            doc_name = parts[0].replace("Document: ", "")
                            content = parts[1].replace("Content: ", "")
                            st.markdown(f"**{doc_name}**")
                            st.code(content)

            st.session_state.chat_history_1.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
                client,
                wml_credentials,
                VECTOR_INDEXES[st.session_state.selected_kb]["index_1"]
            )
            system_prompt = genparam.BOT_1_PROMPT

            response = fetch_response(
                user_input,
                milvus_client,
                emb,
                vector_index_properties,
                vector_store_schema,
                system_prompt,
                st.session_state.chat_history_1
            )
            st.session_state.chat_history_1.append({"role": genparam.BOT_1_NAME, "content": response, "avatar": genparam.BOT_1_AVATAR})
            st.markdown("</div></div>", unsafe_allow_html=True)

        with col2:
            st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_2_AVATAR} {genparam.BOT_2_NAME}")
            st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)

            for message in st.session_state.chat_history_2:
                if message["role"] != "user":
                    with st.chat_message(message["role"], avatar=genparam.BOT_2_AVATAR):
                        st.markdown(message['content'])

            st.session_state.chat_history_2.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
                client,
                wml_credentials,
                VECTOR_INDEXES[st.session_state.selected_kb]["index_1"]
            )
            system_prompt = SYSTEM_PROMPTS[st.session_state.selected_kb]["bot_2"]

            response = fetch_response(
                user_input,
                milvus_client,
                emb,
                vector_index_properties,
                vector_store_schema,
                system_prompt,
                st.session_state.chat_history_2
            )

            if genparam.INPUT_DEBUG_VIEW == 1:
                with col1:
                    st.markdown(f"**{genparam.BOT_2_AVATAR} {genparam.BOT_2_NAME} Prompt Data:**")
                    st.code(st.session_state.get("last_prompt_data", ""), language="text")

            st.session_state.chat_history_2.append({"role": genparam.BOT_2_NAME, "content": response, "avatar": genparam.BOT_2_AVATAR})
            st.markdown("</div></div>", unsafe_allow_html=True)

        with col3:
            st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_3_AVATAR} {genparam.BOT_3_NAME}")
            st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)

            for message in st.session_state.chat_history_3:
                if message["role"] != "user":
                    with st.chat_message(message["role"], avatar=genparam.BOT_3_AVATAR):
                        st.markdown(message['content'])

            st.session_state.chat_history_3.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
                client,
                wml_credentials,
                VECTOR_INDEXES[st.session_state.selected_kb]["index_2"]
            )
            system_prompt = SYSTEM_PROMPTS[st.session_state.selected_kb]["bot_3"]

            response = fetch_response(
                user_input,
                milvus_client,
                emb,
                vector_index_properties,
                vector_store_schema,
                system_prompt,
                st.session_state.chat_history_3
            )

            if genparam.INPUT_DEBUG_VIEW == 1:
                with col1:
                    st.markdown(f"**{genparam.BOT_3_AVATAR} {genparam.BOT_3_NAME} Prompt Data:**")
                    st.code(st.session_state.get("last_prompt_data", ""), language="text")

            st.session_state.chat_history_3.append({"role": genparam.BOT_3_NAME, "content": response, "avatar": genparam.BOT_3_AVATAR})
            st.markdown("</div></div>", unsafe_allow_html=True)

        st.sidebar.markdown("---")
        st.sidebar.markdown("**Latest Question:**")
        st.sidebar.markdown(f"_{user_input}_")


if __name__ == "__main__":
    main()