######################## WRITE YOUR CODE HERE ######################### # Import necessary libraries import os # Interacting with the operating system (reading/writing files) import chromadb # High-performance vector database for storing/querying dense vectors from dotenv import load_dotenv # Loading environment variables from a .env file import json # Parsing and handling JSON data # LangChain imports from langchain_core.documents import Document # Document data structures from langchain_core.runnables import RunnablePassthrough # LangChain core library for running pipelines from langchain_core.output_parsers import StrOutputParser # String output parser from langchain.prompts import ChatPromptTemplate # Template for chat prompts from langchain.chains.query_constructor.base import AttributeInfo # Base classes for query construction from langchain.retrievers.self_query.base import SelfQueryRetriever # Base classes for self-querying retrievers from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker # Document compressors from langchain.retrievers import ContextualCompressionRetriever # Contextual compression retrievers # LangChain community & experimental imports from langchain_community.vectorstores import Chroma # Implementations of vector stores like Chroma from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader # Document loaders for PDFs from langchain_community.cross_encoders import HuggingFaceCrossEncoder # Cross-encoders from HuggingFace from langchain_experimental.text_splitter import SemanticChunker # Experimental text splitting methods from langchain.text_splitter import ( CharacterTextSplitter, # Splitting text by characters RecursiveCharacterTextSplitter # Recursive splitting of text by characters ) from langchain_core.tools import tool from langchain.agents import create_tool_calling_agent, AgentExecutor from langchain_core.prompts import ChatPromptTemplate # LangChain OpenAI imports from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI # OpenAI embeddings and models from langchain.embeddings.openai import OpenAIEmbeddings # OpenAI embeddings for text vectors # LlamaParse & LlamaIndex imports from llama_parse import LlamaParse # Document parsing library from llama_index.core import Settings, SimpleDirectoryReader # Core functionalities of the LlamaIndex # LangGraph import from langgraph.graph import StateGraph, END, START # State graph for managing states in LangChain # Pydantic import from pydantic import BaseModel # Pydantic for data validation # Typing imports from typing import Dict, List, Tuple, Any, TypedDict # Python typing for function annotations # Other utilities import numpy as np # Numpy for numerical operations from groq import Groq from mem0 import MemoryClient import streamlit as st from datetime import datetime #====================================SETUP=====================================# # Fetch secrets from Hugging Face Spaces api_key = os.environ['AZURE_OPENAI_API_KEY'] endpoint = os.environ['AZURE_OPENAI_ENDPOINT'] api_version = os.environ['AZURE_OPENAI_APIVERSION'] model_name = os.environ['CHATGPT_MODEL'] emb_key = os.environ['EMB_MODEL_KEY'] emb_endpoint = os.environ['EMB_DEPLOYMENT'] #llama_api_key = os.environ['GROQ_API_KEY'] llama_api_key = os.environ['LLAMA_API_KEY'] # Initialize the OpenAI embedding function for Chroma embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction( # api_base=_____, # Complete the code to define the API base endpoint # api_key=_____, # Complete the code to define the API key api_base= emb_endpoint, # Complete the code to define the API base endpoint api_key= emb_key, # Complete the code to define the API key api_type='azure', # This is a fixed value and does not need modification api_version='2023-05-15', # This is a fixed value and does not need modification model_name='text-embedding-ada-002' # This is a fixed value and does not need modification ) # This initializes the OpenAI embedding function for the Chroma vectorstore, using the provided Azure endpoint and API key. # Initialize the Azure OpenAI Embeddings embedding_model = AzureOpenAIEmbeddings( # azure_endpoint=_____, # Complete the code to define the Azure endpoint # api_key=_____, # Complete the code to define the API key azure_endpoint= emb_endpoint, # Complete the code to define the Azure endpoint api_key= emb_key, # Complete the code to define the API key api_version='2023-05-15', # This is a fixed value and does not need modification model='text-embedding-ada-002' # This is a fixed value and does not need modification ) # This initializes the Azure OpenAI embeddings model using the specified endpoint, API key, and model name. # Initialize the Azure Chat OpenAI model llm = AzureChatOpenAI( azure_endpoint=endpoint, api_key=api_key, api_version='2024-05-01-preview', azure_deployment='gpt-4o', temperature=0 ) # This initializes the Azure Chat OpenAI model with the provided endpoint, API key, deployment name, and a temperature setting of 0 (to control response variability). # set the LLM and embedding model in the LlamaIndex settings. # Settings.llm = _____ # Complete the code to define the LLM model # Settings.embedding = _____ # Complete the code to define the embedding model Settings.llm = llm # Complete the code to define the LLM model Settings.embedding = embedding_model # Complete the code to define the embedding model #================================Creating Langgraph agent======================# class AgentState(TypedDict): query: str # The current user query expanded_query: str # The expanded version of the user query context: List[Dict[str, Any]] # Retrieved documents (content and metadata) response: str # The generated response to the user query precision_score: float # The precision score of the response groundedness_score: float # The groundedness score of the response groundedness_loop_count: int # Counter for groundedness refinement loops precision_loop_count: int # Counter for precision refinement loops feedback: str query_feedback: str groundedness_check: bool loop_max_iter: int def expand_query(state): """ Expands the user query to improve retrieval of nutrition disorder-related information. Args: state (Dict): The current state of the workflow, containing the user query. Returns: Dict: The updated state with the expanded query. """ print("---------Expanding Query---------") #system_message = '''________________________''' system_message = """ You are a domain expert assisting in answering questions related to nutrition disorder-related information. Convert the user query into something that a nutritionist would understand. Use domain related words. Perform query expansion on the question received. If there are multiple common ways of phrasing a user question \ or common synonyms for key words in the question, make sure to return multiple versions \ of the query with the different phrasings. If the query has multiple parts, split them into separate simpler queries. This is the only case where you can generate more than 3 queries. If there are acronyms or words you are not familiar with, do not try to rephrase them. Return only 3 versions of the question as a list. Generate only a list of questions. Do not mention anything before or after the list. """ expand_prompt = ChatPromptTemplate.from_messages([ ("system", system_message), ("user", "Expand this query: {query} using the feedback: {query_feedback}") ]) chain = expand_prompt | llm | StrOutputParser() expanded_query = chain.invoke({"query": state['query'], "query_feedback":state["query_feedback"]}) print("expanded_query", expanded_query) state["expanded_query"] = expanded_query return state print("Current Working Directory:", os.getcwd()) # Initialize the Chroma vector store for retrieving documents vector_store = Chroma( collection_name="nutritional-medical-reference", persist_directory="./research_db", embedding_function=embedding_model ) # Create a retriever from the vector store retriever = vector_store.as_retriever( search_type='similarity', search_kwargs={'k': 3} ) def retrieve_context(state): """ Retrieves context from the vector store using the expanded or original query. Args: state (Dict): The current state of the workflow, containing the query and expanded query. Returns: Dict: The updated state with the retrieved context. """ print("---------retrieve_context---------") #query = state['_____'] # Complete the code to define the key for the expanded query query = state['expanded_query'] # Complete the code to define the key for the expanded query #print("Query used for retrieval:", query) # Debugging: Print the query # Retrieve documents from the vector store docs = retriever.invoke(query) print("Retrieved documents:", docs) # Debugging: Print the raw docs object # Extract both page_content and metadata from each document context= [ { "content": doc.page_content, # The actual content of the document "metadata": doc.metadata # The metadata (e.g., source, page number, etc.) } for doc in docs ] #state['_____'] = context # Complete the code to define the key for storing the context state['context'] = context # Complete the code to define the key for storing the context print("Extracted context with metadata:", context) # Debugging: Print the extracted context #print(f"Groundedness loop count: {state['groundedness_loop_count']}") return state def craft_response(state: Dict) -> Dict: """ Generates a response using the retrieved context, focusing on nutrition disorders. Args: state (Dict): The current state of the workflow, containing the query and retrieved context. Returns: Dict: The updated state with the generated response. """ print("---------craft_response---------") #system_message = '''________________________''' system_message = """ You are a knowledgeable nutritionist specialized in nutrition and health. Use the provided context to generate a helpful, accurate, and empathetic response to the user's query. Focus on identifying, explaining, or addressing nutrition disorders where relevant. Be clear and concise. """ response_prompt = ChatPromptTemplate.from_messages([ ("system", system_message), ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}") ]) chain = response_prompt | llm response = chain.invoke({ "query": state['query'], "context": "\n".join([doc["content"] for doc in state['context']]), #"feedback": ________________ # add feedback to the prompt "feedback": state['feedback'] # add feedback to the prompt }) state['response'] = response print("intermediate response: ", response) return state def score_groundedness(state: Dict) -> Dict: """ Checks whether the response is grounded in the retrieved context. Args: state (Dict): The current state of the workflow, containing the response and context. Returns: Dict: The updated state with the groundedness score. """ print("---------check_groundedness---------") #system_message = '''________________________''' system_message = '''You are an objective evaluator tasked with scoring the groundedness of a response based on the retrieved context provided. Definition of "groundedness": - A response is considered grounded if it strictly uses information present in the provided context. - It should avoid hallucinating, fabricating, or introducing any claims that are not explicitly supported by the context. Scoring Guidelines: - Return a numeric score between 0 and 1. - 1.0: The response is entirely grounded in the context. - 0.5: The response is partially grounded (some parts supported, others not). - 0.0: The response is not grounded at all (hallucinated or irrelevant). Important: - Do NOT explain your score. - Do NOT provide justification. - ONLY return the score as a number (e.g., 1.0, 0.5, or 0.0). ''' groundedness_prompt = ChatPromptTemplate.from_messages([ ("system", system_message), ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:") ]) chain = groundedness_prompt | llm | StrOutputParser() groundedness_score = float(chain.invoke({ "context": "\n".join([doc["content"] for doc in state['context']]), #"response": __________ # Complete the code to define the response "response": state['response'] # Complete the code to define the response })) print("groundedness_score: ", groundedness_score) state['groundedness_loop_count'] += 1 print("#########Groundedness Incremented###########") state['groundedness_score'] = groundedness_score return state def check_precision(state: Dict) -> Dict: """ Checks whether the response precisely addresses the user’s query. Args: state (Dict): The current state of the workflow, containing the query and response. Returns: Dict: The updated state with the precision score. """ print("---------check_precision---------") system_message = '''________________________''' system_message = '''Given question, answer and context verify if the context was useful in arriving at the given answer. Give verdict as "1" if useful and "0" if not ''' precision_prompt = ChatPromptTemplate.from_messages([ ("system", system_message), ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:") ]) #chain = _____________ | llm | StrOutputParser() # Complete the code to define the chain of processing chain = precision_prompt | llm | StrOutputParser() # Complete the code to define the chain of processing precision_score = float(chain.invoke({ "query": state['query'], #"response":______________ # Complete the code to access the response from the state "response":state['response'] # Complete the code to access the response from the state })) state['precision_score'] = precision_score print("precision_score:", precision_score) state['precision_loop_count'] +=1 print("#########Precision Incremented###########") return state def refine_response(state: Dict) -> Dict: """ Suggests improvements for the generated response. Args: state (Dict): The current state of the workflow, containing the query and response. Returns: Dict: The updated state with response refinement suggestions. """ print("---------refine_response---------") #system_message = '''________________________''' system_message = '''You are a response refinement expert tasked with reviewing and improving AI-generated answers. Your role is to: - Carefully analyze the given response in light of the original user query. - Identify any factual inaccuracies, gaps, or lack of clarity. - Suggest improvements that make the response more complete, precise, and aligned with the query intent. Guidelines: - Be constructive and focused. - Suggest rewordings, additions, or clarifications where needed. - Highlight if any information is missing or should be cited. - Avoid introducing new facts unless they are universally accepted and directly relevant. Output Format: - ONLY return specific suggestions for improving the response. - Do NOT rewrite the full response. - Do NOT return general praise. Focus on actionable refinements.''' refine_response_prompt = ChatPromptTemplate.from_messages([ ("system", system_message), ("user", "Query: {query}\nResponse: {response}\n\n" "What improvements can be made to enhance accuracy and completeness?") ]) chain = refine_response_prompt | llm| StrOutputParser() # Store response suggestions in a structured format feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}" print("feedback: ", feedback) print(f"State: {state}") state['feedback'] = feedback return state def refine_query(state: Dict) -> Dict: """ Suggests improvements for the expanded query. Args: state (Dict): The current state of the workflow, containing the query and expanded query. Returns: Dict: The updated state with query refinement suggestions. """ print("---------refine_query---------") #system_message = '''________________________''' system_message = ''' You are an expert in information retrieval and query optimization. Your job is to analyze an expanded search query that was generated from a user's original question, and suggest specific improvements that will help a search or retrieval system return more relevant, high-quality results. Guidelines: - Ensure the expanded query is clear, concise, and aligned with the user's original intent. - Eliminate any ambiguity or redundancy. - Suggest adding important synonyms, rephrasings, or domain-specific terminology if helpful. - Avoid suggesting overly broad or overly narrow queries. - Do NOT rewrite the query. Just offer targeted suggestions for improvement. Output Format: - Provide bullet-point suggestions for improving the expanded query. - Focus on changes that will improve retrieval quality without losing the user's intent. ''' refine_query_prompt = ChatPromptTemplate.from_messages([ ("system", system_message), ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n" "What improvements can be made for a better search?") ]) chain = refine_query_prompt | llm | StrOutputParser() # Store refinement suggestions without modifying the original expanded query query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}" print("query_feedback: ", query_feedback) print(f"Groundedness loop count: {state['groundedness_loop_count']}") state['query_feedback'] = query_feedback return state def should_continue_groundedness(state): """Decides if groundedness is sufficient or needs improvement.""" print("---------should_continue_groundedness---------") print("groundedness loop count: ", state['groundedness_loop_count']) #if state['groundedness_score'] >= _____: # Complete the code to define the threshold for groundedness if state['groundedness_score'] >= 0.5: # Complete the code to define the threshold for groundedness print("Moving to precision") return "check_precision" else: if state["groundedness_loop_count"] > state['loop_max_iter']: return "max_iterations_reached" else: print(f"---------Groundedness Score Threshold Not met. Refining Response-----------") return "refine_response" def should_continue_precision(state: Dict) -> str: """Decides if precision is sufficient or needs improvement.""" print("---------should_continue_precision---------") #print("precision loop count: ", ___________) print("precision loop count: ",state['precision_loop_count']) #if ___________: # Threshold for precision if state['precision_score']==1.0: # Threshold for precision return "pass" # Complete the workflow else: #if ___________: # Maximum allowed loops if state['precision_loop_count'] >= 3: # Maximum allowed loops return "max_iterations_reached" else: print(f"---------Precision Score Threshold Not met. Refining Query-----------") # Debugging #return ____________ # Refine the query return "refine_query" def max_iterations_reached(state: Dict) -> Dict: """Handles the case when the maximum number of iterations is reached.""" print("---------max_iterations_reached---------") """Handles the case when the maximum number of iterations is reached.""" response = "I'm unable to refine the response further. Please provide more context or clarify your question." state['response'] = response return state from langgraph.graph import END, StateGraph, START def create_workflow() -> StateGraph: """Creates the updated workflow for the AI nutrition agent.""" #workflow = StateGraph(__________) workflow = StateGraph(AgentState) # Add processing nodes #workflow.add_node("expand_query", ___________) # Step 1: Expand user query. workflow.add_node("expand_query", expand_query) # Step 1: Expand user query. #workflow.add_node("retrieve_context", ___________) # Step 2: Retrieve relevant documents. workflow.add_node("retrieve_context", retrieve_context) # Step 2: Retrieve relevant documents. #workflow.add_node("craft_response", ___________) # Step 3: Generate a response based on retrieved data. workflow.add_node("craft_response", craft_response) # Step 3: Generate a response based on retrieved data. #workflow.add_node("score_groundedness", ___________) # Step 4: Evaluate response grounding. workflow.add_node("score_groundedness", score_groundedness) # Step 4: Evaluate response grounding. #workflow.add_node("refine_response", ___________) # Step 5: Improve response if it's weakly grounded. workflow.add_node("refine_response", refine_response) # Step 5: Improve response if it's weakly grounded. #workflow.add_node("check_precision", ___________) # Step 6: Evaluate response precision. workflow.add_node("check_precision", check_precision) # Step 6: Evaluate response precision. #workflow.add_node("refine_query", ___________) # Step 7: Improve query if response lacks precision. workflow.add_node("refine_query", refine_query) # Step 7: Improve query if response lacks precision. #workflow.add_node("max_iterations_reached", ___________) # Step 8: Handle max iterations. workflow.add_node("max_iterations_reached", max_iterations_reached) # Step 8: Handle max iterations. # Main flow edges #workflow.add_edge(__________, ___________) # workflow.add_edge(__________, ___________) # workflow.add_edge(__________, ___________) # workflow.add_edge(__________, ___________) workflow.add_edge(START, "expand_query") workflow.add_edge("expand_query", "retrieve_context") workflow.add_edge("retrieve_context", "craft_response") workflow.add_edge("craft_response", "score_groundedness") # Conditional edges based on groundedness check workflow.add_conditional_edges( "score_groundedness", should_continue_groundedness, # Use the conditional function { "check_precision": "check_precision", # If well-grounded, proceed to precision check. "refine_response": "refine_response", # If not, refine the response. "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit. } ) #workflow.add_edge(__________, ___________) # Refined responses are reprocessed. workflow.add_edge("refine_response", "craft_response") # Refined responses are reprocessed. # Conditional edges based on precision check workflow.add_conditional_edges( "check_precision", should_continue_precision, # Use the conditional function { "pass": END, # If precise, complete the workflow. "refine_query": "refine_query", # If imprecise, refine the query. "max_iterations_reached": "max_iterations_reached" # If max loops reached, exit. } ) # workflow.add_edge(__________, ___________) # Refined queries go through expansion again. # workflow.add_edge(__________, ___________) workflow.add_edge("refine_query", "expand_query") # Refined queries go through expansion again. workflow.add_edge("max_iterations_reached", END) return workflow #=========================== Defining the agentic rag tool ====================# WORKFLOW_APP = create_workflow().compile() @tool def agentic_rag(query: str): """ Runs the RAG-based agent with conversation history for context-aware responses. Args: query (str): The current user query. Returns: Dict[str, Any]: The updated state with the generated response and conversation history. """ # Initialize state with necessary parameters inputs = { "query": query, # Current user query "expanded_query": "", # Complete the code to define the expanded version of the query "context": [], # Retrieved documents (initially empty) "response": "", # Complete the code to define the AI-generated response "precision_score": 0.0, # Complete the code to define the precision score of the response "groundedness_score": 0.0, # Complete the code to define the groundedness score of the response "groundedness_loop_count": 0, # Complete the code to define the counter for groundedness loops "precision_loop_count": 0, # Complete the code to define the counter for precision loops "feedback": "", # Complete the code to define the feedback "query_feedback": "", # Complete the code to define the query feedback "loop_max_iter": 3 # Complete the code to define the maximum number of iterations for loops } output = WORKFLOW_APP.invoke(inputs) return output #================================ Guardrails ===========================# llama_guard_client = Groq(api_key=llama_api_key) # Function to filter user input with Llama Guard #def filter_input_with_llama_guard(user_input, model="meta-llama/llama-guard-4-12b"): def filter_input_with_llama_guard(user_input, model="meta-llama/llama-guard-4-12b"): """ Filters user input using Llama Guard to ensure it is safe. Parameters: - user_input: The input provided by the user. - model: The Llama Guard model to be used for filtering (default is "llama-guard-3-8b"). Returns: - The filtered and safe input. """ try: # Create a request to Llama Guard to filter the user input response = llama_guard_client.chat.completions.create( messages=[{"role": "user", "content": user_input}], model=model, ) # Return the filtered input return response.choices[0].message.content.strip() except Exception as e: print(f"Error with Llama Guard: {e}") return None #============================= Adding Memory to the agent using mem0 ===============================# class NutritionBot: def __init__(self): # Initialize a memory client to store and retrieve customer interactions #self.memory = MemoryClient(os.environ["mem0"]) # Complete the code to define the memory client API key try: self.memory = MemoryClient(os.environ["mem0"]) except Exception as e: st.error(f"Failed to initialize MemoryClient: {e}") #self.memory = MemoryClient(api_key=userdata.get("mem0")) # Complete the code to define the memory client API key # Initialize the Azure OpenAI client using the provided credentials self.client = AzureChatOpenAI( # model_name="_____", # Specify the model to use (e.g., GPT-4 optimized version) # api_key=config['_____'], # API key for authentication # azure_endpoint=config['_____'], # Endpoint URL for Azure OpenAI # api_version=config['_____'], # API version being used # temperature=_____ # Controls randomness in responses; 0 ensures deterministic results model_name= model_name, # Specify the model to use (e.g., GPT-4 optimized version) api_key= api_key, # API key for authentication azure_endpoint= endpoint, # Endpoint URL for Azure OpenAI api_version= api_version, # API version being used temperature=0 # Controls randomness in responses; 0 ensures deterministic results ) """ Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor. """ # Define tools available to the chatbot, such as web search tools = [agentic_rag] # Define the system prompt to set the behavior of the chatbot system_prompt = """You are a helpful nutrition assistant. Answer user questions about nutrition disorders accurately, clearly, and respectfully using available information.""" # Build the prompt template for the agent prompt = ChatPromptTemplate.from_messages([ ("system", system_prompt), # System instructions ("human", "{input}"), # Placeholder for human input ("placeholder", "{agent_scratchpad}") # Placeholder for intermediate reasoning steps ]) # Create an agent capable of interacting with tools and executing tasks agent = create_tool_calling_agent(self.client, tools, prompt) # Wrap the agent in an executor to manage tool interactions and execution flow self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None): """ Store customer interaction in memory for future reference. Args: user_id (str): Unique identifier for the customer. message (str): Customer's query or message. response (str): Chatbot's response. metadata (Dict, optional): Additional metadata for the interaction. """ if metadata is None: metadata = {} # Add a timestamp to the metadata for tracking purposes metadata["timestamp"] = datetime.now().isoformat() # Format the conversation for storage conversation = [ {"role": "user", "content": message}, {"role": "assistant", "content": response} ] # Store the interaction in the memory client self.memory.add( conversation, user_id=user_id, output_format="v1.1", metadata=metadata ) def get_relevant_history(self, user_id: str, query: str) -> List[Dict]: """ Retrieve past interactions relevant to the current query. Args: user_id (str): Unique identifier for the customer. query (str): The customer's current query. Returns: List[Dict]: A list of relevant past interactions. """ return self.memory.search( query=query, # Search for interactions related to the query user_id=user_id, # Restrict search to the specific user limit= 3 # Complete the code to define the limit for retrieved interactions ) def handle_customer_query(self, user_id: str, query: str) -> str: """ Process a customer's query and provide a response, taking into account past interactions. Args: user_id (str): Unique identifier for the customer. query (str): Customer's query. Returns: str: Chatbot's response. """ # Retrieve relevant past interactions for context relevant_history = self.get_relevant_history(user_id, query) # Build a context string from the relevant history context = "Previous relevant interactions:\n" for memory in relevant_history: context += f"Customer: {memory['memory']}\n" # Customer's past messages context += f"Support: {memory['memory']}\n" # Chatbot's past responses context += "---\n" # Print context for debugging purposes print("Context: ", context) # Prepare a prompt combining past context and the current query # prompt = f""" # Context: # {context} # Current customer query: {query} # Provide a helpful response that takes into account any relevant past interactions. # """ prompt = f"{context}\n\nUser: {query}" # Generate a response using the agent response = self.agent_executor.invoke({"input": prompt}) # Store the current interaction for future reference self.store_customer_interaction( user_id=user_id, message=query, response=response["output"], metadata={"type": "support_query"} ) # Return the chatbot's response return response['output'] #=====================User Interface using streamlit ===========================# def nutrition_disorder_streamlit(): """ A Streamlit-based UI for the Nutrition Disorder Specialist Agent. """ st.title("Nutrition Disorder Specialist") st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.") st.write("Type 'exit' to end the conversation.") # Initialize session state for chat history and user_id if they don't exist if 'chat_history' not in st.session_state: st.session_state.chat_history = [] if 'user_id' not in st.session_state: st.session_state.user_id = None # Login form: Only if user is not logged in if st.session_state.user_id is None: with st.form("login_form", clear_on_submit=True): user_id = st.text_input("Please enter your name to begin:") submit_button = st.form_submit_button("Login") if submit_button and user_id: st.session_state.user_id = user_id st.session_state.chat_history.append({ "role": "assistant", "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?" }) st.session_state.login_submitted = True # Set flag to trigger rerun if st.session_state.get("login_submitted", False): st.session_state.pop("login_submitted") st.rerun() else: # Display chat history for message in st.session_state.chat_history: with st.chat_message(message["role"]): st.write(message["content"]) # Chat input with custom placeholder text #user_query = st.chat_input(__________) # Blank #1: Fill in the chat input prompt (e.g., "Type your question here (or 'exit' to end)...") user_query = st.chat_input("Type your question here (or 'exit' to end)...") if user_query: if user_query.lower() == "exit": st.session_state.chat_history.append({"role": "user", "content": "exit"}) with st.chat_message("user"): st.write("exit") goodbye_msg = "Goodbye! Feel free to return if you have more questions about nutrition disorders." st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg}) with st.chat_message("assistant"): st.write(goodbye_msg) st.session_state.user_id = None st.rerun() return st.session_state.chat_history.append({"role": "user", "content": user_query}) with st.chat_message("user"): st.write(user_query) # Filter input using Llama Guard #filtered_result = __________(user_query) # Blank #2: Fill in with the function name for filtering input (e.g., filter_input_with_llama_guard) filtered_result = filter_input_with_llama_guard(user_query) filtered_result = filtered_result.replace("\n", " ") # Normalize the result # Check if input is safe based on allowed statuses #if filtered_result in [__________, __________, __________]: # Blanks #3, #4, #5: Fill in with allowed safe statuses (e.g., "safe", "unsafe S7", "unsafe S6") if filtered_result in ["safe", "unsafe S6", "unsafe S7"]: # Blanks #3, #4, #5: Fill in with allowed safe statuses (e.g., "safe", "unsafe S7", "unsafe S6") try: if 'chatbot' not in st.session_state: #st.session_state.chatbot = __________() # Blank #6: Fill in with the chatbot class initialization (e.g., NutritionBot) st.session_state.chatbot = NutritionBot() #response = st.session_state.chatbot.__________(st.session_state.user_id, user_query) response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query) # Blank #7: Fill in with the method to handle queries (e.g., handle_customer_query) st.write(response) st.session_state.chat_history.append({"role": "assistant", "content": response}) except Exception as e: error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}" st.write(error_msg) st.session_state.chat_history.append({"role": "assistant", "content": error_msg}) else: inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again." st.write(inappropriate_msg) st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg}) if __name__ == "__main__": nutrition_disorder_streamlit()