Ahmud committed · Commit aac43d6 · verified · 1 Parent(s): 5972474

Update agent.py

Files changed (1)
  1. agent.py +38 -25
agent.py CHANGED
@@ -20,44 +20,44 @@ from prompt import system_prompt
  # --------------------------------------------------------------------
  # 1. API Key Rotation Setup
  # --------------------------------------------------------------------
- api_keys = [
+ api_keys = [k for k in [
      os.getenv("OPENROUTER_API_KEY"),
      os.getenv("OPENROUTER_API_KEY_1")
- ]
+ ] if k]

- if not any(api_keys):
+ if not api_keys:
      raise EnvironmentError("No OpenRouter API keys found in environment variables.")

- api_key_cycle = itertools.cycle([k for k in api_keys if k])
+ api_key_cycle = itertools.cycle(api_keys)

- def get_next_api_key():
+ def get_next_api_key() -> str:
      """Get the next API key in rotation."""
      return next(api_key_cycle)

  class RotatingChatOpenAI(ChatOpenAI):
-     """ChatOpenAI wrapper that automatically rotates API keys on failure."""
+     """ChatOpenAI wrapper that automatically rotates OpenRouter API keys on failure."""

      def invoke(self, *args, **kwargs):
-         # Try each key once per call
-         for _ in range(len(api_keys)):
-             self.api_key = get_next_api_key()
+         for _ in range(len(api_keys)):  # try each key once per call
+             current_key = get_next_api_key()
+             self.openai_api_key = current_key  # ✅ Correct for ChatOpenAI
              try:
                  return super().invoke(*args, **kwargs)
              except Exception as e:
                  # Handle rate-limits or auth errors
                  if any(code in str(e) for code in ["429", "401", "403"]):
-                     print(f"[API Key Rotation] Key {self.api_key[:5]}... failed, trying next key...")
+                     print(f"[API Key Rotation] Key {current_key[:5]}... failed, trying next key...")
                      continue
-                 raise  # Re-raise other unexpected errors
-         raise RuntimeError("All OpenRouter API keys failed or rate-limited.")
+                 raise  # Re-raise unexpected errors
+         raise RuntimeError("All OpenRouter API keys failed or were rate-limited.")

  # --------------------------------------------------------------------
  # 2. Initialize LLM with API Key Rotation
  # --------------------------------------------------------------------
  llm = RotatingChatOpenAI(
      base_url="https://openrouter.ai/api/v1",
-     api_key=get_next_api_key(),  # Start with the first key
-     model="qwen/qwen3-coder:free",  # Model must support function calling
+     openai_api_key=get_next_api_key(),  # start with the first key
+     model="qwen/qwen3-coder:free",  # must support tool/function calling
      temperature=1
  )

@@ -99,18 +99,31 @@ def llm_call(state: MessagesState):
  # Tool Node
  def tool_node(state: MessagesState):
      """Executes tools requested by the LLM."""
-     result = []
-     for tool_call in state["messages"][-1].tool_calls:
-         tool = tools_by_name[tool_call["name"]]
-         observation = tool.invoke(tool_call["args"])
-         result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
-     return {"messages": result}
+     results = []
+     last_message = state["messages"][-1]
+
+     for tool_call in getattr(last_message, "tool_calls", []) or []:
+         tool = tools_by_name.get(tool_call["name"])
+         if not tool:
+             results.append(ToolMessage(content=f"Unknown tool: {tool_call['name']}", tool_call_id=tool_call["id"]))
+             continue
+
+         args = tool_call["args"]
+         # Handle dict vs positional args safely
+         try:
+             observation = tool.invoke(**args) if isinstance(args, dict) else tool.invoke(args)
+         except Exception as e:
+             observation = f"[Tool Error] {str(e)}"
+
+         results.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
+
+     return {"messages": results}

  # Conditional Routing
  def should_continue(state: MessagesState) -> Literal["Action", END]:
      """Route to tools if LLM made a tool call, else end."""
      last_message = state["messages"][-1]
-     return "Action" if last_message.tool_calls else END
+     return "Action" if getattr(last_message, "tool_calls", None) else END

  # --------------------------------------------------------------------
  # 5. Build LangGraph Agent
@@ -154,8 +167,8 @@ class LangGraphAgent:
          result = gaia_agent.invoke(input_state, config)
          final_response = result["messages"][-1].content

-         try:
+         # Extract "FINAL ANSWER" if present
+         if isinstance(final_response, str):
              return final_response.split("FINAL ANSWER:")[-1].strip()
-         except Exception:
-             print("Could not split on 'FINAL ANSWER:'")
-             return final_response
+         else:
+             return str(final_response)
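
For reference, the retry loop added in RotatingChatOpenAI.invoke relies on itertools.cycle wrapping around the filtered key list, so a loop of len(api_keys) attempts tries every configured key exactly once per call. A minimal standalone sketch of that rotation pattern, using hypothetical placeholder keys rather than the real OPENROUTER_API_KEY values:

import itertools

# Hypothetical placeholder keys for illustration only; the agent builds this list
# from os.getenv("OPENROUTER_API_KEY") and os.getenv("OPENROUTER_API_KEY_1").
api_keys = [k for k in ["sk-or-key-A", None, "sk-or-key-B"] if k]  # unset keys are dropped
api_key_cycle = itertools.cycle(api_keys)

def get_next_api_key() -> str:
    """Return the next key; the cycle wraps around after the last one."""
    return next(api_key_cycle)

for attempt in range(len(api_keys)):    # same bound as the retry loop in invoke()
    print(attempt, get_next_api_key())  # prints: 0 sk-or-key-A, then 1 sk-or-key-B

Because api_key_cycle is shared module-level state, each call to invoke simply continues the rotation from wherever the previous call stopped.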