Amitgm commited on
Commit
1a311b9
·
verified ·
1 Parent(s): 4c458bf

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +376 -0
main.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.document_loaders import CSVLoader
2
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
3
+ # from langchain_community.embeddings import OpenAIEmbeddings
4
+ from langchain_community.vectorstores import chroma
5
+ from langchain_community.llms import openai
6
+ from langchain.chains import LLMChain
7
+ from dotenv import load_dotenv
8
+ from langchain.chains import ConversationalRetrievalChain
9
+ from langchain_core.prompts import ChatPromptTemplate,PromptTemplate
10
+ from langchain.memory import ConversationBufferMemory
11
+ from langchain_community.chat_models import ChatOpenAI
12
+ from langchain_openai import OpenAIEmbeddings
13
+ from langchain_chroma import Chroma
14
+ import os
15
+ from dotenv import load_dotenv
16
+ import streamlit as st
17
+ import streamlit_chat
18
+ from langchain_groq import ChatGroq
19
+ global seed
20
+ from langchain.chains import LLMChain
21
+ from langchain.prompts import PromptTemplate
22
+ from langchain.memory import ConversationBufferMemory
23
+ from langchain_community.chat_models import ChatOpenAI
24
+ from langchain.docstore.document import Document
25
+ from langchain.llms import HuggingFacePipeline
26
+ from langchain.embeddings import HuggingFaceEmbeddings
27
+
28
+
29
+ import pandas as pd
30
+
31
+ load_dotenv()
32
+
33
+
34
+ # OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
35
+ # os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
36
+
37
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
38
+
39
+
40
+ class prompts:
41
+
42
+ prompt = PromptTemplate.from_template("""
43
+
44
+ You are a helpful fitness assistant. Use the following context to answer the question The Level is provided for you to get a better idea on how to answer the question
45
+ .
46
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.Also make sure to mention the level passed for the user.
47
+ Context:
48
+ {context}
49
+
50
+ Chat History:
51
+ {history}
52
+
53
+ Question:
54
+ {question}
55
+
56
+ Level:
57
+ {level}
58
+
59
+ Answer:
60
+ """)
61
+
62
+ # Data Filteration
63
+ def filter_transform_data(dataframe):
64
+
65
+ dataframe.drop("RatingDesc",axis=1,inplace=True)
66
+
67
+ dataframe.dropna(subset=["Desc","Equipment"],inplace=True)
68
+
69
+ dataframe.drop("Rating",inplace=True,axis=1)
70
+
71
+ # transform data
72
+
73
+ document_data = dataframe.to_dict(orient="records")
74
+
75
+ return document_data
76
+
77
+
78
+ def get_context(vector_store,query,level):
79
+
80
+ results = vector_store.max_marginal_relevance_search(
81
+
82
+ query=query,
83
+ k=5,
84
+ filter={"Level": level},
85
+ )
86
+
87
+ # Creating the LLM Chain
88
+
89
+ # Pass your context manually from retrieved documents
90
+ context = "\n\n".join([doc.page_content for doc in results])
91
+
92
+ return context
93
+
94
+ def generate_vector_store():
95
+
96
+ # embedding = OpenAIEmbeddings(
97
+
98
+ if "vector_store" not in st.session_state:
99
+
100
+ langchain_documents = []
101
+
102
+ dataframe = pd.read_csv("megaGymDataset.csv",index_col=0)
103
+
104
+ document_data = filter_transform_data(dataframe)
105
+
106
+ # Iterate through the sample data and create Document objects
107
+ for item in document_data:
108
+ # Formulate the page_content string
109
+ page_content = (
110
+ f"Title: {item['Title']}\n"
111
+ f"Type:{item['Type']}\n"
112
+ f"BodyPart: {item['BodyPart']}\n"
113
+ f"Desc: {item['Desc']}\n"
114
+ f"Equipment: {item['Equipment']}\n"
115
+ )
116
+
117
+ # Create the metadata dictionary
118
+ metadata = {"Level": item['Level']}
119
+
120
+ # Create the Document object
121
+ doc = Document(page_content=page_content, metadata=metadata)
122
+
123
+ # Add the Document to our list
124
+ langchain_documents.append(doc)
125
+
126
+ # creating the session_state for vector_store
127
+
128
+ # embedding = OpenAIEmbeddings(openai_api_key=os.environ["OPENAI_API_KEY"])
129
+ embedding = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large-instruct")
130
+
131
+
132
+ # if path not exist
133
+ if not os.path.exists("db"):
134
+
135
+ st.session_state.vector_store = Chroma.from_documents(langchain_documents,embedding=embedding,collection_name="gym-queries-data",persist_directory = "db")
136
+ # st.session_state.vector_store.persist()
137
+
138
+ else:
139
+
140
+ st.session_state.vector_store = Chroma(
141
+
142
+ persist_directory="db",
143
+ embedding_function=embedding
144
+ )
145
+
146
+ return st.session_state.vector_store
147
+
148
+ def get_conversational_chain(vector_store,query,level):
149
+
150
+ # model_name = "msu-rcc-lair/RuadaptQwen2.5-32B-Instruct" # Replace with actual name
151
+
152
+ # tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
153
+ # model = AutoModelForCausalLM.from_pretrained(
154
+ # model_name,
155
+ # device_map="auto",
156
+ # torch_dtype=torch.bfloat16, # or float16
157
+ # trust_remote_code=True
158
+ # )
159
+ # generator = pipeline(
160
+ # "text-generation",
161
+ # model=model,
162
+ # tokenizer=tokenizer,
163
+ # max_new_tokens=512,
164
+ # temperature=0.7,
165
+ # return_full_text=False
166
+ # )
167
+
168
+
169
+
170
+ # llm = HuggingFacePipeline(pipeline=generator)
171
+ # llm = ChatOpenAI(temperature=0.5,model_name="gpt-4o")
172
+
173
+ # llama3-70b-8192
174
+ llm = ChatGroq(
175
+ temperature=1,
176
+ groq_api_key = GROQ_API_KEY,
177
+ model_name="llama-3.1-8b-instant",
178
+ max_tokens=560,
179
+ # top_p=0.95,
180
+ # frequency_penalty=1,
181
+ # presence_penalty=1,
182
+ )
183
+ # llm_chain = LLMChain(llm=llm, prompt=prompts.prompt)
184
+
185
+ if "memory" not in st.session_state:
186
+
187
+ st.session_state.memory = ConversationBufferMemory(memory_key="history", input_key="question", return_messages=True)
188
+
189
+ st.session_state.conversational_chain = LLMChain(
190
+ llm=llm,
191
+ # taking the prompt template
192
+ prompt=prompts.prompt,
193
+ memory=st.session_state.memory
194
+ )
195
+
196
+
197
+ return st.session_state.conversational_chain,st.session_state.memory
198
+
199
+ def stick_it_good():
200
+
201
+ # make header sticky.
202
+ st.markdown(
203
+ """
204
+ <div class='fixed-header'/>
205
+ <style>
206
+ div[data-testid="stVerticalBlock"] div:has(div.fixed-header) {
207
+ position: sticky;
208
+ top: 2.875rem;
209
+ background-color: ##393939;
210
+ z-index: 999;
211
+ }
212
+ .fixed-header {
213
+ border-bottom: 1px solid black;
214
+ }
215
+ </style>
216
+ """,
217
+ unsafe_allow_html=True
218
+ )
219
+
220
+
221
+ def show_privacy_policy():
222
+ st.title("Privacy Policy")
223
+
224
+
225
+ def show_terms_of_service():
226
+ st.title("Terms of Service")
227
+
228
+ seed = 0
229
+
230
+ def main():
231
+
232
+ global seed
233
+
234
+ page = st.sidebar.selectbox("Choose a page", ["Home", "Privacy Policy", "Terms of Service"])
235
+
236
+ if page == "Privacy Policy":
237
+
238
+ show_privacy_policy()
239
+
240
+ elif page == "Terms of Service":
241
+
242
+ show_terms_of_service()
243
+
244
+ else:
245
+
246
+ st.write("Welcome to the Home Page")
247
+
248
+ with st.container():
249
+
250
+ st.title("Workout Wizard")
251
+ stick_it_good()
252
+
253
+
254
+ with st.sidebar:
255
+
256
+ if "seed" not in st.session_state:
257
+
258
+ st.session_state.seed = 0
259
+
260
+ # Display the image using the URL
261
+
262
+ choose_mode = st.selectbox('Choose Workout Level',["Beginner","Intermediate","Expert"])
263
+
264
+
265
+ st.markdown("<h2 style='text-align: center;'>Choose Your Avatar</h2>", unsafe_allow_html=True)
266
+
267
+ # st.markdown(f"<h2 style='text-align: center;'>{st.button("Back")}</h2>", unsafe_allow_html=True)
268
+
269
+ # Center the buttons using HTML and CSS
270
+ col1, col2, col3 = st.columns([1, 1, 1])
271
+
272
+
273
+ with col1:
274
+
275
+ st.write("") # Empty column for spacing
276
+
277
+ with col2:
278
+
279
+ print(st.session_state.seed)
280
+
281
+ choose_Avatar = st.button("Next")
282
+
283
+ choose_Avatar_second = st.button("Back")
284
+
285
+
286
+ if choose_Avatar:
287
+
288
+ st.session_state.seed += 1
289
+
290
+ if choose_Avatar_second:
291
+
292
+ st.session_state.seed -= 1
293
+
294
+ avatar_url = f"https://api.dicebear.com/9.x/adventurer/svg?seed={st.session_state.seed}"
295
+
296
+ st.image(avatar_url, caption=f"Avatar {st.session_state.seed }")
297
+
298
+ with col3:
299
+
300
+ st.write("") # Empty column for spacing
301
+
302
+
303
+ streamlit_chat.message("Hi. I'm your friendly Gym Assistant Bot.")
304
+ streamlit_chat.message("Ask me anything about the gym! Just don’t ask me to do any push-ups... I'm already *up* and running!")
305
+ streamlit_chat.message("If you want to change your workout level and avatar, press the top left arrow and you will have options to make changes")
306
+
307
+
308
+ question = st.chat_input("Ask a question related to your GYM queries")
309
+
310
+
311
+ if "conversation_chain" not in st.session_state:
312
+
313
+ st.session_state.conversation_chain = None
314
+
315
+
316
+ # if question:
317
+
318
+ # Converstion chain
319
+ if st.session_state.conversation_chain == None:
320
+ # st.session_state.vectors
321
+
322
+ print("the vector store generated")
323
+
324
+ st.session_state.vector_store = generate_vector_store()
325
+
326
+ st.session_state.conversation_chain, st.session_state.memory = get_conversational_chain(st.session_state.vector_store,question,choose_mode)
327
+
328
+ # the session state memory
329
+ if st.session_state.memory != None:
330
+
331
+ for i,message in enumerate(st.session_state.memory.chat_memory.messages):
332
+
333
+ if i%2 == 0:
334
+
335
+ suffix = f" for {choose_mode} level"
336
+
337
+ # Check if the message ends with the suffix and strip it
338
+ if message.content.endswith(suffix):
339
+
340
+ message.content = message.content[:-len(suffix)]
341
+
342
+ # message.content = message.content.strip(f" for {choose_mode} level")
343
+
344
+ print("this is the message content",message.content)
345
+
346
+ streamlit_chat.message(message.content,is_user=True, avatar_style="adventurer",seed=st.session_state.seed, key=f"user_msg_{i}")
347
+
348
+ else:
349
+
350
+ streamlit_chat.message(message.content,key=f"bot_msg_{i}")
351
+
352
+ st.write("--------------------------------------------------")
353
+
354
+ if question:
355
+
356
+ streamlit_chat.message(question,is_user=True, avatar_style="adventurer",seed=st.session_state.seed)
357
+
358
+ print(question)
359
+
360
+ print("------------------------")
361
+
362
+ # GETTING THE CONTEXT AND ANSWER FROM THE MODEL
363
+
364
+ context = get_context(st.session_state.vector_store,question,choose_mode)
365
+
366
+ print("context::",context)
367
+ print("the choose mode:",choose_mode)
368
+
369
+ response = st.session_state.conversational_chain.run({"context": context, "question": question,"level":choose_mode})
370
+
371
+ streamlit_chat.message(response)
372
+
373
+
374
+ if __name__ == "__main__":
375
+
376
+ main()