File size: 8,295 Bytes
172e21d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540db73
172e21d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540db73
172e21d
 
 
 
 
 
 
 
 
540db73
 
172e21d
 
 
 
540db73
172e21d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540db73
172e21d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540db73
 
172e21d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# Standard library
import json
import os
from typing import Annotated, List, TypedDict

# Third-party
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages

# Local Imports
from application.tools.web_search_tools import get_top_companies_from_web, get_sustainability_report_pdf
from application.tools.pdf_downloader_tool import download_pdf
from application.tools.emission_data_extractor import extract_emission_data_as_json
from application.services.langgraph_service import create_agent
from application.utils.logger import get_logger

# setting up environment and logger
load_dotenv()
logger = get_logger()

# LangSmith tracing. Assigning None into os.environ raises TypeError, so only
# enable tracing when the API key is actually present; otherwise the module
# crashes on import in environments without LangSmith configured.
LANGSMITH_API_KEY = os.getenv('LANGSMITH_API_KEY')
if LANGSMITH_API_KEY:
    os.environ['LANGSMITH_API_KEY'] = LANGSMITH_API_KEY
    os.environ['LANGCHAIN_TRACING_V2'] = 'true'
    os.environ["LANGCHAIN_PROJECT"] = "Sustainability_AI"
else:
    logger.warning("LANGSMITH_API_KEY is not set; LangSmith tracing disabled.")

# OpenAI API key check. Re-assigning os.environ.get(...) back into os.environ
# was a no-op when the key existed and a TypeError when it did not; warn
# early instead so the failure mode is explicit.
if not os.environ.get("OPENAI_API_KEY"):
    logger.warning("OPENAI_API_KEY is not set; ChatOpenAI calls will fail.")

class AgentState(TypedDict):
    """Shared LangGraph state: the accumulated conversation history."""

    # `add_messages` is a reducer: messages returned by a node are appended
    # to this list rather than replacing it wholesale.
    messages: Annotated[List, add_messages]

# NOTE(review): `graph` is never referenced again — the graph actually built
# and compiled below uses `graph_builder`. Looks like dead code; confirm no
# external importer relies on this name before removing.
graph = StateGraph(AgentState)

# Deterministic (temperature 0) tool-calling model.
model = ChatOpenAI(model= 'gpt-4o-mini', temperature=0)
# Every tool the graph's tool node can dispatch to.
tools = [get_top_companies_from_web, get_sustainability_report_pdf, download_pdf, extract_emission_data_as_json]
# Model wrapper that can emit `tool_calls` for the tools above.
model_with_tools = model.bind_tools(tools)

def invoke_model(state: AgentState) -> dict:
    """Run the tool-bound LLM over the conversation so far.

    Returns a partial state update; ``add_messages`` appends the AI reply
    to the history.
    """
    logger.info("--- Invoking Model ---")
    ai_reply = model_with_tools.invoke(state["messages"])
    return {"messages": [ai_reply]}

def _serialize_tool_result(result) -> str:
    """Convert a tool's return value into the string payload of a ToolMessage."""
    if isinstance(result, (list, dict)):
        return json.dumps(result)
    if hasattr(result, 'companies') and isinstance(result.companies, list):
        return f"Companies found: {', '.join(result.companies)}"
    if result is None:
        return "Tool executed successfully, but returned no specific data (None)."
    return str(result)


def invoke_tools(state: AgentState) -> dict:
    """Invokes the necessary tools based on the last AI message.

    Executes each tool call found on the most recent message and returns the
    resulting ToolMessages as a partial state update, so ``add_messages``
    appends them to the conversation. Tool failures are reported back to the
    model as ToolMessages rather than crashing the graph.
    """
    logger.info("--- Invoking Tools ---")
    last_message = state['messages'][-1]

    if not getattr(last_message, 'tool_calls', None):
        logger.info("No tool calls found in the last message.")
        return {}

    tool_map = {tool.name: tool for tool in tools}
    tool_invocation_messages = []

    for tool_call in last_message.tool_calls:
        tool_name = tool_call['name']
        tool_args = tool_call['args']
        tool_call_id = tool_call['id']

        logger.info(f"Executing tool: {tool_name} with args: {tool_args}")

        selected_tool = tool_map.get(tool_name)
        if selected_tool is None:
            # Unknown tool name: tell the model so it can correct itself.
            logger.warning(f"Tool '{tool_name}' not found.")
            tool_invocation_messages.append(
                ToolMessage(content=f"Error: Tool '{tool_name}' not found.", tool_call_id=tool_call_id)
            )
            continue

        try:
            result_content = _serialize_tool_result(selected_tool.invoke(tool_args))
            logger.info(f"Tool {tool_name} result: {result_content}")
            tool_invocation_messages.append(
                ToolMessage(content=result_content, tool_call_id=tool_call_id)
            )
        except Exception as e:
            # Surface the failure to the model instead of raising, so the
            # agent loop can react or retry.
            logger.error(f"Error executing tool {tool_name}: {e}")
            tool_invocation_messages.append(
                ToolMessage(content=f"Error executing tool {tool_name}: {str(e)}", tool_call_id=tool_call_id)
            )

    return {"messages": tool_invocation_messages}

# Build the agent graph (this builder, not the unused `graph` above, is
# what gets compiled).
graph_builder = StateGraph(AgentState)

# "scraper_agent" runs the LLM; "tools" executes any requested tool calls.
graph_builder.add_node("scraper_agent", invoke_model)
graph_builder.add_node("tools", invoke_tools)

# Every run starts with a model invocation.
graph_builder.set_entry_point("scraper_agent")

def router(state: AgentState) -> str:
    """Decide the next node after the agent speaks.

    Returns "tools" when the last AI message requested tool calls,
    otherwise END to finish the run.
    """
    pending_calls = getattr(state['messages'][-1], 'tool_calls', None)
    if not pending_calls:
        logger.info("--- Routing to End ---")
        return END
    logger.info("--- Routing to Tools ---")
    return "tools"

# After the agent speaks: either execute the requested tools or stop.
graph_builder.add_conditional_edges(
    "scraper_agent",
    router,
    {
        "tools": "tools",
         END: END,
    }
)

# Tool results always flow back to the agent for the next reasoning step.
graph_builder.add_edge("tools", "scraper_agent")

# Compile the graph
app = graph_builder.compile()

# # --- Running the Graph ---
# if __name__ == "__main__":
#     logger.info("Starting graph execution...")
#     # Use HumanMessage for the initial input
#     initial_input = {"messages": [HumanMessage(content="Please download this pdf https://www.infosys.com/sustainability/documents/infosys-esg-report-2023-24.pdf")]}

#     # Stream events to see the flow (optional, but helpful for debugging)
#     # Add recursion limit to prevent infinite loops
#     try:
#         final_state = None
#         for event in app.stream(initial_input, {"recursion_limit": 15}):
#             # event is a dictionary where keys are node names and values are outputs
#             logger.info(f"Event: {event}")
#             # Keep track of the latest state if needed, especially the messages
#             if "scraper_agent" in event:
#                 final_state = event["scraper_agent"]
#             elif "tools" in event:
#                 final_state = event["tools"] # Though tool output doesn't directly give full state
#             logger.info("---")

#         logger.info("\n--- Final State Messages ---")
#         # To get the absolute final state after streaming, invoke might be simpler,
#         # or you need to properly aggregate the state from the stream events.
#         # A simpler way to get final output:
#         final_output = app.invoke(initial_input, {"recursion_limit": 15})
#         logger.info(json.dumps(final_output['messages'][-1].dict(), indent=2)) # Print the last message

#     except Exception as e:
#         logger.error(f"\n--- An error occurred during graph execution ---")
#         import traceback
#         traceback.print_exc()


# System prompt for the standalone scraper agent below.
# NOTE(review): the prompt refers to tools by the alias names defined after it
# (search_tool, pdf_finder_tool, pdf_downloader_tool), but the tool objects'
# registered names (e.g. get_top_companies_from_web) are what the model is
# likely shown — confirm the names line up, or the prompt's tool guidance may
# not match what the model sees.
SCRAPER_SYSTEM_PROMPT = """

You are an intelligent assistant specialized in company research and sustainability report retrieval.



You have access to the following tools:

- **search_tool**: Use this tool when the user asks for a list of top companies related to an industry or category (e.g., "top 5 textile companies"). Always preserve any number mentioned (e.g., 'top 5', 'top 10') in the query.

- **pdf_finder_tool**: Use this tool when the user requests a sustainability report or any other specific PDF document about a company. Search specifically for the latest sustainability report if not otherwise specified.

- **pdf_downloader_tool**: Use this tool when the user provides a direct PDF link or asks you to download a PDF document from a URL.



Instructions:

- Carefully read the user's request and select the correct tool based on their intent.

- Always preserve important details like quantity (e.g., "top 5"), industry, or company name.

- If the user mentions multiple companies and asks for reports, find reports for **each** company individually.

- Do not add assumptions, opinions, or unrelated information.

- Always generate clean, direct, and minimal input for the tool β€” close to the user's original query.

- Prioritize the most recent information when searching for reports unless otherwise instructed.



Goal:

- Select the appropriate tool.

- Build a precise query that perfectly reflects the user's request.

- Return only what the user asks β€” no extra text or interpretation.

"""
# Aliases matching the names used in the prompt text above.
search_tool = get_top_companies_from_web
pdf_finder_tool = get_sustainability_report_pdf
pdf_downloader_tool = download_pdf

# NOTE(review): identical configuration to `model` above — consider reusing
# that instance instead of constructing a second client; confirm first.
llm = ChatOpenAI(model= 'gpt-4o-mini', temperature=0)

# Prebuilt scraper agent via the project helper. It receives only the three
# web/PDF tools; `extract_emission_data_as_json` is wired into the graph
# above but not into this agent.
scraper_agent = create_agent(llm, [search_tool, pdf_finder_tool, pdf_downloader_tool], SCRAPER_SYSTEM_PROMPT)