MilanM committed
Commit cf10217 · verified · 1 Parent(s): 5249af8

Delete neo_sages.py

Files changed (1)
  1. neo_sages.py +0 -529
neo_sages.py DELETED
@@ -1,529 +0,0 @@
- import streamlit as st
- from io import BytesIO
- import ibm_watsonx_ai
- import secretsload
- import genparam
- import requests
- import time
- import re
- import json
- 
- from ibm_watsonx_ai.foundation_models import ModelInference
- from ibm_watsonx_ai import Credentials, APIClient
- from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
- from ibm_watsonx_ai.metanames import GenTextReturnOptMetaNames as RetParams
- 
- from ibm_watsonx_ai.foundation_models import Embeddings
- from ibm_watsonx_ai.foundation_models.utils.enums import EmbeddingTypes
- from pymilvus import MilvusClient
- 
- from secretsload import load_stsecrets
- 
- credentials = load_stsecrets()
- 
- st.set_page_config(
-     page_title="The Solutioning Sages",
-     page_icon="🪄",
-     initial_sidebar_state="collapsed",
-     layout="wide"
- )
- 
- # Password protection
- def check_password():
-     def password_entered():
-         if st.session_state["password"] == st.secrets["app_password"]:
-             st.session_state["password_correct"] = True
-             del st.session_state["password"]
-         else:
-             st.session_state["password_correct"] = False
- 
-     if "password_correct" not in st.session_state:
-         st.markdown("\n\n")
-         st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
-         st.divider()
-         st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
-         return False
-     elif not st.session_state["password_correct"]:
-         st.markdown("\n\n")
-         st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
-         st.divider()
-         st.error("😕 Incorrect password")
-         st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
-         return False
-     else:
-         return True
- 
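- # Seed session state: one chat history per column, plus first-question flag, counter, and token statistics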
- def initialize_session_state():
-     if 'chat_history_1' not in st.session_state:
-         st.session_state.chat_history_1 = []
-     if 'chat_history_2' not in st.session_state:
-         st.session_state.chat_history_2 = []
-     if 'chat_history_3' not in st.session_state:
-         st.session_state.chat_history_3 = []
-     if 'first_question' not in st.session_state:
-         st.session_state.first_question = False
-     if "counter" not in st.session_state:
-         st.session_state["counter"] = 0
-     if 'token_statistics' not in st.session_state:
-         st.session_state.token_statistics = []
- 
- 
- three_column_style = """
- <style>
- .stColumn {
-     padding: 0.5rem;
-     border-right: 1px solid #dedede;
- }
- .stColumn:last-child {
-     border-right: none;
- }
- .chat-container {
-     height: calc(100vh - 200px);
-     overflow-y: auto;
-     display: flex;
-     flex-direction: column;
- }
- .chat-messages {
-     display: flex;
-     flex-direction: column;
-     gap: 1rem;
- }
- </style>
- """
- 
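- # Selector helpers: genparam.ACTIVE_MODEL / ACTIVE_INDEX pick which model, prompt template, and vector index are live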
- #-----
- def get_active_model():
-     return genparam.SELECTED_MODEL_1 if genparam.ACTIVE_MODEL == 0 else genparam.SELECTED_MODEL_2
- 
- def get_active_prompt_template():
-     return genparam.PROMPT_TEMPLATE_1 if genparam.ACTIVE_MODEL == 0 else genparam.PROMPT_TEMPLATE_2
- 
- def get_active_vector_index():
-     return st.secrets["vector_index_id_1"] if genparam.ACTIVE_INDEX == 0 else st.secrets["vector_index_id_2"]
- #-----
- 
- def setup_client(project_id):
-     credentials = Credentials(
-         url=st.secrets["url"],
-         api_key=st.secrets["api_key"]
-     )
-     client = APIClient(credentials, project_id=project_id)
-     return credentials, client
- 
- wml_credentials, client = setup_client(st.secrets["project_id"])
- 
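- # Turn a watsonx.ai vector index asset into a live Milvus client plus its embedding model and schema fields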
- def setup_vector_index(client, wml_credentials, vector_index_id):
-     vector_index_details = client.data_assets.get_details(vector_index_id)
-     vector_index_properties = vector_index_details["entity"]["vector_index"]
- 
-     emb = Embeddings(
-         model_id=vector_index_properties["settings"]["embedding_model_id"],
-         # model_id="sentence-transformers/all-minilm-l12-v2",
-         credentials=wml_credentials,
-         project_id=st.secrets["project_id"],
-         params={
-             "truncate_input_tokens": 512
-         }
-     )
- 
-     vector_store_schema = vector_index_properties["settings"]["schema_fields"]
-     connection_details = client.connections.get_details(vector_index_details["entity"]["vector_index"]["store"]["connection_id"])
-     connection_properties = connection_details["entity"]["properties"]
- 
-     milvus_client = MilvusClient(
-         uri=f'https://{connection_properties.get("host")}:{connection_properties.get("port")}',
-         user=connection_properties.get("username"),
-         password=connection_properties.get("password"),
-         db_name=vector_index_properties["store"]["database"]
-     )
- 
-     return milvus_client, emb, vector_index_properties, vector_store_schema
- 
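- # Embed the question, run a top-k L2 search in Milvus, and format the hits as grounding text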
- def proximity_search(question, milvus_client, emb, vector_index_properties, vector_store_schema):
-     query_vectors = emb.embed_query(question)
-     milvus_response = milvus_client.search(
-         collection_name=vector_index_properties["store"]["index"],
-         data=[query_vectors],
-         limit=vector_index_properties["settings"]["top_k"],
-         metric_type="L2",
-         output_fields=[
-             vector_store_schema.get("text"),
-             vector_store_schema.get("document_name"),
-             vector_store_schema.get("page_number")
-         ]
-     )
- 
-     documents = []
- 
-     for hit in milvus_response[0]:
-         text = hit["entity"].get(vector_store_schema.get("text"), "")
-         doc_name = hit["entity"].get(vector_store_schema.get("document_name"), "Unknown Document")
-         page_num = hit["entity"].get(vector_store_schema.get("page_number"), "N/A")
- 
-         formatted_result = f"Document: {doc_name}\nContent: {text}\nPage: {page_num}\n"
-         documents.append(formatted_result)
- 
-     joined = "\n".join(documents)
-     retrieved = f"""Number of Retrieved Documents: {len(documents)}\n\n{joined}"""
- 
-     return retrieved
- 
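- # Build the prompt around the __grounding__ placeholder, optionally folding in the running chat history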
- def prepare_prompt(prompt, chat_history):
-     if genparam.TYPE == "chat" and chat_history:
-         chats = "\n".join([f"{message['role']}: \"{message['content']}\"" for message in chat_history])
-         return f"""Retrieved Contextual Information:\n__grounding__\n\nConversation History:\n{chats}\n\nNew User Input: {prompt}"""
-     else:
-         return f"""Retrieved Contextual Information:\n__grounding__\n\nUser Input: {prompt}"""
- 
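- # Wrap the prompt in the selected model family's chat template when syntax baking is enabled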
- def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
-     model_family_syntax = {
-         "llama3-instruct (llama-3, 3.1 & 3.2) - system": """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
-         "llama3-instruct (llama-3, 3.1 & 3.2) - user": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
-         "granite-13b-chat & instruct - system": """<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
-         "granite-13b-chat & instruct - user": """<|user|>\n{prompt}\n<|assistant|>\n\n""",
-         "mistral & mixtral v2 tokenizer - system": """<s>[INST] System Prompt: {system_prompt} [/INST][INST] {prompt} [/INST]\n\n""",
-         "mistral & mixtral v2 tokenizer - user": """<s>[INST] {prompt} [/INST]\n\n""",
-         "no syntax - system": """{system_prompt}\n\n{prompt}""",
-         "no syntax - user": """{prompt}"""
-     }
- 
-     if bake_in_prompt_syntax:
-         template = model_family_syntax[prompt_template]
-         # Templates without a {system_prompt} slot simply ignore the extra keyword
-         return template.format(system_prompt=system_prompt, prompt=prompt)
-     return prompt
- 
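- # Yield generated text chunks as they stream back from watsonx.ai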
- def generate_response(watsonx_llm, prompt_data, params):
-     generated_response = watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)
-     for chunk in generated_response:
-         yield chunk
- 
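- # Column 1 only displays the retrieved documents; columns 2 and 3 build a grounded prompt and stream a model reply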
- def fetch_response(user_input, milvus_client, emb, vector_index_properties, vector_store_schema, system_prompt, chat_history):
-     # Get grounding documents
-     grounding = proximity_search(
-         question=user_input,
-         milvus_client=milvus_client,
-         emb=emb,
-         vector_index_properties=vector_index_properties,
-         vector_store_schema=vector_store_schema
-     )
- 
-     # Special handling for PATH-er B. (first column)
-     if chat_history == st.session_state.chat_history_1:
-         # Display user question first
-         with st.chat_message("user", avatar=genparam.USER_AVATAR):
-             st.markdown(user_input)
- 
-         # Parse and display each document from the grounding
-         documents = grounding.split("\n\n")[2:]  # Skip the count line and first newline
-         for doc in documents:
-             if doc.strip():  # Only process non-empty strings
-                 parts = doc.split("\n")
-                 doc_name = parts[0].replace("Document: ", "")
-                 content = parts[1].replace("Content: ", "")
- 
-                 # Display document with delay
-                 time.sleep(0.5)
-                 st.markdown(f"**{doc_name}**")
-                 st.code(content)
- 
-         # Column 1 shows documents only; hand the grounding back for the chat history
-         return grounding
- 
-     # For MOD-ther S. (second column)
-     elif chat_history == st.session_state.chat_history_2:
-         prompt = prepare_prompt(user_input, chat_history)
-         prompt_data = apply_prompt_syntax(
-             prompt,
-             system_prompt,
-             get_active_prompt_template(),
-             genparam.BAKE_IN_PROMPT_SYNTAX
-         )
-         prompt_data = prompt_data.replace("__grounding__", grounding)
- 
-         # Add debug information if enabled
-         if genparam.INPUT_DEBUG_VIEW == 1:
-             with st.columns(3)[0]:  # Note: this opens a new set of columns rather than reusing col1
-                 st.markdown(f"**{genparam.BOT_2_AVATAR} {genparam.BOT_2_NAME} Prompt Data:**")
-                 st.code(prompt_data, language="text")
- 
-     # For SYS-ter V. (third column)
-     else:
-         # Get chat history from MOD-ther S.
-         mod_ther_history = st.session_state.chat_history_2
-         prompt = prepare_prompt(user_input, mod_ther_history)
-         prompt_data = apply_prompt_syntax(
-             prompt,
-             system_prompt,
-             get_active_prompt_template(),
-             genparam.BAKE_IN_PROMPT_SYNTAX
-         )
-         prompt_data = prompt_data.replace("__grounding__", grounding)
- 
-         # Add debug information if enabled
-         if genparam.INPUT_DEBUG_VIEW == 1:
-             with st.columns(3)[0]:  # Note: this opens a new set of columns rather than reusing col1
-                 st.markdown(f"**{genparam.BOT_3_AVATAR} {genparam.BOT_3_NAME} Prompt Data:**")
-                 st.code(prompt_data, language="text")
- 
-     # Continue with normal processing for columns 2 and 3
-     watsonx_llm = ModelInference(
-         api_client=client,
-         model_id=get_active_model(),
-         verify=genparam.VERIFY
-     )
- 
-     params = {
-         GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
-         GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
-         GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
-         GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
-         GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
-     }
- 
-     # Column 1 returned above, so only columns 2 and 3 reach this point
-     if chat_history == st.session_state.chat_history_2:
-         bot_name = genparam.BOT_2_NAME
-         bot_avatar = genparam.BOT_2_AVATAR
-     else:
-         bot_name = genparam.BOT_3_NAME
-         bot_avatar = genparam.BOT_3_AVATAR
- 
-     with st.chat_message(bot_name, avatar=bot_avatar):
-         stream = generate_response(watsonx_llm, prompt_data, params)
-         response = st.write_stream(stream)
- 
-         # Only capture tokens for MOD-ther S. and SYS-ter V.
-         if genparam.TOKEN_CAPTURE_ENABLED:
-             token_stats = capture_tokens(prompt_data, response, bot_name)
-             if token_stats:
-                 st.session_state.token_statistics.append(token_stats)
- 
-     return response
- 
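- # Tokenize the prompt and the reply to record per-bot input/output token counts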
- def capture_tokens(prompt_data, response, bot_name):
-     if not genparam.TOKEN_CAPTURE_ENABLED:
-         return
- 
-     watsonx_llm = ModelInference(
-         api_client=client,
-         model_id=get_active_model(),
-         verify=genparam.VERIFY
-     )
- 
-     input_tokens = watsonx_llm.tokenize(prompt=prompt_data)["result"]["token_count"]
-     output_tokens = watsonx_llm.tokenize(prompt=response)["result"]["token_count"]
-     total_tokens = input_tokens + output_tokens
- 
-     return {
-         "bot_name": bot_name,
-         "input_tokens": input_tokens,
-         "output_tokens": output_tokens,
-         "total_tokens": total_tokens,
-         "timestamp": time.strftime("%H:%M:%S")
-     }
- 
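- # Page flow: session state, sidebar token stats, password gate, then one question fanned out across three columns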
- def main():
-     initialize_session_state()
- 
-     # Apply custom styles
-     st.markdown(three_column_style, unsafe_allow_html=True)
- 
-     # Sidebar configuration
-     st.sidebar.header('The Solutioning Sages')
-     st.sidebar.divider()
- 
-     # Display token statistics in sidebar
-     st.sidebar.subheader("Token Usage Statistics")
- 
-     # Group token statistics by interaction (for MOD-ther S. and SYS-ter V. only)
-     if st.session_state.token_statistics:
-         interaction_count = 0
-         stats_by_time = {}
- 
-         # Group stats by timestamp
-         for stat in st.session_state.token_statistics:
-             if stat["timestamp"] not in stats_by_time:
-                 stats_by_time[stat["timestamp"]] = []
-             stats_by_time[stat["timestamp"]].append(stat)
- 
-         # Display grouped stats
-         for timestamp, stats in stats_by_time.items():
-             interaction_count += 1
-             st.sidebar.markdown(f"**Interaction {interaction_count}** ({timestamp})")
- 
-             # Calculate total tokens for this interaction
-             total_input = sum(stat['input_tokens'] for stat in stats)
-             total_output = sum(stat['output_tokens'] for stat in stats)
-             total = total_input + total_output
- 
-             # Display individual bot statistics
-             for stat in stats:
-                 st.sidebar.markdown(
-                     f"_{stat['bot_name']}_  \n"
-                     f"Input: {stat['input_tokens']} tokens  \n"
-                     f"Output: {stat['output_tokens']} tokens  \n"
-                     f"Total: {stat['total_tokens']} tokens"
-                 )
- 
-             # Display interaction totals
-             st.sidebar.markdown("**Interaction Totals:**")
-             st.sidebar.markdown(
-                 f"Total Input: {total_input} tokens  \n"
-                 f"Total Output: {total_output} tokens  \n"
-                 f"Total Usage: {total} tokens"
-             )
-             st.sidebar.markdown("---")
- 
-     st.sidebar.markdown("")
- 
-     if not check_password():
-         st.stop()
- 
-     # Get user input before column creation
-     user_input = st.chat_input("Ask your question here", key="user_input")
- 
-     if user_input:
-         # Create three columns
-         col1, col2, col3 = st.columns(3)
- 
-         # First column - PATH-er B. (Document Display)
-         with col1:
-             st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
-             st.subheader(f"{genparam.BOT_1_AVATAR} {genparam.BOT_1_NAME}")
-             st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)
- 
-             # Display previous messages
-             for message in st.session_state.chat_history_1:
-                 if message["role"] == "user":
-                     with st.chat_message(message["role"], avatar=genparam.USER_AVATAR):
-                         st.markdown(message['content'])
-                 else:
-                     # Parse and display stored documents
-                     documents = message['content'].split("\n\n")[2:]  # Skip count line
-                     for doc in documents:
-                         if doc.strip():
-                             parts = doc.split("\n")
-                             doc_name = parts[0].replace("Document: ", "")
-                             content = parts[1].replace("Content: ", "")
-                             st.markdown(f"**{doc_name}**")
-                             st.code(content)
- 
-             # Add user message and get new response
-             st.session_state.chat_history_1.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
-             milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
-                 client,
-                 wml_credentials,
-                 st.secrets["vector_index_id_1"]  # Use first vector index
-             )
-             system_prompt = genparam.BOT_1_PROMPT
- 
-             response = fetch_response(
-                 user_input,
-                 milvus_client,
-                 emb,
-                 vector_index_properties,
-                 vector_store_schema,
-                 system_prompt,
-                 st.session_state.chat_history_1
-             )
-             st.session_state.chat_history_1.append({"role": genparam.BOT_1_NAME, "content": response, "avatar": genparam.BOT_1_AVATAR})
-             st.markdown("</div></div>", unsafe_allow_html=True)
- 
-         # Second column - MOD-ther S. (Uses documents from first vector index)
-         with col2:
-             st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
-             st.subheader(f"{genparam.BOT_2_AVATAR} {genparam.BOT_2_NAME}")
-             st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)
- 
-             for message in st.session_state.chat_history_2:
-                 if message["role"] != "user":
-                     with st.chat_message(message["role"], avatar=genparam.BOT_2_AVATAR):
-                         st.markdown(message['content'])
- 
-             st.session_state.chat_history_2.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
-             milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
-                 client,
-                 wml_credentials,
-                 st.secrets["vector_index_id_1"]  # Use first vector index
-             )
-             system_prompt = genparam.BOT_2_PROMPT
- 
-             response = fetch_response(
-                 user_input,
-                 milvus_client,
-                 emb,
-                 vector_index_properties,
-                 vector_store_schema,
-                 system_prompt,
-                 st.session_state.chat_history_2
-             )
-             st.session_state.chat_history_2.append({"role": genparam.BOT_2_NAME, "content": response, "avatar": genparam.BOT_2_AVATAR})
-             st.markdown("</div></div>", unsafe_allow_html=True)
490
-
491
- # Third column - SYS-ter V. (Uses second vector index and chat history from second column)
492
- with col3:
493
- st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
494
- st.subheader(f"{genparam.BOT_3_AVATAR} {genparam.BOT_3_NAME}")
495
- st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)
496
-
497
- for message in st.session_state.chat_history_3:
498
- if message["role"] != "user":
499
- with st.chat_message(message["role"], avatar=genparam.BOT_3_AVATAR):
500
- st.markdown(message['content'])
501
-
502
- st.session_state.chat_history_3.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
503
- milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
504
- client,
505
- wml_credentials,
506
- st.secrets["vector_index_id_2"] # Use second vector index
507
- )
508
- system_prompt = genparam.BOT_3_PROMPT
509
-
510
- response = fetch_response(
511
- user_input,
512
- milvus_client,
513
- emb,
514
- vector_index_properties,
515
- vector_store_schema,
516
- system_prompt,
517
- st.session_state.chat_history_3
518
- )
519
- st.session_state.chat_history_3.append({"role": genparam.BOT_3_NAME, "content": response, "avatar": genparam.BOT_3_AVATAR})
520
- st.markdown("</div></div>", unsafe_allow_html=True)
- 
-         # Update sidebar with new question
-         st.sidebar.markdown("---")
-         st.sidebar.markdown("**Latest Question:**")
-         st.sidebar.markdown(f"_{user_input}_")
- 
- if __name__ == "__main__":
-     main()
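
For reference, neo_sages.py imports a local genparam module for its tuning constants and reads credentials from st.secrets. Below is a minimal sketch of the names the script expects; only the names come from the deleted code, and every value shown is an illustrative assumption, not part of this commit.

# genparam.py -- hedged sketch; all values are placeholder assumptions
ACTIVE_MODEL = 0                 # 0 selects SELECTED_MODEL_1, otherwise SELECTED_MODEL_2
SELECTED_MODEL_1 = "model-id-1"  # assumed watsonx.ai model id
SELECTED_MODEL_2 = "model-id-2"  # assumed watsonx.ai model id
ACTIVE_INDEX = 0                 # 0 selects st.secrets["vector_index_id_1"], otherwise ["vector_index_id_2"]

PROMPT_TEMPLATE_1 = "llama3-instruct (llama-3, 3.1 & 3.2) - system"  # key into model_family_syntax
PROMPT_TEMPLATE_2 = "no syntax - system"
TYPE = "chat"                    # "chat" folds conversation history into the prompt
BAKE_IN_PROMPT_SYNTAX = True
VERIFY = False

DECODING_METHOD = "greedy"       # assumed decoding settings
MAX_NEW_TOKENS = 500
MIN_NEW_TOKENS = 1
REPETITION_PENALTY = 1.0
STOP_SEQUENCES = []

INPUT_DEBUG_VIEW = 0             # 1 mirrors each bot's prompt data into the UI
TOKEN_CAPTURE_ENABLED = True

USER_AVATAR = "🧑"               # avatars and system prompts are assumptions
BOT_1_NAME, BOT_1_AVATAR, BOT_1_PROMPT = "PATH-er B.", "📜", "system prompt for bot 1"
BOT_2_NAME, BOT_2_AVATAR, BOT_2_PROMPT = "MOD-ther S.", "🧭", "system prompt for bot 2"
BOT_3_NAME, BOT_3_AVATAR, BOT_3_PROMPT = "SYS-ter V.", "🛠", "system prompt for bot 3"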