# (Removed: "Spaces / Sleeping / Sleeping" — Hugging Face Spaces UI status text
# accidentally captured along with this source file; it is not Python code.)
# Standard library
import os
import json

# Third-party
from dotenv import load_dotenv
# NOTE(review): ToolMessage is imported again below (with HumanMessage) — the
# duplicate import is harmless but should be consolidated.
from langchain_core.messages import ToolMessage
from typing import TypedDict, Annotated
from langgraph.graph.message import add_messages
# NOTE(review): Annotated is already imported above; duplicate kept as-is.
from typing import Annotated, List
from langgraph.graph import StateGraph, END
from langchain_core.messages import ToolMessage, HumanMessage
from langchain_openai import ChatOpenAI

# Local Imports
from application.tools.web_search_tools import get_top_companies_from_web, get_sustainability_report_pdf
from application.tools.pdf_downloader_tool import download_pdf
from application.tools.emission_data_extractor import extract_emission_data_as_json
from application.services.langgraph_service import create_agent
from application.utils.logger import get_logger
# setting up environment and logger
load_dotenv()
logger = get_logger()

# Langsmith tracing configuration.
LANGSMITH_API_KEY = os.getenv('LANGSMITH_API_KEY')
if LANGSMITH_API_KEY is None:
    # `os.environ[k] = None` raises an opaque TypeError; fail with a clear
    # message instead so a missing .env entry is obvious at startup.
    raise EnvironmentError("LANGSMITH_API_KEY is not set (environment or .env).")
os.environ['LANGSMITH_API_KEY'] = LANGSMITH_API_KEY
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ["LANGCHAIN_PROJECT"] = "Sustainability_AI"

# OpenAI api key set up
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if OPENAI_API_KEY is None:
    raise EnvironmentError("OPENAI_API_KEY is not set (environment or .env).")
os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
class AgentState(TypedDict):
    """Shared LangGraph state: the running conversation history.

    The `add_messages` annotation makes LangGraph append each node's
    returned messages to the list instead of overwriting it.
    """

    messages: Annotated[List, add_messages]
# NOTE(review): `graph` is never used in this module — the graph is actually
# built via `graph_builder` further down. Looks dead; confirm no external
# importer relies on it before removing.
graph = StateGraph(AgentState)

# Deterministic LLM (temperature=0) with the scraping/extraction tools bound
# so the model can emit tool calls.
model = ChatOpenAI(model= 'gpt-4o-mini', temperature=0)
tools = [get_top_companies_from_web, get_sustainability_report_pdf, download_pdf, extract_emission_data_as_json]
model_with_tools = model.bind_tools(tools)
def invoke_model(state: AgentState) -> dict:
    """Run the tool-bound LLM over the accumulated conversation.

    Returns a partial state update: the model's reply wrapped in a
    one-element list so `add_messages` appends it to the history.
    """
    logger.info("--- Invoking Model ---")
    ai_message = model_with_tools.invoke(state['messages'])
    return {"messages": [ai_message]}
def invoke_tools(state: AgentState) -> dict:
    """Invokes the necessary tools based on the last AI message.

    Reads the tool calls attached to the most recent message, executes each
    matching tool, and returns the resulting ToolMessages as a state update.
    Returns an empty dict (no state change) when there are no tool calls.
    Tool failures are reported back to the model as error ToolMessages
    rather than raised, so the agent loop can recover.
    """
    logger.info("--- Invoking Tools ---")
    last_message = state['messages'][-1]
    if not hasattr(last_message, 'tool_calls') or not last_message.tool_calls:
        logger.info("No tool calls found in the last message.")
        return {}

    tool_map = {tool.name: tool for tool in tools}
    tool_invocation_messages = []
    for tool_call in last_message.tool_calls:
        tool_name = tool_call['name']
        tool_args = tool_call['args']
        tool_call_id = tool_call['id']
        logger.info(f"Executing tool: {tool_name} with args: {tool_args}")

        # Guard clause: unknown tool name — surface the problem to the model.
        if tool_name not in tool_map:
            logger.warning(f"Tool '{tool_name}' not found.")
            tool_invocation_messages.append(
                ToolMessage(content=f"Error: Tool '{tool_name}' not found.", tool_call_id=tool_call_id)
            )
            continue

        selected_tool = tool_map[tool_name]
        try:
            result = selected_tool.invoke(tool_args)
            result_content = _serialize_tool_result(result)
            logger.info(f"Tool {tool_name} result: {result_content}")
            tool_invocation_messages.append(
                ToolMessage(content=result_content, tool_call_id=tool_call_id)
            )
        except Exception as e:
            logger.error(f"Error executing tool {tool_name}: {e}")
            tool_invocation_messages.append(
                ToolMessage(content=f"Error executing tool {tool_name}: {str(e)}", tool_call_id=tool_call_id)
            )
    return {"messages": tool_invocation_messages}


def _serialize_tool_result(result) -> str:
    """Convert a raw tool return value into ToolMessage string content.

    Order matters: JSON-serializable containers first, then objects exposing
    a `companies` list (company-search result shape — presumably a pydantic
    model; confirm against the tool), then the explicit None case.
    """
    if isinstance(result, (list, dict)):
        return json.dumps(result)
    if hasattr(result, 'companies') and isinstance(result.companies, list):
        return f"Companies found: {', '.join(result.companies)}"
    if result is None:
        return "Tool executed successfully, but returned no specific data (None)."
    return str(result)
# Assemble the agent graph: an LLM node ("scraper_agent") and a
# tool-execution node ("tools"); execution starts at the LLM.
graph_builder = StateGraph(AgentState)
graph_builder.add_node("scraper_agent", invoke_model)
graph_builder.add_node("tools", invoke_tools)
graph_builder.set_entry_point("scraper_agent")
def router(state: AgentState) -> str:
    """Determines the next step based on the last message.

    Routes to the "tools" node when the latest message carries tool calls;
    otherwise ends the graph run.
    """
    pending_calls = getattr(state['messages'][-1], 'tool_calls', None)
    if pending_calls:
        logger.info("--- Routing to Tools ---")
        return "tools"
    logger.info("--- Routing to End ---")
    return END
# After each model turn, either run the requested tools or finish.
graph_builder.add_conditional_edges(
    "scraper_agent",
    router,
    {
        "tools": "tools",
        END: END,
    }
)
# Tool results always loop back to the model for the next reasoning step.
graph_builder.add_edge("tools", "scraper_agent")
# Compile the graph
app = graph_builder.compile()
# # --- Running the Graph --- | |
# if __name__ == "__main__": | |
# logger.info("Starting graph execution...") | |
# # Use HumanMessage for the initial input | |
# initial_input = {"messages": [HumanMessage(content="Please download this pdf https://www.infosys.com/sustainability/documents/infosys-esg-report-2023-24.pdf")]} | |
# # Stream events to see the flow (optional, but helpful for debugging) | |
# # Add recursion limit to prevent infinite loops | |
# try: | |
# final_state = None | |
# for event in app.stream(initial_input, {"recursion_limit": 15}): | |
# # event is a dictionary where keys are node names and values are outputs | |
# logger.info(f"Event: {event}") | |
# # Keep track of the latest state if needed, especially the messages | |
# if "scraper_agent" in event: | |
# final_state = event["scraper_agent"] | |
# elif "tools" in event: | |
# final_state = event["tools"] # Though tool output doesn't directly give full state | |
# logger.info("---") | |
# logger.info("\n--- Final State Messages ---") | |
# # To get the absolute final state after streaming, invoke might be simpler, | |
# # or you need to properly aggregate the state from the stream events. | |
# # A simpler way to get final output: | |
# final_output = app.invoke(initial_input, {"recursion_limit": 15}) | |
# logger.info(json.dumps(final_output['messages'][-1].dict(), indent=2)) # Print the last message | |
# except Exception as e: | |
# logger.error(f"\n--- An error occurred during graph execution ---") | |
# import traceback | |
# traceback.print_exc() | |
# System prompt for the scraper agent. Fixed mojibake: two stray Greek betas
# ("β") that were a mis-encoded em dash in the original text.
# NOTE(review): the tool names below (search_tool, pdf_finder_tool,
# pdf_downloader_tool) are module-level aliases — confirm they match the
# `.name` each underlying tool actually exposes to the LLM.
SCRAPER_SYSTEM_PROMPT = """
You are an intelligent assistant specialized in company research and sustainability report retrieval.
You have access to the following tools:
- **search_tool**: Use this tool when the user asks for a list of top companies related to an industry or category (e.g., "top 5 textile companies"). Always preserve any number mentioned (e.g., 'top 5', 'top 10') in the query.
- **pdf_finder_tool**: Use this tool when the user requests a sustainability report or any other specific PDF document about a company. Search specifically for the latest sustainability report if not otherwise specified.
- **pdf_downloader_tool**: Use this tool when the user provides a direct PDF link or asks you to download a PDF document from a URL.
Instructions:
- Carefully read the user's request and select the correct tool based on their intent.
- Always preserve important details like quantity (e.g., "top 5"), industry, or company name.
- If the user mentions multiple companies and asks for reports, find reports for **each** company individually.
- Do not add assumptions, opinions, or unrelated information.
- Always generate clean, direct, and minimal input for the tool — close to the user's original query.
- Prioritize the most recent information when searching for reports unless otherwise instructed.
Goal:
- Select the appropriate tool.
- Build a precise query that perfectly reflects the user's request.
- Return only what the user asks — no extra text or interpretation.
"""
# Aliases matching the tool names referenced in SCRAPER_SYSTEM_PROMPT.
search_tool = get_top_companies_from_web
pdf_finder_tool = get_sustainability_report_pdf
pdf_downloader_tool = download_pdf

# Dedicated deterministic LLM instance for the prompt-driven scraper agent.
llm = ChatOpenAI(model= 'gpt-4o-mini', temperature=0)
scraper_agent = create_agent(llm, [search_tool, pdf_finder_tool, pdf_downloader_tool], SCRAPER_SYSTEM_PROMPT)