MilanM committed
Commit 11a4188 · verified · 1 Parent(s): cf10217

Update neo_sages5.py

Files changed (1)
  neo_sages5.py +408 -589
neo_sages5.py CHANGED
@@ -1,627 +1,446 @@
  import streamlit as st
- from io import BytesIO
- import ibm_watsonx_ai
- import secretsload
- import genparam
- import requests
- import time
- import re
- import json
- import cos_creds
- from ibm_watsonx_ai.foundation_models import ModelInference
- from ibm_watsonx_ai import Credentials, APIClient
- from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
- from ibm_watsonx_ai.metanames import GenTextReturnOptMetaNames as RetParams
- from knowledge_bases import KNOWLEDGE_BASE_OPTIONS, SYSTEM_PROMPTS, VECTOR_INDEXES
- from ibm_watsonx_ai.foundation_models import Embeddings
- from ibm_watsonx_ai.foundation_models.utils.enums import EmbeddingTypes
- from pymilvus import MilvusClient
-
- from secretsload import load_stsecrets
-
- credentials = load_stsecrets()
-
- st.set_page_config(
-     page_title="The Solutioning Sages",
-     page_icon="🪄",
-     initial_sidebar_state="collapsed",
-     layout="wide"
- )
-
- # Password protection
- def check_password():
-     def password_entered():
-         if st.session_state["password"] == st.secrets["app_password"]:
-             st.session_state["password_correct"] = True
-             del st.session_state["password"]
-         else:
-             st.session_state["password_correct"] = False
-
-     if "password_correct" not in st.session_state:
-         st.markdown("\n\n")
-         st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
-         st.divider()
-         st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
-         return False
-     elif not st.session_state["password_correct"]:
-         st.markdown("\n\n")
-         st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
-         st.divider()
-         st.error("😕 Incorrect password")
-         st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
-         return False
-     else:
-         return True
-
- def initialize_session_state():
-     if 'chat_history_1' not in st.session_state:
-         st.session_state.chat_history_1 = []
-     if 'chat_history_2' not in st.session_state:
-         st.session_state.chat_history_2 = []
-     if 'chat_history_3' not in st.session_state:
-         st.session_state.chat_history_3 = []
-     if 'first_question' not in st.session_state:
-         st.session_state.first_question = False
-     if "counter" not in st.session_state:
-         st.session_state["counter"] = 0
-     if 'token_statistics' not in st.session_state:
-         st.session_state.token_statistics = []
-     if 'selected_kb' not in st.session_state:
-         st.session_state.selected_kb = KNOWLEDGE_BASE_OPTIONS[0]
-     if 'current_system_prompts' not in st.session_state:
-         st.session_state.current_system_prompts = SYSTEM_PROMPTS[st.session_state.selected_kb]
-
- # three_column_style = """
- # <style>
- # .stColumn {
- #     padding: 0.5rem;
- #     border-right: 1px solid #dedede;
- # }
- # .stColumn:last-child {
- #     border-right: none;
- # }
- # .chat-container {
- #     height: calc(100vh - 200px);
- #     overflow-y: auto;
- # }
- # </style>
- # """
-
- three_column_style = """
- <style>
- .stColumn {
-     padding: 0.5rem;
-     border-right: 1px solid #dedede;
- }
- .stColumn:last-child {
-     border-right: none;
- }
- .chat-container {
-     height: calc(100vh - 200px);
-     overflow-y: auto;
-     display: flex;
-     flex-direction: column;
- }
- .chat-messages {
-     display: flex;
-     flex-direction: column;
-     gap: 1rem;
- }
- </style>
- """  # Alt Style
-
- #-----
- def get_active_model():
-     return genparam.SELECTED_MODEL_1 if genparam.ACTIVE_MODEL == 0 else genparam.SELECTED_MODEL_2
-
- def get_active_prompt_template():
-     return genparam.PROMPT_TEMPLATE_1 if genparam.ACTIVE_MODEL == 0 else genparam.PROMPT_TEMPLATE_2
-
- def get_active_vector_index():
-     selected_kb = st.session_state.selected_kb
-     if genparam.ACTIVE_INDEX == 0:
-         return VECTOR_INDEXES[selected_kb]["index_1"]
-     else:
-         return VECTOR_INDEXES[selected_kb]["index_2"]
- #-----
-
- def setup_client(project_id=None):
-     credentials = Credentials(
-         url=st.secrets["url"],
-         api_key=st.secrets["api_key"]
-     )
-     # Use the passed project_id if provided, otherwise fall back to secrets
-     project_id = project_id or st.secrets["project_id"]
-     client = APIClient(credentials, project_id=project_id)
-     return credentials, client
-
- wml_credentials, client = setup_client(st.secrets["project_id"])
-
- import ibm_boto3
- from ibm_botocore.client import Config
- ###==========================================1
- def setup_retrieval_cos_client(config_dict=None):
-     # Default credentials (same as main notebook for now)
-     default_config = {
-         "api_key": cos_creds.COS_API_KEY,
-         "instance_id": cos_creds.COS_INSTANCE_ID,
-         "endpoint": cos_creds.COS_ENDPOINT,
-         "bucket": cos_creds.BUCKET_NAME
-     }
-
-     # Use provided config or default
-     config = config_dict if config_dict is not None else default_config
-
-     # Initialize the retrieval COS client
-     retrieval_cos_client = ibm_boto3.client(
-         "s3",
-         ibm_api_key_id=config["api_key"],
-         ibm_service_instance_id=config["instance_id"],
-         config=Config(signature_version="oauth"),
-         endpoint_url=config["endpoint"]
-     )
-
-     # Verify the connection by trying to list objects
-     try:
-         retrieval_cos_client.list_objects(Bucket=config["bucket"], MaxKeys=1)
-         print("Retrieval COS client successfully initialized and connected")
-     except Exception as e:
-         print(f"Error verifying retrieval COS client connection: {str(e)}")
-         raise
-
-     return retrieval_cos_client
-
- retrieval_cos_client = setup_retrieval_cos_client()
-
- def load_callable_index_config(config_path):
-     try:
-         # Download config file content using retrieval client
-         response = retrieval_cos_client.get_object(Bucket=cos_creds.BUCKET_NAME, Key=config_path)
-         config_content = response['Body'].read().decode('utf-8')
-         return json.loads(config_content)
-     except Exception as e:
-         raise Exception(f"Error loading callable index config: {str(e)}")
- ###==========================================1-2
- def setup_vector_index(client, wml_credentials, config_path):
-     # Load the configuration using load_callable_index_config
-     config = load_callable_index_config(config_path)
-
-     # Initialize embeddings
-     emb = Embeddings(
-         model_id=config["embedding"]["model_id"],
-         credentials=wml_credentials,
-         project_id=PROJECT_ID,
-         params={
-             "truncate_input_tokens": config["embedding"]["max_tokens"]
-         }
-     )
-
-     # Get connection details
-     connection_details = client.connections.get_details(config["connection"]["connection_id"])
-     connection_properties = connection_details["entity"]["properties"]
-
-     # Initialize Milvus client
-     milvus_client = MilvusClient(
-         uri=f'https://{connection_properties.get("host")}:{connection_properties.get("port")}',
-         user=connection_properties.get("username"),
-         password=connection_properties.get("password"),
-         db_name=config["connection"]["database"]
-     )
-
-     # Prepare vector index properties
-     vector_index_properties = {
-         "store": {
-             "index": config["collection"]["name"],
-             "database": config["connection"]["database"]
-         },
-         "settings": {
-             "embedding_model_id": config["embedding"]["model_id"],
-             "schema_fields": config["collection"]["schema"]["fields"],
-             "top_k": config["index_settings"]["top_k"]
-         }
-     }
-
-     # Return the entire config instead of just the schema fields
-     return milvus_client, emb, vector_index_properties, config
-
- ###==========================================1-2
- def proximity_search(question, milvus_client, emb, vector_index_properties, config):
-     query_vectors = emb.embed_query(question)
-     schema_fields = config["collection"]["schema"]["fields"]
-     milvus_response = milvus_client.search(
-         collection_name=vector_index_properties["store"]["index"],
-         data=[query_vectors],
-         limit=vector_index_properties["settings"]["top_k"],
-         metric_type="L2",
-         output_fields=[
-             schema_fields["text_field"],
-             schema_fields["document_name"],
-             schema_fields["page_number"]
-         ]
-     )
-
-     documents = []
-
-     for hit in milvus_response[0]:
-         text = hit["entity"].get(schema_fields["text_field"], "")
-         doc_name = hit["entity"].get(schema_fields["document_name"], "Unknown Document")
-         page_num = hit["entity"].get(schema_fields["page_number"], "N/A")
-
-         formatted_result = f"Document: {doc_name}\nContent: {text}\nPage: {page_num}\n"
-         documents.append(formatted_result)
-
-     # Format final output
-     joined = "\n".join(documents)
-     retrieved = f"Number of Retrieved Documents: {len(documents)}\n\n{joined}"
-     return retrieved
- ###==========================================2-3
-
- def prepare_prompt(prompt, chat_history):
-     if genparam.TYPE == "chat" and chat_history:
-         chats = "\n".join([f"{message['role']}: \"{message['content']}\"" for message in chat_history])
-         prompt = f"""Retrieved Contextual Information:\n__grounding__\n\nConversation History:\n{chats}\n\nNew User Input: {prompt}"""
-         return prompt
-     else:
-         prompt = f"""Retrieved Contextual Information:\n__grounding__\n\nUser Input: {prompt}"""
-         return prompt
-
- def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
-     model_family_syntax = {
-         "llama3-instruct (llama-3, 3.1 & 3.2) - system": """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
-         "llama3-instruct (llama-3, 3.1 & 3.2) - user": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
-         "granite-13b-chat & instruct - system": """<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
-         "granite-13b-chat & instruct - user": """<|user|>\n{prompt}\n<|assistant|>\n\n""",
-         "mistral & mixtral v2 tokenizer - system": """<s>[INST] System Prompt: {system_prompt} [/INST][INST] {prompt} [/INST]\n\n""",
-         "mistral & mixtral v2 tokenizer - user": """<s>[INST] {prompt} [/INST]\n\n""",
-         "no syntax - system": """{system_prompt}\n\n{prompt}""",
-         "no syntax - user": """{prompt}"""
-     }
-
-     if bake_in_prompt_syntax:
-         template = model_family_syntax[prompt_template]
-         if system_prompt:
-             return template.format(system_prompt=system_prompt, prompt=prompt)
-     return prompt
-
- def generate_response(watsonx_llm, prompt_data, params):
-     generated_response = watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)
-     for chunk in generated_response:
-         yield chunk
-
- def fetch_response(user_input, milvus_client, emb, vector_props, config, system_prompt, chat_history):
-     # Get grounding documents
-     grounding = proximity_search(
-         question=user_input,
-         milvus_client=milvus_client,
-         emb=emb,
-         vector_props=vector_promps,
-         config=config_path
-     )
-
-     # Special handling for PATH-er B. (first column)
-     if chat_history == st.session_state.chat_history_1:
-         # Display user question first
-         with st.chat_message("user", avatar=genparam.USER_AVATAR):
-             st.markdown(user_input)
-
-         # Parse and display each document from the grounding
-         documents = grounding.split("\n\n")[2:]  # Skip the count line and first newline
-         for doc in documents:
-             if doc.strip():  # Only process non-empty strings
-                 parts = doc.split("\n")
-                 doc_name = parts[0].replace("Document: ", "")
-                 content = parts[1].replace("Content: ", "")
-
-                 # Display document with delay
-                 time.sleep(0.5)
-                 st.markdown(f"**{doc_name}**")
-                 st.code(content)
-
-         # Store in chat history
-         return grounding
-
-     # For MOD-ther S. (second column) and SYS-ter V. (third column)
-     else:
-         prompt = prepare_prompt(user_input, chat_history)
-         prompt_data = apply_prompt_syntax(
-             prompt,
-             system_prompt,  # Using the system_prompt passed to the function
-             get_active_prompt_template(),
-             genparam.BAKE_IN_PROMPT_SYNTAX
-         )
-         prompt_data = prompt_data.replace("__grounding__", grounding)
-
-         # # Add debug information to column 1 if enabled
-         # if genparam.INPUT_DEBUG_VIEW == 1:
-         #     with col1:  # Access first column
-         #         bot_name = genparam.BOT_2_NAME if chat_history == st.session_state.chat_history_2 else genparam.BOT_3_NAME
-         #         bot_avatar = genparam.BOT_2_AVATAR if chat_history == st.session_state.chat_history_2 else genparam.BOT_3_AVATAR
-         #         st.markdown(f"**{bot_avatar} {bot_name} Prompt Data:**")
-         #         st.code(prompt_data, language="text")
-
-     # Continue with normal processing for columns 2 and 3
-     watsonx_llm = ModelInference(
-         api_client=client,
-         model_id=get_active_model(),
-         verify=genparam.VERIFY
-     )
-
-     params = {
-         GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
-         GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
-         GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
-         GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
-         GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
-     }
-
-     bot_name = None
-     bot_avatar = None
-     if chat_history == st.session_state.chat_history_1:
-         bot_name = genparam.BOT_1_NAME
-         bot_avatar = genparam.BOT_1_AVATAR
-     elif chat_history == st.session_state.chat_history_2:
-         bot_name = genparam.BOT_2_NAME
-         bot_avatar = genparam.BOT_2_AVATAR
-     else:
-         bot_name = genparam.BOT_3_NAME
-         bot_avatar = genparam.BOT_3_AVATAR
-
-     with st.chat_message(bot_name, avatar=bot_avatar):
-         if chat_history != st.session_state.chat_history_1:  # Only generate responses for columns 2 and 3
-             stream = generate_response(watsonx_llm, prompt_data, params)
-             response = st.write_stream(stream)
-
-             # Only capture tokens for MOD-ther S. and SYS-ter V.
-             if genparam.TOKEN_CAPTURE_ENABLED and chat_history != st.session_state.chat_history_1:
-                 token_stats = capture_tokens(prompt_data, response, bot_name)
-                 if token_stats:
-                     st.session_state.token_statistics.append(token_stats)
-         else:
-             response = grounding  # For column 1, we already displayed the content
-
-     return response
-
- def capture_tokens(prompt_data, response, chat_number):
-     if not genparam.TOKEN_CAPTURE_ENABLED:
-         return
-
-     watsonx_llm = ModelInference(
-         api_client=client,
-         model_id=genparam.SELECTED_MODEL,
-         verify=genparam.VERIFY
-     )
-
-     input_tokens = watsonx_llm.tokenize(prompt=prompt_data)["result"]["token_count"]
-     output_tokens = watsonx_llm.tokenize(prompt=response)["result"]["token_count"]
-     total_tokens = input_tokens + output_tokens
-
-     return {
-         "bot_name": bot_name,
-         "input_tokens": input_tokens,
-         "output_tokens": output_tokens,
-         "total_tokens": total_tokens,
-         "timestamp": time.strftime("%H:%M:%S")
-     }

  def main():
      initialize_session_state()
-
-     # Apply custom styles
-     st.markdown(three_column_style, unsafe_allow_html=True)
-
-     # Sidebar configuration
-     st.sidebar.header('The Solutioning Sages')
-     st.sidebar.divider()
-
-     # Knowledge Base Selection
-     selected_kb = st.sidebar.selectbox(
-         "Select Knowledge Base",
-         KNOWLEDGE_BASE_OPTIONS,
-         index=KNOWLEDGE_BASE_OPTIONS.index(st.session_state.selected_kb)
-     )
-
-     # Update knowledge base related values if selection changes
-     if selected_kb != st.session_state.selected_kb:
-         st.session_state.selected_kb = selected_kb
-         # Update the client with the new project_id
-         global client, wml_credentials
-         wml_credentials, client = setup_client(VECTOR_INDEXES[selected_kb]["project_id"])
-
-     # Display current knowledge base contents
-     with st.sidebar.expander("Knowledge Base Contents"):
-         for doc in VECTOR_INDEXES[selected_kb]["contents"]:
-             st.write(f"📄 {doc}")
-
-     # Display active model information
-     st.sidebar.divider()
-     active_model = genparam.SELECTED_MODEL_1 if genparam.ACTIVE_MODEL == 0 else genparam.SELECTED_MODEL_2
-     st.sidebar.markdown("**Active Model:**")
-     st.sidebar.code(active_model)
-
-     st.sidebar.divider()
-
-     # Display token statistics in sidebar
-     st.sidebar.subheader("Token Usage Statistics")
-
-     # Group token statistics by interaction (for MOD-ther S. and SYS-ter V. only)
-     if st.session_state.token_statistics:
-         current_timestamp = None
-         interaction_count = 0
-         stats_by_time = {}
-
-         # Group stats by timestamp
-         for stat in st.session_state.token_statistics:
-             if stat["timestamp"] not in stats_by_time:
-                 stats_by_time[stat["timestamp"]] = []
-             stats_by_time[stat["timestamp"]].append(stat)
-
-         # Display grouped stats
-         for timestamp, stats in stats_by_time.items():
-             interaction_count += 1
-             st.sidebar.markdown(f"**Interaction {interaction_count}** ({timestamp})")
-
-             # Calculate total tokens for this interaction
-             total_input = sum(stat['input_tokens'] for stat in stats)
-             total_output = sum(stat['output_tokens'] for stat in stats)
-             total = total_input + total_output
-
-             # Display individual bot statistics
-             for stat in stats:
-                 st.sidebar.markdown(
-                     f"_{stat['bot_name']}_ \n"
-                     f"Input: {stat['input_tokens']} tokens \n"
-                     f"Output: {stat['output_tokens']} tokens \n"
-                     f"Total: {stat['total_tokens']} tokens"
-                 )
-
-             # Display interaction totals
-             st.sidebar.markdown("**Interaction Totals:**")
-             st.sidebar.markdown(
-                 f"Total Input: {total_input} tokens \n"
-                 f"Total Output: {total_output} tokens \n"
-                 f"Total Usage: {total} tokens"
-             )
-             st.sidebar.markdown("---")
-
-     st.sidebar.markdown("")
-
-     if not check_password():
-         st.stop()
-
-     # Get user input before column creation
-     user_input = st.chat_input("Ask your question here", key="user_input")
-
-     if user_input:
-         # Create three columns
-         col1, col2, col3 = st.columns(3)
-
-         # First column - PATH-er B. (Document Display)
-         with col1:
-             st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
-             st.subheader(f"{genparam.BOT_1_AVATAR} {genparam.BOT_1_NAME}")
-             st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)
-
-             # Display previous messages
-             for message in st.session_state.chat_history_1:
-                 if message["role"] == "user":
-                     with st.chat_message(message["role"], avatar=genparam.USER_AVATAR):
-                         st.markdown(message['content'])
-                 else:
-                     # Parse and display stored documents
-                     documents = message['content'].split("\n\n")[2:]  # Skip count line
-                     for doc in documents:
-                         if doc.strip():
-                             parts = doc.split("\n")
-                             doc_name = parts[0].replace("Document: ", "")
-                             content = parts[1].replace("Content: ", "")
-                             st.markdown(f"**{doc_name}**")
-                             st.code(content)
-
-             # Add user message and get new response
-             st.session_state.chat_history_1.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
-             milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
-                 client,
-                 wml_credentials,
-                 VECTOR_INDEXES[st.session_state.selected_kb]["index_1"]
-             )
-             system_prompt = genparam.BOT_1_PROMPT
-
-             response = fetch_response(
-                 user_input,
-                 milvus_client,
-                 emb,
-                 vector_index_properties,
-                 vector_store_schema,
-                 system_prompt,
-                 st.session_state.chat_history_1
-             )
-             st.session_state.chat_history_1.append({"role": genparam.BOT_1_NAME, "content": response, "avatar": genparam.BOT_1_AVATAR})
-             st.markdown("</div></div>", unsafe_allow_html=True)
-
-         # Second column - MOD-ther S. (Uses documents from first vector index)
-         with col2:
-             st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
-             st.subheader(f"{genparam.BOT_2_AVATAR} {genparam.BOT_2_NAME}")
-             st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)
-
-             for message in st.session_state.chat_history_2:
-                 if message["role"] != "user":
-                     with st.chat_message(message["role"], avatar=genparam.BOT_2_AVATAR):
-                         st.markdown(message['content'])
-
-             st.session_state.chat_history_2.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
-             milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
-                 client,
-                 wml_credentials,
-                 #VECTOR_INDEXES[st.session_state.selected_kb]["index_1"]
-                 config_path
-             )
-             system_prompt = SYSTEM_PROMPTS[st.session_state.selected_kb]["bot_2"]
-
-             response = fetch_response(
-                 user_input,
-                 milvus_client,
-                 emb,
-                 vector_index_properties,
-                 vector_store_schema,
-                 system_prompt,
-                 st.session_state.chat_history_2
-             )
-
-             if genparam.INPUT_DEBUG_VIEW == 1:
-                 with col1:  # Access first column
-                     bot_name = genparam.BOT_2_NAME if st.session_state.chat_history_1 == st.session_state.chat_history_2 else genparam.BOT_3_NAME
-                     bot_avatar = genparam.BOT_2_AVATAR if st.session_state.chat_history_1 == st.session_state.chat_history_2 else genparam.BOT_3_AVATAR
-                     st.markdown(f"**{bot_avatar} {bot_name} Prompt Data:**")
-                     st.code(prompt_data, language="text")
-
-             st.session_state.chat_history_2.append({"role": genparam.BOT_2_NAME, "content": response, "avatar": genparam.BOT_2_AVATAR})
-             st.markdown("</div></div>", unsafe_allow_html=True)
-
-         # Third column - SYS-ter V. (Uses second vector index and chat history from second column)
-         with col3:
-             st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
-             st.subheader(f"{genparam.BOT_3_AVATAR} {genparam.BOT_3_NAME}")
-             st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)
-
-             for message in st.session_state.chat_history_3:
-                 if message["role"] != "user":
-                     with st.chat_message(message["role"], avatar=genparam.BOT_3_AVATAR):
-                         st.markdown(message['content'])
-
-             st.session_state.chat_history_3.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
-             milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
-                 client,
-                 wml_credentials,
-                 VECTOR_INDEXES[st.session_state.selected_kb]["index_2"]
-             )
-             system_prompt = SYSTEM_PROMPTS[st.session_state.selected_kb]["bot_3"]
-
-             response = fetch_response(
-                 user_input,
-                 milvus_client,
-                 emb,
-                 vector_index_properties,
-                 vector_store_schema,
-                 system_prompt,
-                 st.session_state.chat_history_3
-             )
-
-             if genparam.INPUT_DEBUG_VIEW == 1:
-                 with col1:  # Access first column
-                     bot_name = genparam.BOT_2_NAME if st.session_state.chat_history_1 == st.session_state.chat_history_2 else genparam.BOT_3_NAME
-                     bot_avatar = genparam.BOT_2_AVATAR if st.session_state.chat_history_1 == st.session_state.chat_history_2 else genparam.BOT_3_AVATAR
-                     st.markdown(f"**{bot_avatar} {bot_name} Prompt Data:**")
-                     st.code(prompt_data, language="text")
-
-             st.session_state.chat_history_3.append({"role": genparam.BOT_3_NAME, "content": response, "avatar": genparam.BOT_3_AVATAR})
-             st.markdown("</div></div>", unsafe_allow_html=True)
-
-         # Update sidebar with new question
-         st.sidebar.markdown("---")
-         st.sidebar.markdown("**Latest Question:**")
-         st.sidebar.markdown(f"_{user_input}_")

  if __name__ == "__main__":
      main()
 
  import streamlit as st
+ from langchain import PromptTemplate
+ from langchain.graphs import StateGraph
+ from typing import TypedDict, List, Dict, Optional, Any
+ from dataclasses import dataclass, field
+ from enum import Enum
+ import random
+
+ # Data Structures
+ @dataclass
+ class StoryState:
+     current_step: int = 0
+     max_steps: int = 5
+     story_log: List[str] = field(default_factory=list)
+     user_inputs: List[str] = field(default_factory=list)
+     character1_responses: List[str] = field(default_factory=list)
+     character2_responses: List[str] = field(default_factory=list)
+     story_outcome: Optional[str] = None
+
+ class MicroStory(TypedDict):
+     title: str
+     initial_setup: str
+     character1_name: str
+     character2_name: str
+     steps: List[str]
+     success_conditions: List[str]
+     failure_conditions: List[str]
+
+ class StoryNodeType(Enum):
+     SETUP = "setup"
+     USER_INPUT = "user_input"
+     CHARACTER1_RESPONSE = "character1_response"
+     CHARACTER2_RESPONSE = "character2_response"
+     EVALUATION = "evaluation"
+
+ @dataclass
+ class StoryGraphState:
+     current_node: StoryNodeType
+     story_data: Dict[str, Any]
+     accumulated_context: List[Dict[str, str]]
+     step_count: int = 0
+
+ def create_story_graph() -> StateGraph:
+     """Creates the state graph for story progression."""
+     graph = StateGraph()
+
+     # Define state transitions
+     def setup_to_user_input(state: StoryGraphState) -> StoryGraphState:
+         state.current_node = StoryNodeType.USER_INPUT
+         return state
+
+     def user_input_to_char1(state: StoryGraphState, user_input: str) -> StoryGraphState:
+         state.current_node = StoryNodeType.CHARACTER1_RESPONSE
+         state.accumulated_context.append({"role": "user", "content": user_input})
+         return state
+
+     def char1_to_char2(state: StoryGraphState, char1_response: str) -> StoryGraphState:
+         state.current_node = StoryNodeType.CHARACTER2_RESPONSE
+         state.accumulated_context.append({"role": "character1", "content": char1_response})
+         return state
+
+     def char2_to_evaluation(state: StoryGraphState, char2_response: str) -> StoryGraphState:
+         state.current_node = StoryNodeType.EVALUATION
+         state.accumulated_context.append({"role": "character2", "content": char2_response})
+         state.step_count += 1
+         return state
+
+     def evaluation_to_next(state: StoryGraphState) -> StoryGraphState:
+         if state.step_count >= 5:
+             # Story is complete, stay in evaluation
+             return state
+         # Move to next user input
+         state.current_node = StoryNodeType.USER_INPUT
+         return state
+
+     # Add nodes and edges
+     graph.add_node("setup", setup_to_user_input)
+     graph.add_node("user_input", user_input_to_char1)
+     graph.add_node("character1_response", char1_to_char2)
+     graph.add_node("character2_response", char2_to_evaluation)
+     graph.add_node("evaluation", evaluation_to_next)
+
+     # Connect nodes
+     graph.add_edge("setup", "user_input")
+     graph.add_edge("user_input", "character1_response")
+     graph.add_edge("character1_response", "character2_response")
+     graph.add_edge("character2_response", "evaluation")
+     graph.add_edge("evaluation", "user_input")
+
+     return graph
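
Note that `StateGraph` is used above as a bare registry of named callbacks, and `StoryRunner` below additionally calls a `transition()` method on it; stock LangChain does not expose such a class under `langchain.graphs`, so a thin wrapper with that interface still has to be supplied. The cycle itself (user_input, then character1, then character2, then evaluation) does not depend on the wrapper. A minimal sketch of one full step, with the bodies of the nested transition functions inlined since they are not importable:

    # One full story step, inlined from the transition functions above.
    state = StoryGraphState(
        current_node=StoryNodeType.USER_INPUT,
        story_data={},
        accumulated_context=[],
    )
    state.accumulated_context.append({"role": "user", "content": "I examine the carved symbols"})
    state.accumulated_context.append({"role": "character1", "content": "A careful choice..."})
    state.accumulated_context.append({"role": "character2", "content": "Those marks are old..."})
    state.step_count += 1                    # char2_to_evaluation
    if state.step_count < 5:                 # evaluation_to_next
        state.current_node = StoryNodeType.USER_INPUT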
+
+ class StoryRunner:
+     def __init__(self, story_data: Dict[str, Any]):
+         self.graph = create_story_graph()
+         self.state = StoryGraphState(
+             current_node=StoryNodeType.SETUP,
+             story_data=story_data,
+             accumulated_context=[]
+         )
+
+     def process_user_input(self, user_input: str) -> Dict[str, Any]:
+         """Process user input and advance the story state."""
+         if self.state.current_node != StoryNodeType.USER_INPUT:
+             raise ValueError("Not ready for user input")
+
+         # Advance through the graph
+         self.state = self.graph.transition("user_input", self.state, user_input)
+         self.state = self.graph.transition("character1_response", self.state, None)
+         self.state = self.graph.transition("character2_response", self.state, None)
+         self.state = self.graph.transition("evaluation", self.state, None)
+
+         # Return current state info
+         return {
+             "step_count": self.state.step_count,
+             "is_complete": self.state.step_count >= 5,
+             "current_context": self.state.accumulated_context[-3:] if self.state.accumulated_context else [],
+             "current_node": self.state.current_node.value
+         }
+
+     def get_full_context(self) -> List[Dict[str, str]]:
+         """Get the full conversation context."""
+         return self.state.accumulated_context
+
+     def is_complete(self) -> bool:
+         """Check if the story is complete."""
+         return self.state.step_count >= 5
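
A minimal driver sketch for `StoryRunner`, assuming a `transition(name, state, payload)` method with the semantics the class expects (`STORY_LOOKUP` is defined just below; the SETUP node is skipped for brevity):

    runner = StoryRunner(STORY_LOOKUP["The Library Mystery"])
    runner.state.current_node = StoryNodeType.USER_INPUT
    while not runner.is_complete():
        info = runner.process_user_input(input("What do you do? "))
        print(f"step {info['step_count']}/5, at node {info['current_node']}")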
+
+ # Sample Stories Database
+ # Story Categories and Templates
+ STORY_CATEGORIES = {
+     "Mystery": [
+         {
+             "title": "The Library Mystery",
+             "initial_setup": "In the ancient library of St. Bartholomew's, a rare manuscript has gone missing.",
+             "character1_name": "Detective Nash",
+             "character2_name": "Librarian Wells",
+             "steps": [
+                 "You notice strange symbols carved into the reading desk",
+                 "A student mentions seeing someone in medieval clothing",
+                 "The manuscript tracking system shows impossible timestamps",
+                 "Temperature drops significantly in the rare books section",
+                 "You find a hidden door behind the card catalog"
+             ],
+             "success_conditions": [
+                 "mentioned checking the security cameras",
+                 "investigated the symbols",
+                 "questioned the student further",
+                 "connected medieval sighting with timestamps"
+             ],
+             "failure_conditions": [
+                 "accused the librarian",
+                 "ignored the symbols",
+                 "left the library",
+                 "called the police immediately"
+             ]
+         },
+         {
+             "title": "The Digital Deception",
+             "initial_setup": "A tech startup's revolutionary AI algorithm has been stolen right before a major demo.",
+             "character1_name": "Cyber Detective Chen",
+             "character2_name": "System Admin Rodriguez",
+             "steps": [
+                 "The server logs show multiple failed login attempts",
+                 "An employee reports receiving a suspicious email",
+                 "The backup system was manually disabled",
+                 "Strange network traffic appears during off-hours",
+                 "A hidden backdoor program is discovered"
+             ],
+             "success_conditions": [
+                 "checked email headers",
+                 "analyzed network logs",
+                 "investigated backup system",
+                 "traced the backdoor"
+             ],
+             "failure_conditions": [
+                 "restored from backup immediately",
+                 "ignored the suspicious email",
+                 "reset all passwords without investigation",
+                 "blamed the system admin"
+             ]
+         }
+     ],
+     "Adventure": [
+         {
+             "title": "The Lost Temple",
+             "initial_setup": "Deep in the Amazon rainforest, you've discovered the entrance to an ancient temple.",
+             "character1_name": "Dr. Rivera",
+             "character2_name": "Guide Santos",
+             "steps": [
+                 "Ancient markings warn of a curse",
+                 "You find a mechanism with multiple levers",
+                 "A strange humming sound emanates from deeper within",
+                 "The floor tiles show a peculiar pattern",
+                 "A beam of light reveals a hidden chamber"
+             ],
+             "success_conditions": [
+                 "documented the markings",
+                 "observed the pattern",
+                 "tested the mechanism carefully",
+                 "followed the light beam"
+             ],
+             "failure_conditions": [
+                 "ignored the warnings",
+                 "pulled levers randomly",
+                 "split up the group",
+                 "took artifacts without examination"
+             ]
+         }
+     ],
+     "Sci-Fi": [
+         {
+             "title": "The Quantum Anomaly",
+             "initial_setup": "At a cutting-edge research facility, a quantum experiment has created an unexplained phenomenon.",
+             "character1_name": "Dr. Zhang",
+             "character2_name": "Engineer Parker",
+             "steps": [
+                 "Quantum readings are off the charts",
+                 "Equipment starts behaving erratically",
+                 "A shimmer appears in the air",
+                 "Time seems to flow differently near the anomaly",
+                 "Multiple reality signatures detected"
+             ],
+             "success_conditions": [
+                 "monitored quantum fluctuations",
+                 "calibrated equipment",
+                 "documented time discrepancies",
+                 "maintained safe distance"
+             ],
+             "failure_conditions": [
+                 "shut down power immediately",
+                 "entered the anomaly",
+                 "ignored safety protocols",
+                 "attempted to contain without data"
+             ]
+         }
+     ]
+ }
+
+ # Flatten categories for easy access by title
+ STORY_LOOKUP = {
+     story["title"]: story
+     for category in STORY_CATEGORIES.values()
+     for story in category
+ }
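
The flattened map makes a bare title enough to recover a full story record, for example:

    story = STORY_LOOKUP["The Quantum Anomaly"]
    assert story["character1_name"] == "Dr. Zhang"
    assert len(story["steps"]) == 5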
+
+ # Character Response Templates
+ CHARACTER1_TEMPLATE = """
+ Context: You are {character1_name} in this story.
+ Story Progress: {story_log}
+ User's Latest Action: {user_input}
+
+ Respond to the user's action in character, considering:
+ 1. Your role and personality
+ 2. The current story situation
+ 3. The potential consequences of their action
+
+ Response:
+ """
+
+ CHARACTER2_TEMPLATE = """
+ Context: You are {character2_name} in this story.
+ Story Progress: {story_log}
+ User's Latest Action: {user_input}
+ Other Character's Response: {character1_response}
+
+ Respond to both the user and {character1_name}, considering:
+ 1. Your role and personality
+ 2. The current story developments
+ 3. Your relationship with {character1_name}
+ 4. The potential impact on the story's outcome
+
+ Response:
+ """
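
`PromptTemplate` is imported at the top of the file but never applied to these strings; a minimal rendering sketch (plain `str.format` would work equally well, since the templates only use simple `{placeholder}` fields):

    char1_prompt = PromptTemplate.from_template(CHARACTER1_TEMPLATE)
    rendered = char1_prompt.format(
        character1_name="Detective Nash",
        story_log="You notice strange symbols carved into the reading desk",
        user_input="I photograph the symbols and check the tracking system",
    )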
+
+ def initialize_session_state():
+     if 'story_state' not in st.session_state:
+         st.session_state.story_state = StoryState()
+     if 'selected_category' not in st.session_state:
+         st.session_state.selected_category = list(STORY_CATEGORIES.keys())[0]
+     if 'current_story' not in st.session_state:
+         # Select random story from current category
+         st.session_state.current_story = random.choice(STORY_CATEGORIES[st.session_state.selected_category])
+
+ def evaluate_outcome(state: StoryState, story: MicroStory) -> str:
+     user_actions = " ".join(state.user_inputs).lower()
+
+     # Count matches for success and failure conditions
+     success_matches = sum(1 for cond in story["success_conditions"] if cond.lower() in user_actions)
+     failure_matches = sum(1 for cond in story["failure_conditions"] if cond.lower() in user_actions)
+
+     # Calculate success ratio
+     success_ratio = success_matches / len(story["success_conditions"])
+
+     if failure_matches >= 2:
+         return "The story ends in failure. Critical mistakes were made."
+     elif success_ratio >= 0.7:
+         return "The story concludes successfully! Well done!"
+     else:
+         return "The story ends with mixed results. Some opportunities were missed."
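
Conditions only count when they appear as verbatim substrings of the lowercased, concatenated inputs. A worked example against "The Library Mystery" (four success conditions): three verbatim matches give a ratio of 3/4 = 0.75, which clears the 0.7 bar, and with fewer than two failure matches the run is scored a success:

    state = StoryState(user_inputs=[
        "I mentioned checking the security cameras to Wells",
        "then investigated the symbols on the desk",
        "and questioned the student further",
    ])
    result = evaluate_outcome(state, STORY_LOOKUP["The Library Mystery"])
    assert result == "The story concludes successfully! Well done!"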
+
+ def update_story_state(state: StoryState, user_input: str, char1_response: str, char2_response: str):
+     state.current_step += 1
+     state.user_inputs.append(user_input)
+     state.character1_responses.append(char1_response)
+     state.character2_responses.append(char2_response)
+
+     # Check if story should end
+     if state.current_step >= state.max_steps:
+         state.story_outcome = evaluate_outcome(state, st.session_state.current_story)
+
+ def generate_character_response(
+     character_name: str,
+     story_log: List[str],
+     user_input: str,
+     other_response: Optional[str] = None,
+     is_character1: bool = True
+ ) -> str:
+     # This would normally use watsonx.ai or another LLM.
+     # For now, return placeholder responses.
+     if is_character1:
+         return f"{character_name}: That's an interesting approach. Let's see where this leads..."
+     else:
+         return f"{character_name}: I have my doubts about this, but we'll see..."
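
The placeholder above is where the removed file's watsonx.ai plumbing could slot back in. A sketch reusing `ModelInference` and `generate_text_stream` as the deleted `fetch_response` called them; the `client` handle and `model_id` are assumptions, since neither is defined in this file:

    # Hypothetical LLM-backed variant, modeled on the removed code path.
    from ibm_watsonx_ai.foundation_models import ModelInference

    def generate_character_response_llm(character_name, story_log, user_input, client, model_id):
        prompt = PromptTemplate.from_template(CHARACTER1_TEMPLATE).format(
            character1_name=character_name,
            story_log="\n".join(story_log),
            user_input=user_input,
        )
        watsonx_llm = ModelInference(api_client=client, model_id=model_id)
        # Join the streamed chunks into a single response string.
        return "".join(watsonx_llm.generate_text_stream(prompt=prompt))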

  def main():
+     st.set_page_config(page_title="Interactive Story", layout="wide")
      initialize_session_state()
+
+     # Sidebar Configuration
+     st.sidebar.header('Story Selection')
+     st.sidebar.divider()
+
+     # Category Selection
+     selected_category = st.sidebar.selectbox(
+         "Select Story Category",
+         list(STORY_CATEGORIES.keys()),
+         index=list(STORY_CATEGORIES.keys()).index(st.session_state.selected_category)
+     )
+
+     # Update category and story if changed
+     if selected_category != st.session_state.selected_category:
+         st.session_state.selected_category = selected_category
+         st.session_state.current_story = random.choice(STORY_CATEGORIES[selected_category])
+         st.session_state.story_state = StoryState()  # Reset state for new story
+         st.rerun()
+
+     # Display available stories in category
+     with st.sidebar.expander(f"Available {selected_category} Stories"):
+         for story in STORY_CATEGORIES[selected_category]:
+             st.write(f"📖 {story['title']}")
+
+     # Optional: Select specific story
+     specific_story = st.sidebar.selectbox(
+         "Select Specific Story",
+         [story["title"] for story in STORY_CATEGORIES[selected_category]],
+         index=[story["title"] for story in STORY_CATEGORIES[selected_category]].index(st.session_state.current_story["title"])
+     )
+
+     # Update if specific story changed
+     if specific_story != st.session_state.current_story["title"]:
+         st.session_state.current_story = STORY_LOOKUP[specific_story]
+         st.session_state.story_state = StoryState()  # Reset state for new story
+         st.rerun()
+
+     # Display story stats
+     st.sidebar.divider()
+     st.sidebar.subheader("Story Progress")
+     progress = st.session_state.story_state.current_step / st.session_state.story_state.max_steps
+     st.sidebar.progress(progress)
+     st.sidebar.write(f"Step {st.session_state.story_state.current_step + 1}/5")
+
+     # Create three columns
+     col1, col2, col3 = st.columns(3)
+
+     # Story Progress Column
+     with col1:
+         st.header("Story Progress")
+         st.write(f"**{st.session_state.current_story['title']}**")
+         st.write(st.session_state.current_story['initial_setup'])
+
+         # Display story log
+         for step_num, (step, user_input) in enumerate(zip(
+             st.session_state.current_story['steps'][:st.session_state.story_state.current_step],
+             st.session_state.story_state.user_inputs
+         )):
+             st.write(f"Step {step_num + 1}: {step}")
+             st.write(f"Your action: {user_input}")
+             st.write("---")
+
+         # Display current step if story isn't finished
+         if st.session_state.story_state.story_outcome is None:
+             current_step = st.session_state.current_story['steps'][st.session_state.story_state.current_step]
+             st.write(f"Current Situation: {current_step}")
+
+     # Character 1 Column
+     with col2:
+         st.header(st.session_state.current_story['character1_name'])
+         for response in st.session_state.story_state.character1_responses:
+             st.write(response)
+
+     # Character 2 Column
+     with col3:
+         st.header(st.session_state.current_story['character2_name'])
+         for response in st.session_state.story_state.character2_responses:
+             st.write(response)
+
+     # User Input Section
+     if st.session_state.story_state.story_outcome is None:
+         user_input = st.text_input(
+             "What do you do?",
+             key=f"user_input_{st.session_state.story_state.current_step}"
+         )
+
+         if user_input:
+             # Generate character responses
+             char1_response = generate_character_response(
+                 st.session_state.current_story['character1_name'],
+                 st.session_state.story_state.story_log,
+                 user_input
+             )
+
+             char2_response = generate_character_response(
+                 st.session_state.current_story['character2_name'],
+                 st.session_state.story_state.story_log,
+                 user_input,
+                 char1_response,
+                 False
+             )
+
+             # Update state
+             update_story_state(st.session_state.story_state, user_input, char1_response, char2_response)
+             st.rerun()
+     else:
+         # Display story outcome
+         st.write(st.session_state.story_state.story_outcome)
+         if st.button("Start New Story"):
+             st.session_state.story_state = StoryState()
+             st.session_state.current_story = random.choice(STORY_CATEGORIES[st.session_state.selected_category])
+             st.rerun()

  if __name__ == "__main__":
      main()