MilanM committed · verified
Commit 5249af8 · 1 Parent(s): a4c45e2

Create neo_sages5.py

Files changed (1): neo_sages5.py (+627, -0)
neo_sages5.py ADDED
@@ -0,0 +1,627 @@
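"""
The Solutioning Sages - a password-protected Streamlit app that answers each
question in three side-by-side columns: PATH-er B. displays the raw documents
retrieved from a Milvus vector index, while MOD-ther S. and SYS-ter V. stream
grounded answers from watsonx.ai foundation models, each with its own vector
index and system prompt. Index configurations are loaded from IBM Cloud Object
Storage.
"""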
import streamlit as st
from io import BytesIO
import ibm_watsonx_ai
import secretsload
import genparam
import requests
import time
import re
import json
import cos_creds
from ibm_watsonx_ai.foundation_models import ModelInference
from ibm_watsonx_ai import Credentials, APIClient
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from ibm_watsonx_ai.metanames import GenTextReturnOptMetaNames as RetParams
from knowledge_bases import KNOWLEDGE_BASE_OPTIONS, SYSTEM_PROMPTS, VECTOR_INDEXES
from ibm_watsonx_ai.foundation_models import Embeddings
from ibm_watsonx_ai.foundation_models.utils.enums import EmbeddingTypes
from pymilvus import MilvusClient

from secretsload import load_stsecrets

credentials = load_stsecrets()

st.set_page_config(
    page_title="The Solutioning Sages",
    page_icon="🪄",
    initial_sidebar_state="collapsed",
    layout="wide"
)
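# NOTE (inferred from usage in this file; the helper modules themselves are not
# part of this commit): secretsload.load_stsecrets() loads credentials from
# st.secrets; genparam holds tuning constants (ACTIVE_MODEL, ACTIVE_INDEX,
# SELECTED_MODEL_1/2, PROMPT_TEMPLATE_1/2, TYPE, DECODING_METHOD,
# MAX/MIN_NEW_TOKENS, REPETITION_PENALTY, STOP_SEQUENCES, VERIFY,
# BAKE_IN_PROMPT_SYNTAX, TOKEN_CAPTURE_ENABLED, INPUT_DEBUG_VIEW,
# BOT_1/2/3_NAME, BOT_1/2/3_AVATAR, BOT_1_PROMPT, USER_AVATAR); cos_creds holds
# COS_API_KEY, COS_INSTANCE_ID, COS_ENDPOINT, BUCKET_NAME; knowledge_bases
# exposes KNOWLEDGE_BASE_OPTIONS, SYSTEM_PROMPTS and VECTOR_INDEXES (per-KB
# dicts with index_1, index_2, project_id and contents).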
# Password protection
def check_password():
    def password_entered():
        if st.session_state["password"] == st.secrets["app_password"]:
            st.session_state["password_correct"] = True
            del st.session_state["password"]
        else:
            st.session_state["password_correct"] = False

    if "password_correct" not in st.session_state:
        st.markdown("\n\n")
        st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
        st.divider()
        st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
        return False
    elif not st.session_state["password_correct"]:
        st.markdown("\n\n")
        st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
        st.divider()
        st.error("😕 Incorrect password")
        st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
        return False
    else:
        return True
def initialize_session_state():
    if 'chat_history_1' not in st.session_state:
        st.session_state.chat_history_1 = []
    if 'chat_history_2' not in st.session_state:
        st.session_state.chat_history_2 = []
    if 'chat_history_3' not in st.session_state:
        st.session_state.chat_history_3 = []
    if 'first_question' not in st.session_state:
        st.session_state.first_question = False
    if "counter" not in st.session_state:
        st.session_state["counter"] = 0
    if 'token_statistics' not in st.session_state:
        st.session_state.token_statistics = []
    if 'selected_kb' not in st.session_state:
        st.session_state.selected_kb = KNOWLEDGE_BASE_OPTIONS[0]
    if 'current_system_prompts' not in st.session_state:
        st.session_state.current_system_prompts = SYSTEM_PROMPTS[st.session_state.selected_kb]
three_column_style = """
<style>
.stColumn {
    padding: 0.5rem;
    border-right: 1px solid #dedede;
}
.stColumn:last-child {
    border-right: none;
}
.chat-container {
    height: calc(100vh - 200px);
    overflow-y: auto;
    display: flex;
    flex-direction: column;
}
.chat-messages {
    display: flex;
    flex-direction: column;
    gap: 1rem;
}
</style>
"""
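# Note: the .stColumn selector targets Streamlit's internal column wrapper,
# which is not a stable public API and can change between Streamlit releases;
# if the column borders disappear after an upgrade, inspect the rendered DOM
# and adjust the selector.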
#-----
def get_active_model():
    return genparam.SELECTED_MODEL_1 if genparam.ACTIVE_MODEL == 0 else genparam.SELECTED_MODEL_2

def get_active_prompt_template():
    return genparam.PROMPT_TEMPLATE_1 if genparam.ACTIVE_MODEL == 0 else genparam.PROMPT_TEMPLATE_2

def get_active_vector_index():
    selected_kb = st.session_state.selected_kb
    if genparam.ACTIVE_INDEX == 0:
        return VECTOR_INDEXES[selected_kb]["index_1"]
    else:
        return VECTOR_INDEXES[selected_kb]["index_2"]
#-----
def setup_client(project_id=None):
    credentials = Credentials(
        url=st.secrets["url"],
        api_key=st.secrets["api_key"]
    )
    # Use the passed project_id if provided, otherwise fall back to secrets
    project_id = project_id or st.secrets["project_id"]
    client = APIClient(credentials, project_id=project_id)
    return credentials, client

wml_credentials, client = setup_client(st.secrets["project_id"])

# Default project id, used when initializing embeddings in setup_vector_index.
# (Assumption: the embedding model runs in the default project from secrets;
# this constant was referenced below but never defined in the original file.)
PROJECT_ID = st.secrets["project_id"]
import ibm_boto3
from ibm_botocore.client import Config

###==========================================1
def setup_retrieval_cos_client(config_dict=None):
    # Default credentials (same as main notebook for now)
    default_config = {
        "api_key": cos_creds.COS_API_KEY,
        "instance_id": cos_creds.COS_INSTANCE_ID,
        "endpoint": cos_creds.COS_ENDPOINT,
        "bucket": cos_creds.BUCKET_NAME
    }

    # Use provided config or default
    config = config_dict if config_dict is not None else default_config

    # Initialize the retrieval COS client
    retrieval_cos_client = ibm_boto3.client(
        "s3",
        ibm_api_key_id=config["api_key"],
        ibm_service_instance_id=config["instance_id"],
        config=Config(signature_version="oauth"),
        endpoint_url=config["endpoint"]
    )

    # Verify the connection by trying to list objects
    try:
        retrieval_cos_client.list_objects(Bucket=config["bucket"], MaxKeys=1)
        print("Retrieval COS client successfully initialized and connected")
    except Exception as e:
        print(f"Error verifying retrieval COS client connection: {str(e)}")
        raise

    return retrieval_cos_client

retrieval_cos_client = setup_retrieval_cos_client()
def load_callable_index_config(config_path):
    try:
        # Download config file content using retrieval client
        response = retrieval_cos_client.get_object(Bucket=cos_creds.BUCKET_NAME, Key=config_path)
        config_content = response['Body'].read().decode('utf-8')
        return json.loads(config_content)
    except Exception as e:
        raise Exception(f"Error loading callable index config: {str(e)}")
###==========================================1-2
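# For reference, a callable index config of the shape this loader and
# setup_vector_index expect (a sketch inferred from the key accesses below;
# values are illustrative placeholders, not real deployment values):
#
# {
#     "embedding": {"model_id": "ibm/slate-125m-english-rtrvr", "max_tokens": 512},
#     "connection": {"connection_id": "<watsonx.ai connection id>", "database": "default"},
#     "collection": {
#         "name": "<milvus collection name>",
#         "schema": {"fields": {"text_field": "text",
#                               "document_name": "document_name",
#                               "page_number": "page_number"}}
#     },
#     "index_settings": {"top_k": 5}
# }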
def setup_vector_index(client, wml_credentials, config_path):
    # Load the configuration using load_callable_index_config
    config = load_callable_index_config(config_path)

    # Initialize embeddings
    emb = Embeddings(
        model_id=config["embedding"]["model_id"],
        credentials=wml_credentials,
        project_id=PROJECT_ID,
        params={
            "truncate_input_tokens": config["embedding"]["max_tokens"]
        }
    )

    # Get connection details
    connection_details = client.connections.get_details(config["connection"]["connection_id"])
    connection_properties = connection_details["entity"]["properties"]

    # Initialize Milvus client
    milvus_client = MilvusClient(
        uri=f'https://{connection_properties.get("host")}:{connection_properties.get("port")}',
        user=connection_properties.get("username"),
        password=connection_properties.get("password"),
        db_name=config["connection"]["database"]
    )

    # Prepare vector index properties
    vector_index_properties = {
        "store": {
            "index": config["collection"]["name"],
            "database": config["connection"]["database"]
        },
        "settings": {
            "embedding_model_id": config["embedding"]["model_id"],
            "schema_fields": config["collection"]["schema"]["fields"],
            "top_k": config["index_settings"]["top_k"]
        }
    }

    # Return the entire config instead of just the schema fields
    return milvus_client, emb, vector_index_properties, config

###==========================================1-2
def proximity_search(question, milvus_client, emb, vector_index_properties, config):
    query_vectors = emb.embed_query(question)
    schema_fields = config["collection"]["schema"]["fields"]
    milvus_response = milvus_client.search(
        collection_name=vector_index_properties["store"]["index"],
        data=[query_vectors],
        limit=vector_index_properties["settings"]["top_k"],
        # The metric type belongs inside search_params for MilvusClient.search;
        # as a bare keyword argument it was silently ignored.
        search_params={"metric_type": "L2"},
        output_fields=[
            schema_fields["text_field"],
            schema_fields["document_name"],
            schema_fields["page_number"]
        ]
    )

    documents = []

    for hit in milvus_response[0]:
        text = hit["entity"].get(schema_fields["text_field"], "")
        doc_name = hit["entity"].get(schema_fields["document_name"], "Unknown Document")
        page_num = hit["entity"].get(schema_fields["page_number"], "N/A")

        formatted_result = f"Document: {doc_name}\nContent: {text}\nPage: {page_num}\n"
        documents.append(formatted_result)

    # Format final output
    joined = "\n".join(documents)
    retrieved = f"Number of Retrieved Documents: {len(documents)}\n\n{joined}"
    return retrieved
###==========================================2-3
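# For reference, the string returned above looks like:
#
#   Number of Retrieved Documents: 2
#
#   Document: some_guide.pdf
#   Content: <chunk text>
#   Page: 12
#
#   Document: another_doc.pdf
#   ...
#
# so grounding.split("\n\n")[1:] yields one chunk per document, with index 0
# being the count line (the parsing in fetch_response and main() relies on this).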
def prepare_prompt(prompt, chat_history):
    if genparam.TYPE == "chat" and chat_history:
        chats = "\n".join([f"{message['role']}: \"{message['content']}\"" for message in chat_history])
        prompt = f"""Retrieved Contextual Information:\n__grounding__\n\nConversation History:\n{chats}\n\nNew User Input: {prompt}"""
        return prompt
    else:
        prompt = f"""Retrieved Contextual Information:\n__grounding__\n\nUser Input: {prompt}"""
        return prompt
def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
    model_family_syntax = {
        "llama3-instruct (llama-3, 3.1 & 3.2) - system": """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "llama3-instruct (llama-3, 3.1 & 3.2) - user": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "granite-13b-chat & instruct - system": """<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "granite-13b-chat & instruct - user": """<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "mistral & mixtral v2 tokenizer - system": """<s>[INST] System Prompt: {system_prompt} [/INST][INST] {prompt} [/INST]\n\n""",
        "mistral & mixtral v2 tokenizer - user": """<s>[INST] {prompt} [/INST]\n\n""",
        "no syntax - system": """{system_prompt}\n\n{prompt}""",
        "no syntax - user": """{prompt}"""
    }

    if bake_in_prompt_syntax:
        template = model_family_syntax[prompt_template]
        if system_prompt:
            return template.format(system_prompt=system_prompt, prompt=prompt)
    return prompt
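# Example (hypothetical values): apply_prompt_syntax(
#     "What is RAG?", "You are a helpful assistant.",
#     "llama3-instruct (llama-3, 3.1 & 3.2) - system", True)
# wraps the system prompt and user turn in Llama-3 chat tags, ready to be sent
# to generate_text_stream; with bake_in_prompt_syntax=False the prompt passes
# through unchanged.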
def generate_response(watsonx_llm, prompt_data, params):
    generated_response = watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)
    for chunk in generated_response:
        yield chunk
def fetch_response(user_input, milvus_client, emb, vector_props, config, system_prompt, chat_history):
    # Get grounding documents
    grounding = proximity_search(
        question=user_input,
        milvus_client=milvus_client,
        emb=emb,
        vector_index_properties=vector_props,
        config=config
    )

    # Special handling for PATH-er B. (first column)
    if chat_history is st.session_state.chat_history_1:
        # Display user question first
        with st.chat_message("user", avatar=genparam.USER_AVATAR):
            st.markdown(user_input)

        # Parse and display each document from the grounding
        documents = grounding.split("\n\n")[1:]  # Skip the document-count line
        for doc in documents:
            if doc.strip():  # Only process non-empty strings
                parts = doc.split("\n")
                doc_name = parts[0].replace("Document: ", "")
                content = parts[1].replace("Content: ", "")

                # Display document with delay
                time.sleep(0.5)
                st.markdown(f"**{doc_name}**")
                st.code(content)

        # Store in chat history
        return grounding

    # For MOD-ther S. (second column) and SYS-ter V. (third column)
    else:
        prompt = prepare_prompt(user_input, chat_history)
        prompt_data = apply_prompt_syntax(
            prompt,
            system_prompt,  # Using the system_prompt passed to the function
            get_active_prompt_template(),
            genparam.BAKE_IN_PROMPT_SYNTAX
        )
        prompt_data = prompt_data.replace("__grounding__", grounding)

        # # Add debug information to column 1 if enabled
        # if genparam.INPUT_DEBUG_VIEW == 1:
        #     with col1:  # Access first column
        #         bot_name = genparam.BOT_2_NAME if chat_history is st.session_state.chat_history_2 else genparam.BOT_3_NAME
        #         bot_avatar = genparam.BOT_2_AVATAR if chat_history is st.session_state.chat_history_2 else genparam.BOT_3_AVATAR
        #         st.markdown(f"**{bot_avatar} {bot_name} Prompt Data:**")
        #         st.code(prompt_data, language="text")

        # Continue with normal processing for columns 2 and 3
        watsonx_llm = ModelInference(
            api_client=client,
            model_id=get_active_model(),
            verify=genparam.VERIFY
        )

        params = {
            GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
            GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
            GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
            GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
            GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
        }

        # PATH-er B. returned above, so only bots 2 and 3 reach this point
        if chat_history is st.session_state.chat_history_2:
            bot_name = genparam.BOT_2_NAME
            bot_avatar = genparam.BOT_2_AVATAR
        else:
            bot_name = genparam.BOT_3_NAME
            bot_avatar = genparam.BOT_3_AVATAR

        with st.chat_message(bot_name, avatar=bot_avatar):
            stream = generate_response(watsonx_llm, prompt_data, params)
            response = st.write_stream(stream)

            # Only capture tokens for MOD-ther S. and SYS-ter V.
            if genparam.TOKEN_CAPTURE_ENABLED:
                token_stats = capture_tokens(prompt_data, response, bot_name)
                if token_stats:
                    st.session_state.token_statistics.append(token_stats)

        return response
def capture_tokens(prompt_data, response, bot_name):
    if not genparam.TOKEN_CAPTURE_ENABLED:
        return

    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY
    )

    input_tokens = watsonx_llm.tokenize(prompt=prompt_data)["result"]["token_count"]
    output_tokens = watsonx_llm.tokenize(prompt=response)["result"]["token_count"]
    total_tokens = input_tokens + output_tokens

    return {
        "bot_name": bot_name,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": total_tokens,
        "timestamp": time.strftime("%H:%M:%S")
    }
def main():
    initialize_session_state()

    # Apply custom styles
    st.markdown(three_column_style, unsafe_allow_html=True)

    # Sidebar configuration
    st.sidebar.header('The Solutioning Sages')
    st.sidebar.divider()

    # Knowledge Base Selection
    selected_kb = st.sidebar.selectbox(
        "Select Knowledge Base",
        KNOWLEDGE_BASE_OPTIONS,
        index=KNOWLEDGE_BASE_OPTIONS.index(st.session_state.selected_kb)
    )

    # Update knowledge-base-related values if the selection changes
    if selected_kb != st.session_state.selected_kb:
        st.session_state.selected_kb = selected_kb
        # Re-create the client with the knowledge base's own project_id
        global client, wml_credentials
        wml_credentials, client = setup_client(VECTOR_INDEXES[selected_kb]["project_id"])

    # Display current knowledge base contents
    with st.sidebar.expander("Knowledge Base Contents"):
        for doc in VECTOR_INDEXES[selected_kb]["contents"]:
            st.write(f"📄 {doc}")

    # Display active model information
    st.sidebar.divider()
    active_model = get_active_model()
    st.sidebar.markdown("**Active Model:**")
    st.sidebar.code(active_model)

    st.sidebar.divider()

    # Display token statistics in sidebar
    st.sidebar.subheader("Token Usage Statistics")

    # Group token statistics by interaction (for MOD-ther S. and SYS-ter V. only)
    if st.session_state.token_statistics:
        interaction_count = 0
        stats_by_time = {}

        # Group stats by timestamp
        for stat in st.session_state.token_statistics:
            if stat["timestamp"] not in stats_by_time:
                stats_by_time[stat["timestamp"]] = []
            stats_by_time[stat["timestamp"]].append(stat)

        # Display grouped stats
        for timestamp, stats in stats_by_time.items():
            interaction_count += 1
            st.sidebar.markdown(f"**Interaction {interaction_count}** ({timestamp})")

            # Calculate total tokens for this interaction
            total_input = sum(stat['input_tokens'] for stat in stats)
            total_output = sum(stat['output_tokens'] for stat in stats)
            total = total_input + total_output

            # Display individual bot statistics (two trailing spaces force markdown line breaks)
            for stat in stats:
                st.sidebar.markdown(
                    f"_{stat['bot_name']}_  \n"
                    f"Input: {stat['input_tokens']} tokens  \n"
                    f"Output: {stat['output_tokens']} tokens  \n"
                    f"Total: {stat['total_tokens']} tokens"
                )

            # Display interaction totals
            st.sidebar.markdown("**Interaction Totals:**")
            st.sidebar.markdown(
                f"Total Input: {total_input} tokens  \n"
                f"Total Output: {total_output} tokens  \n"
                f"Total Usage: {total} tokens"
            )
            st.sidebar.markdown("---")

    st.sidebar.markdown("")

    if not check_password():
        st.stop()
    # Get user input before column creation
    user_input = st.chat_input("Ask your question here", key="user_input")

    if user_input:
        # Create three columns
        col1, col2, col3 = st.columns(3)

        # First column - PATH-er B. (Document Display)
        with col1:
            st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_1_AVATAR} {genparam.BOT_1_NAME}")
            st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)

            # Display previous messages
            for message in st.session_state.chat_history_1:
                if message["role"] == "user":
                    with st.chat_message(message["role"], avatar=genparam.USER_AVATAR):
                        st.markdown(message['content'])
                else:
                    # Parse and display stored documents
                    documents = message['content'].split("\n\n")[1:]  # Skip the document-count line
                    for doc in documents:
                        if doc.strip():
                            parts = doc.split("\n")
                            doc_name = parts[0].replace("Document: ", "")
                            content = parts[1].replace("Content: ", "")
                            st.markdown(f"**{doc_name}**")
                            st.code(content)

            # Add user message and get new response
            st.session_state.chat_history_1.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            milvus_client, emb, vector_index_properties, index_config = setup_vector_index(
                client,
                wml_credentials,
                VECTOR_INDEXES[st.session_state.selected_kb]["index_1"]
            )
            system_prompt = genparam.BOT_1_PROMPT

            response = fetch_response(
                user_input,
                milvus_client,
                emb,
                vector_index_properties,
                index_config,
                system_prompt,
                st.session_state.chat_history_1
            )
            st.session_state.chat_history_1.append({"role": genparam.BOT_1_NAME, "content": response, "avatar": genparam.BOT_1_AVATAR})
            st.markdown("</div></div>", unsafe_allow_html=True)
        # Second column - MOD-ther S. (Uses documents from first vector index)
        with col2:
            st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_2_AVATAR} {genparam.BOT_2_NAME}")
            st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)

            for message in st.session_state.chat_history_2:
                if message["role"] != "user":
                    with st.chat_message(message["role"], avatar=genparam.BOT_2_AVATAR):
                        st.markdown(message['content'])

            st.session_state.chat_history_2.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            milvus_client, emb, vector_index_properties, index_config = setup_vector_index(
                client,
                wml_credentials,
                VECTOR_INDEXES[st.session_state.selected_kb]["index_1"]
            )
            system_prompt = SYSTEM_PROMPTS[st.session_state.selected_kb]["bot_2"]

            response = fetch_response(
                user_input,
                milvus_client,
                emb,
                vector_index_properties,
                index_config,
                system_prompt,
                st.session_state.chat_history_2
            )

            # # Debug view disabled: prompt_data is assembled inside fetch_response
            # # and is not available in this scope; see the commented block there.
            # if genparam.INPUT_DEBUG_VIEW == 1:
            #     with col1:  # Access first column
            #         st.markdown(f"**{genparam.BOT_2_AVATAR} {genparam.BOT_2_NAME} Prompt Data:**")
            #         st.code(prompt_data, language="text")

            st.session_state.chat_history_2.append({"role": genparam.BOT_2_NAME, "content": response, "avatar": genparam.BOT_2_AVATAR})
            st.markdown("</div></div>", unsafe_allow_html=True)
        # Third column - SYS-ter V. (Uses second vector index)
        with col3:
            st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_3_AVATAR} {genparam.BOT_3_NAME}")
            st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)

            for message in st.session_state.chat_history_3:
                if message["role"] != "user":
                    with st.chat_message(message["role"], avatar=genparam.BOT_3_AVATAR):
                        st.markdown(message['content'])

            st.session_state.chat_history_3.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            milvus_client, emb, vector_index_properties, index_config = setup_vector_index(
                client,
                wml_credentials,
                VECTOR_INDEXES[st.session_state.selected_kb]["index_2"]
            )
            system_prompt = SYSTEM_PROMPTS[st.session_state.selected_kb]["bot_3"]

            response = fetch_response(
                user_input,
                milvus_client,
                emb,
                vector_index_properties,
                index_config,
                system_prompt,
                st.session_state.chat_history_3
            )

            # # Debug view disabled: prompt_data is assembled inside fetch_response
            # # and is not available in this scope; see the commented block there.
            # if genparam.INPUT_DEBUG_VIEW == 1:
            #     with col1:  # Access first column
            #         st.markdown(f"**{genparam.BOT_3_AVATAR} {genparam.BOT_3_NAME} Prompt Data:**")
            #         st.code(prompt_data, language="text")

            st.session_state.chat_history_3.append({"role": genparam.BOT_3_NAME, "content": response, "avatar": genparam.BOT_3_AVATAR})
            st.markdown("</div></div>", unsafe_allow_html=True)

        # Update sidebar with new question
        st.sidebar.markdown("---")
        st.sidebar.markdown("**Latest Question:**")
        st.sidebar.markdown(f"_{user_input}_")

if __name__ == "__main__":
    main()