MilanM commited on
Commit
8276485
·
verified ·
1 Parent(s): 00d7298

Create functions.py

Browse files
Files changed (1) hide show
  1. functions.py +151 -0
functions.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from ibm_watsonx_ai.foundation_models import ModelInference
3
+ from ibm_watsonx_ai import Credentials, APIClient
4
+ from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
5
+ from knowledge_bases import KNOWLEDGE_BASE_OPTIONS, SYSTEM_PROMPTS
6
+ import genparam
7
+ import time
8
+
9
+ def check_password():
10
+ """Password protection check for the app."""
11
+ def password_entered():
12
+ if st.session_state["password"] == st.secrets["app_password"]:
13
+ st.session_state["password_correct"] = True
14
+ del st.session_state["password"]
15
+ else:
16
+ st.session_state["password_correct"] = False
17
+
18
+ if "password_correct" not in st.session_state:
19
+ st.markdown("\n\n")
20
+ st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
21
+ st.divider()
22
+ st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
23
+ return False
24
+ elif not st.session_state["password_correct"]:
25
+ st.markdown("\n\n")
26
+ st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
27
+ st.divider()
28
+ st.error("😕 Incorrect password")
29
+ st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
30
+ return False
31
+ else:
32
+ return True
33
+
34
+ def initialize_session_state():
35
+ """Initialize all session state variables."""
36
+ if 'chat_history_1' not in st.session_state:
37
+ st.session_state.chat_history_1 = []
38
+ if 'chat_history_2' not in st.session_state:
39
+ st.session_state.chat_history_2 = []
40
+ if 'chat_history_3' not in st.session_state:
41
+ st.session_state.chat_history_3 = []
42
+ if 'first_question' not in st.session_state:
43
+ st.session_state.first_question = False
44
+ if "counter" not in st.session_state:
45
+ st.session_state["counter"] = 0
46
+ if 'token_statistics' not in st.session_state:
47
+ st.session_state.token_statistics = []
48
+ if 'selected_kb' not in st.session_state:
49
+ st.session_state.selected_kb = KNOWLEDGE_BASE_OPTIONS[0]
50
+ if 'current_system_prompts' not in st.session_state:
51
+ st.session_state.current_system_prompts = SYSTEM_PROMPTS[st.session_state.selected_kb]
52
+
53
+ def setup_client(project_id=None):
54
+ """Setup WatsonX client with credentials."""
55
+ credentials = Credentials(
56
+ url=st.secrets["url"],
57
+ api_key=st.secrets["api_key"]
58
+ )
59
+ project_id = project_id or st.secrets["project_id"]
60
+ client = APIClient(credentials, project_id=project_id)
61
+ return credentials, client
62
+
63
+ def get_active_model():
64
+ """Get the currently active model based on configuration."""
65
+ return genparam.SELECTED_MODEL_1 if genparam.ACTIVE_MODEL == 0 else genparam.SELECTED_MODEL_2
66
+
67
+ def get_active_prompt_template():
68
+ """Get the currently active prompt template."""
69
+ return genparam.PROMPT_TEMPLATE_1 if genparam.ACTIVE_MODEL == 0 else genparam.PROMPT_TEMPLATE_2
70
+
71
+ def prepare_prompt(prompt, chat_history):
72
+ """Prepare the prompt with chat history if available."""
73
+ if genparam.TYPE == "chat" and chat_history:
74
+ chats = "\n".join([f"{message['role']}: \"{message['content']}\"" for message in chat_history])
75
+ return f"Conversation History:\n{chats}\n\nNew User Input: {prompt}"
76
+ return f"User Input: {prompt}"
77
+
78
+ def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
79
+ """Apply appropriate syntax to the prompt based on model requirements."""
80
+ model_family_syntax = {
81
+ "llama3-instruct (llama-3, 3.1 & 3.2) - system": """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
82
+ "llama3-instruct (llama-3, 3.1 & 3.2) - user": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
83
+ "granite-13b-chat & instruct - system": """<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
84
+ "granite-13b-chat & instruct - user": """<|user|>\n{prompt}\n<|assistant|>\n\n""",
85
+ "mistral & mixtral v2 tokenizer - system": """<s>[INST] System Prompt: {system_prompt} [/INST][INST] {prompt} [/INST]\n\n""",
86
+ "mistral & mixtral v2 tokenizer - user": """<s>[INST] {prompt} [/INST]\n\n""",
87
+ "no syntax - system": """{system_prompt}\n\n{prompt}""",
88
+ "no syntax - user": """{prompt}"""
89
+ }
90
+
91
+ if bake_in_prompt_syntax:
92
+ template = model_family_syntax[prompt_template]
93
+ if system_prompt:
94
+ return template.format(system_prompt=system_prompt, prompt=prompt)
95
+ return prompt
96
+
97
+ def generate_response(watsonx_llm, prompt_data, params):
98
+ """Generate streaming response from the model."""
99
+ generated_response = watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)
100
+ for chunk in generated_response:
101
+ yield chunk
102
+
103
+ def capture_tokens(prompt_data, response, client, bot_name):
104
+ """Capture token usage statistics."""
105
+ if not genparam.TOKEN_CAPTURE_ENABLED:
106
+ return
107
+
108
+ watsonx_llm = ModelInference(
109
+ api_client=client,
110
+ model_id=genparam.SELECTED_MODEL,
111
+ verify=genparam.VERIFY
112
+ )
113
+
114
+ input_tokens = watsonx_llm.tokenize(prompt=prompt_data)["result"]["token_count"]
115
+ output_tokens = watsonx_llm.tokenize(prompt=response)["result"]["token_count"]
116
+ total_tokens = input_tokens + output_tokens
117
+
118
+ return {
119
+ "bot_name": bot_name,
120
+ "input_tokens": input_tokens,
121
+ "output_tokens": output_tokens,
122
+ "total_tokens": total_tokens,
123
+ "timestamp": time.strftime("%H:%M:%S")
124
+ }
125
+
126
+ def fetch_response(user_input, client, system_prompt, chat_history):
127
+ """Fetch response from the model for the given input."""
128
+ prompt = prepare_prompt(user_input, chat_history)
129
+ prompt_data = apply_prompt_syntax(
130
+ prompt,
131
+ system_prompt,
132
+ get_active_prompt_template(),
133
+ genparam.BAKE_IN_PROMPT_SYNTAX
134
+ )
135
+
136
+ watsonx_llm = ModelInference(
137
+ api_client=client,
138
+ model_id=get_active_model(),
139
+ verify=genparam.VERIFY
140
+ )
141
+
142
+ params = {
143
+ GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
144
+ GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
145
+ GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
146
+ GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
147
+ GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
148
+ }
149
+
150
+ stream = generate_response(watsonx_llm, prompt_data, params)
151
+ return stream, prompt_data