Ahmud committed on
Commit 38cfdef · verified · 1 Parent(s): aac43d6

Update agent.py

Files changed (1):
  agent.py  +125 -111

agent.py CHANGED
@@ -1,92 +1,110 @@
  import os
- import itertools
- from typing import TypedDict, Annotated, Literal

  from langchain_openai import ChatOpenAI
  from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
  from langgraph.graph.message import add_messages
- from langgraph.graph import MessagesState, StateGraph, START, END
- from langchain_core.runnables import RunnableConfig # for LangSmith tracking

  from langchain_community.tools import BraveSearch # web search
  from langchain_experimental.tools.python.tool import PythonAstREPLTool # for logic/math problems

- from tools import (
-     calculator_basic, datetime_tools, transcribe_audio,
-     transcribe_youtube, query_image, webpage_content, read_excel
- )
  from prompt import system_prompt

- # --------------------------------------------------------------------
- # 1. API Key Rotation Setup
- # --------------------------------------------------------------------
- api_keys = [k for k in [
-     os.getenv("OPENROUTER_API_KEY"),
-     os.getenv("OPENROUTER_API_KEY_1")
- ] if k]
-
- if not api_keys:
-     raise EnvironmentError("No OpenRouter API keys found in environment variables.")
-
- api_key_cycle = itertools.cycle(api_keys)
-
- def get_next_api_key() -> str:
-     """Get the next API key in rotation."""
-     return next(api_key_cycle)
-
- class RotatingChatOpenAI(ChatOpenAI):
-     """ChatOpenAI wrapper that automatically rotates OpenRouter API keys on failure."""
-
-     def invoke(self, *args, **kwargs):
-         for _ in range(len(api_keys)): # try each key once per call
-             current_key = get_next_api_key()
-             self.openai_api_key = current_key # ✅ Correct for ChatOpenAI
-             try:
-                 return super().invoke(*args, **kwargs)
-             except Exception as e:
-                 # Handle rate-limits or auth errors
-                 if any(code in str(e) for code in ["429", "401", "403"]):
-                     print(f"[API Key Rotation] Key {current_key[:5]}... failed, trying next key...")
-                     continue
-                 raise # Re-raise unexpected errors
-         raise RuntimeError("All OpenRouter API keys failed or were rate-limited.")
-
- # --------------------------------------------------------------------
- # 2. Initialize LLM with API Key Rotation
- # --------------------------------------------------------------------
- llm = RotatingChatOpenAI(
-     base_url="https://openrouter.ai/api/v1",
-     openai_api_key=get_next_api_key(), # ✅ start with the first key
-     model="qwen/qwen3-coder:free", # must support tool/function calling
-     temperature=1
- )

- # --------------------------------------------------------------------
- # 3. Tools Setup
- # --------------------------------------------------------------------
  python_tool = PythonAstREPLTool()
  search_tool = BraveSearch.from_api_key(
-     api_key=os.getenv("BRAVE_SEARCH_API"),
-     search_kwargs={"count": 4},
-     description="Web search using Brave"
  )

  community_tools = [search_tool, python_tool]
- custom_tools = calculator_basic + datetime_tools + [
-     transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel
- ]

  tools = community_tools + custom_tools
  llm_with_tools = llm.bind_tools(tools)
  tools_by_name = {tool.name: tool for tool in tools}

- # --------------------------------------------------------------------
- # 4. Define LangGraph State and Nodes
- # --------------------------------------------------------------------
- class MessagesState(TypedDict):
      messages: Annotated[list[AnyMessage], add_messages]

- # LLM Node
  def llm_call(state: MessagesState):
      return {
          "messages": [
@@ -96,64 +114,61 @@ def llm_call(state: MessagesState):
          ]
      }

- # Tool Node
  def tool_node(state: MessagesState):
-     """Executes tools requested by the LLM."""
-     results = []
-     last_message = state["messages"][-1]
-
-     for tool_call in getattr(last_message, "tool_calls", []) or []:
-         tool = tools_by_name.get(tool_call["name"])
-         if not tool:
-             results.append(ToolMessage(content=f"Unknown tool: {tool_call['name']}", tool_call_id=tool_call["id"]))
-             continue
-
-         args = tool_call["args"]
-         # Handle dict vs positional args safely
-         try:
-             observation = tool.invoke(**args) if isinstance(args, dict) else tool.invoke(args)
-         except Exception as e:
-             observation = f"[Tool Error] {str(e)}"
-
-         results.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
-
-     return {"messages": results}
-
- # Conditional Routing
  def should_continue(state: MessagesState) -> Literal["Action", END]:
-     """Route to tools if LLM made a tool call, else end."""
-     last_message = state["messages"][-1]
-     return "Action" if getattr(last_message, "tool_calls", None) else END

- # --------------------------------------------------------------------
- # 5. Build LangGraph Agent
- # --------------------------------------------------------------------
  builder = StateGraph(MessagesState)

  builder.add_node("llm_call", llm_call)
  builder.add_node("environment", tool_node)

  builder.add_edge(START, "llm_call")
  builder.add_conditional_edges(
      "llm_call",
      should_continue,
-     {"Action": "environment", END: END}
  )
- builder.add_edge("environment", "llm_call")

- gaia_agent = builder.compile()

- # --------------------------------------------------------------------
- # 6. Agent Wrapper
- # --------------------------------------------------------------------
  class LangGraphAgent:
      def __init__(self):
-         print("LangGraphAgent initialized with API key rotation.")

      def __call__(self, question: str) -> str:
-         input_state = {"messages": [HumanMessage(content=question)]}
          print(f"Running LangGraphAgent with input: {question[:150]}...")
-
          config = RunnableConfig(
              config={
                  "run_name": "GAIA Agent",
@@ -163,12 +178,11 @@ class LangGraphAgent:
                  "tracing": True
              }
          )
-
-         result = gaia_agent.invoke(input_state, config)
          final_response = result["messages"][-1].content
-
-         # Extract "FINAL ANSWER" if present
-         if isinstance(final_response, str):
-             return final_response.split("FINAL ANSWER:")[-1].strip()
-         else:
-             return str(final_response)
 
@@ -1,92 +1,110 @@
  import os

  from langchain_openai import ChatOpenAI
  from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
  from langgraph.graph.message import add_messages
+ from langgraph.graph import MessagesState
+ from langgraph.graph import StateGraph, START, END
+ from typing import TypedDict, Annotated, Literal

  from langchain_community.tools import BraveSearch # web search
  from langchain_experimental.tools.python.tool import PythonAstREPLTool # for logic/math problems

+ from tools import (calculator_basic, datetime_tools, transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel)
  from prompt import system_prompt

+ from langchain_core.runnables import RunnableConfig # for LangSmith tracking
+
+ # LangSmith is used to observe the agent
+ langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
+ langsmith_tracing = os.getenv("LANGSMITH_TRACING")
+
+ # Global counter to track the number of questions answered
+ question_counter = 0
+
+ # get_llm accepts an optional api_key parameter
+ def get_llm(api_key=None):
+     if api_key is None:
+         api_keys = [os.getenv("OPENROUTER_API_KEY"), os.getenv("OPENROUTER_API_KEY_1")]
+     else:
+         api_keys = [api_key]
+     last_exception = None
+     for key in api_keys:
+         if not key:
+             continue
+         try:
+             llm = ChatOpenAI(
+                 base_url="https://openrouter.ai/api/v1",
+                 api_key=key,
+                 model="qwen/qwen3-coder:free",
+                 temperature=1
+             )
+             return llm
+         except Exception as e:
+             last_exception = e
+             continue
+     raise RuntimeError(f"All OpenRouter API keys failed: {last_exception}")
+
+ # The global llm instance is removed
+ # llm = get_llm()

+ # The LangGraphAgent class selects the key based on the counter
+ class LangGraphAgent:
+     def __init__(self):
+         print("LangGraphAgent initialized.")
+         self.counter = 0
+         self.total = 0  # Set this to the total number of GAIA questions
+
+     def __call__(self, question: str) -> str:
+         # Decide which key to use
+         if self.total == 0:
+             self.total = 100  # Replace with the actual total if known
+         halfway = self.total // 2
+         if self.counter < halfway:
+             api_key = os.getenv("OPENROUTER_API_KEY")
+         else:
+             api_key = os.getenv("OPENROUTER_API_KEY_1")
+         llm = get_llm(api_key)
+         llm_with_tools = llm.bind_tools(tools)
+         input_state = {"messages": [HumanMessage(content=question)]}
+         print(f"Running LangGraphAgent with input: {question[:150]}...")
+         config = RunnableConfig(
+             config={
+                 "run_name": "GAIA Agent",
+                 "tags": ["gaia", "langgraph", "agent"],
+                 "metadata": {"user_input": question},
+                 "recursion_limit": 30,
+                 "tracing": True
+             }
+         )
+         result = gaia_agent.invoke(input_state, config)
+         final_response = result["messages"][-1].content
+         self.counter += 1
+         try:
+             return final_response.split("FINAL ANSWER:")[-1].strip()
+         except Exception:
+             print("Could not split on 'FINAL ANSWER:'")
+             return final_response
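
For reference, the counter-based key selection added above can be exercised on its own. The sketch below is an illustration, not part of the commit; pick_openrouter_key is a hypothetical helper that mirrors the halfway split in __call__, using the same placeholder total of 100 and the same environment variable names as the diff.

import os
from typing import Optional

def pick_openrouter_key(counter: int, total: int = 100) -> Optional[str]:
    """Hypothetical mirror of the split in LangGraphAgent.__call__: the first
    half of the run uses OPENROUTER_API_KEY, the second half OPENROUTER_API_KEY_1."""
    halfway = total // 2
    env_name = "OPENROUTER_API_KEY" if counter < halfway else "OPENROUTER_API_KEY_1"
    return os.getenv(env_name)

# With total=100, questions 0-49 resolve to the first key and 50-99 to the second.
for i in (0, 49, 50, 99):
    print(i, "->", "OPENROUTER_API_KEY" if i < 50 else "OPENROUTER_API_KEY_1")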

  python_tool = PythonAstREPLTool()
  search_tool = BraveSearch.from_api_key(
+     api_key=os.getenv("BRAVE_SEARCH_API"),
+     search_kwargs={"count": 4},  # returns the 4 best results and their URLs
+     description="Web search using Brave"
  )

  community_tools = [search_tool, python_tool]
+ custom_tools = calculator_basic + datetime_tools + [transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel]

  tools = community_tools + custom_tools
  llm_with_tools = llm.bind_tools(tools)
+
+ # Prepare tools by name
  tools_by_name = {tool.name: tool for tool in tools}

+ class MessagesState(TypedDict):  # creates the state (like the agent's memory at any moment)
      messages: Annotated[list[AnyMessage], add_messages]

+ # LLM node
  def llm_call(state: MessagesState):
      return {
          "messages": [
@@ -96,64 +114,61 @@ def llm_call(state: MessagesState):
          ]
      }

+ # Tool node
  def tool_node(state: MessagesState):
+     """Executes the tools"""
+
+     result = []
+     for tool_call in state["messages"][-1].tool_calls:  # the list of tools the LLM decided to call
+         tool = tools_by_name[tool_call["name"]]  # look up the actual tool function in the dictionary
+         observation = tool.invoke(tool_call["args"])  # execute the tool
+         result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))  # the tool's result is added to memory
+     return {"messages": result}  # add_messages appends the result to the agent's message history automatically
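
The comment on tool_node's return value leans on the add_messages reducer declared in MessagesState: returning {"messages": [...]} from a node appends to the existing history instead of replacing it. A minimal standalone check of that append behaviour is sketched below; the tool call is a dummy, not one of the agent's real tools.

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.graph.message import add_messages

history = [
    HumanMessage(content="What is 2 + 2?"),
    AIMessage(content="", tool_calls=[{"name": "some_tool", "args": {}, "id": "call_1"}]),
]
update = [ToolMessage(content="4", tool_call_id="call_1")]

merged = add_messages(history, update)  # appends rather than overwrites
print(len(merged))  # 3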
+
+ # Conditional edge function: route to the tool node or end, depending on whether the LLM made a tool call
  def should_continue(state: MessagesState) -> Literal["Action", END]:
+     """Decide whether to continue the loop or stop, based on whether the LLM made a tool call"""
+
+     last_message = state["messages"][-1]  # look at the last message (usually from the LLM)
+
+     # If the LLM made a tool call, perform an action
+     if last_message.tool_calls:
+         return "Action"
+     # Otherwise, stop (reply to the user)
+     return END
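
The routing rule can be sanity-checked directly against should_continue as defined above; the tool call below is again a stand-in, and END resolves to LangGraph's end-of-graph sentinel.

from langchain_core.messages import AIMessage

with_call = AIMessage(content="", tool_calls=[{"name": "some_tool", "args": {}, "id": "call_1"}])
without_call = AIMessage(content="FINAL ANSWER: 4")

print(should_continue({"messages": [with_call]}))     # "Action" -> routed to the environment node
print(should_continue({"messages": [without_call]}))  # END -> the graph stops and replies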

+ # Build the workflow
  builder = StateGraph(MessagesState)

+ # Add nodes
  builder.add_node("llm_call", llm_call)
  builder.add_node("environment", tool_node)

+ # Add edges to connect the nodes
  builder.add_edge(START, "llm_call")
  builder.add_conditional_edges(
      "llm_call",
      should_continue,
+     {"Action": "environment",  # name returned by should_continue : name of the next node
+      END: END}
  )
+ # If there are tool calls -> "Action" -> environment (executes the tool)
+ # If there are no tool calls -> END

+ builder.add_edge("environment", "llm_call")  # after running the tools, go back to the LLM for another round of reasoning

+ gaia_agent = builder.compile()  # turns the builder into a runnable agent, invoked with gaia_agent.invoke()
+
+ # Wrapper class to initialize and call the LangGraph agent with a user question
  class LangGraphAgent:
      def __init__(self):
+         print("LangGraphAgent initialized.")

      def __call__(self, question: str) -> str:
+         input_state = {"messages": [HumanMessage(content=question)]}  # prepare the initial user message
          print(f"Running LangGraphAgent with input: {question[:150]}...")
+
+         # tracing configuration for LangSmith
          config = RunnableConfig(
              config={
                  "run_name": "GAIA Agent",
@@ -163,12 +178,11 @@ class LangGraphAgent:
                  "tracing": True
              }
          )
+         result = gaia_agent.invoke(input_state, config)  # the recursion_limit above guards against endless tool-calling loops
          final_response = result["messages"][-1].content
+
+         try:
+             return final_response.split("FINAL ANSWER:")[-1].strip()  # keep only what follows "FINAL ANSWER:"
+         except Exception:
+             print("Could not split on 'FINAL ANSWER:'")
+             return final_response
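
Finally, a hypothetical smoke test for the updated module, assuming the OpenRouter, Brave Search, and LangSmith environment variables referenced above are set; importing agent runs the module-level tool and graph setup.

from agent import LangGraphAgent

agent = LangGraphAgent()
answer = agent("What is 2 + 2?")
print(answer)  # expected to be only the text after "FINAL ANSWER:"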