stckwok commited on
Commit
57d0f31
·
verified ·
1 Parent(s): 2c719c5

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +868 -0
app.py ADDED
@@ -0,0 +1,868 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Import necessary libraries
3
+ import os # Interacting with the operating system (reading/writing files)
4
+ import chromadb # High-performance vector database for storing/querying dense vectors
5
+ from dotenv import load_dotenv # Loading environment variables from a .env file
6
+ import json # Parsing and handling JSON data
7
+
8
+ # LangChain imports
9
+ from langchain_core.documents import Document # Document data structures
10
+ from langchain_core.runnables import RunnablePassthrough # LangChain core library for running pipelines
11
+ from langchain_core.output_parsers import StrOutputParser # String output parser
12
+ from langchain.prompts import ChatPromptTemplate # Template for chat prompts
13
+ from langchain.chains.query_constructor.base import AttributeInfo # Base classes for query construction
14
+ from langchain.retrievers.self_query.base import SelfQueryRetriever # Base classes for self-querying retrievers
15
+ from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker # Document compressors
16
+ from langchain.retrievers import ContextualCompressionRetriever # Contextual compression retrievers
17
+
18
+ # LangChain community & experimental imports
19
+ from langchain_community.vectorstores import Chroma # Implementations of vector stores like Chroma
20
+ from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader # Document loaders for PDFs
21
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder # Cross-encoders from HuggingFace
22
+ from langchain_experimental.text_splitter import SemanticChunker # Experimental text splitting methods
23
+ from langchain.text_splitter import (
24
+ CharacterTextSplitter, # Splitting text by characters
25
+ RecursiveCharacterTextSplitter # Recursive splitting of text by characters
26
+ )
27
+ from langchain_core.tools import tool
28
+ from langchain.agents import create_tool_calling_agent, AgentExecutor
29
+ from langchain_core.prompts import ChatPromptTemplate
30
+
31
+ # LangChain OpenAI imports
32
+ from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI # OpenAI embeddings and models
33
+ from langchain.embeddings.openai import OpenAIEmbeddings # OpenAI embeddings for text vectors
34
+
35
+ # LlamaParse & LlamaIndex imports
36
+ from llama_parse import LlamaParse # Document parsing library
37
+ from llama_index.core import Settings, SimpleDirectoryReader # Core functionalities of the LlamaIndex
38
+
39
+ # LangGraph import
40
+ from langgraph.graph import StateGraph, END, START # State graph for managing states in LangChain
41
+
42
+ # Pydantic import
43
+ from pydantic import BaseModel # Pydantic for data validation
44
+
45
+ # Typing imports
46
+ from typing import Dict, List, Tuple, Any, TypedDict # Python typing for function annotations
47
+
48
+ # Other utilities
49
+ import numpy as np # Numpy for numerical operations
50
+ from groq import Groq
51
+ from mem0 import MemoryClient
52
+ import streamlit as st
53
+ from datetime import datetime
54
+
55
+ #====================================SETUP=====================================#
56
+ # Fetch secrets from Hugging Face Spaces
57
+ api_key = config.get("API_KEY")
58
+ endpoint = config.get("OPENAI_API_BASE")
59
+ llama_api_key = os.environ['GROQ_API_KEY']
60
+ MEM0_api_key = os.environ['mem0']
61
+
62
+ # Initialize the OpenAI embedding function for Chroma
63
+ embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
64
+ api_base=endpoint, # Complete the code to define the API base endpoint
65
+ api_key=api_key, # Complete the code to define the API key
66
+ model_name='text-embedding-ada-002' # This is a fixed value and does not need modification
67
+ )
68
+
69
+
70
+ # Initialize the OpenAI Embeddings
71
+ embedding_model = OpenAIEmbeddings(
72
+ openai_api_base=endpoint,
73
+ openai_api_key=api_key,
74
+ model='text-embedding-ada-002'
75
+ )
76
+
77
+
78
+ # Initialize the Chat OpenAI model
79
+ llm = ChatOpenAI(
80
+ openai_api_base=endpoint,
81
+ openai_api_key=api_key,
82
+ model="gpt-4o-mini",
83
+ streaming=False
84
+ )
85
+
86
+
87
+ # set the LLM and embedding model in the LlamaIndex settings.
88
+ Settings.llm = llm
89
+ Settings.embedding = embedding_model
90
+
91
+ #================================Creating Langgraph agent======================#
92
+
93
+ class AgentState(TypedDict):
94
+ query: str # The current user query
95
+ expanded_query: str # The expanded version of the user query
96
+ context: List[Dict[str, Any]] # Retrieved documents (content and metadata)
97
+ response: str # The generated response to the user query
98
+ precision_score: float # The precision score of the response
99
+ groundedness_score: float # The groundedness score of the response
100
+ groundedness_loop_count: int # Counter for groundedness refinement loops
101
+ precision_loop_count: int # Counter for precision refinement loops
102
+ feedback: str
103
+ query_feedback: str
104
+ groundedness_check: bool
105
+ loop_max_iter: int
106
+
107
+ def expand_query(state):
108
+ """
109
+ Expands the user query to improve retrieval of nutrition disorder-related information.
110
+
111
+ Args:
112
+ state (Dict): The current state of the workflow, containing the user query.
113
+
114
+ Returns:
115
+ Dict: The updated state with the expanded query.
116
+ """
117
+ print("---------Expanding Query---------")
118
+ system_message = '''You are an AI specializing in improving search queries to retrieve the most relevant nutrition disorder-related information.
119
+ Your task is to **refine** and **expand** the given query so that better search results are obtained, while **keeping the original intent** unchanged.
120
+
121
+ Guidelines:
122
+ - Add **specific details** where needed. Example: If a user asks about "anorexia," specify aspects like symptoms, causes, or treatment options.
123
+ - Include **related terms** to improve retrieval (e.g., “bulimia” → “bulimia nervosa vs binge eating disorder”).
124
+ - If the user provides an unclear query, suggest necessary clarifications.
125
+ - **DO NOT** answer the question. Your job is only to enhance the query.
126
+
127
+ Examples:
128
+ 1. User Query: "Tell me about eating disorders."
129
+ Expanded Query: "Provide details on eating disorders, including types (e.g., anorexia nervosa, bulimia nervosa), symptoms, causes, and treatment options."
130
+
131
+ 2. User Query: "What is anorexia?"
132
+ Expanded Query: "Explain anorexia nervosa, including its symptoms, causes, risk factors, and treatment options."
133
+
134
+ 3. User Query: "How to treat bulimia?"
135
+ Expanded Query: "Describe treatment options for bulimia nervosa, including psychotherapy, medications, and lifestyle changes."
136
+
137
+ 4. User Query: "What are the effects of malnutrition?"
138
+ Expanded Query: "Explain the effects of malnutrition on physical and mental health, including specific nutrient deficiencies and their consequences."
139
+
140
+ Now, expand the following query:'''
141
+
142
+ expand_prompt = ChatPromptTemplate.from_messages([
143
+ ("system", system_message),
144
+ ("user", "Expand this query: {query} using the feedback: {query_feedback}")
145
+
146
+ ])
147
+
148
+ chain = expand_prompt | llm | StrOutputParser()
149
+ expanded_query = chain.invoke({"query": state['query'], "query_feedback":state["query_feedback"]})
150
+ print("expanded_query", expanded_query)
151
+ state["expanded_query"] = expanded_query
152
+ return state
153
+
154
+
155
+ # Initialize the Chroma vector store for retrieving documents
156
+ vector_store = Chroma(
157
+ collection_name="nutritional_hypotheticals",
158
+ persist_directory="./nutritional_db",
159
+ embedding_function=embedding_model
160
+
161
+ )
162
+
163
+ # Create a retriever from the vector store
164
+ retriever = vector_store.as_retriever(
165
+ search_type='similarity',
166
+ search_kwargs={'k': 3}
167
+ )
168
+
169
+ def retrieve_context(state):
170
+ """
171
+ Retrieves context from the vector store using the expanded or original query.
172
+
173
+ Args:
174
+ state (Dict): The current state of the workflow, containing the query and expanded query.
175
+
176
+ Returns:
177
+ Dict: The updated state with the retrieved context.
178
+ """
179
+ print("---------retrieve_context---------")
180
+ query = state['expanded_query']
181
+ #print("Query used for retrieval:", query) # Debugging: Print the query
182
+
183
+ # Retrieve documents from the vector store
184
+ docs = retriever.invoke(query)
185
+ print("Retrieved documents:", docs) # Debugging: Print the raw docs object
186
+
187
+ # Extract both page_content and metadata from each document
188
+ context= [
189
+ {
190
+ "content": doc.page_content, # The actual content of the document
191
+ "metadata": doc.metadata # The metadata (e.g., source, page number, etc.)
192
+ }
193
+ for doc in docs
194
+ ]
195
+ state['context'] = context
196
+ print("Extracted context with metadata:", context) # Debugging: Print the extracted context
197
+ #print(f"Groundedness loop count: {state['groundedness_loop_count']}")
198
+ return state
199
+
200
+
201
+
202
+ def craft_response(state: Dict) -> Dict:
203
+ """
204
+ Generates a response using the retrieved context, focusing on nutrition disorders.
205
+
206
+ Args:
207
+ state (Dict): The current state of the workflow, containing the query and retrieved context.
208
+
209
+ Returns:
210
+ Dict: The updated state with the generated response.
211
+ """
212
+ system_message = '''You are a professional AI nutrition disorder specialist generating responses based on retrieved documents.
213
+ Your task is to use the given **context** to generate a highly accurate, informative, and user-friendly response.
214
+
215
+ Guidelines:
216
+ - **Be direct and concise** while ensuring completeness.
217
+ - **DO NOT include information that is not present in the context.**
218
+ - If multiple sources exist, synthesize them into a coherent response.
219
+ - If the context does not fully answer the query, state what additional information is needed.
220
+ - Use bullet points when explaining complex concepts.
221
+
222
+ Example:
223
+ User Query: "What are the symptoms of anorexia nervosa?"
224
+ Context:
225
+ 1. Anorexia nervosa is characterized by extreme weight loss and fear of gaining weight.
226
+ 2. Common symptoms include restricted eating, distorted body image, and excessive exercise.
227
+ Response:
228
+ "Anorexia nervosa is an eating disorder characterized by extreme weight loss and an intense fear of gaining weight. Common symptoms include:
229
+ - Restricted eating
230
+ - Distorted body image
231
+ - Excessive exercise
232
+ If you or someone you know is experiencing these symptoms, it is important to seek professional help."'''
233
+
234
+ response_prompt = ChatPromptTemplate.from_messages([
235
+ ("system", system_message),
236
+ ("user", "Query: {query}\nContext: {context}\n\nResponse:")
237
+ ])
238
+
239
+ chain = response_prompt | llm | StrOutputParser()
240
+ state['response'] = chain.invoke({
241
+ "query": state['query'],
242
+ "context": "\n".join([doc["content"] for doc in state['context']]) # Extract content from each document
243
+ })
244
+ return state
245
+
246
+
247
+
248
+ def score_groundedness(state: Dict) -> Dict:
249
+ """
250
+ Checks whether the response is grounded in the retrieved context.
251
+
252
+ Args:
253
+ state (Dict): The current state of the workflow, containing the response and context.
254
+
255
+ Returns:
256
+ Dict: The updated state with the groundedness score.
257
+ """
258
+ print("---------check_groundedness---------")
259
+ system_message = '''You are an AI tasked with evaluating whether a response is grounded in the provided context and includes proper citations.
260
+
261
+ Guidelines:
262
+ 1. **Groundedness Check**:
263
+ - Verify that the response accurately reflects the information in the context.
264
+ - Flag any unsupported claims or deviations from the context.
265
+
266
+ 2. **Citation Check**:
267
+ - Ensure that the response includes citations to the source material (e.g., "According to [Source], ...").
268
+ - If citations are missing, suggest adding them.
269
+
270
+ 3. **Scoring**:
271
+ - Assign a groundedness score between 0 and 1, where 1 means fully grounded and properly cited.
272
+
273
+ Examples:
274
+ 1. Response: "Anorexia nervosa is caused by genetic factors (Source 1)."
275
+ Context: "Anorexia nervosa is influenced by genetic, environmental, and psychological factors (Source 1)."
276
+ Evaluation: "The response is grounded and properly cited. Groundedness score: 1.0."
277
+
278
+ 2. Response: "Bulimia nervosa can be cured with diet alone."
279
+ Context: "Treatment for bulimia nervosa involves psychotherapy and medications (Source 2)."
280
+ Evaluation: "The response is ungrounded and lacks citations. Groundedness score: 0.2."
281
+
282
+ 3. Response: "Anorexia nervosa has a high mortality rate."
283
+ Context: "Anorexia nervosa has one of the highest mortality rates among psychiatric disorders (Source 3)."
284
+ Evaluation: "The response is grounded but lacks a citation. Groundedness score: 0.7. ."
285
+
286
+ ****Return only a float score (e.g., 0.9). Do not provide explanations.****
287
+
288
+ Now, evaluate the following response:
289
+ '''
290
+
291
+ groundedness_prompt = ChatPromptTemplate.from_messages([
292
+ ("system", system_message),
293
+ ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
294
+ ])
295
+
296
+ chain = groundedness_prompt | llm | StrOutputParser()
297
+ groundedness_score = float(chain.invoke({
298
+ "context": "\n".join([doc["content"] for doc in state['context']]),
299
+ "response": state['response']
300
+ }))
301
+ print("groundedness_score: ",groundedness_score)
302
+ state['groundedness_loop_count'] +=1
303
+ print("#########Groundedness Incremented###########")
304
+ state['groundedness_score'] = groundedness_score
305
+ return state
306
+
307
+
308
+
309
+ def check_precision(state: Dict) -> Dict:
310
+ """
311
+ Checks whether the response precisely addresses the user’s query.
312
+
313
+ Args:
314
+ state (Dict): The current state of the workflow, containing the query and response.
315
+
316
+ Returns:
317
+ Dict: The updated state with the precision score.
318
+ """
319
+ print("---------check_precision---------")
320
+ system_message = '''You are an AI evaluator assessing the **precision** of the response.
321
+ Your task is to **score** how well the response addresses the user’s original nutrition disorder-related query.
322
+
323
+ Scoring Criteria:
324
+ - 1.0 → The response is fully precise, directly answering the question.
325
+ - 0.7 → The response is mostly correct but contains some generalization.
326
+ - 0.5 → The response is somewhat relevant but lacks key details.
327
+ - 0.3 → The response is vague or only partially correct.
328
+ - 0.0 → The response is incorrect or misleading.
329
+
330
+ Examples:
331
+ 1. Query: "What are the symptoms of anorexia nervosa?"
332
+ Response: "The symptoms of anorexia nervosa include extreme weight loss, fear of gaining weight, and a distorted body image."
333
+ Precision Score: 1.0
334
+
335
+ 2. Query: "How is bulimia nervosa treated?"
336
+ Response: "Bulimia nervosa is treated with therapy and medications."
337
+ Precision Score: 0.7
338
+
339
+ 3. Query: "What causes binge eating disorder?"
340
+ Response: "Binge eating disorder is caused by a combination of genetic, psychological, and environmental factors."
341
+ Precision Score: 0.5
342
+
343
+ 4. Query: "What are the effects of malnutrition?"
344
+ Response: "Malnutrition can lead to health problems."
345
+ Precision Score: 0.3
346
+
347
+ 5. Query: "What is the mortality rate of anorexia nervosa?"
348
+ Response: "Anorexia nervosa is a type of eating disorder."
349
+ Precision Score: 0.0
350
+
351
+ *****Return only a float score (e.g., 0.9). Do not provide explanations.*****
352
+ Now, evaluate the following query and response:
353
+ '''
354
+ precision_prompt = ChatPromptTemplate.from_messages([
355
+ ("system", system_message),
356
+ ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
357
+ ])
358
+
359
+ chain = precision_prompt | llm | StrOutputParser()
360
+ precision_score = float(chain.invoke({
361
+ "query": state['query'],
362
+ "response": state['response']
363
+ }))
364
+ state['precision_score'] = precision_score
365
+ print("precision_score:", precision_score)
366
+ state['precision_loop_count'] +=1
367
+ print("#########Precision Incremented###########")
368
+ return state
369
+
370
+
371
+
372
+ def refine_response(state: Dict) -> Dict:
373
+ """
374
+ Suggests improvements for the generated response.
375
+
376
+ Args:
377
+ state (Dict): The current state of the workflow, containing the query and response.
378
+
379
+ Returns:
380
+ Dict: The updated state with response refinement suggestions.
381
+ """
382
+ print("---------refine_response---------")
383
+
384
+ system_message = '''You are an AI response refinement assistant. Your task is to suggest **improvements** for the given response.
385
+
386
+ ### Guidelines:
387
+ - Identify **gaps in the explanation** (missing key details).
388
+ - Highlight **unclear or vague parts** that need elaboration.
389
+ - Suggest **additional details** that should be included for better accuracy.
390
+ - Ensure the refined response is **precise** and **grounded** in the retrieved context.
391
+
392
+ ### Examples:
393
+ 1. Query: "What are the symptoms of anorexia nervosa?"
394
+ Response: "The symptoms include weight loss and fear of gaining weight."
395
+ Suggestions: "The response is missing key details about behavioral and emotional symptoms. Add details like 'distorted body image' and 'restrictive eating patterns.'"
396
+
397
+ 2. Query: "How is bulimia nervosa treated?"
398
+ Response: "Bulimia nervosa is treated with therapy."
399
+ Suggestions: "The response is too vague. Specify the types of therapy (e.g., cognitive-behavioral therapy) and mention other treatments like nutritional counseling and medications."
400
+
401
+ 3. Query: "What causes binge eating disorder?"
402
+ Response: "Binge eating disorder is caused by psychological factors."
403
+ Suggestions: "The response is incomplete. Add details about genetic and environmental factors, and explain how they contribute to the disorder."
404
+
405
+ Now, suggest improvements for the following response:
406
+ '''
407
+
408
+ refine_response_prompt = ChatPromptTemplate.from_messages([
409
+ ("system", system_message),
410
+ ("user", "Query: {query}\nResponse: {response}\n\n"
411
+ "What improvements can be made to enhance accuracy and completeness?")
412
+ ])
413
+
414
+ chain = refine_response_prompt | llm| StrOutputParser()
415
+
416
+ # Store response suggestions in a structured format
417
+ feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
418
+ print("feedback: ", feedback)
419
+ print(f"State: {state}")
420
+ state['feedback'] = feedback
421
+ return state
422
+
423
+
424
+
425
+ def refine_query(state: Dict) -> Dict:
426
+ """
427
+ Suggests improvements for the expanded query.
428
+
429
+ Args:
430
+ state (Dict): The current state of the workflow, containing the query and expanded query.
431
+
432
+ Returns:
433
+ Dict: The updated state with query refinement suggestions.
434
+ """
435
+ print("---------refine_query---------")
436
+ system_message = '''You are an AI query refinement assistant. Your task is to suggest **improvements** for the expanded query.
437
+
438
+ ### Guidelines:
439
+ - Add **specific keywords** to improve document retrieval.
440
+ - Identify **missing details** that should be included.
441
+ - Suggest **ways to narrow the scope** for better precision.
442
+
443
+ ### Examples:
444
+ 1. Original Query: "Tell me about eating disorders."
445
+ Expanded Query: "Provide details on eating disorders, including types, symptoms, causes, and treatment options."
446
+ Suggestions: "Add specific types of eating disorders like 'anorexia nervosa' and 'bulimia nervosa' to improve retrieval."
447
+
448
+ 2. Original Query: "What is anorexia?"
449
+ Expanded Query: "Explain anorexia nervosa, including its symptoms and causes."
450
+ Suggestions: "Include details about treatment options and risk factors to make the query more comprehensive."
451
+
452
+ 3. Original Query: "How to treat bulimia?"
453
+ Expanded Query: "Describe treatment options for bulimia nervosa."
454
+ Suggestions: "Specify types of treatments like 'cognitive-behavioral therapy' and 'medications' for better precision."
455
+
456
+ Now, suggest improvements for the following expanded query:
457
+ '''
458
+
459
+ refine_query_prompt = ChatPromptTemplate.from_messages([
460
+ ("system", system_message),
461
+ ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
462
+ "What improvements can be made for a better search?")
463
+ ])
464
+
465
+ chain = refine_query_prompt | llm | StrOutputParser()
466
+
467
+ # Store refinement suggestions without modifying the original expanded query
468
+ query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}"
469
+ print("query_feedback: ", query_feedback)
470
+ print(f"Groundedness loop count: {state['groundedness_loop_count']}")
471
+ state['query_feedback'] = query_feedback
472
+ return state
473
+
474
+
475
+
476
+ def should_continue_groundedness(state):
477
+ """Decides if groundedness is sufficient or needs improvement."""
478
+ print("---------should_continue_groundedness---------")
479
+ print("groundedness loop count: ", state['groundedness_loop_count'])
480
+ if state['groundedness_score'] >= 0.4: # Threshold for groundedness
481
+ print("Moving to precision")
482
+ return "check_precision"
483
+ else:
484
+ if state["groundedness_loop_count"] > state['loop_max_iter']:
485
+ return "max_iterations_reached"
486
+ else:
487
+ print(f"---------Groundedness Score Threshold Not met. Refining Response-----------")
488
+ return "refine_response"
489
+
490
+
491
+ def should_continue_precision(state: Dict) -> str:
492
+ """Decides if precision is sufficient or needs improvement."""
493
+ print("---------should_continue_precision---------")
494
+ print("precision loop count: ",state['precision_loop_count'])
495
+ if state['precision_score'] >= 0.7: # Threshold for precision
496
+ return "pass" # Complete the workflow
497
+ else:
498
+ if state['precision_loop_count'] > state['loop_max_iter']: # Maximum allowed loops
499
+ return "max_iterations_reached"
500
+ else:
501
+ print(f"---------Precision Score Threshold Not met. Refining Query-----------") # Debugging
502
+ # Exit the loop
503
+ return "refine_query" # Refine the query
504
+
505
+
506
+
507
+ def max_iterations_reached(state: Dict) -> Dict:
508
+ """Handles the case when the maximum number of iterations is reached."""
509
+ print("---------max_iterations_reached---------")
510
+ """Handles the case when the maximum number of iterations is reached."""
511
+ response = "I'm unable to refine the response further. Please provide more context or clarify your question."
512
+ state['response'] = response
513
+ return state
514
+
515
+
516
+
517
+ def create_workflow() -> StateGraph:
518
+ """Creates the updated workflow for the AI nutrition agent."""
519
+ workflow = StateGraph(AgentState)
520
+
521
+ # Add processing nodes
522
+ workflow.add_node("expand_query", expand_query) # Step 1: Expand user query.
523
+ workflow.add_node("retrieve_context", retrieve_context) # Step 2: Retrieve relevant documents.
524
+ workflow.add_node("craft_response", craft_response) # Step 3: Generate a response based on retrieved data.
525
+ workflow.add_node("score_groundedness", score_groundedness) # Step 4: Evaluate response grounding.
526
+ workflow.add_node("refine_response", refine_response) # Step 5: Improve response if it's weakly grounded.
527
+ workflow.add_node("check_precision", check_precision) # Step 6: Evaluate response precision.
528
+ workflow.add_node("refine_query", refine_query) # Step 7: Improve query if response lacks precision.
529
+ workflow.add_node("max_iterations_reached", max_iterations_reached) # Step 8: Handle max iterations.
530
+ # workflow.add_node("groundedness_decider",groundedness_decider)
531
+ # Main flow edges
532
+ workflow.add_edge(START, "expand_query")
533
+ workflow.add_edge("expand_query", "retrieve_context")
534
+ workflow.add_edge("retrieve_context", "craft_response")
535
+ workflow.add_edge("craft_response", "score_groundedness")
536
+ # workflow.add_edge("score_groundedness","groundedness_decider")
537
+
538
+
539
+ # Conditional edges based on groundedness check
540
+ workflow.add_conditional_edges(
541
+ "score_groundedness",
542
+ should_continue_groundedness, # Use the conditional function
543
+ {
544
+ "check_precision": "check_precision", # If well-grounded, proceed to precision check.
545
+ "refine_response": "refine_response", # If not, refine the response.
546
+ "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
547
+ }
548
+ )
549
+ workflow.add_edge("refine_response", "craft_response") # Refined responses are reprocessed.
550
+
551
+ # Conditional edges based on precision check
552
+ workflow.add_conditional_edges(
553
+ "check_precision",
554
+ should_continue_precision, # Use the conditional function
555
+ {
556
+ "pass": END, # If precise, complete the workflow.
557
+ "refine_query": "refine_query", # If imprecise, refine the query.
558
+ "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
559
+ }
560
+ )
561
+ workflow.add_edge("refine_query", "expand_query") # Refined queries go through expansion again.
562
+
563
+ workflow.add_edge("max_iterations_reached", END)
564
+ # Set entry point
565
+ # workflow.set_entry_point("expand_query")
566
+
567
+ return workflow
568
+
569
+
570
+
571
+ #=========================== Defining the agentic rag tool ====================#
572
+ WORKFLOW_APP = create_workflow().compile()
573
+ @tool
574
+ def agentic_rag(query: str):
575
+ """
576
+ Runs the RAG-based agent with conversation history for context-aware responses.
577
+
578
+ Args:
579
+ query (str): The current user query.
580
+
581
+ Returns:
582
+ Dict[str, Any]: The updated state with the generated response and conversation history.
583
+ """
584
+ # Initialize state with necessary parameters
585
+ inputs = {
586
+ "query": query, # Current user query
587
+ "expanded_query": "", # Expanded version of the query
588
+ "context": [], # Retrieved documents (initially empty)
589
+ "response": "", # AI-generated response
590
+ "precision_score": 0.0, # Precision score of the response
591
+ "groundedness_score": 0.0, # Groundedness score of the response
592
+ "groundedness_loop_count": 0, # Counter for groundedness loops
593
+ "precision_loop_count": 0, # Counter for precision loops
594
+ "feedback": "",
595
+ "query_feedback":"",
596
+ "loop_max_iter":2
597
+
598
+ }
599
+
600
+ output = WORKFLOW_APP.invoke(inputs)
601
+
602
+ return output
603
+
604
+
605
+ #================================ Guardrails ===========================#
606
+ llama_guard_client = Groq(api_key=llama_api_key)
607
+ # Function to filter user input with Llama Guard
608
+ def filter_input_with_llama_guard(user_input, model="llama-guard-3-8b"):
609
+ """
610
+ Filters user input using Llama Guard to ensure it is safe.
611
+
612
+ Parameters:
613
+ - user_input: The input provided by the user.
614
+ - model: The Llama Guard model to be used for filtering (default is "llama-guard-3-8b").
615
+
616
+ Returns:
617
+ - The filtered and safe input.
618
+ """
619
+ try:
620
+ # Create a request to Llama Guard to filter the user input
621
+ response = llama_guard_client.chat.completions.create(
622
+ messages=[{"role": "user", "content": user_input}],
623
+ model=model,
624
+ )
625
+ # Return the filtered input
626
+ return response.choices[0].message.content.strip()
627
+ except Exception as e:
628
+ print(f"Error with Llama Guard: {e}")
629
+ return None
630
+
631
+
632
+ #============================= Adding Memory to the agent using mem0 ===============================#
633
+
634
+ class NutritionBot:
635
+ def __init__(self):
636
+ """
637
+ Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.
638
+ """
639
+
640
+ # Initialize a memory client to store and retrieve customer interactions
641
+ self.memory = MemoryClient(api_key=MEM0_api_key)
642
+
643
+ self.client = ChatOpenAI(
644
+ model_name="gpt-4o-mini", # Specify the model to use (e.g., GPT-4 optimized version)
645
+ api_key=config.get("API_KEY"), # API key for authentication
646
+ endpoint = config.get("OPENAI_API_BASE"),
647
+ temperature=0 # Controls randomness in responses; 0 ensures deterministic results
648
+ )
649
+
650
+
651
+ # Define tools available to the chatbot, such as web search
652
+ tools = [agentic_rag]
653
+
654
+ # Define the system prompt to set the behavior of the chatbot
655
+ system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
656
+ Guidelines for Interaction:
657
+ Maintain a polite, professional, and reassuring tone.
658
+ Show genuine empathy for customer concerns and health challenges.
659
+ Reference past interactions to provide personalized and consistent advice.
660
+ Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
661
+ Ensure consistent and accurate information across conversations.
662
+ If any detail is unclear or missing, proactively ask for clarification.
663
+ Always use the agentic_rag tool to retrieve up-to-date and evidence-based nutrition insights.
664
+ Keep track of ongoing issues and follow-ups to ensure continuity in support.
665
+ Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.
666
+
667
+ """
668
+
669
+ # Build the prompt template for the agent
670
+ prompt = ChatPromptTemplate.from_messages([
671
+ ("system", system_prompt), # System instructions
672
+ ("human", "{input}"), # Placeholder for human input
673
+ ("placeholder", "{agent_scratchpad}") # Placeholder for intermediate reasoning steps
674
+ ])
675
+
676
+ # Create an agent capable of interacting with tools and executing tasks
677
+ agent = create_tool_calling_agent(self.client, tools, prompt)
678
+
679
+ # Wrap the agent in an executor to manage tool interactions and execution flow
680
+ self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
681
+
682
+ def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
683
+ """
684
+ Store customer interaction in memory for future reference.
685
+
686
+ Args:
687
+ user_id (str): Unique identifier for the customer.
688
+ message (str): Customer's query or message.
689
+ response (str): Chatbot's response.
690
+ metadata (Dict, optional): Additional metadata for the interaction.
691
+ """
692
+ if metadata is None:
693
+ metadata = {}
694
+
695
+ # Add a timestamp to the metadata for tracking purposes
696
+ metadata["timestamp"] = datetime.now().isoformat()
697
+
698
+ # Format the conversation for storage
699
+ conversation = [
700
+ {"role": "user", "content": message},
701
+ {"role": "assistant", "content": response}
702
+ ]
703
+
704
+ # Store the interaction in the memory client
705
+ self.memory.add(
706
+ conversation,
707
+ user_id=user_id,
708
+ output_format="v1.1",
709
+ metadata=metadata
710
+ )
711
+
712
+ def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
713
+ """
714
+ Retrieve past interactions relevant to the current query.
715
+
716
+ Args:
717
+ user_id (str): Unique identifier for the customer.
718
+ query (str): The customer's current query.
719
+
720
+ Returns:
721
+ List[Dict]: A list of relevant past interactions.
722
+ """
723
+ return self.memory.search(
724
+ query=query, # Search for interactions related to the query
725
+ user_id=user_id, # Restrict search to the specific user
726
+ limit=5 # Retrieve up to 5 relevant interactions
727
+ )
728
+
729
+ def handle_customer_query(self, user_id: str, query: str) -> str:
730
+ """
731
+ Process a customer's query and provide a response, taking into account past interactions.
732
+
733
+ Args:
734
+ user_id (str): Unique identifier for the customer.
735
+ query (str): Customer's query.
736
+
737
+ Returns:
738
+ str: Chatbot's response.
739
+ """
740
+
741
+ # Retrieve relevant past interactions for context
742
+ relevant_history = self.get_relevant_history(user_id, query)
743
+
744
+ # Build a context string from the relevant history
745
+ context = "Previous relevant interactions:\n"
746
+ for memory in relevant_history:
747
+ context += f"Customer: {memory['memory']}\n" # Customer's past messages
748
+ context += f"Support: {memory['memory']}\n" # Chatbot's past responses
749
+ context += "---\n"
750
+
751
+ # Print context for debugging purposes
752
+ print("Context: ", context)
753
+
754
+ # Prepare a prompt combining past context and the current query
755
+ prompt = f"""
756
+ Context:
757
+ {context}
758
+
759
+ Current customer query: {query}
760
+
761
+ Provide a helpful response that takes into account any relevant past interactions.
762
+ """
763
+
764
+ # Generate a response using the agent
765
+ response = self.agent_executor.invoke({"input": prompt})
766
+
767
+ # Store the current interaction for future reference
768
+ self.store_customer_interaction(
769
+ user_id=user_id,
770
+ message=query,
771
+ response=response["output"],
772
+ metadata={"type": "support_query"}
773
+ )
774
+
775
+ # Return the chatbot's response
776
+ return response['output']
777
+
778
+
779
+ #=====================User Interface using streamlit ===========================#
780
+ def nutrition_disorder_streamlit():
781
+ """
782
+ A Streamlit-based UI for the Nutrition Disorder Specialist Agent.
783
+ """
784
+ st.title("Nutrition Disorder Specialist")
785
+ st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
786
+ st.write("Type 'exit' to end the conversation.")
787
+
788
+ # Initialize session state for chat history and user_id if they don't exist
789
+ if 'chat_history' not in st.session_state:
790
+ st.session_state.chat_history = []
791
+ if 'user_id' not in st.session_state:
792
+ st.session_state.user_id = None
793
+
794
+ # Login form: Only if user is not logged in
795
+ if st.session_state.user_id is None:
796
+ with st.form("login_form", clear_on_submit=True):
797
+ user_id = st.text_input("Please enter your name to begin:")
798
+ submit_button = st.form_submit_button("Login")
799
+ if submit_button and user_id:
800
+ st.session_state.user_id = user_id
801
+ st.session_state.chat_history.append({
802
+ "role": "assistant",
803
+ "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
804
+ })
805
+ st.session_state.login_submitted = True # Set flag to trigger rerun
806
+
807
+ # Trigger rerun outside the form if login was successful
808
+ if st.session_state.get("login_submitted", False):
809
+ st.session_state.pop("login_submitted")
810
+ st.rerun()
811
+ else:
812
+ # Display chat history
813
+ for message in st.session_state.chat_history:
814
+ with st.chat_message(message["role"]):
815
+ st.write(message["content"])
816
+
817
+ # Chat input
818
+ user_query = st.chat_input("Type your question here (or 'exit' to end)...")
819
+
820
+ if user_query:
821
+ # Check if user wants to exit
822
+ if user_query.lower() == "exit":
823
+ st.session_state.chat_history.append({"role": "user", "content": "exit"})
824
+ with st.chat_message("user"):
825
+ st.write("exit")
826
+ goodbye_msg = "Goodbye! Feel free to return if you have more questions about nutrition disorders."
827
+ st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
828
+ with st.chat_message("assistant"):
829
+ st.write(goodbye_msg)
830
+ st.session_state.user_id = None
831
+ st.rerun()
832
+ return
833
+
834
+ # Add user message to chat history
835
+ st.session_state.chat_history.append({"role": "user", "content": user_query})
836
+ with st.chat_message("user"):
837
+ st.write(user_query)
838
+
839
+ # Filter input
840
+ filtered_result = filter_input_with_llama_guard(user_query)
841
+
842
+ # Process through the agent
843
+ with st.chat_message("assistant"):
844
+ if filtered_result in ["safe", "unsafe S7", "unsafe S6"]:
845
+ try:
846
+ # Initialize chatbot if not already done
847
+ if 'chatbot' not in st.session_state:
848
+ st.session_state.chatbot = NutritionBot()
849
+
850
+ # Get response from the chatbot
851
+ response = st.session_state.chatbot.handle_customer_query(
852
+ st.session_state.user_id,
853
+ user_query
854
+ )
855
+
856
+ st.write(response)
857
+ st.session_state.chat_history.append({"role": "assistant", "content": response})
858
+ except Exception as e:
859
+ error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
860
+ st.write(error_msg)
861
+ st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
862
+ else:
863
+ inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
864
+ st.write(inappropriate_msg)
865
+ st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})
866
+
867
+ if __name__ == "__main__":
868
+ nutrition_disorder_streamlit()