import os
import json
from typing import Annotated, List, TypedDict

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
# Local Imports
from application.tools.web_search_tools import get_top_companies_from_web, get_sustainability_report_pdf
from application.tools.pdf_downloader_tool import download_pdf
from application.tools.emission_data_extractor import extract_emission_data_as_json
from application.services.langgraph_service import create_agent
from application.utils.logger import get_logger
# Set up environment and logger
load_dotenv()
logger = get_logger()

# LangSmith tracing (only export the key if it is actually set, to avoid a TypeError on None)
LANGSMITH_API_KEY = os.getenv('LANGSMITH_API_KEY')
if LANGSMITH_API_KEY:
    os.environ['LANGSMITH_API_KEY'] = LANGSMITH_API_KEY
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ["LANGCHAIN_PROJECT"] = "Sustainability_AI"

# OpenAI API key is expected to come from the environment / .env file
if not os.environ.get("OPENAI_API_KEY"):
    logger.warning("OPENAI_API_KEY is not set; ChatOpenAI calls will fail.")
class AgentState(TypedDict):
    """Shared graph state: `messages` is append-only via the add_messages reducer."""
    messages: Annotated[List, add_messages]

model = ChatOpenAI(model='gpt-4o-mini', temperature=0)
tools = [get_top_companies_from_web, get_sustainability_report_pdf, download_pdf, extract_emission_data_as_json]
# Bind the tools so the model can emit structured tool_calls
model_with_tools = model.bind_tools(tools)
def invoke_model(state: AgentState) -> dict:
    """Invokes the LLM with the current conversation history."""
    logger.info("--- Invoking Model ---")
    response = model_with_tools.invoke(state['messages'])
    return {"messages": [response]}
def invoke_tools(state: AgentState) -> dict:
    """Invokes the necessary tools based on the last AI message."""
    logger.info("--- Invoking Tools ---")
    last_message = state['messages'][-1]
    if not hasattr(last_message, 'tool_calls') or not last_message.tool_calls:
        logger.info("No tool calls found in the last message.")
        return {}
    tool_invocation_messages = []
    tool_map = {tool.name: tool for tool in tools}
    for tool_call in last_message.tool_calls:
        tool_name = tool_call['name']
        tool_args = tool_call['args']
        tool_call_id = tool_call['id']
        logger.info(f"Executing tool: {tool_name} with args: {tool_args}")
        if tool_name in tool_map:
            selected_tool = tool_map[tool_name]
            try:
                result = selected_tool.invoke(tool_args)
                if isinstance(result, (list, dict)):
                    result_content = json.dumps(result)
                elif hasattr(result, 'companies') and isinstance(result.companies, list):
                    result_content = f"Companies found: {', '.join(result.companies)}"
                elif result is None:
                    result_content = "Tool executed successfully, but returned no specific data (None)."
                else:
                    result_content = str(result)
                logger.info(f"Tool {tool_name} result: {result_content}")
                tool_invocation_messages.append(
                    ToolMessage(content=result_content, tool_call_id=tool_call_id)
                )
            except Exception as e:
                logger.error(f"Error executing tool {tool_name}: {e}")
                tool_invocation_messages.append(
                    ToolMessage(content=f"Error executing tool {tool_name}: {str(e)}", tool_call_id=tool_call_id)
                )
        else:
            logger.warning(f"Tool '{tool_name}' not found.")
            tool_invocation_messages.append(
                ToolMessage(content=f"Error: Tool '{tool_name}' not found.", tool_call_id=tool_call_id)
            )
    return {"messages": tool_invocation_messages}
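
# --- Graph wiring ---
# The LLM node ("scraper_agent") and the tool-execution node ("tools") form a loop:
# the model proposes tool calls, the tools run, and control returns to the model
# until the router below sees a response without tool calls.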
graph_builder = StateGraph(AgentState)
graph_builder.add_node("scraper_agent", invoke_model)
graph_builder.add_node("tools", invoke_tools)
graph_builder.set_entry_point("scraper_agent")
def router(state: AgentState) -> str:
    """Determines the next step based on the last message."""
    last_message = state['messages'][-1]
    if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
        logger.info("--- Routing to Tools ---")
        return "tools"
    else:
        logger.info("--- Routing to End ---")
        return END
graph_builder.add_conditional_edges(
    "scraper_agent",
    router,
    {
        "tools": "tools",
        END: END,
    }
)
graph_builder.add_edge("tools", "scraper_agent")
# Compile the graph
app = graph_builder.compile()
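# `app` is the compiled LangGraph runnable; it can be driven with .invoke() or .stream(),
# as the disabled example below sketches.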
# --- Running the Graph (example, kept disabled) ---
# if __name__ == "__main__":
#     logger.info("Starting graph execution...")
#     # Use HumanMessage for the initial input
#     initial_input = {"messages": [HumanMessage(content="Please download this pdf https://www.infosys.com/sustainability/documents/infosys-esg-report-2023-24.pdf")]}
#     try:
#         # Stream events to observe the flow (optional, but helpful for debugging);
#         # the recursion limit prevents infinite agent/tool loops.
#         for event in app.stream(initial_input, {"recursion_limit": 15}):
#             # Each event maps a node name ("scraper_agent" or "tools") to its output.
#             logger.info(f"Event: {event}")
#         # invoke() returns the final aggregated state directly.
#         final_output = app.invoke(initial_input, {"recursion_limit": 15})
#         logger.info(json.dumps(final_output['messages'][-1].dict(), indent=2))  # Print the last message
#     except Exception as e:
#         logger.error(f"--- An error occurred during graph execution: {e} ---")
#         import traceback
#         traceback.print_exc()
SCRAPER_SYSTEM_PROMPT = """
You are an intelligent assistant specialized in company research and sustainability report retrieval.
You have access to the following tools:
- **search_tool**: Use this tool when the user asks for a list of top companies related to an industry or category (e.g., "top 5 textile companies"). Always preserve any number mentioned (e.g., 'top 5', 'top 10') in the query.
- **pdf_finder_tool**: Use this tool when the user requests a sustainability report or any other specific PDF document about a company. Search specifically for the latest sustainability report if not otherwise specified.
- **pdf_downloader_tool**: Use this tool when the user provides a direct PDF link or asks you to download a PDF document from a URL.
Instructions:
- Carefully read the user's request and select the correct tool based on their intent.
- Always preserve important details like quantity (e.g., "top 5"), industry, or company name.
- If the user mentions multiple companies and asks for reports, find reports for **each** company individually.
- Do not add assumptions, opinions, or unrelated information.
- Always generate clean, direct, and minimal input for the tool, staying close to the user's original query.
- Prioritize the most recent information when searching for reports unless otherwise instructed.
Goal:
- Select the appropriate tool.
- Build a precise query that perfectly reflects the user's request.
- Return only what the user asks for, with no extra text or interpretation.
"""
# Tool aliases mirroring the names described in the system prompt
search_tool = get_top_companies_from_web
pdf_finder_tool = get_sustainability_report_pdf
pdf_downloader_tool = download_pdf

llm = ChatOpenAI(model='gpt-4o-mini', temperature=0)
scraper_agent = create_agent(llm, [search_tool, pdf_finder_tool, pdf_downloader_tool], SCRAPER_SYSTEM_PROMPT)
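
# Minimal usage sketch (assumption: create_agent returns a LangChain runnable that
# accepts a messages dict like the graph above; adjust to the helper's actual API).
# Kept commented out so importing this module stays side-effect free.
# result = scraper_agent.invoke(
#     {"messages": [HumanMessage(content="Find the latest sustainability report for Infosys")]}
# )
# logger.info(result)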