Ahmud committed
Commit 66f6cc6 · verified · 1 Parent(s): 9c9b3ff

Update agent.py

Files changed (1)
  1. agent.py +90 -58
agent.py CHANGED
@@ -1,51 +1,92 @@
 import os
+import itertools
+from typing import TypedDict, Annotated, Literal
 
 from langchain_openai import ChatOpenAI
 from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
 from langgraph.graph.message import add_messages
-from langgraph.graph import MessagesState
-from langgraph.graph import StateGraph, START, END
-from typing import TypedDict, Annotated, Literal
+from langgraph.graph import MessagesState, StateGraph, START, END
+from langchain_core.runnables import RunnableConfig # for LangSmith tracking
 
 from langchain_community.tools import BraveSearch # web search
 from langchain_experimental.tools.python.tool import PythonAstREPLTool # for logic/math problems
 
-from tools import (calculator_basic, datetime_tools, transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel)
+from tools import (
+    calculator_basic, datetime_tools, transcribe_audio,
+    transcribe_youtube, query_image, webpage_content, read_excel
+)
 from prompt import system_prompt
 
-from langchain_core.runnables import RunnableConfig # for LangSmith tracking
-
-# LangSmith to observe the agent
-langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
-langsmith_tracing = os.getenv("LANGSMITH_TRACING")
-
-llm = ChatOpenAI(
+# --------------------------------------------------------------------
+# 1. API Key Rotation Setup
+# --------------------------------------------------------------------
+api_keys = [
+    os.getenv("OPENROUTER_API_KEY"),
+    os.getenv("OPENROUTER_API_KEY_1")
+]
+
+if not any(api_keys):
+    raise EnvironmentError("No OpenRouter API keys found in environment variables.")
+
+api_key_cycle = itertools.cycle([k for k in api_keys if k])
+
+def get_next_api_key():
+    """Get the next API key in rotation."""
+    return next(api_key_cycle)
+
+class RotatingChatOpenAI(ChatOpenAI):
+    """ChatOpenAI wrapper that automatically rotates API keys on failure."""
+
+    def invoke(self, *args, **kwargs):
+        # Try each key once per call
+        for _ in range(len(api_keys)):
+            self.api_key = get_next_api_key()
+            try:
+                return super().invoke(*args, **kwargs)
+            except Exception as e:
+                # Handle rate-limits or auth errors
+                if any(code in str(e) for code in ["429", "401", "403"]):
+                    print(f"[API Key Rotation] Key {self.api_key[:5]}... failed, trying next key...")
+                    continue
+                raise # Re-raise other unexpected errors
+        raise RuntimeError("All OpenRouter API keys failed or rate-limited.")
+
+# --------------------------------------------------------------------
+# 2. Initialize LLM with API Key Rotation
+# --------------------------------------------------------------------
+llm = RotatingChatOpenAI(
     base_url="https://openrouter.ai/api/v1",
-    api_key=os.getenv("OPENROUTER_API_KEY"),
-    model="qwen/qwen3-coder:free", # Model must support function calling in OpenRouter
+    api_key=get_next_api_key(), # Start with the first key
+    model="qwen/qwen3-coder:free", # Model must support function calling
     temperature=1
 )
 
+# --------------------------------------------------------------------
+# 3. Tools Setup
+# --------------------------------------------------------------------
 python_tool = PythonAstREPLTool()
 search_tool = BraveSearch.from_api_key(
-    api_key=os.getenv("BRAVE_SEARCH_API"),
-    search_kwargs={"count": 4}, # returns the 4 best results and their URL
-    description="Web search using Brave"
+    api_key=os.getenv("BRAVE_SEARCH_API"),
+    search_kwargs={"count": 4},
+    description="Web search using Brave"
 )
 
 community_tools = [search_tool, python_tool]
-custom_tools = calculator_basic + datetime_tools + [transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel]
+custom_tools = calculator_basic + datetime_tools + [
+    transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel
+]
 
 tools = community_tools + custom_tools
 llm_with_tools = llm.bind_tools(tools)
-
-# Prepare tools by name
 tools_by_name = {tool.name: tool for tool in tools}
 
-class MessagesState(TypedDict): # creates the state (is like the agent's memory at any moment)
+# --------------------------------------------------------------------
+# 4. Define LangGraph State and Nodes
+# --------------------------------------------------------------------
+class MessagesState(TypedDict):
     messages: Annotated[list[AnyMessage], add_messages]
 
-# LLM node
+# LLM Node
 def llm_call(state: MessagesState):
     return {
         "messages": [
@@ -55,61 +96,51 @@ def llm_call(state: MessagesState):
         ]
     }
 
-# Tool node
+# Tool Node
 def tool_node(state: MessagesState):
-    """Executes the tools"""
-
+    """Executes tools requested by the LLM."""
     result = []
-    for tool_call in state["messages"][-1].tool_calls: # gives a list of the tools the LLM decided to call
-        tool = tools_by_name[tool_call["name"]] # look up the actual tool function using a dictionary
-        observation = tool.invoke(tool_call["args"]) # executes the tool
-        result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"])) # the result from the tool is added to the memory
-    return {"messages": result} # thanks to add_messages, LangGraph will automatically append the result to the agent's message history
+    for tool_call in state["messages"][-1].tool_calls:
+        tool = tools_by_name[tool_call["name"]]
+        observation = tool.invoke(tool_call["args"])
+        result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
+    return {"messages": result}
 
-# Conditional edge function to route to the tool node or end based upon whether the LLM made a tool call
+# Conditional Routing
 def should_continue(state: MessagesState) -> Literal["Action", END]:
-    """Decide if we should continue the loop or stop based upon whether the LLM made a tool call"""
-
-    last_message = state["messages"][-1] # looks at the last message (usually from the LLM)
-
-    # If the LLM makes a tool call, then perform an action
-    if last_message.tool_calls:
-        return "Action"
-    # Otherwise, we stop (reply to the user)
-    return END
+    """Route to tools if LLM made a tool call, else end."""
+    last_message = state["messages"][-1]
+    return "Action" if last_message.tool_calls else END
 
-# Build workflow
+# --------------------------------------------------------------------
+# 5. Build LangGraph Agent
+# --------------------------------------------------------------------
 builder = StateGraph(MessagesState)
 
-# Add nodes
 builder.add_node("llm_call", llm_call)
 builder.add_node("environment", tool_node)
 
-# Add edges to connect nodes
 builder.add_edge(START, "llm_call")
 builder.add_conditional_edges(
     "llm_call",
     should_continue,
-    {"Action": "environment", # name returned by should_continue : Name of the next node
-     END: END}
+    {"Action": "environment", END: END}
 )
-# If tool calls -> "Action" -> environment (executes the tool)
-# If no tool calls -> END
-
-builder.add_edge("environment", "llm_call") # after running the tools go back to the LLM for another round of reasoning
+builder.add_edge("environment", "llm_call")
 
-gaia_agent = builder.compile() # converts my builder into a runnable agent by using gaia_agent.invoke()
+gaia_agent = builder.compile()
 
-# Wrapper class to initialize and call the LangGraph agent with a user question
+# --------------------------------------------------------------------
+# 6. Agent Wrapper
+# --------------------------------------------------------------------
 class LangGraphAgent:
     def __init__(self):
-        print("LangGraphAgent initialized.")
+        print("LangGraphAgent initialized with API key rotation.")
 
     def __call__(self, question: str) -> str:
-        input_state = {"messages": [HumanMessage(content=question)]} # prepare the initial user message
+        input_state = {"messages": [HumanMessage(content=question)]}
         print(f"Running LangGraphAgent with input: {question[:150]}...")
-
-        # tracing configuration for LangSmith
+
         config = RunnableConfig(
             config={
                 "run_name": "GAIA Agent",
@@ -119,11 +150,12 @@ class LangGraphAgent:
                 "tracing": True
             }
         )
-        result = gaia_agent.invoke(input_state, config) # prevents infinite looping when the LLM keeps calling tools over and over
+
+        result = gaia_agent.invoke(input_state, config)
         final_response = result["messages"][-1].content
-
+
        try:
-            return final_response.split("FINAL ANSWER:")[-1].strip() # parse out only what's after "FINAL ANSWER:"
+            return final_response.split("FINAL ANSWER:")[-1].strip()
         except Exception:
             print("Could not split on 'FINAL ANSWER:'")
-            return final_response
+            return final_response
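The heart of this change is the key-rotation pattern added in sections 1-2. A minimal standalone sketch of that pattern, using stand-in keys and a fake request function instead of real OpenRouter calls (all names below are illustrative, not part of the commit):

import itertools

api_keys = ["key-A", "key-B"]  # stand-ins for OPENROUTER_API_KEY / OPENROUTER_API_KEY_1
key_cycle = itertools.cycle([k for k in api_keys if k])

def call_with_rotation(attempt):
    """Try each key once; move on when a key fails with a rate-limit/auth-style error."""
    for _ in range(len(api_keys)):
        key = next(key_cycle)
        try:
            return attempt(key)
        except RuntimeError as e:
            # Mirrors the commit's check for "429"/"401"/"403" in the error text
            if any(code in str(e) for code in ["429", "401", "403"]):
                continue
            raise
    raise RuntimeError("All keys failed or rate-limited.")

def fake_attempt(key):
    if key == "key-A":
        raise RuntimeError("429 Too Many Requests")  # first key is rate-limited
    return f"ok via {key}"

print(call_with_rotation(fake_attempt))  # -> ok via key-B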
 
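For an end-to-end smoke test of the updated agent, something along these lines should work, assuming the repo's tools and prompt modules are importable and at least one of the two OpenRouter key variables is exported (the question string is just an example):

from agent import LangGraphAgent

agent = LangGraphAgent()  # prints "LangGraphAgent initialized with API key rotation."
answer = agent("What is the capital of France?")
print(answer)  # the text after "FINAL ANSWER:" in the model's reply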