# PDFExtractor / main.py — Vela (modified functions, commit 540db73)
import os
import operator
import functools
from typing import Annotated, Sequence, TypedDict, Union, Optional
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable
from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langgraph.graph import StateGraph, END
from application.agents.scraper_agent import scraper_agent
from application.agents.extractor_agent import extractor_agent
from application.utils.logger import get_logger
# Load variables from a local .env file before any configuration is read.
load_dotenv()
logger = get_logger()

# Fail fast at import time if the OpenAI key is missing — every downstream
# LLM call would otherwise fail later with a less obvious error.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    logger.error("OPENAI_API_KEY is missing. Please set it in your environment variables.")
    raise EnvironmentError("OPENAI_API_KEY not found in environment variables.")

# Worker agents the supervisor can delegate to; "FINISH" ends the graph run.
MEMBERS = ["Scraper", "Extractor"]
OPTIONS = ["FINISH"] + MEMBERS
# System prompt for the supervisor LLM. {members} is filled in later via
# ChatPromptTemplate.partial(); the model either routes to a worker,
# finishes, or answers the user directly.
SUPERVISOR_SYSTEM_PROMPT = (
    "You are a supervisor tasked with managing a conversation between the following workers: {members}. "
    "Given the user's request and the previous messages, determine what to do next:\n"
    "- If the user asks to search, find, or scrape data from the web, choose 'Scraper'.\n"
    "- If the user asks to extract ESG emissions data from a file or PDF, choose 'Extractor'.\n"
    "- If the task is complete, choose 'FINISH'.\n"
    "- If the message is general conversation (like greetings, questions, thanks, chatting), directly respond with a message.\n"
    "Each worker will perform its task and report back.\n"
    "When you respond directly, make sure your message is friendly and helpful."
)
# OpenAI tool schema the supervisor model is forced to call (see
# bind_tools(..., tool_choice="route_or_respond") below). The model fills in
# either `next` (route to a worker / FINISH) or `response` (answer directly);
# neither field is required, so supervisor_node validates the output itself.
FUNCTION_DEF = {
    "name": "route_or_respond",
    "description": "Select the next role OR respond directly.",
    "parameters": {
        "title": "RouteOrRespondSchema",
        "type": "object",
        "properties": {
            "next": {
                "title": "Next Worker",
                # Constrain routing to the known options: FINISH / Scraper / Extractor.
                "anyOf": [{"enum": OPTIONS}],
                "description": "Choose next worker if needed."
            },
            "response": {
                "title": "Supervisor Response",
                "type": "string",
                "description": "Respond directly if no worker action is needed."
            }
        },
        "required": [],
    },
}
class AgentState(TypedDict):
    """Shared LangGraph state passed between the supervisor and worker nodes."""

    # Conversation history. The operator.add annotation makes LangGraph
    # concatenate each node's returned messages onto the existing list
    # instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Next worker chosen by the supervisor ("Scraper"/"Extractor"/"FINISH"), if any.
    next: Optional[str]
    # Direct reply from the supervisor when no worker action is needed.
    response: Optional[str]
# Temperature 0 so routing decisions are as deterministic as possible.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
def agent_node(state: "AgentState", agent: Runnable, name: str) -> dict:
    """Run *agent* on the current graph state and report its output.

    The agent's "output" text is wrapped in a HumanMessage tagged with the
    worker's name so the supervisor can attribute it in the conversation.
    Failures are logged with a traceback and re-raised to the graph runner.
    """
    logger.info(f"Agent {name} invoked.")
    try:
        outcome = agent.invoke(state)
        logger.info(f"Agent {name} completed successfully.")
        return {"messages": [HumanMessage(content=outcome["output"], name=name)]}
    except Exception as err:
        logger.exception(f"Agent {name} failed with error: {str(err)}")
        raise
# Supervisor prompt: system instructions, then the running conversation, then
# a trailing system nudge restating the allowed choices. {options} and
# {members} are pre-bound here; only {messages} remains to be filled per call.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SUPERVISOR_SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Based on the conversation, either select next worker (one of: {options}) or respond directly with a message.",
        ),
    ]
).partial(options=str(OPTIONS), members=", ".join(MEMBERS))
# supervisor_chain = (
# prompt
# | llm.bind_functions(functions=[FUNCTION_DEF], function_call="route_or_respond")
# | JsonOutputFunctionsParser()
# )
# Supervisor decision chain: prompt -> model constrained to call the
# route_or_respond tool -> parser that extracts that tool call's arguments.
# NOTE(review): JsonOutputKeyToolsParser without first_tool_only=True returns
# a list of argument dicts; supervisor_node unwraps the first element.
supervisor_chain = (
    prompt
    | llm.bind_tools(tools=[FUNCTION_DEF], tool_choice="route_or_respond")
    | JsonOutputKeyToolsParser(key_name="route_or_respond")
)
def supervisor_node(state: "AgentState") -> dict:
    """Ask the supervisor chain what to do next and update the routing state.

    Returns a partial state update containing:
      - ``next``: chosen worker name / "FINISH", or None
      - ``response``: a direct reply to the user, or None

    Raises:
        ValueError: if the model produced neither a route nor a response.
    """
    logger.info("Supervisor invoked.")
    output = supervisor_chain.invoke(state)
    logger.info(f"Supervisor output: {output}")
    # The tool parser may return a list of tool-call argument dicts; take the first.
    if isinstance(output, list) and len(output) > 0:
        output = output[0]
    next_step = output.get("next")
    response = output.get("response")
    if not next_step and not response:
        raise ValueError(f"Supervisor produced invalid output: {output}")
    # BUG FIX: do NOT echo state["messages"] back. The `messages` field is
    # annotated with operator.add, so returning the existing history would
    # concatenate it onto itself, duplicating every message per supervisor pass.
    # Return only the routing fields; the history is left untouched.
    return {
        "next": next_step,
        "response": response,
    }
# Build the graph: two worker nodes, the supervisor, and a terminal node that
# emits the supervisor's direct reply as an AI message.
workflow = StateGraph(AgentState)
workflow.add_node("Scraper", functools.partial(agent_node, agent=scraper_agent, name="Scraper"))
workflow.add_node("Extractor", functools.partial(agent_node, agent=extractor_agent, name="Extractor"))
workflow.add_node("supervisor", supervisor_node)
# workflow.add_node("supervisor", supervisor_chain)
workflow.add_node("supervisor_response", lambda state: {"messages": [AIMessage(content=state["response"], name="Supervisor")]})
# Every worker reports back to the supervisor once its task is done.
for member in MEMBERS:
    workflow.add_edge(member, "supervisor")
def router(state: "AgentState"):
    """Pick the supervisor's outgoing edge.

    A direct reply takes priority: any truthy ``response`` routes to the
    supervisor_response node; otherwise follow whichever worker (or FINISH)
    the supervisor put in ``next``.
    """
    return "supervisor_response" if state.get("response") else state.get("next")
# Map router outputs to destinations: each worker routes to itself,
# "FINISH" terminates the run, and a direct reply goes to supervisor_response.
conditional_map = {member: member for member in MEMBERS}
conditional_map["FINISH"] = END
conditional_map["supervisor_response"] = "supervisor_response"
workflow.add_conditional_edges("supervisor", router, conditional_map)
# BUG FIX: supervisor_response previously had no outgoing edge, leaving a
# dead-end node (LangGraph rejects dead-ends at compile time). After emitting
# the supervisor's direct reply, the run is complete.
workflow.add_edge("supervisor_response", END)
workflow.set_entry_point("supervisor")
graph = workflow.compile()
# # === Example Run ===
# if __name__ == "__main__":
# logger.info("Starting the graph execution...")
# initial_message = HumanMessage(content="Can you get zalando pdf link")
# input_state = {"messages": [initial_message]}
# for step in graph.stream(input_state):
# if "__end__" not in step:
# logger.info(f"Graph Step Output: {step}")
# print(step)
# print("----")
# logger.info("Graph execution completed.")