Spaces:
Sleeping
Sleeping
File size: 10,175 Bytes
172e21d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
import os
from dotenv import load_dotenv
import json
from langchain_core.messages import ToolMessage
from typing import TypedDict, Annotated
from langgraph.graph.message import add_messages
from typing import Annotated, List
from langgraph.graph import StateGraph, END
from langchain_core.messages import ToolMessage, HumanMessage
from langchain_openai import ChatOpenAI
# Local Imports
from application.tools.web_search_tools import get_top_companies_from_web, get_sustainability_report_pdf
from application.tools.pdf_downloader_tool import download_pdf
from application.tools.emission_data_extractor import extract_emission_data_as_json
from application.services.langgraph_service import create_agent
from application.utils.logger import get_logger
# setting up environment and logger
load_dotenv()
logger = get_logger()

# LangSmith tracing configuration.
# Assigning None into os.environ raises TypeError, so only enable tracing
# when an API key is actually present; otherwise run without tracing.
LANGSMITH_API_KEY = os.getenv('LANGSMITH_API_KEY')
if LANGSMITH_API_KEY:
    os.environ['LANGSMITH_API_KEY'] = LANGSMITH_API_KEY
    os.environ['LANGCHAIN_TRACING_V2'] = 'true'
    os.environ["LANGCHAIN_PROJECT"] = "Sustainability_AI"
else:
    logger.warning("LANGSMITH_API_KEY not set; LangSmith tracing disabled.")

# OpenAI API key is mandatory (ChatOpenAI below requires it). Fail fast with
# a clear message instead of the opaque TypeError the bare assignment raised.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise EnvironmentError(
        "OPENAI_API_KEY is not set; add it to the environment or the .env file."
    )
os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
class AgentState(TypedDict):
    """Shared LangGraph state: the running conversation history.

    The ``add_messages`` reducer makes LangGraph append messages returned by
    nodes to this list instead of overwriting it.
    """
    # Full message history (Human/AI/Tool messages), accumulated across turns.
    messages: Annotated[List, add_messages]
# NOTE(review): this StateGraph is never built on or compiled — the graph
# actually used is `graph_builder` below. `graph` looks like dead code;
# kept only in case something imports it. Confirm and remove.
graph = StateGraph(AgentState)
# temperature=0 for deterministic tool selection
model = ChatOpenAI(model= 'gpt-4o-mini', temperature=0)
# Tools the agent may call; binding them lets the model emit tool_calls.
tools = [get_top_companies_from_web, get_sustainability_report_pdf, download_pdf, extract_emission_data_as_json]
model_with_tools = model.bind_tools(tools)
def invoke_model(state: AgentState) -> dict:
    """Runs the tool-bound LLM over the accumulated conversation history.

    Returns the model's reply (an AIMessage, possibly carrying tool_calls)
    under the ``messages`` key so the ``add_messages`` reducer appends it
    to the state rather than replacing the history.
    """
    logger.info("--- Invoking Model ---")
    ai_reply = model_with_tools.invoke(state['messages'])
    return {"messages": [ai_reply]}
def invoke_tools(state: AgentState) -> dict:
    """Executes every tool requested by the last AI message.

    Reads the most recent message from the state; when it carries
    ``tool_calls``, each call is dispatched to the matching tool and the
    result is wrapped in a ``ToolMessage`` linked back via ``tool_call_id``
    so the model can consume it on the next turn.

    Returns:
        dict: ``{"messages": [ToolMessage, ...]}`` to merge into the state,
        or ``{}`` when the last message requested no tools (which typically
        means the conversation should end or needs clarification).
    """
    logger.info("--- Invoking Tools ---")
    last_message = state['messages'][-1]
    if not getattr(last_message, 'tool_calls', None):
        logger.info("No tool calls found in the last message.")
        return {}

    tool_map = {tool.name: tool for tool in tools}
    tool_invocation_messages = []
    for tool_call in last_message.tool_calls:
        tool_name = tool_call['name']
        tool_args = tool_call['args']
        tool_call_id = tool_call['id']  # crucial for linking result to request
        logger.info(f"Executing tool: {tool_name} with args: {tool_args}")

        if tool_name not in tool_map:
            logger.warning(f"Tool '{tool_name}' not found.")
            tool_invocation_messages.append(
                ToolMessage(content=f"Error: Tool '{tool_name}' not found.", tool_call_id=tool_call_id)
            )
            continue

        try:
            result = tool_map[tool_name].invoke(tool_args)
            result_content = _serialize_tool_result(result)
            logger.info(f"Tool {tool_name} result: {result_content}")
            tool_invocation_messages.append(
                ToolMessage(content=result_content, tool_call_id=tool_call_id)
            )
        except Exception as e:
            # Surface the failure to the model instead of crashing the graph;
            # the agent can then retry or report the error.
            logger.error(f"Error executing tool {tool_name}: {e}")
            tool_invocation_messages.append(
                ToolMessage(content=f"Error executing tool {tool_name}: {str(e)}", tool_call_id=tool_call_id)
            )

    return {"messages": tool_invocation_messages}


def _serialize_tool_result(result) -> str:
    """Converts a raw tool return value into ToolMessage-safe string content."""
    if isinstance(result, (list, dict)):
        return json.dumps(result)  # keep structure machine-readable
    if hasattr(result, 'companies') and isinstance(result.companies, list):
        # e.g. a CompanyListResponse-style object from the search tool
        return f"Companies found: {', '.join(result.companies)}"
    if result is None:
        return "Tool executed successfully, but returned no specific data (None)."
    return str(result)
# --- Graph Definition ---
graph_builder = StateGraph(AgentState)
# Add nodes: the LLM step and the tool-execution step.
graph_builder.add_node("scraper_agent", invoke_model)
graph_builder.add_node("tools", invoke_tools) # Renamed for clarity
# Define edges — every run starts at the agent node.
graph_builder.set_entry_point("scraper_agent")
# Conditional edge: After the agent runs, decide whether to call tools or end.
def router(state: AgentState) -> str:
    """Decides the next graph step after the agent node.

    Returns "tools" when the latest AI message requests tool calls,
    otherwise END to finish the run.
    """
    pending_calls = getattr(state['messages'][-1], 'tool_calls', None)
    if pending_calls:
        logger.info("--- Routing to Tools ---")
        return "tools"
    logger.info("--- Routing to End ---")
    return END
# Conditional edge: the router inspects the agent's last message each turn.
graph_builder.add_conditional_edges(
    "scraper_agent",
    router,
    {
        "tools": "tools", # If router returns "tools", go to the "tools" node
        END: END, # If router returns END, finish the graph execution
    }
)
# After tools are invoked, their results (ToolMessages) should go back to the agent
graph_builder.add_edge("tools", "scraper_agent")
# Compile the graph into a runnable app (invoke/stream entry point).
app = graph_builder.compile()
# # --- Running the Graph ---
# if __name__ == "__main__":
# logger.info("Starting graph execution...")
# # Use HumanMessage for the initial input
# initial_input = {"messages": [HumanMessage(content="Please download this pdf https://www.infosys.com/sustainability/documents/infosys-esg-report-2023-24.pdf")]}
# # Stream events to see the flow (optional, but helpful for debugging)
# # Add recursion limit to prevent infinite loops
# try:
# final_state = None
# for event in app.stream(initial_input, {"recursion_limit": 15}):
# # event is a dictionary where keys are node names and values are outputs
# logger.info(f"Event: {event}")
# # Keep track of the latest state if needed, especially the messages
# if "scraper_agent" in event:
# final_state = event["scraper_agent"]
# elif "tools" in event:
# final_state = event["tools"] # Though tool output doesn't directly give full state
# logger.info("---")
# logger.info("\n--- Final State Messages ---")
# # To get the absolute final state after streaming, invoke might be simpler,
# # or you need to properly aggregate the state from the stream events.
# # A simpler way to get final output:
# final_output = app.invoke(initial_input, {"recursion_limit": 15})
# logger.info(json.dumps(final_output['messages'][-1].dict(), indent=2)) # Print the last message
# except Exception as e:
# logger.error(f"\n--- An error occurred during graph execution ---")
# import traceback
# traceback.print_exc()
# System prompt for the scraper agent. Fixed mojibake: a mis-encoded "β"
# appeared where an em dash belongs in two instruction lines.
SCRAPER_SYSTEM_PROMPT = """
You are an intelligent assistant specialized in company research and sustainability report retrieval.
You have access to the following tools:
- **search_tool**: Use this tool when the user asks for a list of top companies related to an industry or category (e.g., "top 5 textile companies"). Always preserve any number mentioned (e.g., 'top 5', 'top 10') in the query.
- **pdf_finder_tool**: Use this tool when the user requests a sustainability report or any other specific PDF document about a company. Search specifically for the latest sustainability report if not otherwise specified.
- **pdf_downloader_tool**: Use this tool when the user provides a direct PDF link or asks you to download a PDF document from a URL.
Instructions:
- Carefully read the user's request and select the correct tool based on their intent.
- Always preserve important details like quantity (e.g., "top 5"), industry, or company name.
- If the user mentions multiple companies and asks for reports, find reports for **each** company individually.
- Do not add assumptions, opinions, or unrelated information.
- Always generate clean, direct, and minimal input for the tool — close to the user's original query.
- Prioritize the most recent information when searching for reports unless otherwise instructed.
Goal:
- Select the appropriate tool.
- Build a precise query that perfectly reflects the user's request.
- Return only what the user asks — no extra text or interpretation.
"""
# Aliases matching the tool names the system prompt refers to.
search_tool = get_top_companies_from_web
pdf_finder_tool = get_sustainability_report_pdf
pdf_downloader_tool = download_pdf
# temperature=0 for deterministic tool selection
llm = ChatOpenAI(model= 'gpt-4o-mini', temperature=0)
scraper_agent = create_agent(llm, [search_tool, pdf_finder_tool, pdf_downloader_tool], SCRAPER_SYSTEM_PROMPT)