Fading_Moments / functions.py
MilanM's picture
Create functions.py
8276485 verified
import streamlit as st
from ibm_watsonx_ai.foundation_models import ModelInference
from ibm_watsonx_ai import Credentials, APIClient
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from knowledge_bases import KNOWLEDGE_BASE_OPTIONS, SYSTEM_PROMPTS
import genparam
import time
def check_password():
    """Gate the app behind the shared password stored in Streamlit secrets.

    Renders a password input (followed by a divider and the footer) until
    the visitor submits a value equal to ``st.secrets["app_password"]``.
    The outcome is stored in ``st.session_state["password_correct"]`` so it
    is available on later reruns of the script.

    Returns:
        bool: True once the correct password has been entered in this
        session; False while the prompt is still being displayed.
    """
    import hmac  # local import so this block does not rely on file-level imports

    def password_entered():
        """Validate the submitted password and record the result."""
        submitted = str(st.session_state["password"]).encode("utf-8")
        expected = str(st.secrets["app_password"]).encode("utf-8")
        # Constant-time comparison avoids leaking how many leading characters
        # matched; bytes are compared because compare_digest rejects
        # non-ASCII str arguments.
        if hmac.compare_digest(submitted, expected):
            st.session_state["password_correct"] = True
            # Do not keep the plain-text password in session state.
            del st.session_state["password"]
        else:
            st.session_state["password_correct"] = False

    if st.session_state.get("password_correct", False):
        return True

    st.markdown("\n\n")
    st.text_input(
        "Enter the password",
        type="password",
        on_change=password_entered,
        key="password",
    )
    st.divider()
    # The key is only written by password_entered, so its presence here
    # (with a falsy value) means a wrong password was submitted.
    if "password_correct" in st.session_state:
        st.error("😕 Incorrect password")
    st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
    return False
def initialize_session_state():
    """Seed ``st.session_state`` with defaults for every key not yet present.

    Existing values are never overwritten. Keys are processed in a fixed
    order because the default for ``current_system_prompts`` is looked up
    from ``selected_kb``, which must therefore be populated first.
    """
    state = st.session_state

    # Each default is produced by a zero-argument factory so that it is only
    # evaluated when the key is actually missing (and so that every list
    # default is a fresh object).
    default_factories = (
        ("chat_history_1", list),
        ("chat_history_2", list),
        ("chat_history_3", list),
        ("first_question", lambda: False),
        ("counter", lambda: 0),
        ("token_statistics", list),
        ("selected_kb", lambda: KNOWLEDGE_BASE_OPTIONS[0]),
        ("current_system_prompts", lambda: SYSTEM_PROMPTS[state.selected_kb]),
    )

    for key, make_default in default_factories:
        if key not in state:
            state[key] = make_default()
def setup_client(project_id=None):
    """Build watsonx.ai credentials and an API client from Streamlit secrets.

    Args:
        project_id: Project to bind the client to. When falsy, the value of
            ``st.secrets["project_id"]`` is used instead.

    Returns:
        tuple: ``(credentials, client)`` — the ``Credentials`` object and the
        ``APIClient`` created from it.
    """
    watsonx_credentials = Credentials(
        url=st.secrets["url"],
        api_key=st.secrets["api_key"],
    )

    # Fall back to the configured project only when the caller gave none.
    if not project_id:
        project_id = st.secrets["project_id"]

    api_client = APIClient(watsonx_credentials, project_id=project_id)
    return watsonx_credentials, api_client
def get_active_model():
    """Return the model id selected by ``genparam.ACTIVE_MODEL``.

    ``ACTIVE_MODEL == 0`` selects ``genparam.SELECTED_MODEL_1``; any other
    value selects ``genparam.SELECTED_MODEL_2``.
    """
    if genparam.ACTIVE_MODEL == 0:
        return genparam.SELECTED_MODEL_1
    return genparam.SELECTED_MODEL_2
def get_active_prompt_template():
    """Return the prompt template name selected by ``genparam.ACTIVE_MODEL``.

    ``ACTIVE_MODEL == 0`` selects ``genparam.PROMPT_TEMPLATE_1``; any other
    value selects ``genparam.PROMPT_TEMPLATE_2``.
    """
    if genparam.ACTIVE_MODEL == 0:
        return genparam.PROMPT_TEMPLATE_1
    return genparam.PROMPT_TEMPLATE_2
def prepare_prompt(prompt, chat_history):
    """Wrap the user's input, prepending the conversation when in chat mode.

    Args:
        prompt: The new user input.
        chat_history: Iterable of message mappings with ``role`` and
            ``content`` keys; may be empty.

    Returns:
        str: ``"User Input: ..."`` when ``genparam.TYPE`` is not ``"chat"``
        or there is no history; otherwise a transcript of the history
        followed by the new input.
    """
    use_history = genparam.TYPE == "chat" and chat_history
    if not use_history:
        return f"User Input: {prompt}"

    # One `role: "content"` line per earlier message.
    transcript_lines = []
    for message in chat_history:
        transcript_lines.append(
            '{}: "{}"'.format(message['role'], message['content'])
        )
    transcript = "\n".join(transcript_lines)

    return f"Conversation History:\n{transcript}\n\nNew User Input: {prompt}"
# Prompt wrappers keyed by "<model family> - system|user". The "system"
# variants have slots for both {system_prompt} and {prompt}; the "user"
# variants only have a {prompt} slot.
_MODEL_FAMILY_SYNTAX = {
    "llama3-instruct (llama-3, 3.1 & 3.2) - system": """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
    "llama3-instruct (llama-3, 3.1 & 3.2) - user": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
    "granite-13b-chat & instruct - system": """<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
    "granite-13b-chat & instruct - user": """<|user|>\n{prompt}\n<|assistant|>\n\n""",
    "mistral & mixtral v2 tokenizer - system": """<s>[INST] System Prompt: {system_prompt} [/INST][INST] {prompt} [/INST]\n\n""",
    "mistral & mixtral v2 tokenizer - user": """<s>[INST] {prompt} [/INST]\n\n""",
    "no syntax - system": """{system_prompt}\n\n{prompt}""",
    "no syntax - user": """{prompt}""",
}


def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
    """Wrap ``prompt`` in the chat syntax of the selected model family.

    Args:
        prompt: The already-prepared user prompt text.
        system_prompt: System instructions; may be empty or None.
        prompt_template: A key of the template table, of the form
            ``"<family> - system"`` or ``"<family> - user"``.
        bake_in_prompt_syntax: When falsy, ``prompt`` is returned untouched.

    Returns:
        str: The formatted prompt. If no system prompt is supplied and a
        ``"- system"`` template was requested, the matching ``"- user"``
        template is used so that no empty system section is emitted.
        (Previously the raw prompt was returned in that case, and the
        ``"- user"`` templates were skipped whenever the system prompt was
        empty.)

    Raises:
        KeyError: If baking is enabled and ``prompt_template`` is not a
            known template name.
    """
    if not bake_in_prompt_syntax:
        return prompt

    system_suffix = " - system"
    if not system_prompt and prompt_template.endswith(system_suffix):
        # Without system instructions, use the family's user-only wrapper.
        prompt_template = prompt_template[: -len(system_suffix)] + " - user"

    template = _MODEL_FAMILY_SYNTAX[prompt_template]
    # str.format ignores the unused system_prompt keyword for user-only
    # templates, and braces inside the substituted values are not re-parsed.
    return template.format(system_prompt=system_prompt, prompt=prompt)
def generate_response(watsonx_llm, prompt_data, params):
    """Lazily yield the chunks of a streamed text generation.

    Args:
        watsonx_llm: Object exposing ``generate_text_stream(prompt=, params=)``.
        prompt_data: The fully formatted prompt to send.
        params: Generation parameters forwarded unchanged.

    Yields:
        Each chunk produced by ``generate_text_stream``, in order. The
        request is not started until the first chunk is requested.
    """
    yield from watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)
def capture_tokens(prompt_data, response, client, bot_name):
    """Count the tokens of a prompt/response pair for usage statistics.

    Args:
        prompt_data: The prompt text that was sent to the model.
        response: The complete response text that was generated.
        client: The watsonx ``APIClient`` used to reach the tokenizer.
        bot_name: Label stored with the statistics entry.

    Returns:
        dict | None: None when ``genparam.TOKEN_CAPTURE_ENABLED`` is falsy;
        otherwise a dict with ``bot_name``, ``input_tokens``,
        ``output_tokens``, ``total_tokens`` and a ``timestamp`` (local
        ``HH:MM:SS``).
    """
    if not genparam.TOKEN_CAPTURE_ENABLED:
        return None

    # Tokenize with the same model that fetch_response uses for generation
    # (chosen via genparam.ACTIVE_MODEL) so the counts match the real usage.
    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY,
    )

    input_tokens = watsonx_llm.tokenize(prompt=prompt_data)["result"]["token_count"]
    output_tokens = watsonx_llm.tokenize(prompt=response)["result"]["token_count"]

    return {
        "bot_name": bot_name,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
        "timestamp": time.strftime("%H:%M:%S"),
    }
def fetch_response(user_input, client, system_prompt, chat_history):
    """Build the final prompt and open a response stream for it.

    Args:
        user_input: The text the user just entered.
        client: The watsonx ``APIClient`` used for inference.
        system_prompt: System instructions passed to the prompt formatter.
        chat_history: Earlier messages, passed to ``prepare_prompt``.

    Returns:
        tuple: ``(stream, prompt_data)`` — the generator returned by
        ``generate_response`` and the exact prompt string given to it.
    """
    # Step 1: user input (+ history) -> model-specific prompt string.
    conversation_prompt = prepare_prompt(user_input, chat_history)
    template_name = get_active_prompt_template()
    prompt_data = apply_prompt_syntax(
        conversation_prompt,
        system_prompt,
        template_name,
        genparam.BAKE_IN_PROMPT_SYNTAX,
    )

    # Step 2: inference handle for the currently active model.
    inference = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY,
    )

    # Step 3: generation settings, all taken from genparam.
    generation_params = {
        GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
        GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
        GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
        GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
        GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES,
    }

    return generate_response(inference, prompt_data, generation_params), prompt_data