Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
@@ -29,7 +29,7 @@ from langchain.agents import create_tool_calling_agent, AgentExecutor
|
|
29 |
from langchain_core.prompts import ChatPromptTemplate
|
30 |
|
31 |
# LangChain OpenAI imports
|
32 |
-
from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI, ChatOpenAI
|
33 |
from langchain.embeddings.openai import OpenAIEmbeddings # OpenAI embeddings for text vectors
|
34 |
|
35 |
# LlamaParse & LlamaIndex imports
|
@@ -54,7 +54,7 @@ from datetime import datetime
|
|
54 |
|
55 |
#====================================SETUP=====================================#
|
56 |
# Fetch secrets from Hugging Face Spaces
|
57 |
-
api_key = os.getenv("API_KEY")
|
58 |
endpoint = os.getenv("OPENAI_API_BASE")
|
59 |
llama_api_key = os.environ['GROQ_API_KEY']
|
60 |
MEM0_api_key = os.environ['mem0']
|
@@ -66,6 +66,7 @@ embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
|
|
66 |
model_name='text-embedding-ada-002' # This is a fixed value and does not need modification
|
67 |
)
|
68 |
|
|
|
69 |
|
70 |
# Initialize the OpenAI Embeddings
|
71 |
embedding_model = OpenAIEmbeddings(
|
@@ -82,11 +83,14 @@ llm = ChatOpenAI(
|
|
82 |
model="gpt-4o-mini",
|
83 |
streaming=False
|
84 |
)
|
85 |
-
|
86 |
|
87 |
# set the LLM and embedding model in the LlamaIndex settings.
|
88 |
-
Settings.llm =
|
89 |
-
Settings.embedding =
|
|
|
|
|
|
|
90 |
|
91 |
#================================Creating Langgraph agent======================#
|
92 |
|
@@ -115,29 +119,10 @@ def expand_query(state):
|
|
115 |
Dict: The updated state with the expanded query.
|
116 |
"""
|
117 |
print("---------Expanding Query---------")
|
118 |
-
system_message = '''
|
119 |
-
Your task is to
|
120 |
-
|
121 |
-
Guidelines:
|
122 |
-
- Add **specific details** where needed. Example: If a user asks about "anorexia," specify aspects like symptoms, causes, or treatment options.
|
123 |
-
- Include **related terms** to improve retrieval (e.g., “bulimia” → “bulimia nervosa vs binge eating disorder”).
|
124 |
-
- If the user provides an unclear query, suggest necessary clarifications.
|
125 |
-
- **DO NOT** answer the question. Your job is only to enhance the query.
|
126 |
-
|
127 |
-
Examples:
|
128 |
-
1. User Query: "Tell me about eating disorders."
|
129 |
-
Expanded Query: "Provide details on eating disorders, including types (e.g., anorexia nervosa, bulimia nervosa), symptoms, causes, and treatment options."
|
130 |
|
131 |
-
2. User Query: "What is anorexia?"
|
132 |
-
Expanded Query: "Explain anorexia nervosa, including its symptoms, causes, risk factors, and treatment options."
|
133 |
|
134 |
-
3. User Query: "How to treat bulimia?"
|
135 |
-
Expanded Query: "Describe treatment options for bulimia nervosa, including psychotherapy, medications, and lifestyle changes."
|
136 |
-
|
137 |
-
4. User Query: "What are the effects of malnutrition?"
|
138 |
-
Expanded Query: "Explain the effects of malnutrition on physical and mental health, including specific nutrient deficiencies and their consequences."
|
139 |
-
|
140 |
-
Now, expand the following query:'''
|
141 |
|
142 |
expand_prompt = ChatPromptTemplate.from_messages([
|
143 |
("system", system_message),
|
@@ -177,7 +162,8 @@ def retrieve_context(state):
|
|
177 |
Dict: The updated state with the retrieved context.
|
178 |
"""
|
179 |
print("---------retrieve_context---------")
|
180 |
-
query = state['
|
|
|
181 |
#print("Query used for retrieval:", query) # Debugging: Print the query
|
182 |
|
183 |
# Retrieve documents from the vector store
|
@@ -192,7 +178,9 @@ def retrieve_context(state):
|
|
192 |
}
|
193 |
for doc in docs
|
194 |
]
|
195 |
-
state['
|
|
|
|
|
196 |
print("Extracted context with metadata:", context) # Debugging: Print the extracted context
|
197 |
#print(f"Groundedness loop count: {state['groundedness_loop_count']}")
|
198 |
return state
|
@@ -209,38 +197,27 @@ def craft_response(state: Dict) -> Dict:
|
|
209 |
Returns:
|
210 |
Dict: The updated state with the generated response.
|
211 |
"""
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
- **Be direct and concise** while ensuring completeness.
|
217 |
-
- **DO NOT include information that is not present in the context.**
|
218 |
-
- If multiple sources exist, synthesize them into a coherent response.
|
219 |
-
- If the context does not fully answer the query, state what additional information is needed.
|
220 |
-
- Use bullet points when explaining complex concepts.
|
221 |
-
|
222 |
-
Example:
|
223 |
-
User Query: "What are the symptoms of anorexia nervosa?"
|
224 |
-
Context:
|
225 |
-
1. Anorexia nervosa is characterized by extreme weight loss and fear of gaining weight.
|
226 |
-
2. Common symptoms include restricted eating, distorted body image, and excessive exercise.
|
227 |
-
Response:
|
228 |
-
"Anorexia nervosa is an eating disorder characterized by extreme weight loss and an intense fear of gaining weight. Common symptoms include:
|
229 |
-
- Restricted eating
|
230 |
-
- Distorted body image
|
231 |
-
- Excessive exercise
|
232 |
-
If you or someone you know is experiencing these symptoms, it is important to seek professional help."'''
|
233 |
|
234 |
response_prompt = ChatPromptTemplate.from_messages([
|
235 |
("system", system_message),
|
236 |
-
("user", "Query: {query}\nContext: {context}\n\
|
237 |
])
|
238 |
|
239 |
-
chain = response_prompt | llm
|
240 |
-
|
241 |
"query": state['query'],
|
242 |
-
"context": "\n".join([doc["content"] for doc in state['context']])
|
|
|
|
|
|
|
243 |
})
|
|
|
|
|
|
|
244 |
return state
|
245 |
|
246 |
|
@@ -256,37 +233,9 @@ def score_groundedness(state: Dict) -> Dict:
|
|
256 |
Dict: The updated state with the groundedness score.
|
257 |
"""
|
258 |
print("---------check_groundedness---------")
|
259 |
-
system_message = '''
|
260 |
-
|
261 |
-
Guidelines:
|
262 |
-
1. **Groundedness Check**:
|
263 |
-
- Verify that the response accurately reflects the information in the context.
|
264 |
-
- Flag any unsupported claims or deviations from the context.
|
265 |
-
|
266 |
-
2. **Citation Check**:
|
267 |
-
- Ensure that the response includes citations to the source material (e.g., "According to [Source], ...").
|
268 |
-
- If citations are missing, suggest adding them.
|
269 |
|
270 |
-
3. **Scoring**:
|
271 |
-
- Assign a groundedness score between 0 and 1, where 1 means fully grounded and properly cited.
|
272 |
-
|
273 |
-
Examples:
|
274 |
-
1. Response: "Anorexia nervosa is caused by genetic factors (Source 1)."
|
275 |
-
Context: "Anorexia nervosa is influenced by genetic, environmental, and psychological factors (Source 1)."
|
276 |
-
Evaluation: "The response is grounded and properly cited. Groundedness score: 1.0."
|
277 |
-
|
278 |
-
2. Response: "Bulimia nervosa can be cured with diet alone."
|
279 |
-
Context: "Treatment for bulimia nervosa involves psychotherapy and medications (Source 2)."
|
280 |
-
Evaluation: "The response is ungrounded and lacks citations. Groundedness score: 0.2."
|
281 |
-
|
282 |
-
3. Response: "Anorexia nervosa has a high mortality rate."
|
283 |
-
Context: "Anorexia nervosa has one of the highest mortality rates among psychiatric disorders (Source 3)."
|
284 |
-
Evaluation: "The response is grounded but lacks a citation. Groundedness score: 0.7. ."
|
285 |
-
|
286 |
-
****Return only a float score (e.g., 0.9). Do not provide explanations.****
|
287 |
-
|
288 |
-
Now, evaluate the following response:
|
289 |
-
'''
|
290 |
|
291 |
groundedness_prompt = ChatPromptTemplate.from_messages([
|
292 |
("system", system_message),
|
@@ -296,12 +245,14 @@ Now, evaluate the following response:
|
|
296 |
chain = groundedness_prompt | llm | StrOutputParser()
|
297 |
groundedness_score = float(chain.invoke({
|
298 |
"context": "\n".join([doc["content"] for doc in state['context']]),
|
299 |
-
"response":
|
|
|
300 |
}))
|
301 |
-
print("groundedness_score: ",groundedness_score)
|
302 |
-
state['groundedness_loop_count'] +=1
|
303 |
print("#########Groundedness Incremented###########")
|
304 |
state['groundedness_score'] = groundedness_score
|
|
|
305 |
return state
|
306 |
|
307 |
|
@@ -317,49 +268,23 @@ def check_precision(state: Dict) -> Dict:
|
|
317 |
Dict: The updated state with the precision score.
|
318 |
"""
|
319 |
print("---------check_precision---------")
|
320 |
-
system_message = '''
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
- 1.0 → The response is fully precise, directly answering the question.
|
325 |
-
- 0.7 → The response is mostly correct but contains some generalization.
|
326 |
-
- 0.5 → The response is somewhat relevant but lacks key details.
|
327 |
-
- 0.3 → The response is vague or only partially correct.
|
328 |
-
- 0.0 → The response is incorrect or misleading.
|
329 |
-
|
330 |
-
Examples:
|
331 |
-
1. Query: "What are the symptoms of anorexia nervosa?"
|
332 |
-
Response: "The symptoms of anorexia nervosa include extreme weight loss, fear of gaining weight, and a distorted body image."
|
333 |
-
Precision Score: 1.0
|
334 |
-
|
335 |
-
2. Query: "How is bulimia nervosa treated?"
|
336 |
-
Response: "Bulimia nervosa is treated with therapy and medications."
|
337 |
-
Precision Score: 0.7
|
338 |
-
|
339 |
-
3. Query: "What causes binge eating disorder?"
|
340 |
-
Response: "Binge eating disorder is caused by a combination of genetic, psychological, and environmental factors."
|
341 |
-
Precision Score: 0.5
|
342 |
-
|
343 |
-
4. Query: "What are the effects of malnutrition?"
|
344 |
-
Response: "Malnutrition can lead to health problems."
|
345 |
-
Precision Score: 0.3
|
346 |
-
|
347 |
-
5. Query: "What is the mortality rate of anorexia nervosa?"
|
348 |
-
Response: "Anorexia nervosa is a type of eating disorder."
|
349 |
-
Precision Score: 0.0
|
350 |
-
|
351 |
-
*****Return only a float score (e.g., 0.9). Do not provide explanations.*****
|
352 |
-
Now, evaluate the following query and response:
|
353 |
-
'''
|
354 |
precision_prompt = ChatPromptTemplate.from_messages([
|
355 |
("system", system_message),
|
356 |
("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
|
357 |
])
|
358 |
|
359 |
-
chain =
|
|
|
|
|
360 |
precision_score = float(chain.invoke({
|
361 |
"query": state['query'],
|
362 |
-
"response": state
|
|
|
|
|
363 |
}))
|
364 |
state['precision_score'] = precision_score
|
365 |
print("precision_score:", precision_score)
|
@@ -381,29 +306,9 @@ def refine_response(state: Dict) -> Dict:
|
|
381 |
"""
|
382 |
print("---------refine_response---------")
|
383 |
|
384 |
-
system_message = '''
|
385 |
-
|
386 |
-
### Guidelines:
|
387 |
-
- Identify **gaps in the explanation** (missing key details).
|
388 |
-
- Highlight **unclear or vague parts** that need elaboration.
|
389 |
-
- Suggest **additional details** that should be included for better accuracy.
|
390 |
-
- Ensure the refined response is **precise** and **grounded** in the retrieved context.
|
391 |
-
|
392 |
-
### Examples:
|
393 |
-
1. Query: "What are the symptoms of anorexia nervosa?"
|
394 |
-
Response: "The symptoms include weight loss and fear of gaining weight."
|
395 |
-
Suggestions: "The response is missing key details about behavioral and emotional symptoms. Add details like 'distorted body image' and 'restrictive eating patterns.'"
|
396 |
-
|
397 |
-
2. Query: "How is bulimia nervosa treated?"
|
398 |
-
Response: "Bulimia nervosa is treated with therapy."
|
399 |
-
Suggestions: "The response is too vague. Specify the types of therapy (e.g., cognitive-behavioral therapy) and mention other treatments like nutritional counseling and medications."
|
400 |
-
|
401 |
-
3. Query: "What causes binge eating disorder?"
|
402 |
-
Response: "Binge eating disorder is caused by psychological factors."
|
403 |
-
Suggestions: "The response is incomplete. Add details about genetic and environmental factors, and explain how they contribute to the disorder."
|
404 |
|
405 |
-
Now, suggest improvements for the following response:
|
406 |
-
'''
|
407 |
|
408 |
refine_response_prompt = ChatPromptTemplate.from_messages([
|
409 |
("system", system_message),
|
@@ -433,28 +338,9 @@ def refine_query(state: Dict) -> Dict:
|
|
433 |
Dict: The updated state with query refinement suggestions.
|
434 |
"""
|
435 |
print("---------refine_query---------")
|
436 |
-
system_message = '''
|
|
|
437 |
|
438 |
-
### Guidelines:
|
439 |
-
- Add **specific keywords** to improve document retrieval.
|
440 |
-
- Identify **missing details** that should be included.
|
441 |
-
- Suggest **ways to narrow the scope** for better precision.
|
442 |
-
|
443 |
-
### Examples:
|
444 |
-
1. Original Query: "Tell me about eating disorders."
|
445 |
-
Expanded Query: "Provide details on eating disorders, including types, symptoms, causes, and treatment options."
|
446 |
-
Suggestions: "Add specific types of eating disorders like 'anorexia nervosa' and 'bulimia nervosa' to improve retrieval."
|
447 |
-
|
448 |
-
2. Original Query: "What is anorexia?"
|
449 |
-
Expanded Query: "Explain anorexia nervosa, including its symptoms and causes."
|
450 |
-
Suggestions: "Include details about treatment options and risk factors to make the query more comprehensive."
|
451 |
-
|
452 |
-
3. Original Query: "How to treat bulimia?"
|
453 |
-
Expanded Query: "Describe treatment options for bulimia nervosa."
|
454 |
-
Suggestions: "Specify types of treatments like 'cognitive-behavioral therapy' and 'medications' for better precision."
|
455 |
-
|
456 |
-
Now, suggest improvements for the following expanded query:
|
457 |
-
'''
|
458 |
|
459 |
refine_query_prompt = ChatPromptTemplate.from_messages([
|
460 |
("system", system_message),
|
@@ -477,7 +363,8 @@ def should_continue_groundedness(state):
|
|
477 |
"""Decides if groundedness is sufficient or needs improvement."""
|
478 |
print("---------should_continue_groundedness---------")
|
479 |
print("groundedness loop count: ", state['groundedness_loop_count'])
|
480 |
-
if state['groundedness_score'] >=
|
|
|
481 |
print("Moving to precision")
|
482 |
return "check_precision"
|
483 |
else:
|
@@ -491,19 +378,24 @@ def should_continue_groundedness(state):
|
|
491 |
def should_continue_precision(state: Dict) -> str:
|
492 |
"""Decides if precision is sufficient or needs improvement."""
|
493 |
print("---------should_continue_precision---------")
|
494 |
-
print("precision loop count: ",
|
495 |
-
if
|
|
|
|
|
|
|
496 |
return "pass" # Complete the workflow
|
497 |
else:
|
498 |
-
if
|
|
|
499 |
return "max_iterations_reached"
|
500 |
else:
|
501 |
print(f"---------Precision Score Threshold Not met. Refining Query-----------") # Debugging
|
502 |
-
#
|
503 |
return "refine_query" # Refine the query
|
504 |
|
505 |
|
506 |
|
|
|
507 |
def max_iterations_reached(state: Dict) -> Dict:
|
508 |
"""Handles the case when the maximum number of iterations is reached."""
|
509 |
print("---------max_iterations_reached---------")
|
@@ -514,27 +406,29 @@ def max_iterations_reached(state: Dict) -> Dict:
|
|
514 |
|
515 |
|
516 |
|
|
|
|
|
517 |
def create_workflow() -> StateGraph:
|
518 |
"""Creates the updated workflow for the AI nutrition agent."""
|
519 |
-
workflow = StateGraph(
|
|
|
520 |
|
521 |
# Add processing nodes
|
522 |
-
workflow.add_node("expand_query", expand_query) # Step 1: Expand user query.
|
523 |
-
workflow.add_node("retrieve_context", retrieve_context) # Step 2: Retrieve relevant documents.
|
524 |
-
workflow.add_node("craft_response", craft_response) # Step 3: Generate a response based on retrieved data.
|
525 |
-
workflow.add_node("score_groundedness", score_groundedness) # Step 4: Evaluate response grounding.
|
526 |
-
workflow.add_node("refine_response", refine_response) # Step 5: Improve response if it's weakly grounded.
|
527 |
-
workflow.add_node("check_precision", check_precision) # Step 6: Evaluate response precision.
|
528 |
-
workflow.add_node("refine_query", refine_query) # Step 7: Improve query if response lacks precision.
|
529 |
-
workflow.add_node("max_iterations_reached", max_iterations_reached) # Step 8: Handle max iterations.
|
530 |
-
|
|
|
531 |
# Main flow edges
|
532 |
workflow.add_edge(START, "expand_query")
|
533 |
workflow.add_edge("expand_query", "retrieve_context")
|
534 |
workflow.add_edge("retrieve_context", "craft_response")
|
535 |
workflow.add_edge("craft_response", "score_groundedness")
|
536 |
-
# workflow.add_edge("score_groundedness","groundedness_decider")
|
537 |
-
|
538 |
|
539 |
# Conditional edges based on groundedness check
|
540 |
workflow.add_conditional_edges(
|
@@ -546,6 +440,7 @@ def create_workflow() -> StateGraph:
|
|
546 |
"max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
|
547 |
}
|
548 |
)
|
|
|
549 |
workflow.add_edge("refine_response", "craft_response") # Refined responses are reprocessed.
|
550 |
|
551 |
# Conditional edges based on precision check
|
@@ -555,19 +450,17 @@ def create_workflow() -> StateGraph:
|
|
555 |
{
|
556 |
"pass": END, # If precise, complete the workflow.
|
557 |
"refine_query": "refine_query", # If imprecise, refine the query.
|
558 |
-
"max_iterations_reached": "max_iterations_reached"
|
559 |
}
|
560 |
)
|
|
|
561 |
workflow.add_edge("refine_query", "expand_query") # Refined queries go through expansion again.
|
562 |
|
563 |
workflow.add_edge("max_iterations_reached", END)
|
564 |
-
# Set entry point
|
565 |
-
# workflow.set_entry_point("expand_query")
|
566 |
|
567 |
return workflow
|
568 |
|
569 |
|
570 |
-
|
571 |
#=========================== Defining the agentic rag tool ====================#
|
572 |
WORKFLOW_APP = create_workflow().compile()
|
573 |
@tool
|
@@ -584,17 +477,16 @@ def agentic_rag(query: str):
|
|
584 |
# Initialize state with necessary parameters
|
585 |
inputs = {
|
586 |
"query": query, # Current user query
|
587 |
-
"expanded_query": "", #
|
588 |
"context": [], # Retrieved documents (initially empty)
|
589 |
-
"response": "", # AI-generated response
|
590 |
-
"precision_score": 0.0, #
|
591 |
-
"groundedness_score": 0.0, #
|
592 |
-
"groundedness_loop_count": 0, #
|
593 |
-
"precision_loop_count": 0, #
|
594 |
-
"feedback": "",
|
595 |
-
"query_feedback":"",
|
596 |
-
"loop_max_iter":
|
597 |
-
|
598 |
}
|
599 |
|
600 |
output = WORKFLOW_APP.invoke(inputs)
|
@@ -638,8 +530,9 @@ class NutritionBot:
|
|
638 |
"""
|
639 |
|
640 |
# Initialize a memory client to store and retrieve customer interactions
|
641 |
-
self.memory = MemoryClient(api_key=
|
642 |
|
|
|
643 |
self.client = ChatOpenAI(
|
644 |
model_name="gpt-4o-mini", # Specify the model to use (e.g., GPT-4 optimized version)
|
645 |
api_key=config.get("API_KEY"), # API key for authentication
|
@@ -647,7 +540,6 @@ class NutritionBot:
|
|
647 |
temperature=0 # Controls randomness in responses; 0 ensures deterministic results
|
648 |
)
|
649 |
|
650 |
-
|
651 |
# Define tools available to the chatbot, such as web search
|
652 |
tools = [agentic_rag]
|
653 |
|
@@ -679,6 +571,7 @@ class NutritionBot:
|
|
679 |
# Wrap the agent in an executor to manage tool interactions and execution flow
|
680 |
self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
681 |
|
|
|
682 |
def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
|
683 |
"""
|
684 |
Store customer interaction in memory for future reference.
|
@@ -709,6 +602,7 @@ class NutritionBot:
|
|
709 |
metadata=metadata
|
710 |
)
|
711 |
|
|
|
712 |
def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
|
713 |
"""
|
714 |
Retrieve past interactions relevant to the current query.
|
@@ -723,9 +617,12 @@ class NutritionBot:
|
|
723 |
return self.memory.search(
|
724 |
query=query, # Search for interactions related to the query
|
725 |
user_id=user_id, # Restrict search to the specific user
|
726 |
-
limit=
|
|
|
|
|
727 |
)
|
728 |
|
|
|
729 |
def handle_customer_query(self, user_id: str, query: str) -> str:
|
730 |
"""
|
731 |
Process a customer's query and provide a response, taking into account past interactions.
|
@@ -803,8 +700,6 @@ def nutrition_disorder_streamlit():
|
|
803 |
"content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
|
804 |
})
|
805 |
st.session_state.login_submitted = True # Set flag to trigger rerun
|
806 |
-
|
807 |
-
# Trigger rerun outside the form if login was successful
|
808 |
if st.session_state.get("login_submitted", False):
|
809 |
st.session_state.pop("login_submitted")
|
810 |
st.rerun()
|
@@ -814,11 +709,11 @@ def nutrition_disorder_streamlit():
|
|
814 |
with st.chat_message(message["role"]):
|
815 |
st.write(message["content"])
|
816 |
|
817 |
-
# Chat input
|
818 |
-
user_query = st.chat_input("Type your question here (or 'exit' to end)...")
|
|
|
819 |
|
820 |
if user_query:
|
821 |
-
# Check if user wants to exit
|
822 |
if user_query.lower() == "exit":
|
823 |
st.session_state.chat_history.append({"role": "user", "content": "exit"})
|
824 |
with st.chat_message("user"):
|
@@ -831,38 +726,41 @@ def nutrition_disorder_streamlit():
|
|
831 |
st.rerun()
|
832 |
return
|
833 |
|
834 |
-
# Add user message to chat history
|
835 |
st.session_state.chat_history.append({"role": "user", "content": user_query})
|
836 |
with st.chat_message("user"):
|
837 |
st.write(user_query)
|
838 |
|
839 |
-
|
840 |
-
|
841 |
-
|
842 |
-
#
|
843 |
-
|
844 |
-
|
845 |
-
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
-
|
850 |
-
|
851 |
-
|
852 |
-
|
853 |
-
|
854 |
-
)
|
855 |
-
|
856 |
-
|
857 |
-
|
858 |
-
|
859 |
-
|
860 |
-
|
861 |
-
|
862 |
-
|
863 |
-
|
864 |
-
st.write(
|
865 |
-
st.session_state.chat_history.append({"role": "assistant", "content":
|
|
|
|
|
|
|
|
|
866 |
|
867 |
if __name__ == "__main__":
|
868 |
nutrition_disorder_streamlit()
|
|
|
29 |
from langchain_core.prompts import ChatPromptTemplate
|
30 |
|
31 |
# LangChain OpenAI imports
|
32 |
+
from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI, ChatOpenAI # OpenAI embeddings and models
|
33 |
from langchain.embeddings.openai import OpenAIEmbeddings # OpenAI embeddings for text vectors
|
34 |
|
35 |
# LlamaParse & LlamaIndex imports
|
|
|
54 |
|
55 |
#====================================SETUP=====================================#
|
56 |
# Fetch secrets from Hugging Face Spaces
|
57 |
+
api_key = os.getenv("API_KEY")
|
58 |
endpoint = os.getenv("OPENAI_API_BASE")
|
59 |
llama_api_key = os.environ['GROQ_API_KEY']
|
60 |
MEM0_api_key = os.environ['mem0']
|
|
|
66 |
model_name='text-embedding-ada-002' # This is a fixed value and does not need modification
|
67 |
)
|
68 |
|
69 |
+
# This initializes the OpenAI embedding function for the Chroma vectorstore, using the provided endpoint and API key.
|
70 |
|
71 |
# Initialize the OpenAI Embeddings
|
72 |
embedding_model = OpenAIEmbeddings(
|
|
|
83 |
model="gpt-4o-mini",
|
84 |
streaming=False
|
85 |
)
|
86 |
+
# This initializes the Chat OpenAI model with the provided endpoint, API key, deployment name, and a temperature setting of 0 (to control response variability).
|
87 |
|
88 |
# set the LLM and embedding model in the LlamaIndex settings.
|
89 |
+
# Settings.llm = _____ # Complete the code to define the LLM model
|
90 |
+
# Settings.embedding = _____ # Complete the code to define the embedding model
|
91 |
+
Settings.llm = llm # Complete the code to define the LLM model
|
92 |
+
Settings.embedding = embedding_model # Complete the code to define the embedding model
|
93 |
+
|
94 |
|
95 |
#================================Creating Langgraph agent======================#
|
96 |
|
|
|
119 |
Dict: The updated state with the expanded query.
|
120 |
"""
|
121 |
print("---------Expanding Query---------")
|
122 |
+
# system_message = '''________________________'''
|
123 |
+
system_message = '''You are a nutrition expert and language model specialized in nutritional disorders. Your task is to expand the provided query by incorporating related keywords, synonyms, and additional context that can improve the retrieval of detailed nutrition disorder-related information.'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
|
|
|
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
expand_prompt = ChatPromptTemplate.from_messages([
|
128 |
("system", system_message),
|
|
|
162 |
Dict: The updated state with the retrieved context.
|
163 |
"""
|
164 |
print("---------retrieve_context---------")
|
165 |
+
# query = state['_____'] # Complete the code to define the key for the expanded query
|
166 |
+
query = state['expanded_query'] # Complete the code to define the key for the expanded query
|
167 |
#print("Query used for retrieval:", query) # Debugging: Print the query
|
168 |
|
169 |
# Retrieve documents from the vector store
|
|
|
178 |
}
|
179 |
for doc in docs
|
180 |
]
|
181 |
+
# state['_____'] = context # Complete the code to define the key for storing the context
|
182 |
+
state['context'] = context # Complete the code to define the key for storing the context
|
183 |
+
|
184 |
print("Extracted context with metadata:", context) # Debugging: Print the extracted context
|
185 |
#print(f"Groundedness loop count: {state['groundedness_loop_count']}")
|
186 |
return state
|
|
|
197 |
Returns:
|
198 |
Dict: The updated state with the generated response.
|
199 |
"""
|
200 |
+
print("---------craft_response---------")
|
201 |
+
# system_message = '''________________________'''
|
202 |
+
system_message = '''You are a nutrition expert and your responses should be clear, concise, and evidence-based. Use the provided context to accurately address the user's query regarding nutritional disorders.'''
|
203 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
response_prompt = ChatPromptTemplate.from_messages([
|
206 |
("system", system_message),
|
207 |
+
("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
|
208 |
])
|
209 |
|
210 |
+
chain = response_prompt | llm
|
211 |
+
response = chain.invoke({
|
212 |
"query": state['query'],
|
213 |
+
"context": "\n".join([doc["content"] for doc in state['context']]),
|
214 |
+
# "feedback": ________________ # add feedback to the prompt
|
215 |
+
"feedback": state.get("query_feedback", "No additional feedback provided") # add feedback to the prompt
|
216 |
+
|
217 |
})
|
218 |
+
state['response'] = response
|
219 |
+
print("intermediate response: ", response)
|
220 |
+
|
221 |
return state
|
222 |
|
223 |
|
|
|
233 |
Dict: The updated state with the groundedness score.
|
234 |
"""
|
235 |
print("---------check_groundedness---------")
|
236 |
+
# system_message = '''________________________'''
|
237 |
+
system_message = '''You are an evaluator for response groundedness. Given the context and the response related to nutritional disorders, provide a numerical score between 0 and 1 where 0 means the response is not grounded at all, and 1 means it is completely grounded in the context.'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
|
240 |
groundedness_prompt = ChatPromptTemplate.from_messages([
|
241 |
("system", system_message),
|
|
|
245 |
chain = groundedness_prompt | llm | StrOutputParser()
|
246 |
groundedness_score = float(chain.invoke({
|
247 |
"context": "\n".join([doc["content"] for doc in state['context']]),
|
248 |
+
# "response": __________ # Complete the code to define the response
|
249 |
+
"response": state['response'] # Complete the code to define the response
|
250 |
}))
|
251 |
+
print("groundedness_score: ", groundedness_score)
|
252 |
+
state['groundedness_loop_count'] += 1
|
253 |
print("#########Groundedness Incremented###########")
|
254 |
state['groundedness_score'] = groundedness_score
|
255 |
+
|
256 |
return state
|
257 |
|
258 |
|
|
|
268 |
Dict: The updated state with the precision score.
|
269 |
"""
|
270 |
print("---------check_precision---------")
|
271 |
+
# system_message = '''________________________'''
|
272 |
+
system_message = '''You are an evaluator for response precision. Given the query and the response, provide a numerical score between 0 and 1 where 0 indicates that the response does not address the query at all, and 1 indicates that the response precisely addresses the query.'''
|
273 |
+
|
274 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
precision_prompt = ChatPromptTemplate.from_messages([
|
276 |
("system", system_message),
|
277 |
("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
|
278 |
])
|
279 |
|
280 |
+
# chain = _____________ | llm | StrOutputParser() # Complete the code to define the chain of processing
|
281 |
+
chain = precision_prompt | llm | StrOutputParser() # Complete the code to define the chain of processing
|
282 |
+
|
283 |
precision_score = float(chain.invoke({
|
284 |
"query": state['query'],
|
285 |
+
# "response":______________ # Complete the code to access the response from the state
|
286 |
+
"response":state['response'] # Complete the code to access the response from the state
|
287 |
+
|
288 |
}))
|
289 |
state['precision_score'] = precision_score
|
290 |
print("precision_score:", precision_score)
|
|
|
306 |
"""
|
307 |
print("---------refine_response---------")
|
308 |
|
309 |
+
# system_message = '''________________________'''
|
310 |
+
system_message = '''You are an expert editor in nutritional science communications. Your role is to review the response given to a nutritional query and provide clear suggestions to improve its accuracy, clarity, and completeness. Focus on making sure that the response fully addresses the query and is supported by evidence-based nutritional guidelines.'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
311 |
|
|
|
|
|
312 |
|
313 |
refine_response_prompt = ChatPromptTemplate.from_messages([
|
314 |
("system", system_message),
|
|
|
338 |
Dict: The updated state with query refinement suggestions.
|
339 |
"""
|
340 |
print("---------refine_query---------")
|
341 |
+
# system_message = '''________________________'''
|
342 |
+
system_message = '''You are a search query refinement expert. Given the original and expanded queries related to nutritional disorders, provide suggestions to refine the query further for improved search results.'''
|
343 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
|
345 |
refine_query_prompt = ChatPromptTemplate.from_messages([
|
346 |
("system", system_message),
|
|
|
363 |
"""Decides if groundedness is sufficient or needs improvement."""
|
364 |
print("---------should_continue_groundedness---------")
|
365 |
print("groundedness loop count: ", state['groundedness_loop_count'])
|
366 |
+
# if state['groundedness_score'] >= _____: # Complete the code to define the threshold for groundedness
|
367 |
+
if state['groundedness_score'] >= 0.7: # Complete the code to define the threshold for groundedness
|
368 |
print("Moving to precision")
|
369 |
return "check_precision"
|
370 |
else:
|
|
|
378 |
def should_continue_precision(state: Dict) -> str:
|
379 |
"""Decides if precision is sufficient or needs improvement."""
|
380 |
print("---------should_continue_precision---------")
|
381 |
+
# print("precision loop count: ", ___________)
|
382 |
+
# if ___________: # Threshold for precision
|
383 |
+
|
384 |
+
print("precision loop count: ", state['precision_loop_count'])
|
385 |
+
if state['precision_score'] >= 0.8: # Threshold for precision
|
386 |
return "pass" # Complete the workflow
|
387 |
else:
|
388 |
+
# if ___________: # Maximum allowed loops
|
389 |
+
if state["precision_loop_count"] > state['loop_max_iter']: # Maximum allowed loops
|
390 |
return "max_iterations_reached"
|
391 |
else:
|
392 |
print(f"---------Precision Score Threshold Not met. Refining Query-----------") # Debugging
|
393 |
+
# return ____________ # Refine the query
|
394 |
return "refine_query" # Refine the query
|
395 |
|
396 |
|
397 |
|
398 |
+
|
399 |
def max_iterations_reached(state: Dict) -> Dict:
|
400 |
"""Handles the case when the maximum number of iterations is reached."""
|
401 |
print("---------max_iterations_reached---------")
|
|
|
406 |
|
407 |
|
408 |
|
409 |
+
from langgraph.graph import END, StateGraph, START
|
410 |
+
|
411 |
def create_workflow() -> StateGraph:
|
412 |
"""Creates the updated workflow for the AI nutrition agent."""
|
413 |
+
# workflow = StateGraph(_____ ) # Complete the code to define the initial state of the agent
|
414 |
+
workflow = StateGraph(dict) # Complete the code to define the initial state of the agent
|
415 |
|
416 |
# Add processing nodes
|
417 |
+
workflow.add_node("expand_query", expand_query ) # Step 1: Expand user query. Complete with the function to expand the query
|
418 |
+
workflow.add_node("retrieve_context", retrieve_context ) # Step 2: Retrieve relevant documents. Complete with the function to retrieve context
|
419 |
+
workflow.add_node("craft_response", craft_response ) # Step 3: Generate a response based on retrieved data. Complete with the function to craft a response
|
420 |
+
workflow.add_node("score_groundedness", score_groundedness ) # Step 4: Evaluate response grounding. Complete with the function to score groundedness
|
421 |
+
workflow.add_node("refine_response", refine_response ) # Step 5: Improve response if it's weakly grounded. Complete with the function to refine the response
|
422 |
+
workflow.add_node("check_precision", check_precision ) # Step 6: Evaluate response precision. Complete with the function to check precision
|
423 |
+
workflow.add_node("refine_query", refine_query ) # Step 7: Improve query if response lacks precision. Complete with the function to refine the query
|
424 |
+
workflow.add_node("max_iterations_reached", max_iterations_reached ) # Step 8: Handle max iterations. Complete with the function to handle max iterations
|
425 |
+
|
426 |
+
|
427 |
# Main flow edges
|
428 |
workflow.add_edge(START, "expand_query")
|
429 |
workflow.add_edge("expand_query", "retrieve_context")
|
430 |
workflow.add_edge("retrieve_context", "craft_response")
|
431 |
workflow.add_edge("craft_response", "score_groundedness")
|
|
|
|
|
432 |
|
433 |
# Conditional edges based on groundedness check
|
434 |
workflow.add_conditional_edges(
|
|
|
440 |
"max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
|
441 |
}
|
442 |
)
|
443 |
+
|
444 |
workflow.add_edge("refine_response", "craft_response") # Refined responses are reprocessed.
|
445 |
|
446 |
# Conditional edges based on precision check
|
|
|
450 |
{
|
451 |
"pass": END, # If precise, complete the workflow.
|
452 |
"refine_query": "refine_query", # If imprecise, refine the query.
|
453 |
+
"max_iterations_reached": "max_iterations_reached" # If max loops reached, exit.
|
454 |
}
|
455 |
)
|
456 |
+
|
457 |
workflow.add_edge("refine_query", "expand_query") # Refined queries go through expansion again.
|
458 |
|
459 |
workflow.add_edge("max_iterations_reached", END)
|
|
|
|
|
460 |
|
461 |
return workflow
|
462 |
|
463 |
|
|
|
464 |
#=========================== Defining the agentic rag tool ====================#
|
465 |
WORKFLOW_APP = create_workflow().compile()
|
466 |
@tool
|
|
|
477 |
# Initialize state with necessary parameters
|
478 |
inputs = {
|
479 |
"query": query, # Current user query
|
480 |
+
"expanded_query": "", # Complete the code to define the expanded version of the query
|
481 |
"context": [], # Retrieved documents (initially empty)
|
482 |
+
"response": "", # Complete the code to define the AI-generated response
|
483 |
+
"precision_score": 0.0, # Complete the code to define the precision score of the response
|
484 |
+
"groundedness_score": 0.0, # Complete the code to define the groundedness score of the response
|
485 |
+
"groundedness_loop_count": 0, # Complete the code to define the counter for groundedness loops
|
486 |
+
"precision_loop_count": 0, # Complete the code to define the counter for precision loops
|
487 |
+
"feedback": "", # Complete the code to define the feedback
|
488 |
+
"query_feedback": "", # Complete the code to define the query feedback
|
489 |
+
"loop_max_iter": 5 # Complete the code to define the maximum number of iterations for loops
|
|
|
490 |
}
|
491 |
|
492 |
output = WORKFLOW_APP.invoke(inputs)
|
|
|
530 |
"""
|
531 |
|
532 |
# Initialize a memory client to store and retrieve customer interactions
|
533 |
+
self.memory = MemoryClient(api_key=userdata.get("mem0")) # Complete the code to define the memory client API key
|
534 |
|
535 |
+
# Initialize the OpenAI client using the provided credentials
|
536 |
self.client = ChatOpenAI(
|
537 |
model_name="gpt-4o-mini", # Specify the model to use (e.g., GPT-4 optimized version)
|
538 |
api_key=config.get("API_KEY"), # API key for authentication
|
|
|
540 |
temperature=0 # Controls randomness in responses; 0 ensures deterministic results
|
541 |
)
|
542 |
|
|
|
543 |
# Define tools available to the chatbot, such as web search
|
544 |
tools = [agentic_rag]
|
545 |
|
|
|
571 |
# Wrap the agent in an executor to manage tool interactions and execution flow
|
572 |
self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
573 |
|
574 |
+
|
575 |
def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
|
576 |
"""
|
577 |
Store customer interaction in memory for future reference.
|
|
|
602 |
metadata=metadata
|
603 |
)
|
604 |
|
605 |
+
|
606 |
def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
|
607 |
"""
|
608 |
Retrieve past interactions relevant to the current query.
|
|
|
617 |
return self.memory.search(
|
618 |
query=query, # Search for interactions related to the query
|
619 |
user_id=user_id, # Restrict search to the specific user
|
620 |
+
# limit=_____ # Complete the code to define the limit for retrieved interactions
|
621 |
+
limit=5 # Complete the code to define the limit for retrieved interactions
|
622 |
+
|
623 |
)
|
624 |
|
625 |
+
|
626 |
def handle_customer_query(self, user_id: str, query: str) -> str:
|
627 |
"""
|
628 |
Process a customer's query and provide a response, taking into account past interactions.
|
|
|
700 |
"content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
|
701 |
})
|
702 |
st.session_state.login_submitted = True # Set flag to trigger rerun
|
|
|
|
|
703 |
if st.session_state.get("login_submitted", False):
|
704 |
st.session_state.pop("login_submitted")
|
705 |
st.rerun()
|
|
|
709 |
with st.chat_message(message["role"]):
|
710 |
st.write(message["content"])
|
711 |
|
712 |
+
# Chat input with custom placeholder text
|
713 |
+
# user_query = st.chat_input(__________) # Blank #1: Fill in the chat input prompt (e.g., "Type your question here (or 'exit' to end)...")
|
714 |
+
user_query = st.chat_input("Type your question here (or 'exit' to end)...") # Blank #1
|
715 |
|
716 |
if user_query:
|
|
|
717 |
if user_query.lower() == "exit":
|
718 |
st.session_state.chat_history.append({"role": "user", "content": "exit"})
|
719 |
with st.chat_message("user"):
|
|
|
726 |
st.rerun()
|
727 |
return
|
728 |
|
|
|
729 |
st.session_state.chat_history.append({"role": "user", "content": user_query})
|
730 |
with st.chat_message("user"):
|
731 |
st.write(user_query)
|
732 |
|
733 |
+
|
734 |
+
|
735 |
+
# Filter input using Llama Guard
|
736 |
+
# filtered_result = __________(user_query) # Blank #2: Fill in with the function name for filtering input (e.g., filter_input_with_llama_guard)
|
737 |
+
|
738 |
+
# filtered_result = filter_input_with_llama_guard(user_query) # Blank #2: Fill in with the function name for filtering input (e.g., filter_input_with_llama_guard)
|
739 |
+
filtered_result = check_input_safety(user_query, llama_guard_client)
|
740 |
+
filtered_result = filtered_result.replace("\n", " ") # Normalize the result
|
741 |
+
|
742 |
+
# Check if input is safe based on allowed statuses
|
743 |
+
# if filtered_result in [__________, __________, __________]: # Blanks #3, #4, #5: Fill in with allowed safe statuses (e.g., "safe", "unsafe S7", "unsafe S6")
|
744 |
+
if filtered_result in ["SAFE", "S6", "S7"]: # Blanks #3, #4, #5: Fill in with allowed safe statuses (e.g., "safe", "unsafe S7", "unsafe S6")
|
745 |
+
|
746 |
+
try:
|
747 |
+
if 'chatbot' not in st.session_state:
|
748 |
+
st.session_state.chatbot = NutritionBot() # Blank #6: Fill in with the chatbot class initialization (e.g., NutritionBot)
|
749 |
+
|
750 |
+
# response = st.session_state.chatbot.__________(st.session_state.user_id, user_query)
|
751 |
+
response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
|
752 |
+
|
753 |
+
# Blank #7: Fill in with the method to handle queries (e.g., handle_customer_query)
|
754 |
+
st.write(response)
|
755 |
+
st.session_state.chat_history.append({"role": "assistant", "content": response})
|
756 |
+
except Exception as e:
|
757 |
+
error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
|
758 |
+
st.write(error_msg)
|
759 |
+
st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
|
760 |
+
else:
|
761 |
+
inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
|
762 |
+
st.write(inappropriate_msg)
|
763 |
+
st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})
|
764 |
|
765 |
if __name__ == "__main__":
|
766 |
nutrition_disorder_streamlit()
|