# PDFExtractor/application/agents/scraper_agent.py
# Vela — modified functions (commit 540db73)
import os
from dotenv import load_dotenv
import json
from langchain_core.messages import ToolMessage
from typing import TypedDict, Annotated
from langgraph.graph.message import add_messages
from typing import Annotated, List
from langgraph.graph import StateGraph, END
from langchain_core.messages import ToolMessage, HumanMessage
from langchain_openai import ChatOpenAI
# Local Imports
from application.tools.web_search_tools import get_top_companies_from_web, get_sustainability_report_pdf
from application.tools.pdf_downloader_tool import download_pdf
from application.tools.emission_data_extractor import extract_emission_data_as_json
from application.services.langgraph_service import create_agent
from application.utils.logger import get_logger
# Setting up environment and logger.
load_dotenv()
logger = get_logger()

# LangSmith tracing configuration.
# Guard the assignment: os.environ values must be strings, so writing a
# missing key (None) would raise TypeError at import time.
LANGSMITH_API_KEY = os.getenv('LANGSMITH_API_KEY')
if LANGSMITH_API_KEY:
    os.environ['LANGSMITH_API_KEY'] = LANGSMITH_API_KEY
else:
    logger.warning("LANGSMITH_API_KEY is not set; LangSmith tracing may not work.")
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ["LANGCHAIN_PROJECT"] = "Sustainability_AI"

# OpenAI API key set up: load_dotenv() already exports it to os.environ when
# present, so re-assigning it to itself is unnecessary (and raised TypeError
# when the variable was absent). Just warn if it is missing.
if not os.environ.get("OPENAI_API_KEY"):
    logger.warning("OPENAI_API_KEY is not set; ChatOpenAI calls will fail.")
class AgentState(TypedDict):
    """Graph state shared by all nodes: the running conversation history.

    `add_messages` makes LangGraph append messages returned by a node to
    the existing list instead of replacing it.
    """
    messages: Annotated[List, add_messages]
# NOTE(review): `graph` appears unused — the compiled graph below is built
# from `graph_builder` instead; confirm before removing this line.
graph = StateGraph(AgentState)
# temperature=0 for deterministic tool selection.
model = ChatOpenAI(model= 'gpt-4o-mini', temperature=0)
# Tools the scraper loop may call; also used by invoke_tools' name lookup.
tools = [get_top_companies_from_web, get_sustainability_report_pdf, download_pdf, extract_emission_data_as_json]
model_with_tools = model.bind_tools(tools)
def invoke_model(state: AgentState) -> dict:
    """Run the tool-bound LLM on the accumulated conversation history.

    Returns a state update whose single new message is the model's reply
    (which may carry tool calls for the tool node to execute).
    """
    logger.info("--- Invoking Model ---")
    return {"messages": [model_with_tools.invoke(state['messages'])]}
def _serialize_tool_result(result) -> str:
    """Render a tool's return value as ToolMessage content (always a string)."""
    # JSON-encode plain containers so the model receives structured data.
    if isinstance(result, (list, dict)):
        return json.dumps(result)
    # Result objects exposing a `companies` list (e.g. a pydantic model).
    if hasattr(result, 'companies') and isinstance(result.companies, list):
        return f"Companies found: {', '.join(result.companies)}"
    if result is None:
        return "Tool executed successfully, but returned no specific data (None)."
    return str(result)

def invoke_tools(state: AgentState) -> dict:
    """Execute every tool call requested by the last AI message.

    Returns a state update containing one ToolMessage per tool call.
    Failures (unknown tool name, tool exception) are reported back to the
    model as ToolMessage content rather than crashing the graph run.
    Returns an empty update when the last message carries no tool calls.
    """
    logger.info("--- Invoking Tools ---")
    last_message = state['messages'][-1]
    if not hasattr(last_message, 'tool_calls') or not last_message.tool_calls:
        logger.info("No tool calls found in the last message.")
        return {}

    tool_invocation_messages = []
    tool_map = {tool.name: tool for tool in tools}
    for tool_call in last_message.tool_calls:
        tool_name = tool_call['name']
        tool_args = tool_call['args']
        tool_call_id = tool_call['id']
        logger.info(f"Executing tool: {tool_name} with args: {tool_args}")
        # Guard clause: the model asked for a tool we do not provide.
        if tool_name not in tool_map:
            logger.warning(f"Tool '{tool_name}' not found.")
            tool_invocation_messages.append(
                ToolMessage(content=f"Error: Tool '{tool_name}' not found.", tool_call_id=tool_call_id)
            )
            continue
        try:
            result = tool_map[tool_name].invoke(tool_args)
            result_content = _serialize_tool_result(result)
            logger.info(f"Tool {tool_name} result: {result_content}")
            tool_invocation_messages.append(
                ToolMessage(content=result_content, tool_call_id=tool_call_id)
            )
        except Exception as e:
            # Surface the failure to the model instead of aborting the run.
            logger.error(f"Error executing tool {tool_name}: {e}")
            tool_invocation_messages.append(
                ToolMessage(content=f"Error executing tool {tool_name}: {str(e)}", tool_call_id=tool_call_id)
            )
    return {"messages": tool_invocation_messages}
# Assemble the agent loop: model node <-> tool node.
graph_builder = StateGraph(AgentState)
graph_builder.add_node("scraper_agent", invoke_model)
graph_builder.add_node("tools", invoke_tools)
graph_builder.set_entry_point("scraper_agent")
def router(state: AgentState) -> str:
    """Decide the next node: run tools if the model requested any, else end."""
    pending_calls = getattr(state['messages'][-1], 'tool_calls', None)
    if pending_calls:
        logger.info("--- Routing to Tools ---")
        return "tools"
    logger.info("--- Routing to End ---")
    return END
# After the model speaks, either execute the requested tools or finish.
graph_builder.add_conditional_edges(
    "scraper_agent",
    router,
    {
        "tools": "tools",
        END: END,
    }
)
# Tool output always flows back to the model for its next decision.
graph_builder.add_edge("tools", "scraper_agent")
# Compile the graph
app = graph_builder.compile()
# # --- Running the Graph ---
# if __name__ == "__main__":
# logger.info("Starting graph execution...")
# # Use HumanMessage for the initial input
# initial_input = {"messages": [HumanMessage(content="Please download this pdf https://www.infosys.com/sustainability/documents/infosys-esg-report-2023-24.pdf")]}
# # Stream events to see the flow (optional, but helpful for debugging)
# # Add recursion limit to prevent infinite loops
# try:
# final_state = None
# for event in app.stream(initial_input, {"recursion_limit": 15}):
# # event is a dictionary where keys are node names and values are outputs
# logger.info(f"Event: {event}")
# # Keep track of the latest state if needed, especially the messages
# if "scraper_agent" in event:
# final_state = event["scraper_agent"]
# elif "tools" in event:
# final_state = event["tools"] # Though tool output doesn't directly give full state
# logger.info("---")
# logger.info("\n--- Final State Messages ---")
# # To get the absolute final state after streaming, invoke might be simpler,
# # or you need to properly aggregate the state from the stream events.
# # A simpler way to get final output:
# final_output = app.invoke(initial_input, {"recursion_limit": 15})
# logger.info(json.dumps(final_output['messages'][-1].dict(), indent=2)) # Print the last message
# except Exception as e:
# logger.error(f"\n--- An error occurred during graph execution ---")
# import traceback
# traceback.print_exc()
SCRAPER_SYSTEM_PROMPT = """
You are an intelligent assistant specialized in company research and sustainability report retrieval.
You have access to the following tools:
- **search_tool**: Use this tool when the user asks for a list of top companies related to an industry or category (e.g., "top 5 textile companies"). Always preserve any number mentioned (e.g., 'top 5', 'top 10') in the query.
- **pdf_finder_tool**: Use this tool when the user requests a sustainability report or any other specific PDF document about a company. Search specifically for the latest sustainability report if not otherwise specified.
- **pdf_downloader_tool**: Use this tool when the user provides a direct PDF link or asks you to download a PDF document from a URL.
Instructions:
- Carefully read the user's request and select the correct tool based on their intent.
- Always preserve important details like quantity (e.g., "top 5"), industry, or company name.
- If the user mentions multiple companies and asks for reports, find reports for **each** company individually.
- Do not add assumptions, opinions, or unrelated information.
- Always generate clean, direct, and minimal input for the tool β€” close to the user's original query.
- Prioritize the most recent information when searching for reports unless otherwise instructed.
Goal:
- Select the appropriate tool.
- Build a precise query that perfectly reflects the user's request.
- Return only what the user asks β€” no extra text or interpretation.
"""
# Aliases so the prompt's tool names line up with the actual tool objects.
search_tool = get_top_companies_from_web
pdf_finder_tool = get_sustainability_report_pdf
pdf_downloader_tool = download_pdf
# NOTE(review): duplicates the `model` instance configured earlier in this
# module — confirm whether a single shared ChatOpenAI instance is intended.
llm = ChatOpenAI(model= 'gpt-4o-mini', temperature=0)
# Agent assembled by the project's langgraph_service helper.
scraper_agent = create_agent(llm, [search_tool, pdf_finder_tool, pdf_downloader_tool], SCRAPER_SYSTEM_PROMPT)