import hmac
import time

import streamlit as st
from ibm_watsonx_ai import APIClient, Credentials
from ibm_watsonx_ai.foundation_models import ModelInference
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams

import genparam
from knowledge_bases import KNOWLEDGE_BASE_OPTIONS, SYSTEM_PROMPTS


def check_password():
    """Gate the app behind the shared password in ``st.secrets["app_password"]``.

    Renders a password form until the user has entered the correct password
    in this session.

    Returns:
        True if ``st.session_state["password_correct"]`` is True; otherwise
        renders the password form (plus an error after a failed attempt) and
        returns False.
    """

    def password_entered():
        """on_change callback for the password field: validate and record the result."""
        entered = str(st.session_state["password"])
        expected = str(st.secrets["app_password"])
        # Constant-time comparison so response timing does not reveal how
        # much of the password matched.
        is_correct = hmac.compare_digest(
            entered.encode("utf-8"), expected.encode("utf-8")
        )
        st.session_state["password_correct"] = is_correct
        if is_correct:
            # Do not keep the plaintext password around in session state.
            del st.session_state["password"]

    if st.session_state.get("password_correct", False):
        return True

    st.markdown("\n\n")
    st.text_input(
        "Enter the password",
        type="password",
        on_change=password_entered,
        key="password",
    )
    st.divider()
    # The key only exists after at least one attempt; reaching here with it
    # present means that attempt failed.
    if "password_correct" in st.session_state:
        st.error("😕 Incorrect password")
    st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
    return False


def initialize_session_state():
    """Initialize all session state variables that are not already set.

    Existing values are left untouched, so this is safe to call on every
    Streamlit rerun.
    """
    defaults = {
        "chat_history_1": [],
        "chat_history_2": [],
        "chat_history_3": [],
        "first_question": False,
        "counter": 0,
        "token_statistics": [],
        "selected_kb": KNOWLEDGE_BASE_OPTIONS[0],
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    # Depends on selected_kb, so it must be initialized after the loop above.
    if "current_system_prompts" not in st.session_state:
        st.session_state.current_system_prompts = SYSTEM_PROMPTS[
            st.session_state.selected_kb
        ]


def setup_client(project_id=None):
    """Create watsonx.ai credentials and an API client from Streamlit secrets.

    Args:
        project_id: Project to bind the client to; falls back to
            ``st.secrets["project_id"]`` when falsy.

    Returns:
        A ``(credentials, client)`` tuple.
    """
    credentials = Credentials(
        url=st.secrets["url"],
        api_key=st.secrets["api_key"],
    )
    project_id = project_id or st.secrets["project_id"]
    client = APIClient(credentials, project_id=project_id)
    return credentials, client


def get_active_model():
    """Return the model id selected by ``genparam.ACTIVE_MODEL`` (0 -> model 1, else model 2)."""
    return genparam.SELECTED_MODEL_1 if genparam.ACTIVE_MODEL == 0 else genparam.SELECTED_MODEL_2


def get_active_prompt_template():
    """Return the prompt-template name paired with the active model."""
    return genparam.PROMPT_TEMPLATE_1 if genparam.ACTIVE_MODEL == 0 else genparam.PROMPT_TEMPLATE_2


def prepare_prompt(prompt, chat_history):
    """Build the user-facing prompt text, prefixed with chat history when in chat mode.

    Args:
        prompt: The new user input.
        chat_history: List of message dicts with ``"role"`` and ``"content"``
            keys; only used when ``genparam.TYPE == "chat"``.

    Returns:
        The prompt string to be wrapped in model syntax.
    """
    if genparam.TYPE == "chat" and chat_history:
        chats = "\n".join(
            f"{message['role']}: \"{message['content']}\"" for message in chat_history
        )
        return f"Conversation History:\n{chats}\n\nNew User Input: {prompt}"
    return f"User Input: {prompt}"


# Chat-syntax templates per model family. Each family has a "- system"
# variant (with a system prompt slot) and a "- user" variant (without one).
_SYSTEM_SUFFIX = " - system"
_USER_SUFFIX = " - user"
_MODEL_FAMILY_SYNTAX = {
    "llama3-instruct (llama-3, 3.1 & 3.2) - system": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
    "llama3-instruct (llama-3, 3.1 & 3.2) - user": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
    "granite-13b-chat & instruct - system": "<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n",
    "granite-13b-chat & instruct - user": "<|user|>\n{prompt}\n<|assistant|>\n\n",
    "mistral & mixtral v2 tokenizer - system": "[INST] System Prompt: {system_prompt} [/INST][INST] {prompt} [/INST]\n\n",
    "mistral & mixtral v2 tokenizer - user": "[INST] {prompt} [/INST]\n\n",
    "no syntax - system": "{system_prompt}\n\n{prompt}",
    "no syntax - user": "{prompt}",
}


def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
    """Wrap the prompt in the chat syntax expected by the model family.

    Args:
        prompt: Prepared prompt text (see ``prepare_prompt``).
        system_prompt: System instructions; may be empty.
        prompt_template: Key into the model-family template table.
        bake_in_prompt_syntax: When False, the prompt is returned unchanged.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: If ``bake_in_prompt_syntax`` is set and ``prompt_template``
            is not a known template name.
    """
    if not bake_in_prompt_syntax:
        return prompt

    template_name = prompt_template
    # A "- system" template with no system prompt would emit an empty system
    # turn; use the family's "- user" variant instead.
    if not system_prompt and template_name.endswith(_SYSTEM_SUFFIX):
        template_name = template_name[: -len(_SYSTEM_SUFFIX)] + _USER_SUFFIX

    template = _MODEL_FAMILY_SYNTAX[template_name]
    # str.format ignores unused keyword arguments, so "- user" templates
    # simply drop system_prompt.
    return template.format(system_prompt=system_prompt or "", prompt=prompt)


def generate_response(watsonx_llm, prompt_data, params):
    """Yield text chunks streamed from ``watsonx_llm.generate_text_stream``."""
    yield from watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)


def capture_tokens(prompt_data, response, client, bot_name):
    """Count input/output tokens for one exchange.

    Tokenizes with the active model (the same one ``fetch_response`` uses),
    so counts match the model that produced ``response``.

    Args:
        prompt_data: The exact prompt sent to the model.
        response: The full generated response text.
        client: watsonx.ai ``APIClient``.
        bot_name: Label stored with the statistics.

    Returns:
        ``None`` when ``genparam.TOKEN_CAPTURE_ENABLED`` is falsy; otherwise a
        dict with ``bot_name``, ``input_tokens``, ``output_tokens``,
        ``total_tokens`` and an ``HH:MM:SS`` ``timestamp``.
    """
    if not genparam.TOKEN_CAPTURE_ENABLED:
        return None

    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY,
    )
    input_tokens = watsonx_llm.tokenize(prompt=prompt_data)["result"]["token_count"]
    output_tokens = watsonx_llm.tokenize(prompt=response)["result"]["token_count"]

    return {
        "bot_name": bot_name,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
        "timestamp": time.strftime("%H:%M:%S"),
    }


def fetch_response(user_input, client, system_prompt, chat_history):
    """Start a streamed generation for ``user_input`` with the active model.

    Returns:
        A ``(stream, prompt_data)`` tuple: a generator of response chunks and
        the fully formatted prompt that was sent (useful for token counting).
    """
    prompt = prepare_prompt(user_input, chat_history)
    prompt_data = apply_prompt_syntax(
        prompt,
        system_prompt,
        get_active_prompt_template(),
        genparam.BAKE_IN_PROMPT_SYNTAX,
    )

    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY,
    )

    params = {
        GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
        GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
        GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
        GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
        GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES,
    }

    stream = generate_response(watsonx_llm, prompt_data, params)
    return stream, prompt_data