Spaces:
Sleeping
Sleeping
import os | |
import operator | |
import functools | |
from typing import Annotated, Sequence, TypedDict, Union, Optional | |
from dotenv import load_dotenv | |
from langchain_openai import ChatOpenAI | |
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langchain_core.runnables import Runnable | |
from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser | |
from langgraph.graph import StateGraph, END | |
from application.agents.scraper_agent import scraper_agent | |
from application.agents.extractor_agent import extractor_agent | |
from application.utils.logger import get_logger | |
# Populate os.environ from a local .env file (if any) before reading settings.
load_dotenv()
logger = get_logger()

# Fail fast at import time: the supervisor LLM below cannot work without a key.
# (OSError is the same class as the legacy alias EnvironmentError.)
if not (OPENAI_API_KEY := os.getenv("OPENAI_API_KEY")):
    logger.error("OPENAI_API_KEY is missing. Please set it in your environment variables.")
    raise OSError("OPENAI_API_KEY not found in environment variables.")
# Worker nodes the supervisor may delegate to; each name is also a graph node id.
MEMBERS = ["Scraper", "Extractor"]
# Every routing target the supervisor may choose; "FINISH" is mapped to END.
OPTIONS = ["FINISH", *MEMBERS]

# System prompt for the supervisor. ``{members}`` is a template variable that
# is filled in through the prompt template's ``partial(...)`` call.
SUPERVISOR_SYSTEM_PROMPT = "\n".join(
    [
        "You are a supervisor tasked with managing a conversation between the "
        "following workers: {members}. Given the user's request and the "
        "previous messages, determine what to do next:",
        "- If the user asks to search, find, or scrape data from the web, choose 'Scraper'.",
        "- If the user asks to extract ESG emissions data from a file or PDF, choose 'Extractor'.",
        "- If the task is complete, choose 'FINISH'.",
        "- If the message is general conversation (like greetings, questions, "
        "thanks, chatting), directly respond with a message.",
        "Each worker will perform its task and report back.",
        "When you respond directly, make sure your message is friendly and helpful.",
    ]
)

# Tool schema bound to the supervisor LLM. Neither field is required, so the
# caller must check that at least one of ``next`` / ``response`` came back.
FUNCTION_DEF = dict(
    name="route_or_respond",
    description="Select the next role OR respond directly.",
    parameters=dict(
        title="RouteOrRespondSchema",
        type="object",
        properties=dict(
            next=dict(
                title="Next Worker",
                anyOf=[dict(enum=OPTIONS)],
                description="Choose next worker if needed.",
            ),
            response=dict(
                title="Supervisor Response",
                type="string",
                description="Respond directly if no worker action is needed.",
            ),
        ),
        required=[],
    ),
)
class AgentState(TypedDict):
    """State shared by the supervisor and worker nodes of the graph."""

    # Conversation history. The ``operator.add`` reducer concatenates each
    # node's returned messages onto the existing sequence instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Supervisor's routing choice: a worker name or "FINISH" (or None).
    next: Union[str, None]
    # Supervisor's direct reply text when it answers without a worker (or None).
    response: Union[str, None]
# Supervisor model; temperature 0 keeps routing decisions as repeatable as the API allows.
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
def agent_node(state: AgentState, agent: Runnable, name: str) -> dict:
    """Run one worker agent on the current state and wrap its answer as a message.

    Args:
        state: Current graph state, passed to ``agent.invoke`` unchanged.
        agent: The worker runnable. Its ``invoke`` result is indexed with
            ``["output"]`` (assumes an AgentExecutor-style mapping; confirm
            for new agents).
        name: Worker name used in log lines and as the message ``name``.

    Returns:
        A partial state update holding a single ``HumanMessage`` with the
        agent's output. NOTE(review): the output is recorded as a HumanMessage
        rather than an AIMessage, presumably so the supervisor LLM treats it as
        fresh input. Confirm this is intended.

    Raises:
        Exception: Any error from the agent (or a missing ``"output"`` key) is
            logged with its traceback and re-raised unchanged.
    """
    logger.info(f"Agent {name} invoked.")
    try:
        agent_result = agent.invoke(state)
        logger.info(f"Agent {name} completed successfully.")
        # Built inside the try so a missing "output" key is logged as a failure too.
        reply = HumanMessage(content=agent_result["output"], name=name)
        return {"messages": [reply]}
    except Exception as err:
        logger.exception(f"Agent {name} failed with error: {err}")
        raise
# Supervisor prompt: system instructions, the running conversation, then a
# closing reminder of the allowed routing targets.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SUPERVISOR_SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Based on the conversation, either select next worker "
            "(one of: {options}) or respond directly with a message.",
        ),
    ]
)
# Pre-fill the static template variables so callers only supply ``messages``.
prompt = prompt.partial(options=str(OPTIONS), members=", ".join(MEMBERS))
# Supervisor chain: prompt -> LLM forced to call the ``route_or_respond`` tool
# -> parsed tool arguments. With its default settings, JsonOutputKeyToolsParser
# returns a *list* of argument dicts for every matching tool call, so consumers
# must unwrap the first element (and cope with an empty list).
supervisor_chain = (
    prompt
    | llm.bind_tools(tools=[FUNCTION_DEF], tool_choice="route_or_respond")
    | JsonOutputKeyToolsParser(key_name="route_or_respond")
)
def supervisor_node(state: AgentState) -> dict:
    """Ask the supervisor LLM whether to delegate, finish, or reply directly.

    Args:
        state: Current graph state; it is passed to ``supervisor_chain``,
            whose prompt consumes ``messages``.

    Returns:
        A partial state update with ``next`` (a worker name, ``"FINISH"`` or
        ``None``) and ``response`` (direct reply text or ``None``). ``messages``
        is intentionally not returned. Its channel uses an ``operator.add``
        reducer, so echoing ``state["messages"]`` back would append the entire
        history to itself on every supervisor pass.

    Raises:
        ValueError: If the model produced no tool call, a non-dict payload, or
            neither a ``next`` target nor a ``response``.
    """
    logger.info("Supervisor invoked.")
    raw = supervisor_chain.invoke(state)
    logger.info(f"Supervisor output: {raw}")
    # The tools parser yields a list of tool-call argument dicts; use the first.
    # An empty list (no tool call) falls through to the error below instead of
    # crashing with AttributeError on ``.get``.
    output = raw[0] if isinstance(raw, list) and raw else raw
    if not isinstance(output, dict):
        raise ValueError(f"Supervisor produced invalid output: {raw}")
    next_step = output.get("next")
    response = output.get("response")
    if not next_step and not response:
        raise ValueError(f"Supervisor produced invalid output: {output}")
    # Both keys are always written so a value from a previous pass never lingers.
    return {"next": next_step, "response": response}
workflow = StateGraph(AgentState)
# Worker nodes: functools.partial binds each agent and its display name to agent_node.
workflow.add_node("Scraper", functools.partial(agent_node, agent=scraper_agent, name="Scraper"))
workflow.add_node("Extractor", functools.partial(agent_node, agent=extractor_agent, name="Extractor"))
workflow.add_node("supervisor", supervisor_node)
# Emits the supervisor's direct reply as an AIMessage.
workflow.add_node(
    "supervisor_response",
    lambda state: {"messages": [AIMessage(content=state["response"], name="Supervisor")]},
)
# Workers always report back to the supervisor, which decides the next step.
for member in MEMBERS:
    workflow.add_edge(member, "supervisor")
# A direct reply ends the run. Without this edge "supervisor_response" is a
# dead-end node with no outgoing edge.
workflow.add_edge("supervisor_response", END)
def router(state: AgentState):
    """Choose the node to run after the supervisor.

    A truthy direct ``response`` takes precedence and routes to
    ``"supervisor_response"``; otherwise the supervisor's ``next`` value (a
    worker name or ``"FINISH"``) is returned as-is.
    """
    return "supervisor_response" if state.get("response") else state.get("next")
# Translate router return values into graph targets: each worker maps to its
# own node, "FINISH" ends the run, and a direct reply goes to its emitter node.
conditional_map = {
    **{member: member for member in MEMBERS},
    "FINISH": END,
    "supervisor_response": "supervisor_response",
}
workflow.add_conditional_edges("supervisor", router, conditional_map)
# Every run starts by letting the supervisor look at the incoming messages.
workflow.set_entry_point("supervisor")
graph = workflow.compile()
# # === Example Run === | |
# if __name__ == "__main__": | |
# logger.info("Starting the graph execution...") | |
# initial_message = HumanMessage(content="Can you get zalando pdf link") | |
# input_state = {"messages": [initial_message]} | |
# for step in graph.stream(input_state): | |
# if "__end__" not in step: | |
# logger.info(f"Graph Step Output: {step}") | |
# print(step) | |
# print("----") | |
# logger.info("Graph execution completed.") |