import os
import json
from typing import TypedDict, Annotated, List

from dotenv import load_dotenv
from langchain_core.messages import ToolMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages

# Local Imports
from application.tools.web_search_tools import get_top_companies_from_web, get_sustainability_report_pdf
from application.tools.pdf_downloader_tool import download_pdf
from application.tools.emission_data_extractor import extract_emission_data_as_json
from application.services.langgraph_service import create_agent
from application.utils.logger import get_logger

# setting up environment and logger
load_dotenv()
logger = get_logger()

# LangSmith tracing configuration
LANGSMITH_API_KEY = os.getenv('LANGSMITH_API_KEY')
if LANGSMITH_API_KEY:
    os.environ['LANGSMITH_API_KEY'] = LANGSMITH_API_KEY
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ["LANGCHAIN_PROJECT"] = "Sustainability_AI"

# OpenAI API key is expected to be provided via the environment / .env file
if not os.environ.get("OPENAI_API_KEY"):
    logger.warning("OPENAI_API_KEY is not set; ChatOpenAI calls will fail.")


class AgentState(TypedDict):
    messages: Annotated[List, add_messages]


model = ChatOpenAI(model='gpt-4o-mini', temperature=0)
tools = [get_top_companies_from_web, get_sustainability_report_pdf, download_pdf, extract_emission_data_as_json]
model_with_tools = model.bind_tools(tools)

def invoke_model(state: AgentState) -> dict:
    """Invokes the LLM with the current conversation history."""
    logger.info("--- Invoking Model ---")
    # LangGraph automatically passes the entire state
    # The model_with_tools expects a list of BaseMessages
    response = model_with_tools.invoke(state['messages'])
    # logger.info(f"Model response: {response}")
    # We return a dictionary with the key corresponding to the state field name
    return {"messages": [response]} # The response is already an AIMessage

def invoke_tools(state: AgentState) -> dict:
    """Invokes the necessary tools based on the last AI message."""
    logger.info("--- Invoking Tools ---")
    # The state contains the history, the last message is the AI's request
    last_message = state['messages'][-1]

    # Check if the last message is an AIMessage with tool_calls
    if not hasattr(last_message, 'tool_calls') or not last_message.tool_calls:
        logger.info("No tool calls found in the last message.")
        # No tool calls usually means the conversation should end or needs clarification.
        # Returning an empty dict leaves the state unchanged.
        # Alternative: return {"messages": [SystemMessage(content="No tool calls requested.")]}
        return {}

    tool_invocation_messages = []

    # Find the tool object by name
    tool_map = {tool.name: tool for tool in tools}

    for tool_call in last_message.tool_calls:
        tool_name = tool_call['name']
        tool_args = tool_call['args']
        tool_call_id = tool_call['id'] # Crucial for linking the result

        logger.info(f"Executing tool: {tool_name} with args: {tool_args}")

        if tool_name in tool_map:
            selected_tool = tool_map[tool_name]
            try:
                # Use the tool's invoke method, passing the arguments dictionary
                result = selected_tool.invoke(tool_args)

                # IMPORTANT: Convert the result to a string or a JSON serializable format
                # if it's a complex object. ToolMessage content should be simple.
                # Adjust this based on what your tools actually return.
                if isinstance(result, (list, dict)):
                    result_content = json.dumps(result)  # Convert dict/list to a JSON string
                elif hasattr(result, 'companies') and isinstance(result.companies, list):  # Handle CompanyListResponse-style objects
                    result_content = f"Companies found: {', '.join(result.companies)}"
                elif result is None:
                    result_content = "Tool executed successfully, but returned no specific data (None)."
                else:
                    result_content = str(result)  # Default to string conversion

                logger.info(f"Tool {tool_name} result: {result_content}")
                tool_invocation_messages.append(
                    ToolMessage(content=result_content, tool_call_id=tool_call_id)
                )
            except Exception as e:
                logger.error(f"Error executing tool {tool_name}: {e}")
                # Return an error message in the ToolMessage
                tool_invocation_messages.append(
                    ToolMessage(content=f"Error executing tool {tool_name}: {str(e)}", tool_call_id=tool_call_id)
                )
        else:
            logger.warning(f"Tool '{tool_name}' not found.")
            tool_invocation_messages.append(
                ToolMessage(content=f"Error: Tool '{tool_name}' not found.", tool_call_id=tool_call_id)
            )

    # Return the collected ToolMessages to be added to the state
    return {"messages": tool_invocation_messages}

# --- Graph Definition ---
graph_builder = StateGraph(AgentState)

# Add nodes
graph_builder.add_node("scraper_agent", invoke_model)
graph_builder.add_node("tools", invoke_tools) # Renamed for clarity

# Define edges
graph_builder.set_entry_point("scraper_agent")

# Conditional edge: After the agent runs, decide whether to call tools or end.
def router(state: AgentState) -> str:
    """Determines the next step based on the last message."""
    last_message = state['messages'][-1]
    if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
        # If the AI message has tool calls, invoke the tools node
        logger.info("--- Routing to Tools ---")
        return "tools"
    else:
        # Otherwise, the conversation can end
        logger.info("--- Routing to End ---")
        return END

graph_builder.add_conditional_edges(
    "scraper_agent",
    router,
    {
        "tools": "tools",  # If router returns "tools", go to the "tools" node
        END: END,          # If router returns END, finish the graph execution
    }
)

# After tools are invoked, their results (ToolMessages) should go back to the agent
graph_builder.add_edge("tools", "scraper_agent")

# Compile the graph
app = graph_builder.compile()

# # --- Running the Graph ---
# if __name__ == "__main__":
#     logger.info("Starting graph execution...")
#     # Use HumanMessage for the initial input
#     initial_input = {"messages": [HumanMessage(content="Please download this pdf https://www.infosys.com/sustainability/documents/infosys-esg-report-2023-24.pdf")]}

#     # Stream events to see the flow (optional, but helpful for debugging)
#     # Add recursion limit to prevent infinite loops
#     try:
#         final_state = None
#         for event in app.stream(initial_input, {"recursion_limit": 15}):
#             # event is a dictionary where keys are node names and values are outputs
#             logger.info(f"Event: {event}")
#             # Keep track of the latest state if needed, especially the messages
#             if "scraper_agent" in event:
#                 final_state = event["scraper_agent"]
#             elif "tools" in event:
#                 final_state = event["tools"] # Though tool output doesn't directly give full state
#             logger.info("---")

#         logger.info("\n--- Final State Messages ---")
#         # To get the absolute final state after streaming, invoke might be simpler,
#         # or you need to properly aggregate the state from the stream events.
#         # A simpler way to get final output:
#         final_output = app.invoke(initial_input, {"recursion_limit": 15})
#         logger.info(json.dumps(final_output['messages'][-1].dict(), indent=2)) # Print the last message

#     except Exception as e:
#         logger.error(f"\n--- An error occurred during graph execution: {e} ---")
#         import traceback
#         traceback.print_exc()


SCRAPER_SYSTEM_PROMPT = """
You are an intelligent assistant specialized in company research and sustainability report retrieval.

You have access to the following tools:

- **search_tool**: Use this tool when the user asks for a list of top companies related to an industry or category (e.g., "top 5 textile companies"). Always preserve any number mentioned (e.g., 'top 5', 'top 10') in the query.
- **pdf_finder_tool**: Use this tool when the user requests a sustainability report or any other specific PDF document about a company. Search specifically for the latest sustainability report if not otherwise specified.
- **pdf_downloader_tool**: Use this tool when the user provides a direct PDF link or asks you to download a PDF document from a URL.

Instructions:

- Carefully read the user's request and select the correct tool based on their intent.
- Always preserve important details like quantity (e.g., "top 5"), industry, or company name.
- If the user mentions multiple companies and asks for reports, find reports for **each** company individually.
- Do not add assumptions, opinions, or unrelated information.
- Always generate clean, direct, and minimal input for the tool, staying close to the user's original query.
- Prioritize the most recent information when searching for reports unless otherwise instructed.

Goal:

- Select the appropriate tool.
- Build a precise query that perfectly reflects the user's request.
- Return only what the user asks, with no extra text or interpretation.
"""
search_tool = get_top_companies_from_web
pdf_finder_tool = get_sustainability_report_pdf
pdf_downloader_tool = download_pdf

llm = ChatOpenAI(model='gpt-4o-mini', temperature=0)

scraper_agent = create_agent(llm, [search_tool, pdf_finder_tool, pdf_downloader_tool], SCRAPER_SYSTEM_PROMPT)
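
# Minimal usage sketch (commented out so it does not run on import). It assumes create_agent
# returns a runnable exposing .invoke() that accepts a {"messages": [...]} dict, like the
# compiled graph above; adjust this to whatever the local create_agent actually returns.
# if __name__ == "__main__":
#     demo_input = {"messages": [HumanMessage(content="Find the latest sustainability report for Infosys")]}
#     demo_result = scraper_agent.invoke(demo_input)
#     logger.info(demo_result)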