stckwok commited on
Commit
588bdab
·
verified ·
1 Parent(s): de3b4f2

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +130 -232
app.py CHANGED
@@ -29,7 +29,7 @@ from langchain.agents import create_tool_calling_agent, AgentExecutor
29
  from langchain_core.prompts import ChatPromptTemplate
30
 
31
  # LangChain OpenAI imports
32
- from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI, ChatOpenAI # OpenAI embeddings and models
33
  from langchain.embeddings.openai import OpenAIEmbeddings # OpenAI embeddings for text vectors
34
 
35
  # LlamaParse & LlamaIndex imports
@@ -54,7 +54,7 @@ from datetime import datetime
54
 
55
  #====================================SETUP=====================================#
56
  # Fetch secrets from Hugging Face Spaces
57
- api_key = os.getenv("API_KEY") #config.get("API_KEY")
58
  endpoint = os.getenv("OPENAI_API_BASE")
59
  llama_api_key = os.environ['GROQ_API_KEY']
60
  MEM0_api_key = os.environ['mem0']
@@ -66,6 +66,7 @@ embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
66
  model_name='text-embedding-ada-002' # This is a fixed value and does not need modification
67
  )
68
 
 
69
 
70
  # Initialize the OpenAI Embeddings
71
  embedding_model = OpenAIEmbeddings(
@@ -82,11 +83,14 @@ llm = ChatOpenAI(
82
  model="gpt-4o-mini",
83
  streaming=False
84
  )
85
-
86
 
87
  # set the LLM and embedding model in the LlamaIndex settings.
88
- Settings.llm = llm
89
- Settings.embedding = embedding_model
 
 
 
90
 
91
  #================================Creating Langgraph agent======================#
92
 
@@ -115,29 +119,10 @@ def expand_query(state):
115
  Dict: The updated state with the expanded query.
116
  """
117
  print("---------Expanding Query---------")
118
- system_message = '''You are an AI specializing in improving search queries to retrieve the most relevant nutrition disorder-related information.
119
- Your task is to **refine** and **expand** the given query so that better search results are obtained, while **keeping the original intent** unchanged.
120
-
121
- Guidelines:
122
- - Add **specific details** where needed. Example: If a user asks about "anorexia," specify aspects like symptoms, causes, or treatment options.
123
- - Include **related terms** to improve retrieval (e.g., “bulimia” → “bulimia nervosa vs binge eating disorder”).
124
- - If the user provides an unclear query, suggest necessary clarifications.
125
- - **DO NOT** answer the question. Your job is only to enhance the query.
126
-
127
- Examples:
128
- 1. User Query: "Tell me about eating disorders."
129
- Expanded Query: "Provide details on eating disorders, including types (e.g., anorexia nervosa, bulimia nervosa), symptoms, causes, and treatment options."
130
 
131
- 2. User Query: "What is anorexia?"
132
- Expanded Query: "Explain anorexia nervosa, including its symptoms, causes, risk factors, and treatment options."
133
 
134
- 3. User Query: "How to treat bulimia?"
135
- Expanded Query: "Describe treatment options for bulimia nervosa, including psychotherapy, medications, and lifestyle changes."
136
-
137
- 4. User Query: "What are the effects of malnutrition?"
138
- Expanded Query: "Explain the effects of malnutrition on physical and mental health, including specific nutrient deficiencies and their consequences."
139
-
140
- Now, expand the following query:'''
141
 
142
  expand_prompt = ChatPromptTemplate.from_messages([
143
  ("system", system_message),
@@ -177,7 +162,8 @@ def retrieve_context(state):
177
  Dict: The updated state with the retrieved context.
178
  """
179
  print("---------retrieve_context---------")
180
- query = state['expanded_query']
 
181
  #print("Query used for retrieval:", query) # Debugging: Print the query
182
 
183
  # Retrieve documents from the vector store
@@ -192,7 +178,9 @@ def retrieve_context(state):
192
  }
193
  for doc in docs
194
  ]
195
- state['context'] = context
 
 
196
  print("Extracted context with metadata:", context) # Debugging: Print the extracted context
197
  #print(f"Groundedness loop count: {state['groundedness_loop_count']}")
198
  return state
@@ -209,38 +197,27 @@ def craft_response(state: Dict) -> Dict:
209
  Returns:
210
  Dict: The updated state with the generated response.
211
  """
212
- system_message = '''You are a professional AI nutrition disorder specialist generating responses based on retrieved documents.
213
- Your task is to use the given **context** to generate a highly accurate, informative, and user-friendly response.
214
-
215
- Guidelines:
216
- - **Be direct and concise** while ensuring completeness.
217
- - **DO NOT include information that is not present in the context.**
218
- - If multiple sources exist, synthesize them into a coherent response.
219
- - If the context does not fully answer the query, state what additional information is needed.
220
- - Use bullet points when explaining complex concepts.
221
-
222
- Example:
223
- User Query: "What are the symptoms of anorexia nervosa?"
224
- Context:
225
- 1. Anorexia nervosa is characterized by extreme weight loss and fear of gaining weight.
226
- 2. Common symptoms include restricted eating, distorted body image, and excessive exercise.
227
- Response:
228
- "Anorexia nervosa is an eating disorder characterized by extreme weight loss and an intense fear of gaining weight. Common symptoms include:
229
- - Restricted eating
230
- - Distorted body image
231
- - Excessive exercise
232
- If you or someone you know is experiencing these symptoms, it is important to seek professional help."'''
233
 
234
  response_prompt = ChatPromptTemplate.from_messages([
235
  ("system", system_message),
236
- ("user", "Query: {query}\nContext: {context}\n\nResponse:")
237
  ])
238
 
239
- chain = response_prompt | llm | StrOutputParser()
240
- state['response'] = chain.invoke({
241
  "query": state['query'],
242
- "context": "\n".join([doc["content"] for doc in state['context']]) # Extract content from each document
 
 
 
243
  })
 
 
 
244
  return state
245
 
246
 
@@ -256,37 +233,9 @@ def score_groundedness(state: Dict) -> Dict:
256
  Dict: The updated state with the groundedness score.
257
  """
258
  print("---------check_groundedness---------")
259
- system_message = '''You are an AI tasked with evaluating whether a response is grounded in the provided context and includes proper citations.
260
-
261
- Guidelines:
262
- 1. **Groundedness Check**:
263
- - Verify that the response accurately reflects the information in the context.
264
- - Flag any unsupported claims or deviations from the context.
265
-
266
- 2. **Citation Check**:
267
- - Ensure that the response includes citations to the source material (e.g., "According to [Source], ...").
268
- - If citations are missing, suggest adding them.
269
 
270
- 3. **Scoring**:
271
- - Assign a groundedness score between 0 and 1, where 1 means fully grounded and properly cited.
272
-
273
- Examples:
274
- 1. Response: "Anorexia nervosa is caused by genetic factors (Source 1)."
275
- Context: "Anorexia nervosa is influenced by genetic, environmental, and psychological factors (Source 1)."
276
- Evaluation: "The response is grounded and properly cited. Groundedness score: 1.0."
277
-
278
- 2. Response: "Bulimia nervosa can be cured with diet alone."
279
- Context: "Treatment for bulimia nervosa involves psychotherapy and medications (Source 2)."
280
- Evaluation: "The response is ungrounded and lacks citations. Groundedness score: 0.2."
281
-
282
- 3. Response: "Anorexia nervosa has a high mortality rate."
283
- Context: "Anorexia nervosa has one of the highest mortality rates among psychiatric disorders (Source 3)."
284
- Evaluation: "The response is grounded but lacks a citation. Groundedness score: 0.7. ."
285
-
286
- ****Return only a float score (e.g., 0.9). Do not provide explanations.****
287
-
288
- Now, evaluate the following response:
289
- '''
290
 
291
  groundedness_prompt = ChatPromptTemplate.from_messages([
292
  ("system", system_message),
@@ -296,12 +245,14 @@ Now, evaluate the following response:
296
  chain = groundedness_prompt | llm | StrOutputParser()
297
  groundedness_score = float(chain.invoke({
298
  "context": "\n".join([doc["content"] for doc in state['context']]),
299
- "response": state['response']
 
300
  }))
301
- print("groundedness_score: ",groundedness_score)
302
- state['groundedness_loop_count'] +=1
303
  print("#########Groundedness Incremented###########")
304
  state['groundedness_score'] = groundedness_score
 
305
  return state
306
 
307
 
@@ -317,49 +268,23 @@ def check_precision(state: Dict) -> Dict:
317
  Dict: The updated state with the precision score.
318
  """
319
  print("---------check_precision---------")
320
- system_message = '''You are an AI evaluator assessing the **precision** of the response.
321
- Your task is to **score** how well the response addresses the user’s original nutrition disorder-related query.
322
-
323
- Scoring Criteria:
324
- - 1.0 → The response is fully precise, directly answering the question.
325
- - 0.7 → The response is mostly correct but contains some generalization.
326
- - 0.5 → The response is somewhat relevant but lacks key details.
327
- - 0.3 → The response is vague or only partially correct.
328
- - 0.0 → The response is incorrect or misleading.
329
-
330
- Examples:
331
- 1. Query: "What are the symptoms of anorexia nervosa?"
332
- Response: "The symptoms of anorexia nervosa include extreme weight loss, fear of gaining weight, and a distorted body image."
333
- Precision Score: 1.0
334
-
335
- 2. Query: "How is bulimia nervosa treated?"
336
- Response: "Bulimia nervosa is treated with therapy and medications."
337
- Precision Score: 0.7
338
-
339
- 3. Query: "What causes binge eating disorder?"
340
- Response: "Binge eating disorder is caused by a combination of genetic, psychological, and environmental factors."
341
- Precision Score: 0.5
342
-
343
- 4. Query: "What are the effects of malnutrition?"
344
- Response: "Malnutrition can lead to health problems."
345
- Precision Score: 0.3
346
-
347
- 5. Query: "What is the mortality rate of anorexia nervosa?"
348
- Response: "Anorexia nervosa is a type of eating disorder."
349
- Precision Score: 0.0
350
-
351
- *****Return only a float score (e.g., 0.9). Do not provide explanations.*****
352
- Now, evaluate the following query and response:
353
- '''
354
  precision_prompt = ChatPromptTemplate.from_messages([
355
  ("system", system_message),
356
  ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
357
  ])
358
 
359
- chain = precision_prompt | llm | StrOutputParser()
 
 
360
  precision_score = float(chain.invoke({
361
  "query": state['query'],
362
- "response": state['response']
 
 
363
  }))
364
  state['precision_score'] = precision_score
365
  print("precision_score:", precision_score)
@@ -381,29 +306,9 @@ def refine_response(state: Dict) -> Dict:
381
  """
382
  print("---------refine_response---------")
383
 
384
- system_message = '''You are an AI response refinement assistant. Your task is to suggest **improvements** for the given response.
385
-
386
- ### Guidelines:
387
- - Identify **gaps in the explanation** (missing key details).
388
- - Highlight **unclear or vague parts** that need elaboration.
389
- - Suggest **additional details** that should be included for better accuracy.
390
- - Ensure the refined response is **precise** and **grounded** in the retrieved context.
391
-
392
- ### Examples:
393
- 1. Query: "What are the symptoms of anorexia nervosa?"
394
- Response: "The symptoms include weight loss and fear of gaining weight."
395
- Suggestions: "The response is missing key details about behavioral and emotional symptoms. Add details like 'distorted body image' and 'restrictive eating patterns.'"
396
-
397
- 2. Query: "How is bulimia nervosa treated?"
398
- Response: "Bulimia nervosa is treated with therapy."
399
- Suggestions: "The response is too vague. Specify the types of therapy (e.g., cognitive-behavioral therapy) and mention other treatments like nutritional counseling and medications."
400
-
401
- 3. Query: "What causes binge eating disorder?"
402
- Response: "Binge eating disorder is caused by psychological factors."
403
- Suggestions: "The response is incomplete. Add details about genetic and environmental factors, and explain how they contribute to the disorder."
404
 
405
- Now, suggest improvements for the following response:
406
- '''
407
 
408
  refine_response_prompt = ChatPromptTemplate.from_messages([
409
  ("system", system_message),
@@ -433,28 +338,9 @@ def refine_query(state: Dict) -> Dict:
433
  Dict: The updated state with query refinement suggestions.
434
  """
435
  print("---------refine_query---------")
436
- system_message = '''You are an AI query refinement assistant. Your task is to suggest **improvements** for the expanded query.
 
437
 
438
- ### Guidelines:
439
- - Add **specific keywords** to improve document retrieval.
440
- - Identify **missing details** that should be included.
441
- - Suggest **ways to narrow the scope** for better precision.
442
-
443
- ### Examples:
444
- 1. Original Query: "Tell me about eating disorders."
445
- Expanded Query: "Provide details on eating disorders, including types, symptoms, causes, and treatment options."
446
- Suggestions: "Add specific types of eating disorders like 'anorexia nervosa' and 'bulimia nervosa' to improve retrieval."
447
-
448
- 2. Original Query: "What is anorexia?"
449
- Expanded Query: "Explain anorexia nervosa, including its symptoms and causes."
450
- Suggestions: "Include details about treatment options and risk factors to make the query more comprehensive."
451
-
452
- 3. Original Query: "How to treat bulimia?"
453
- Expanded Query: "Describe treatment options for bulimia nervosa."
454
- Suggestions: "Specify types of treatments like 'cognitive-behavioral therapy' and 'medications' for better precision."
455
-
456
- Now, suggest improvements for the following expanded query:
457
- '''
458
 
459
  refine_query_prompt = ChatPromptTemplate.from_messages([
460
  ("system", system_message),
@@ -477,7 +363,8 @@ def should_continue_groundedness(state):
477
  """Decides if groundedness is sufficient or needs improvement."""
478
  print("---------should_continue_groundedness---------")
479
  print("groundedness loop count: ", state['groundedness_loop_count'])
480
- if state['groundedness_score'] >= 0.4: # Threshold for groundedness
 
481
  print("Moving to precision")
482
  return "check_precision"
483
  else:
@@ -491,19 +378,24 @@ def should_continue_groundedness(state):
491
  def should_continue_precision(state: Dict) -> str:
492
  """Decides if precision is sufficient or needs improvement."""
493
  print("---------should_continue_precision---------")
494
- print("precision loop count: ",state['precision_loop_count'])
495
- if state['precision_score'] >= 0.7: # Threshold for precision
 
 
 
496
  return "pass" # Complete the workflow
497
  else:
498
- if state['precision_loop_count'] > state['loop_max_iter']: # Maximum allowed loops
 
499
  return "max_iterations_reached"
500
  else:
501
  print(f"---------Precision Score Threshold Not met. Refining Query-----------") # Debugging
502
- # Exit the loop
503
  return "refine_query" # Refine the query
504
 
505
 
506
 
 
507
  def max_iterations_reached(state: Dict) -> Dict:
508
  """Handles the case when the maximum number of iterations is reached."""
509
  print("---------max_iterations_reached---------")
@@ -514,27 +406,29 @@ def max_iterations_reached(state: Dict) -> Dict:
514
 
515
 
516
 
 
 
517
  def create_workflow() -> StateGraph:
518
  """Creates the updated workflow for the AI nutrition agent."""
519
- workflow = StateGraph(AgentState)
 
520
 
521
  # Add processing nodes
522
- workflow.add_node("expand_query", expand_query) # Step 1: Expand user query.
523
- workflow.add_node("retrieve_context", retrieve_context) # Step 2: Retrieve relevant documents.
524
- workflow.add_node("craft_response", craft_response) # Step 3: Generate a response based on retrieved data.
525
- workflow.add_node("score_groundedness", score_groundedness) # Step 4: Evaluate response grounding.
526
- workflow.add_node("refine_response", refine_response) # Step 5: Improve response if it's weakly grounded.
527
- workflow.add_node("check_precision", check_precision) # Step 6: Evaluate response precision.
528
- workflow.add_node("refine_query", refine_query) # Step 7: Improve query if response lacks precision.
529
- workflow.add_node("max_iterations_reached", max_iterations_reached) # Step 8: Handle max iterations.
530
- # workflow.add_node("groundedness_decider",groundedness_decider)
 
531
  # Main flow edges
532
  workflow.add_edge(START, "expand_query")
533
  workflow.add_edge("expand_query", "retrieve_context")
534
  workflow.add_edge("retrieve_context", "craft_response")
535
  workflow.add_edge("craft_response", "score_groundedness")
536
- # workflow.add_edge("score_groundedness","groundedness_decider")
537
-
538
 
539
  # Conditional edges based on groundedness check
540
  workflow.add_conditional_edges(
@@ -546,6 +440,7 @@ def create_workflow() -> StateGraph:
546
  "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
547
  }
548
  )
 
549
  workflow.add_edge("refine_response", "craft_response") # Refined responses are reprocessed.
550
 
551
  # Conditional edges based on precision check
@@ -555,19 +450,17 @@ def create_workflow() -> StateGraph:
555
  {
556
  "pass": END, # If precise, complete the workflow.
557
  "refine_query": "refine_query", # If imprecise, refine the query.
558
- "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
559
  }
560
  )
 
561
  workflow.add_edge("refine_query", "expand_query") # Refined queries go through expansion again.
562
 
563
  workflow.add_edge("max_iterations_reached", END)
564
- # Set entry point
565
- # workflow.set_entry_point("expand_query")
566
 
567
  return workflow
568
 
569
 
570
-
571
  #=========================== Defining the agentic rag tool ====================#
572
  WORKFLOW_APP = create_workflow().compile()
573
  @tool
@@ -584,17 +477,16 @@ def agentic_rag(query: str):
584
  # Initialize state with necessary parameters
585
  inputs = {
586
  "query": query, # Current user query
587
- "expanded_query": "", # Expanded version of the query
588
  "context": [], # Retrieved documents (initially empty)
589
- "response": "", # AI-generated response
590
- "precision_score": 0.0, # Precision score of the response
591
- "groundedness_score": 0.0, # Groundedness score of the response
592
- "groundedness_loop_count": 0, # Counter for groundedness loops
593
- "precision_loop_count": 0, # Counter for precision loops
594
- "feedback": "",
595
- "query_feedback":"",
596
- "loop_max_iter":2
597
-
598
  }
599
 
600
  output = WORKFLOW_APP.invoke(inputs)
@@ -638,8 +530,9 @@ class NutritionBot:
638
  """
639
 
640
  # Initialize a memory client to store and retrieve customer interactions
641
- self.memory = MemoryClient(api_key=MEM0_api_key)
642
 
 
643
  self.client = ChatOpenAI(
644
  model_name="gpt-4o-mini", # Specify the model to use (e.g., GPT-4 optimized version)
645
  api_key=config.get("API_KEY"), # API key for authentication
@@ -647,7 +540,6 @@ class NutritionBot:
647
  temperature=0 # Controls randomness in responses; 0 ensures deterministic results
648
  )
649
 
650
-
651
  # Define tools available to the chatbot, such as web search
652
  tools = [agentic_rag]
653
 
@@ -679,6 +571,7 @@ class NutritionBot:
679
  # Wrap the agent in an executor to manage tool interactions and execution flow
680
  self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
681
 
 
682
  def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
683
  """
684
  Store customer interaction in memory for future reference.
@@ -709,6 +602,7 @@ class NutritionBot:
709
  metadata=metadata
710
  )
711
 
 
712
  def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
713
  """
714
  Retrieve past interactions relevant to the current query.
@@ -723,9 +617,12 @@ class NutritionBot:
723
  return self.memory.search(
724
  query=query, # Search for interactions related to the query
725
  user_id=user_id, # Restrict search to the specific user
726
- limit=5 # Retrieve up to 5 relevant interactions
 
 
727
  )
728
 
 
729
  def handle_customer_query(self, user_id: str, query: str) -> str:
730
  """
731
  Process a customer's query and provide a response, taking into account past interactions.
@@ -803,8 +700,6 @@ def nutrition_disorder_streamlit():
803
  "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
804
  })
805
  st.session_state.login_submitted = True # Set flag to trigger rerun
806
-
807
- # Trigger rerun outside the form if login was successful
808
  if st.session_state.get("login_submitted", False):
809
  st.session_state.pop("login_submitted")
810
  st.rerun()
@@ -814,11 +709,11 @@ def nutrition_disorder_streamlit():
814
  with st.chat_message(message["role"]):
815
  st.write(message["content"])
816
 
817
- # Chat input
818
- user_query = st.chat_input("Type your question here (or 'exit' to end)...")
 
819
 
820
  if user_query:
821
- # Check if user wants to exit
822
  if user_query.lower() == "exit":
823
  st.session_state.chat_history.append({"role": "user", "content": "exit"})
824
  with st.chat_message("user"):
@@ -831,38 +726,41 @@ def nutrition_disorder_streamlit():
831
  st.rerun()
832
  return
833
 
834
- # Add user message to chat history
835
  st.session_state.chat_history.append({"role": "user", "content": user_query})
836
  with st.chat_message("user"):
837
  st.write(user_query)
838
 
839
- # Filter input
840
- filtered_result = filter_input_with_llama_guard(user_query)
841
-
842
- # Process through the agent
843
- with st.chat_message("assistant"):
844
- if filtered_result in ["safe", "unsafe S7", "unsafe S6"]:
845
- try:
846
- # Initialize chatbot if not already done
847
- if 'chatbot' not in st.session_state:
848
- st.session_state.chatbot = NutritionBot()
849
-
850
- # Get response from the chatbot
851
- response = st.session_state.chatbot.handle_customer_query(
852
- st.session_state.user_id,
853
- user_query
854
- )
855
-
856
- st.write(response)
857
- st.session_state.chat_history.append({"role": "assistant", "content": response})
858
- except Exception as e:
859
- error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
860
- st.write(error_msg)
861
- st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
862
- else:
863
- inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
864
- st.write(inappropriate_msg)
865
- st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})
 
 
 
 
866
 
867
  if __name__ == "__main__":
868
  nutrition_disorder_streamlit()
 
29
  from langchain_core.prompts import ChatPromptTemplate
30
 
31
  # LangChain OpenAI imports
32
+ from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI, ChatOpenAI # OpenAI embeddings and models
33
  from langchain.embeddings.openai import OpenAIEmbeddings # OpenAI embeddings for text vectors
34
 
35
  # LlamaParse & LlamaIndex imports
 
54
 
55
  #====================================SETUP=====================================#
56
  # Fetch secrets from Hugging Face Spaces
57
+ api_key = os.getenv("API_KEY")
58
  endpoint = os.getenv("OPENAI_API_BASE")
59
  llama_api_key = os.environ['GROQ_API_KEY']
60
  MEM0_api_key = os.environ['mem0']
 
66
  model_name='text-embedding-ada-002' # This is a fixed value and does not need modification
67
  )
68
 
69
+ # This initializes the OpenAI embedding function for the Chroma vectorstore, using the provided endpoint and API key.
70
 
71
  # Initialize the OpenAI Embeddings
72
  embedding_model = OpenAIEmbeddings(
 
83
  model="gpt-4o-mini",
84
  streaming=False
85
  )
86
+ # This initializes the Chat OpenAI model with the provided endpoint, API key, deployment name, and a temperature setting of 0 (to control response variability).
87
 
88
  # set the LLM and embedding model in the LlamaIndex settings.
89
+ # Settings.llm = _____ # Complete the code to define the LLM model
90
+ # Settings.embedding = _____ # Complete the code to define the embedding model
91
+ Settings.llm = llm # Complete the code to define the LLM model
92
+ Settings.embedding = embedding_model # Complete the code to define the embedding model
93
+
94
 
95
  #================================Creating Langgraph agent======================#
96
 
 
119
  Dict: The updated state with the expanded query.
120
  """
121
  print("---------Expanding Query---------")
122
+ # system_message = '''________________________'''
123
+ system_message = '''You are a nutrition expert and language model specialized in nutritional disorders. Your task is to expand the provided query by incorporating related keywords, synonyms, and additional context that can improve the retrieval of detailed nutrition disorder-related information.'''
 
 
 
 
 
 
 
 
 
 
124
 
 
 
125
 
 
 
 
 
 
 
 
126
 
127
  expand_prompt = ChatPromptTemplate.from_messages([
128
  ("system", system_message),
 
162
  Dict: The updated state with the retrieved context.
163
  """
164
  print("---------retrieve_context---------")
165
+ # query = state['_____'] # Complete the code to define the key for the expanded query
166
+ query = state['expanded_query'] # Complete the code to define the key for the expanded query
167
  #print("Query used for retrieval:", query) # Debugging: Print the query
168
 
169
  # Retrieve documents from the vector store
 
178
  }
179
  for doc in docs
180
  ]
181
+ # state['_____'] = context # Complete the code to define the key for storing the context
182
+ state['context'] = context # Complete the code to define the key for storing the context
183
+
184
  print("Extracted context with metadata:", context) # Debugging: Print the extracted context
185
  #print(f"Groundedness loop count: {state['groundedness_loop_count']}")
186
  return state
 
197
  Returns:
198
  Dict: The updated state with the generated response.
199
  """
200
+ print("---------craft_response---------")
201
+ # system_message = '''________________________'''
202
+ system_message = '''You are a nutrition expert and your responses should be clear, concise, and evidence-based. Use the provided context to accurately address the user's query regarding nutritional disorders.'''
203
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  response_prompt = ChatPromptTemplate.from_messages([
206
  ("system", system_message),
207
+ ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
208
  ])
209
 
210
+ chain = response_prompt | llm
211
+ response = chain.invoke({
212
  "query": state['query'],
213
+ "context": "\n".join([doc["content"] for doc in state['context']]),
214
+ # "feedback": ________________ # add feedback to the prompt
215
+ "feedback": state.get("query_feedback", "No additional feedback provided") # add feedback to the prompt
216
+
217
  })
218
+ state['response'] = response
219
+ print("intermediate response: ", response)
220
+
221
  return state
222
 
223
 
 
233
  Dict: The updated state with the groundedness score.
234
  """
235
  print("---------check_groundedness---------")
236
+ # system_message = '''________________________'''
237
+ system_message = '''You are an evaluator for response groundedness. Given the context and the response related to nutritional disorders, provide a numerical score between 0 and 1 where 0 means the response is not grounded at all, and 1 means it is completely grounded in the context.'''
 
 
 
 
 
 
 
 
238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
  groundedness_prompt = ChatPromptTemplate.from_messages([
241
  ("system", system_message),
 
245
  chain = groundedness_prompt | llm | StrOutputParser()
246
  groundedness_score = float(chain.invoke({
247
  "context": "\n".join([doc["content"] for doc in state['context']]),
248
+ # "response": __________ # Complete the code to define the response
249
+ "response": state['response'] # Complete the code to define the response
250
  }))
251
+ print("groundedness_score: ", groundedness_score)
252
+ state['groundedness_loop_count'] += 1
253
  print("#########Groundedness Incremented###########")
254
  state['groundedness_score'] = groundedness_score
255
+
256
  return state
257
 
258
 
 
268
  Dict: The updated state with the precision score.
269
  """
270
  print("---------check_precision---------")
271
+ # system_message = '''________________________'''
272
+ system_message = '''You are an evaluator for response precision. Given the query and the response, provide a numerical score between 0 and 1 where 0 indicates that the response does not address the query at all, and 1 indicates that the response precisely addresses the query.'''
273
+
274
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  precision_prompt = ChatPromptTemplate.from_messages([
276
  ("system", system_message),
277
  ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
278
  ])
279
 
280
+ # chain = _____________ | llm | StrOutputParser() # Complete the code to define the chain of processing
281
+ chain = precision_prompt | llm | StrOutputParser() # Complete the code to define the chain of processing
282
+
283
  precision_score = float(chain.invoke({
284
  "query": state['query'],
285
+ # "response":______________ # Complete the code to access the response from the state
286
+ "response":state['response'] # Complete the code to access the response from the state
287
+
288
  }))
289
  state['precision_score'] = precision_score
290
  print("precision_score:", precision_score)
 
306
  """
307
  print("---------refine_response---------")
308
 
309
+ # system_message = '''________________________'''
310
+ system_message = '''You are an expert editor in nutritional science communications. Your role is to review the response given to a nutritional query and provide clear suggestions to improve its accuracy, clarity, and completeness. Focus on making sure that the response fully addresses the query and is supported by evidence-based nutritional guidelines.'''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
 
 
 
312
 
313
  refine_response_prompt = ChatPromptTemplate.from_messages([
314
  ("system", system_message),
 
338
  Dict: The updated state with query refinement suggestions.
339
  """
340
  print("---------refine_query---------")
341
+ # system_message = '''________________________'''
342
+ system_message = '''You are a search query refinement expert. Given the original and expanded queries related to nutritional disorders, provide suggestions to refine the query further for improved search results.'''
343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
  refine_query_prompt = ChatPromptTemplate.from_messages([
346
  ("system", system_message),
 
363
  """Decides if groundedness is sufficient or needs improvement."""
364
  print("---------should_continue_groundedness---------")
365
  print("groundedness loop count: ", state['groundedness_loop_count'])
366
+ # if state['groundedness_score'] >= _____: # Complete the code to define the threshold for groundedness
367
+ if state['groundedness_score'] >= 0.7: # Complete the code to define the threshold for groundedness
368
  print("Moving to precision")
369
  return "check_precision"
370
  else:
 
378
  def should_continue_precision(state: Dict) -> str:
379
  """Decides if precision is sufficient or needs improvement."""
380
  print("---------should_continue_precision---------")
381
+ # print("precision loop count: ", ___________)
382
+ # if ___________: # Threshold for precision
383
+
384
+ print("precision loop count: ", state['precision_loop_count'])
385
+ if state['precision_score'] >= 0.8: # Threshold for precision
386
  return "pass" # Complete the workflow
387
  else:
388
+ # if ___________: # Maximum allowed loops
389
+ if state["precision_loop_count"] > state['loop_max_iter']: # Maximum allowed loops
390
  return "max_iterations_reached"
391
  else:
392
  print(f"---------Precision Score Threshold Not met. Refining Query-----------") # Debugging
393
+ # return ____________ # Refine the query
394
  return "refine_query" # Refine the query
395
 
396
 
397
 
398
+
399
  def max_iterations_reached(state: Dict) -> Dict:
400
  """Handles the case when the maximum number of iterations is reached."""
401
  print("---------max_iterations_reached---------")
 
406
 
407
 
408
 
409
+ from langgraph.graph import END, StateGraph, START
410
+
411
  def create_workflow() -> StateGraph:
412
  """Creates the updated workflow for the AI nutrition agent."""
413
+ # workflow = StateGraph(_____ ) # Complete the code to define the initial state of the agent
414
+ workflow = StateGraph(dict) # Complete the code to define the initial state of the agent
415
 
416
  # Add processing nodes
417
+ workflow.add_node("expand_query", expand_query ) # Step 1: Expand user query. Complete with the function to expand the query
418
+ workflow.add_node("retrieve_context", retrieve_context ) # Step 2: Retrieve relevant documents. Complete with the function to retrieve context
419
+ workflow.add_node("craft_response", craft_response ) # Step 3: Generate a response based on retrieved data. Complete with the function to craft a response
420
+ workflow.add_node("score_groundedness", score_groundedness ) # Step 4: Evaluate response grounding. Complete with the function to score groundedness
421
+ workflow.add_node("refine_response", refine_response ) # Step 5: Improve response if it's weakly grounded. Complete with the function to refine the response
422
+ workflow.add_node("check_precision", check_precision ) # Step 6: Evaluate response precision. Complete with the function to check precision
423
+ workflow.add_node("refine_query", refine_query ) # Step 7: Improve query if response lacks precision. Complete with the function to refine the query
424
+ workflow.add_node("max_iterations_reached", max_iterations_reached ) # Step 8: Handle max iterations. Complete with the function to handle max iterations
425
+
426
+
427
  # Main flow edges
428
  workflow.add_edge(START, "expand_query")
429
  workflow.add_edge("expand_query", "retrieve_context")
430
  workflow.add_edge("retrieve_context", "craft_response")
431
  workflow.add_edge("craft_response", "score_groundedness")
 
 
432
 
433
  # Conditional edges based on groundedness check
434
  workflow.add_conditional_edges(
 
440
  "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
441
  }
442
  )
443
+
444
  workflow.add_edge("refine_response", "craft_response") # Refined responses are reprocessed.
445
 
446
  # Conditional edges based on precision check
 
450
  {
451
  "pass": END, # If precise, complete the workflow.
452
  "refine_query": "refine_query", # If imprecise, refine the query.
453
+ "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
454
  }
455
  )
456
+
457
  workflow.add_edge("refine_query", "expand_query") # Refined queries go through expansion again.
458
 
459
  workflow.add_edge("max_iterations_reached", END)
 
 
460
 
461
  return workflow
462
 
463
 
 
464
  #=========================== Defining the agentic rag tool ====================#
465
  WORKFLOW_APP = create_workflow().compile()
466
  @tool
 
477
  # Initialize state with necessary parameters
478
  inputs = {
479
  "query": query, # Current user query
480
+ "expanded_query": "", # Complete the code to define the expanded version of the query
481
  "context": [], # Retrieved documents (initially empty)
482
+ "response": "", # Complete the code to define the AI-generated response
483
+ "precision_score": 0.0, # Complete the code to define the precision score of the response
484
+ "groundedness_score": 0.0, # Complete the code to define the groundedness score of the response
485
+ "groundedness_loop_count": 0, # Complete the code to define the counter for groundedness loops
486
+ "precision_loop_count": 0, # Complete the code to define the counter for precision loops
487
+ "feedback": "", # Complete the code to define the feedback
488
+ "query_feedback": "", # Complete the code to define the query feedback
489
+ "loop_max_iter": 5 # Complete the code to define the maximum number of iterations for loops
 
490
  }
491
 
492
  output = WORKFLOW_APP.invoke(inputs)
 
530
  """
531
 
532
  # Initialize a memory client to store and retrieve customer interactions
533
+ self.memory = MemoryClient(api_key=userdata.get("mem0")) # Complete the code to define the memory client API key
534
 
535
+ # Initialize the OpenAI client using the provided credentials
536
  self.client = ChatOpenAI(
537
  model_name="gpt-4o-mini", # Specify the model to use (e.g., GPT-4 optimized version)
538
  api_key=config.get("API_KEY"), # API key for authentication
 
540
  temperature=0 # Controls randomness in responses; 0 ensures deterministic results
541
  )
542
 
 
543
  # Define tools available to the chatbot, such as web search
544
  tools = [agentic_rag]
545
 
 
571
  # Wrap the agent in an executor to manage tool interactions and execution flow
572
  self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
573
 
574
+
575
  def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
576
  """
577
  Store customer interaction in memory for future reference.
 
602
  metadata=metadata
603
  )
604
 
605
+
606
  def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
607
  """
608
  Retrieve past interactions relevant to the current query.
 
617
  return self.memory.search(
618
  query=query, # Search for interactions related to the query
619
  user_id=user_id, # Restrict search to the specific user
620
+ # limit=_____ # Complete the code to define the limit for retrieved interactions
621
+ limit=5 # Complete the code to define the limit for retrieved interactions
622
+
623
  )
624
 
625
+
626
  def handle_customer_query(self, user_id: str, query: str) -> str:
627
  """
628
  Process a customer's query and provide a response, taking into account past interactions.
 
700
  "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
701
  })
702
  st.session_state.login_submitted = True # Set flag to trigger rerun
 
 
703
  if st.session_state.get("login_submitted", False):
704
  st.session_state.pop("login_submitted")
705
  st.rerun()
 
709
  with st.chat_message(message["role"]):
710
  st.write(message["content"])
711
 
712
+ # Chat input with custom placeholder text
713
+ # user_query = st.chat_input(__________) # Blank #1: Fill in the chat input prompt (e.g., "Type your question here (or 'exit' to end)...")
714
+ user_query = st.chat_input("Type your question here (or 'exit' to end)...") # Blank #1
715
 
716
  if user_query:
 
717
  if user_query.lower() == "exit":
718
  st.session_state.chat_history.append({"role": "user", "content": "exit"})
719
  with st.chat_message("user"):
 
726
  st.rerun()
727
  return
728
 
 
729
  st.session_state.chat_history.append({"role": "user", "content": user_query})
730
  with st.chat_message("user"):
731
  st.write(user_query)
732
 
733
+
734
+
735
+ # Filter input using Llama Guard
736
+ # filtered_result = __________(user_query) # Blank #2: Fill in with the function name for filtering input (e.g., filter_input_with_llama_guard)
737
+
738
+ # filtered_result = filter_input_with_llama_guard(user_query) # Blank #2: Fill in with the function name for filtering input (e.g., filter_input_with_llama_guard)
739
+ filtered_result = check_input_safety(user_query, llama_guard_client)
740
+ filtered_result = filtered_result.replace("\n", " ") # Normalize the result
741
+
742
+ # Check if input is safe based on allowed statuses
743
+ # if filtered_result in [__________, __________, __________]: # Blanks #3, #4, #5: Fill in with allowed safe statuses (e.g., "safe", "unsafe S7", "unsafe S6")
744
+ if filtered_result in ["SAFE", "S6", "S7"]: # Blanks #3, #4, #5: Fill in with allowed safe statuses (e.g., "safe", "unsafe S7", "unsafe S6")
745
+
746
+ try:
747
+ if 'chatbot' not in st.session_state:
748
+ st.session_state.chatbot = NutritionBot() # Blank #6: Fill in with the chatbot class initialization (e.g., NutritionBot)
749
+
750
+ # response = st.session_state.chatbot.__________(st.session_state.user_id, user_query)
751
+ response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
752
+
753
+ # Blank #7: Fill in with the method to handle queries (e.g., handle_customer_query)
754
+ st.write(response)
755
+ st.session_state.chat_history.append({"role": "assistant", "content": response})
756
+ except Exception as e:
757
+ error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
758
+ st.write(error_msg)
759
+ st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
760
+ else:
761
+ inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
762
+ st.write(inappropriate_msg)
763
+ st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})
764
 
765
  if __name__ == "__main__":
766
  nutrition_disorder_streamlit()