File size: 5,900 Bytes
75115cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540db73
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import operator
import functools
from typing import Annotated, Sequence, TypedDict, Union, Optional

from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable
from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langgraph.graph import StateGraph, END

from application.agents.scraper_agent import scraper_agent
from application.agents.extractor_agent import extractor_agent
from application.utils.logger import get_logger

# Load variables from a local .env file so OPENAI_API_KEY (and any other
# secrets) are reachable via os.environ below.
load_dotenv()
logger = get_logger()

# Fail fast at import time: the supervisor LLM cannot run without a key.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    logger.error("OPENAI_API_KEY is missing. Please set it in your environment variables.")
    raise EnvironmentError("OPENAI_API_KEY not found in environment variables.")

# Worker node names; OPTIONS additionally includes the "FINISH" sentinel the
# supervisor emits when the task is complete.
MEMBERS = ["Scraper", "Extractor"]
OPTIONS = ["FINISH"] + MEMBERS

# System prompt for the routing LLM. {members} is filled in via
# prompt.partial(...) below; the model either names a worker or replies
# directly for conversational turns.
SUPERVISOR_SYSTEM_PROMPT = (
    "You are a supervisor tasked with managing a conversation between the following workers: {members}. "
    "Given the user's request and the previous messages, determine what to do next:\n"
    "- If the user asks to search, find, or scrape data from the web, choose 'Scraper'.\n"
    "- If the user asks to extract ESG emissions data from a file or PDF, choose 'Extractor'.\n"
    "- If the task is complete, choose 'FINISH'.\n"
    "- If the message is general conversation (like greetings, questions, thanks, chatting), directly respond with a message.\n"
    "Each worker will perform its task and report back.\n"
    "When you respond directly, make sure your message is friendly and helpful."
)

# Tool schema the supervisor LLM is forced to call (see bind_tools with
# tool_choice below). Both fields are optional at the schema level
# ("required": []), so supervisor_node validates that at least one of
# "next"/"response" is present.
FUNCTION_DEF = {
    "name": "route_or_respond",
    "description": "Select the next role OR respond directly.",
    "parameters": {
        "title": "RouteOrRespondSchema",
        "type": "object",
        "properties": {
            "next": {
                "title": "Next Worker",
                # anyOf wrapping a single enum constrains the value to OPTIONS;
                # presumably kept in this legacy shape for OpenAI function-call
                # compatibility — TODO confirm before simplifying.
                "anyOf": [{"enum": OPTIONS}],
                "description": "Choose next worker if needed."
            },
            "response": {
                "title": "Supervisor Response",
                "type": "string",
                "description": "Respond directly if no worker action is needed."
            }
        },
        "required": [],
    },
}

class AgentState(TypedDict):
    """Shared LangGraph state passed between the supervisor and worker nodes."""

    # Conversation history. The operator.add reducer makes LangGraph APPEND
    # any messages a node returns, rather than replacing the list.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Next worker to run ("Scraper"/"Extractor"/"FINISH"), if the supervisor routed.
    next: Optional[str]
    # Direct reply text when the supervisor answers without a worker.
    response: Optional[str]

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

def agent_node(state: AgentState, agent: Runnable, name: str) -> dict:
    """Run a worker agent and wrap its output as a message for the graph.

    Args:
        state: Current graph state, passed straight through to the agent.
        agent: The runnable worker (bound via functools.partial below).
        name: Worker name, attached to the resulting message for attribution.

    Returns:
        A partial state update; the single message is appended to
        ``AgentState.messages`` by the ``operator.add`` reducer.

    Raises:
        Exception: Re-raises whatever the agent raised, after logging it.
    """
    # Lazy %-style args avoid building the log string when the level is off.
    logger.info("Agent %s invoked.", name)
    try:
        # Keep the try body to the single call that can fail.
        result = agent.invoke(state)
    except Exception as e:
        # logger.exception records the full traceback automatically.
        logger.info("Agent %s failed with error: %s", name, e)
        logger.exception("Agent %s failed with error: %s", name, str(e))
        raise
    logger.info("Agent %s completed successfully.", name)
    return {"messages": [HumanMessage(content=result["output"], name=name)]}

# Supervisor prompt: system instructions, the full conversation so far, then a
# final system nudge restating the choice (route to a worker or reply directly).
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SUPERVISOR_SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Based on the conversation, either select next worker (one of: {options}) or respond directly with a message.",
        ),
    ]
).partial(options=str(OPTIONS), members=", ".join(MEMBERS))

# Force the model to call the routing tool on every turn; the parser extracts
# that tool call's arguments as a dict with "next" and/or "response" keys.
supervisor_chain = (
    prompt
    | llm.bind_tools(tools=[FUNCTION_DEF], tool_choice="route_or_respond")
    | JsonOutputKeyToolsParser(key_name="route_or_respond")
)

def supervisor_node(state: AgentState) -> dict:
    """Decide the next step: route to a worker, finish, or respond directly.

    Returns a partial state update carrying only ``next`` and ``response``.
    It deliberately does NOT return ``messages``: the ``operator.add``
    reducer on ``AgentState.messages`` would append them to the existing
    history, duplicating every message on each supervisor pass.

    Raises:
        ValueError: If the model produced neither a route nor a response.
    """
    logger.info("Supervisor invoked.")
    output = supervisor_chain.invoke(state)
    logger.info("Supervisor output: %s", output)

    # JsonOutputKeyToolsParser can return a list of tool calls; use the first.
    if isinstance(output, list) and output:
        output = output[0]

    next_step = output.get("next")
    response = output.get("response")

    if not next_step and not response:
        raise ValueError(f"Supervisor produced invalid output: {output}")

    return {"next": next_step, "response": response}

# Build the graph: two worker nodes, the supervisor router, and a terminal
# node that turns the supervisor's direct reply into a message.
workflow = StateGraph(AgentState)

workflow.add_node("Scraper", functools.partial(agent_node, agent=scraper_agent, name="Scraper"))
workflow.add_node("Extractor", functools.partial(agent_node, agent=extractor_agent, name="Extractor"))
workflow.add_node("supervisor", supervisor_node)
# Emits the supervisor's direct response as an AIMessage attributed to "Supervisor".
workflow.add_node("supervisor_response", lambda state: {"messages": [AIMessage(content=state["response"], name="Supervisor")]})

# Every worker reports back to the supervisor after it runs.
for member in MEMBERS:
    workflow.add_edge(member, "supervisor")

def router(state: AgentState):
    """Pick the supervisor's outgoing edge from the current state.

    A direct reply takes priority; otherwise follow the routed worker name
    (or "FINISH"), which conditional_map resolves to a graph target.
    """
    return "supervisor_response" if state.get("response") else state.get("next")

# Map router outputs to graph targets: each worker name loops to its own node,
# "FINISH" ends the run, and a direct reply goes to the supervisor_response node.
conditional_map = {member: member for member in MEMBERS}
conditional_map["FINISH"] = END
conditional_map["supervisor_response"] = "supervisor_response"


workflow.add_conditional_edges("supervisor", router, conditional_map)

# Every run starts with the supervisor's routing decision.
workflow.set_entry_point("supervisor")

graph = workflow.compile()

# # === Example Run ===
# if __name__ == "__main__":
#     logger.info("Starting the graph execution...")
#     initial_message = HumanMessage(content="Can you get zalando pdf link")
#     input_state = {"messages": [initial_message]}

#     for step in graph.stream(input_state):
#         if "__end__" not in step:
#             logger.info(f"Graph Step Output: {step}")
#             print(step)
#             print("----")

#     logger.info("Graph execution completed.")