Create neo_sages2.py
Browse files- neo_sages2.py +529 -0
    	
        neo_sages2.py
    ADDED
    
    | @@ -0,0 +1,529 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import streamlit as st
         | 
| 2 | 
            +
            from io import BytesIO
         | 
| 3 | 
            +
            import ibm_watsonx_ai
         | 
| 4 | 
            +
            import secretsload
         | 
| 5 | 
            +
            import genparam
         | 
| 6 | 
            +
            import requests
         | 
| 7 | 
            +
            import time
         | 
| 8 | 
            +
            import re
         | 
| 9 | 
            +
            import json
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from ibm_watsonx_ai.foundation_models import ModelInference
         | 
| 12 | 
            +
            from ibm_watsonx_ai import Credentials, APIClient
         | 
| 13 | 
            +
            from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
         | 
| 14 | 
            +
            from ibm_watsonx_ai.metanames import GenTextReturnOptMetaNames as RetParams
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from ibm_watsonx_ai.foundation_models import Embeddings
         | 
| 17 | 
            +
            from ibm_watsonx_ai.foundation_models.utils.enums import EmbeddingTypes
         | 
| 18 | 
            +
            from pymilvus import MilvusClient
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from secretsload import load_stsecrets
         | 
| 21 | 
            +
             | 
# Value returned by secretsload.load_stsecrets().
# NOTE(review): not referenced again in the visible part of this file - confirm it is still needed.
credentials = load_stsecrets()

# Browser-tab title/icon, sidebar state and page layout for the app.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="collapsed",
    page_icon="🪄",
    page_title="The Solutioning Sages",
)
| 30 | 
            +
             | 
| 31 | 
            +
            # Password protection
         | 
def check_password():
    """Gate the app behind the password stored in ``st.secrets["app_password"]``.

    Renders a password field until the correct password has been entered in
    the current session. The outcome of the last attempt is kept in
    ``st.session_state["password_correct"]``.

    Returns:
        bool: True once the correct password has been entered, otherwise
        False (after the password prompt has been rendered).
    """
    import hmac  # local import keeps this block self-contained

    def password_entered():
        # on_change callback of the password field.
        entered = str(st.session_state["password"]).encode("utf-8")
        expected = str(st.secrets["app_password"]).encode("utf-8")
        # Constant-time comparison; bytes are used because compare_digest
        # rejects non-ASCII str arguments.
        if hmac.compare_digest(entered, expected):
            st.session_state["password_correct"] = True
            # Do not keep the plain-text password in the session.
            del st.session_state["password"]
        else:
            st.session_state["password_correct"] = False

    first_attempt = "password_correct" not in st.session_state
    if not first_attempt and st.session_state["password_correct"]:
        return True

    # Same prompt for "not tried yet" and "wrong password"; only the error
    # banner differs.
    st.markdown("\n\n")
    st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
    st.divider()
    if not first_attempt:
        st.error("😕 Incorrect password")
    st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
    return False
| 55 | 
            +
             | 
def initialize_session_state():
    """Seed ``st.session_state`` with a default for every key that is not set yet.

    Existing values are left untouched, so calling this on every rerun is safe.
    """
    # Factories (not shared instances) so every session gets its own lists.
    default_factories = {
        "chat_history_1": list,
        "chat_history_2": list,
        "chat_history_3": list,
        "first_question": lambda: False,
        "counter": lambda: 0,
        "token_statistics": list,
    }
    for key, make_default in default_factories.items():
        if key not in st.session_state:
            st.session_state[key] = make_default()
| 69 | 
            +
             | 
| 70 | 
            +
            # three_column_style = """
         | 
| 71 | 
            +
            #     <style>
         | 
| 72 | 
            +
            #     .stColumn {
         | 
| 73 | 
            +
            #         padding: 0.5rem;
         | 
| 74 | 
            +
            #         border-right: 1px solid #dedede;
         | 
| 75 | 
            +
            #     }
         | 
| 76 | 
            +
            #     .stColumn:last-child {
         | 
| 77 | 
            +
            #         border-right: none;
         | 
| 78 | 
            +
            #     }
         | 
| 79 | 
            +
            #     .chat-container {
         | 
| 80 | 
            +
            #         height: calc(100vh - 200px);
         | 
| 81 | 
            +
            #         overflow-y: auto;
         | 
| 82 | 
            +
            #     }
         | 
| 83 | 
            +
            #     </style>
         | 
| 84 | 
            +
            # """
         | 
| 85 | 
            +
             | 
# CSS injected via st.markdown(..., unsafe_allow_html=True) in main():
# column separators plus flex layout for the chat containers ("Alt Style").
three_column_style = """
    <style>
    .stColumn {
        padding: 0.5rem;
        border-right: 1px solid #dedede;
    }
    .stColumn:last-child {
        border-right: none;
    }
    .chat-container {
        height: calc(100vh - 200px);
        overflow-y: auto;
        display: flex;
        flex-direction: column;
    }
    .chat-messages {
        display: flex;
        flex-direction: column;
        gap: 1rem;
    }
    </style>
"""
| 108 | 
            +
             | 
| 109 | 
            +
            #-----
         | 
def get_active_model():
    """Return the model id selected by ``genparam.ACTIVE_MODEL``.

    Returns:
        ``genparam.SELECTED_MODEL_1`` when ``ACTIVE_MODEL`` is 0, otherwise
        ``genparam.SELECTED_MODEL_2``.
    """
    if genparam.ACTIVE_MODEL == 0:
        return genparam.SELECTED_MODEL_1
    return genparam.SELECTED_MODEL_2
| 112 | 
            +
             | 
def get_active_prompt_template():
    """Return the prompt-template name that belongs to the active model.

    Returns:
        ``genparam.PROMPT_TEMPLATE_1`` when ``genparam.ACTIVE_MODEL`` is 0,
        otherwise ``genparam.PROMPT_TEMPLATE_2``.
    """
    if genparam.ACTIVE_MODEL == 0:
        return genparam.PROMPT_TEMPLATE_1
    return genparam.PROMPT_TEMPLATE_2
| 115 | 
            +
             | 
def get_active_vector_index():
    """Return the vector-index asset id selected by ``genparam.ACTIVE_INDEX``.

    Returns:
        ``st.secrets["vector_index_id_1"]`` when ``ACTIVE_INDEX`` is 0,
        otherwise ``st.secrets["vector_index_id_2"]``.
    """
    if genparam.ACTIVE_INDEX == 0:
        return st.secrets["vector_index_id_1"]
    return st.secrets["vector_index_id_2"]
| 118 | 
            +
            #-----
         | 
| 119 | 
            +
             | 
def setup_client(project_id):
    """Create watsonx.ai credentials and an API client bound to a project.

    The service URL and API key are read from ``st.secrets["url"]`` and
    ``st.secrets["api_key"]``.

    Args:
        project_id: Project id passed to ``APIClient``.

    Returns:
        tuple: ``(credentials, client)`` - the ``Credentials`` object and the
        ``APIClient`` built from it.
    """
    # Named differently from the module-level ``credentials`` to avoid
    # shadowing it; the previously unused extra copy of the API key is gone.
    wx_credentials = Credentials(
        url=st.secrets["url"],
        api_key=st.secrets["api_key"],
    )
    api_client = APIClient(wx_credentials, project_id=project_id)
    return wx_credentials, api_client
| 128 | 
            +
             | 
# Module-level watsonx.ai credentials and API client, shared by the functions below.
wml_credentials, client = setup_client(project_id=st.secrets["project_id"])
| 130 | 
            +
             | 
def setup_vector_index(client, wml_credentials, vector_index_id):
    """Turn a vector-index data asset into a Milvus client plus an embedder.

    Args:
        client: API client exposing ``data_assets.get_details`` and
            ``connections.get_details``.
        wml_credentials: Credentials handed to ``Embeddings``.
        vector_index_id: Id of the vector-index data asset to look up.

    Returns:
        tuple: ``(milvus_client, emb, vector_index_properties,
        vector_store_schema)``.
    """
    asset_details = client.data_assets.get_details(vector_index_id)
    index_props = asset_details["entity"]["vector_index"]
    index_settings = index_props["settings"]

    # Embed with the same model the index was built with.
    embedder = Embeddings(
        model_id=index_settings["embedding_model_id"],
        credentials=wml_credentials,
        project_id=st.secrets["project_id"],
        params={"truncate_input_tokens": 512},
    )

    schema_fields = index_settings["schema_fields"]

    # The asset only references its store connection; fetch the connection
    # to obtain host, port and login.
    store = index_props["store"]
    connection = client.connections.get_details(store["connection_id"])
    conn_props = connection["entity"]["properties"]

    host = conn_props.get("host")
    port = conn_props.get("port")
    milvus = MilvusClient(
        uri=f"https://{host}:{port}",
        user=conn_props.get("username"),
        password=conn_props.get("password"),
        db_name=store["database"],
    )

    return milvus, embedder, index_props, schema_fields
| 157 | 
            +
             | 
def proximity_search(question, milvus_client, emb, vector_index_properties, vector_store_schema):
    """Retrieve the documents closest to ``question`` and format them as text.

    Args:
        question: Text that is embedded and used as the search vector.
        milvus_client: Object exposing ``search(...)``.
        emb: Object exposing ``embed_query(text)``.
        vector_index_properties: Mapping providing ``["store"]["index"]``
            (collection name) and ``["settings"]["top_k"]`` (result limit).
        vector_store_schema: Mapping from the logical names ``"text"``,
            ``"document_name"`` and ``"page_number"`` to the stored field names.

    Returns:
        str: A ``"Number of Retrieved Documents: N"`` line, a blank line, then
        one ``Document:/Content:/Page:`` entry per hit, entries separated by a
        blank line.
    """
    text_field = vector_store_schema.get("text")
    name_field = vector_store_schema.get("document_name")
    page_field = vector_store_schema.get("page_number")

    query_vectors = emb.embed_query(question)
    milvus_response = milvus_client.search(
        collection_name=vector_index_properties["store"]["index"],
        data=[query_vectors],
        limit=vector_index_properties["settings"]["top_k"],
        metric_type="L2",
        output_fields=[text_field, name_field, page_field],
    )

    def _format_entity(entity):
        # Missing fields fall back to placeholders instead of raising.
        return (
            f"Document: {entity.get(name_field, 'Unknown Document')}\n"
            f"Content: {entity.get(text_field, '')}\n"
            f"Page: {entity.get(page_field, 'N/A')}\n"
        )

    # One query vector was sent, so only the first result list is relevant;
    # an empty response is treated as "no hits" rather than an IndexError.
    hits = milvus_response[0] if milvus_response else []
    documents = [_format_entity(hit["entity"]) for hit in hits]

    joined = "\n".join(documents)
    return f"""Number of Retrieved Documents: {len(documents)}\n\n{joined}"""
| 186 | 
            +
             | 
def prepare_prompt(prompt, chat_history):
    """Wrap the user input in the retrieval prompt scaffold.

    The returned text contains the literal placeholder ``__grounding__``,
    which the caller replaces with the retrieved documents.

    Args:
        prompt: The new user input.
        chat_history: Sequence of ``{"role": ..., "content": ...}`` messages.
            Only used when ``genparam.TYPE`` is ``"chat"`` and it is non-empty.

    Returns:
        str: The prompt, with a conversation-history section in chat mode.
    """
    # Single-turn form: not in chat mode, or nothing has been said yet.
    if genparam.TYPE != "chat" or not chat_history:
        return f"""Retrieved Contextual Information:\n__grounding__\n\nUser Input: {prompt}"""

    transcript = "\n".join(
        f"{entry['role']}: \"{entry['content']}\"" for entry in chat_history
    )
    return f"""Retrieved Contextual Information:\n__grounding__\n\nConversation History:\n{transcript}\n\nNew User Input: {prompt}"""
| 195 | 
            +
             | 
def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
    """Wrap ``prompt`` in the chat syntax of a model family.

    Args:
        prompt: The user-side prompt text.
        system_prompt: System instructions; may be empty or None.
        prompt_template: Key selecting one of the templates below, e.g.
            ``"granite-13b-chat & instruct - user"``.
        bake_in_prompt_syntax: When falsy, ``prompt`` is returned unchanged.

    Returns:
        str: The formatted prompt, or ``prompt`` itself when
        ``bake_in_prompt_syntax`` is falsy.

    Raises:
        KeyError: If ``bake_in_prompt_syntax`` is truthy and
            ``prompt_template`` is not a known template name.
    """
    model_family_syntax = {
        "llama3-instruct (llama-3, 3.1 & 3.2) - system": """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "llama3-instruct (llama-3, 3.1 & 3.2) - user": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "granite-13b-chat & instruct - system": """<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "granite-13b-chat & instruct - user": """<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "mistral & mixtral v2 tokenizer - system": """<s>[INST] System Prompt: {system_prompt} [/INST][INST] {prompt} [/INST]\n\n""",
        "mistral & mixtral v2 tokenizer - user": """<s>[INST] {prompt} [/INST]\n\n""",
        "no syntax - system": """{system_prompt}\n\n{prompt}""",
        "no syntax - user": """{prompt}""",
    }

    if not bake_in_prompt_syntax:
        return prompt

    template = model_family_syntax[prompt_template]
    # Always apply the template: the "- user" variants have no system slot,
    # so an empty system prompt must not cause the chat syntax to be skipped.
    # str.format ignores keyword arguments the template does not use.
    return template.format(system_prompt=system_prompt or "", prompt=prompt)
| 213 | 
            +
             | 
def generate_response(watsonx_llm, prompt_data, params):
    """Stream generated text chunks for ``prompt_data``.

    Args:
        watsonx_llm: Object exposing ``generate_text_stream(prompt=, params=)``.
        prompt_data: The fully formatted prompt.
        params: Generation parameters passed through unchanged.

    Yields:
        Each chunk produced by ``generate_text_stream``, in order. Being a
        generator, the request is only started once iteration begins.
    """
    yield from watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)
| 218 | 
            +
             | 
def _resolve_chat_column(chat_history):
    """Return 1, 2 or 3 for the session chat history ``chat_history`` refers to."""
    state = st.session_state
    histories = (
        state.chat_history_1,
        state.chat_history_2,
        state.get("chat_history_3"),
    )
    # Identity first: histories with equal content (e.g. all still empty)
    # must not be mistaken for one another.
    for column, history in enumerate(histories, start=1):
        if chat_history is history:
            return column
    # Fallback for callers that pass a copy: the previous equality test.
    if chat_history == histories[0]:
        return 1
    if chat_history == histories[1]:
        return 2
    return 3


def _iter_grounding_documents(grounding):
    """Yield ``(document_name, first_content_line)`` for each retrieved document."""
    name_prefix = "Document: "
    content_prefix = "Content: "
    # Chunk 0 is the "Number of Retrieved Documents" header; every following
    # chunk is one document entry.
    for chunk in grounding.split("\n\n")[1:]:
        lines = chunk.strip().split("\n")
        if not lines[0].startswith(name_prefix):
            # Empty chunk, or the tail of a content field that itself
            # contained a blank line - nothing to display.
            continue
        doc_name = lines[0][len(name_prefix):]
        content = lines[1] if len(lines) > 1 else ""
        if content.startswith(content_prefix):
            content = content[len(content_prefix):]
        yield doc_name, content


def fetch_response(user_input, milvus_client, emb, vector_index_properties, vector_store_schema, system_prompt, chat_history):
    """Answer ``user_input`` for one of the three chat columns and render it.

    The column is derived from ``chat_history``:

    * ``st.session_state.chat_history_1`` (PATH-er B.): shows the user
      question and the retrieved documents; no model call.
    * ``st.session_state.chat_history_2`` (MOD-ther S.) and anything else
      (SYS-ter V.): builds a grounded prompt, using column 2's history as
      conversation context, and streams the model answer.

    Args:
        user_input: The question typed by the user.
        milvus_client, emb, vector_index_properties, vector_store_schema:
            Passed through to ``proximity_search``.
        system_prompt: System instructions for the prompt template.
        chat_history: One of the session-state chat history lists.

    Returns:
        The grounding text for column 1, otherwise the value returned by
        ``st.write_stream`` for the streamed answer.
    """
    # Get grounding documents
    grounding = proximity_search(
        question=user_input,
        milvus_client=milvus_client,
        emb=emb,
        vector_index_properties=vector_index_properties,
        vector_store_schema=vector_store_schema,
    )

    column = _resolve_chat_column(chat_history)

    # PATH-er B. (first column): show the sources only.
    if column == 1:
        with st.chat_message("user", avatar=genparam.USER_AVATAR):
            st.markdown(user_input)

        for doc_name, content in _iter_grounding_documents(grounding):
            # Short pause so the documents appear one after another.
            time.sleep(0.5)
            st.markdown(f"**{doc_name}**")
            st.code(content)

        return grounding

    if column == 2:
        bot_name, bot_avatar = genparam.BOT_2_NAME, genparam.BOT_2_AVATAR
    else:
        bot_name, bot_avatar = genparam.BOT_3_NAME, genparam.BOT_3_AVATAR

    # Columns 2 and 3 both use MOD-ther S.'s history as conversation context.
    prompt = prepare_prompt(user_input, st.session_state.chat_history_2)
    prompt_data = apply_prompt_syntax(
        prompt,
        system_prompt,
        get_active_prompt_template(),
        genparam.BAKE_IN_PROMPT_SYNTAX,
    )
    prompt_data = prompt_data.replace("__grounding__", grounding)

    if genparam.INPUT_DEBUG_VIEW == 1:
        # NOTE(review): st.columns(3) lays out a new set of columns here
        # rather than reusing the caller's first column - confirm this is
        # the intended placement of the debug output.
        with st.columns(3)[0]:
            st.markdown(f"**{bot_avatar} {bot_name} Prompt Data:**")
            st.code(prompt_data, language="text")

    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY,
    )

    params = {
        GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
        GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
        GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
        GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
        GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES,
    }

    with st.chat_message(bot_name, avatar=bot_avatar):
        stream = generate_response(watsonx_llm, prompt_data, params)
        response = st.write_stream(stream)

        if genparam.TOKEN_CAPTURE_ENABLED:
            token_stats = capture_tokens(prompt_data, response, bot_name)
            if token_stats:
                st.session_state.token_statistics.append(token_stats)

    return response
| 328 | 
            +
             | 
def capture_tokens(prompt_data, response, chat_number):
    """Count prompt and response tokens for one exchange.

    Args:
        prompt_data: The prompt text that was sent to the model.
        response: The generated answer text.
        chat_number: Label identifying the bot/column; stored under the
            ``"bot_name"`` key of the result.

    Returns:
        dict | None: ``bot_name``, ``input_tokens``, ``output_tokens``,
        ``total_tokens`` and a ``timestamp`` (``HH:MM:SS``), or None when
        ``genparam.TOKEN_CAPTURE_ENABLED`` is falsy.
    """
    if not genparam.TOKEN_CAPTURE_ENABLED:
        return None

    # Tokenise with the model that generates the answers so the counts match
    # what was actually sent and received.
    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY,
    )

    input_tokens = watsonx_llm.tokenize(prompt=prompt_data)["result"]["token_count"]
    output_tokens = watsonx_llm.tokenize(prompt=response)["result"]["token_count"]

    return {
        # The label arrives as the ``chat_number`` parameter; ``bot_name`` is
        # not defined in this scope.
        "bot_name": chat_number,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
        "timestamp": time.strftime("%H:%M:%S"),
    }
| 350 | 
            +
             | 
| 351 | 
            +
            def _render_token_sidebar():
                """Show captured token statistics in the sidebar, grouped by timestamp."""
                st.sidebar.subheader("Token Usage Statistics")

                if not st.session_state.token_statistics:
                    return

                # Bucket entries by their "timestamp" string, keeping first-seen order.
                # NOTE(review): the key is only an HH:MM:SS string, so entries captured
                # in different seconds land in different "interactions" - confirm that
                # this is the intended grouping.
                stats_by_time = {}
                for stat in st.session_state.token_statistics:
                    stats_by_time.setdefault(stat["timestamp"], []).append(stat)

                for interaction_count, (timestamp, stats) in enumerate(stats_by_time.items(), start=1):
                    st.sidebar.markdown(f"**Interaction {interaction_count}** ({timestamp})")

                    total_input = sum(stat['input_tokens'] for stat in stats)
                    total_output = sum(stat['output_tokens'] for stat in stats)
                    total = total_input + total_output

                    # One entry per bot, followed by the totals for the whole group.
                    for stat in stats:
                        st.sidebar.markdown(
                            f"_{stat['bot_name']}_  \n"
                            f"Input: {stat['input_tokens']} tokens  \n"
                            f"Output: {stat['output_tokens']} tokens  \n"
                            f"Total: {stat['total_tokens']} tokens"
                        )

                    st.sidebar.markdown("**Interaction Totals:**")
                    st.sidebar.markdown(
                        f"Total Input: {total_input} tokens  \n"
                        f"Total Output: {total_output} tokens  \n"
                        f"Total Usage: {total} tokens"
                    )
                    st.sidebar.markdown("---")


            def _open_chat_column(bot_name, bot_avatar):
                """Emit the opening container markup and the header of one chat column."""
                st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
                st.subheader(f"{bot_avatar} {bot_name}")
                st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)


            def _show_document_history(history):
                """Replay first-column history: user turns as chat messages, bot turns as documents."""
                for message in history:
                    if message["role"] == "user":
                        with st.chat_message(message["role"], avatar=genparam.USER_AVATAR):
                            st.markdown(message['content'])
                        continue

                    # Stored content is split on blank lines and the first two chunks
                    # are skipped before the per-document chunks are parsed.
                    for doc in message['content'].split("\n\n")[2:]:
                        if not doc.strip():
                            continue
                        parts = doc.split("\n")
                        doc_name = parts[0].replace("Document: ", "")
                        # A chunk without a second line used to raise IndexError;
                        # show it with empty content instead.
                        content = parts[1].replace("Content: ", "") if len(parts) > 1 else ""
                        st.markdown(f"**{doc_name}**")
                        st.code(content)


            def _show_bot_history(history, bot_avatar):
                """Replay only the non-user messages of a column's history."""
                for message in history:
                    if message["role"] != "user":
                        with st.chat_message(message["role"], avatar=bot_avatar):
                            st.markdown(message['content'])


            def _answer_in_column(user_input, history, bot_name, bot_avatar, system_prompt, index_secret):
                """Record the question, fetch the bot's response, store it and close the column markup.

                `history` must be the session-state list itself (not a copy): it is
                mutated in place and handed to fetch_response unchanged.
                """
                history.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
                milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
                    client,
                    wml_credentials,
                    st.secrets[index_secret],
                )
                response = fetch_response(
                    user_input,
                    milvus_client,
                    emb,
                    vector_index_properties,
                    vector_store_schema,
                    system_prompt,
                    history,
                )
                history.append({"role": bot_name, "content": response, "avatar": bot_avatar})
                st.markdown("</div></div>", unsafe_allow_html=True)


            def main():
                """Run the three-column chat page.

                Initializes session state, applies the page styles, renders the
                sidebar (title and token statistics), enforces the password gate and,
                when the user submitted a question, processes it in each of the three
                bot columns in turn before echoing the question in the sidebar.
                """
                initialize_session_state()

                # Apply custom styles
                st.markdown(three_column_style, unsafe_allow_html=True)

                # Sidebar configuration
                st.sidebar.header('The Solutioning Sages')
                st.sidebar.divider()
                _render_token_sidebar()
                st.sidebar.markdown("")

                if not check_password():
                    st.stop()

                # Get user input before column creation
                user_input = st.chat_input("Ask your question here", key="user_input")
                if not user_input:
                    return

                col1, col2, col3 = st.columns(3)

                # First column - document display, backed by the first vector index
                with col1:
                    _open_chat_column(genparam.BOT_1_NAME, genparam.BOT_1_AVATAR)
                    _show_document_history(st.session_state.chat_history_1)
                    _answer_in_column(
                        user_input,
                        st.session_state.chat_history_1,
                        genparam.BOT_1_NAME,
                        genparam.BOT_1_AVATAR,
                        genparam.BOT_1_PROMPT,
                        "vector_index_id_1",
                    )

                # Second column - also uses the first vector index
                with col2:
                    _open_chat_column(genparam.BOT_2_NAME, genparam.BOT_2_AVATAR)
                    _show_bot_history(st.session_state.chat_history_2, genparam.BOT_2_AVATAR)
                    _answer_in_column(
                        user_input,
                        st.session_state.chat_history_2,
                        genparam.BOT_2_NAME,
                        genparam.BOT_2_AVATAR,
                        genparam.BOT_2_PROMPT,
                        "vector_index_id_1",
                    )

                # Third column - uses the second vector index
                with col3:
                    _open_chat_column(genparam.BOT_3_NAME, genparam.BOT_3_AVATAR)
                    _show_bot_history(st.session_state.chat_history_3, genparam.BOT_3_AVATAR)
                    _answer_in_column(
                        user_input,
                        st.session_state.chat_history_3,
                        genparam.BOT_3_NAME,
                        genparam.BOT_3_AVATAR,
                        genparam.BOT_3_PROMPT,
                        "vector_index_id_2",
                    )

                # Update sidebar with new question
                st.sidebar.markdown("---")
                st.sidebar.markdown("**Latest Question:**")
                st.sidebar.markdown(f"_{user_input}_")
         | 
| 527 | 
            +
             | 
| 528 | 
            +
            # Script entry point: run the app only when this file is executed directly.
            if __name__ == "__main__":
         | 
| 529 | 
            +
                main()
         | 
