"""Supervisor-routed LangGraph workflow.

A supervisor LLM inspects the conversation and, via a forced tool call
(``route_or_respond``), either:

* delegates to a worker agent (``Scraper`` for web search/scraping,
  ``Extractor`` for ESG/PDF extraction),
* replies to the user directly (general conversation), or
* chooses ``FINISH`` to end the run.

Workers report back to the supervisor, which decides the next step.
"""

import functools
import operator
import os
from typing import Annotated, Optional, Sequence, TypedDict

from dotenv import load_dotenv
from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph

from application.agents.extractor_agent import extractor_agent
from application.agents.scraper_agent import scraper_agent
from application.utils.logger import get_logger

load_dotenv()
logger = get_logger()

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    logger.error("OPENAI_API_KEY is missing. Please set it in your environment variables.")
    raise EnvironmentError("OPENAI_API_KEY not found in environment variables.")

# Worker agents the supervisor may delegate to; "FINISH" terminates the run.
MEMBERS = ["Scraper", "Extractor"]
OPTIONS = ["FINISH"] + MEMBERS

SUPERVISOR_SYSTEM_PROMPT = (
    "You are a supervisor tasked with managing a conversation between the following workers: {members}. "
    "Given the user's request and the previous messages, determine what to do next:\n"
    "- If the user asks to search, find, or scrape data from the web, choose 'Scraper'.\n"
    "- If the user asks to extract ESG emissions data from a file or PDF, choose 'Extractor'.\n"
    "- If the task is complete, choose 'FINISH'.\n"
    "- If the message is general conversation (like greetings, questions, thanks, chatting), directly respond with a message.\n"
    "Each worker will perform its task and report back.\n"
    "When you respond directly, make sure your message is friendly and helpful."
)

# OpenAI tool schema: the model fills "next" to route to a worker, or
# "response" to answer the user directly. Neither field is required, so
# supervisor_node validates that at least one is present.
FUNCTION_DEF = {
    "name": "route_or_respond",
    "description": "Select the next role OR respond directly.",
    "parameters": {
        "title": "RouteOrRespondSchema",
        "type": "object",
        "properties": {
            "next": {
                "title": "Next Worker",
                "anyOf": [{"enum": OPTIONS}],
                "description": "Choose next worker if needed.",
            },
            "response": {
                "title": "Supervisor Response",
                "type": "string",
                "description": "Respond directly if no worker action is needed.",
            },
        },
        "required": [],
    },
}


class AgentState(TypedDict):
    # Conversation history. The operator.add reducer makes LangGraph
    # append each node's returned messages to the existing list.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Routing target chosen by the supervisor (one of OPTIONS), if any.
    next: Optional[str]
    # Direct reply text when the supervisor answers without delegating.
    response: Optional[str]


llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)


def agent_node(state: AgentState, agent: Runnable, name: str) -> dict:
    """Run a worker agent and wrap its output as a named HumanMessage.

    Args:
        state: Current graph state passed to the agent.
        agent: The worker Runnable (expects an ``"output"`` key in its result).
        name: Worker name attached to the produced message.

    Returns:
        Partial state update containing the worker's report.

    Raises:
        Exception: Re-raises whatever the agent raised, after logging.
    """
    logger.info(f"Agent {name} invoked.")
    try:
        result = agent.invoke(state)
        logger.info(f"Agent {name} completed successfully.")
        return {"messages": [HumanMessage(content=result["output"], name=name)]}
    except Exception as e:
        logger.exception(f"Agent {name} failed with error: {str(e)}")
        raise


prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SUPERVISOR_SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Based on the conversation, either select next worker (one of: {options}) or respond directly with a message.",
        ),
    ]
).partial(options=str(OPTIONS), members=", ".join(MEMBERS))

# Force the model to call route_or_respond and parse that tool call's
# arguments; the parser yields a list of argument dicts.
supervisor_chain = (
    prompt
    | llm.bind_tools(tools=[FUNCTION_DEF], tool_choice="route_or_respond")
    | JsonOutputKeyToolsParser(key_name="route_or_respond")
)


def supervisor_node(state: AgentState) -> dict:
    """Invoke the supervisor chain and decide the next step.

    Returns:
        Partial state update with ``next`` (routing target) and
        ``response`` (direct reply), either of which may be None.

    Raises:
        ValueError: If the model produced neither a route nor a response.
    """
    logger.info("Supervisor invoked.")
    output = supervisor_chain.invoke(state)
    logger.info(f"Supervisor output: {output}")

    # The tools parser returns a list of tool-call argument dicts.
    if isinstance(output, list) and len(output) > 0:
        output = output[0]

    next_step = output.get("next")
    response = output.get("response")

    if not next_step and not response:
        raise ValueError(f"Supervisor produced invalid output: {output}")

    # NOTE: do NOT include state["messages"] here. The messages channel
    # uses an operator.add reducer, so echoing the history back would
    # append a duplicate of the whole conversation on every pass.
    return {"next": next_step, "response": response}


def _supervisor_response_node(state: AgentState) -> dict:
    """Emit the supervisor's direct reply into the conversation."""
    return {"messages": [AIMessage(content=state["response"], name="Supervisor")]}


def router(state: AgentState):
    """Route from the supervisor: direct reply takes priority over delegation."""
    if state.get("response"):
        return "supervisor_response"
    return state.get("next")


workflow = StateGraph(AgentState)
workflow.add_node("Scraper", functools.partial(agent_node, agent=scraper_agent, name="Scraper"))
workflow.add_node("Extractor", functools.partial(agent_node, agent=extractor_agent, name="Extractor"))
workflow.add_node("supervisor", supervisor_node)
workflow.add_node("supervisor_response", _supervisor_response_node)

# Every worker reports back to the supervisor.
for member in MEMBERS:
    workflow.add_edge(member, "supervisor")

conditional_map = {member: member for member in MEMBERS}
conditional_map["FINISH"] = END
conditional_map["supervisor_response"] = "supervisor_response"
workflow.add_conditional_edges("supervisor", router, conditional_map)

# A direct reply ends the run; without this edge the node is a dead end,
# which newer LangGraph versions reject at compile time.
workflow.add_edge("supervisor_response", END)

workflow.set_entry_point("supervisor")

graph = workflow.compile()