Ahmud commited on
Commit
0edd622
·
verified ·
1 Parent(s): 38cfdef

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +54 -54
agent.py CHANGED
@@ -19,15 +19,8 @@ from langchain_core.runnables import RunnableConfig # for LangSmith tracking
19
  langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
20
  langsmith_tracing = os.getenv("LANGSMITH_TRACING")
21
 
22
- # Add a global counter to track the number of questions answered
23
- question_counter = 0
24
-
25
- # Modify get_llm to accept a key parameter
26
- def get_llm(api_key=None):
27
- if api_key is None:
28
- api_keys = [os.getenv("OPENROUTER_API_KEY"), os.getenv("OPENROUTER_API_KEY_1")]
29
- else:
30
- api_keys = [api_key]
31
  last_exception = None
32
  for key in api_keys:
33
  if not key:
@@ -39,52 +32,15 @@ def get_llm(api_key=None):
39
  model="qwen/qwen3-coder:free",
40
  temperature=1
41
  )
 
 
42
  return llm
43
  except Exception as e:
44
  last_exception = e
45
  continue
46
  raise RuntimeError(f"All OpenRouter API keys failed: {last_exception}")
47
 
48
- # Remove the global llm instance
49
- # llm = get_llm()
50
-
51
- # In the LangGraphAgent class, select the key based on the counter
52
- class LangGraphAgent:
53
- def __init__(self):
54
- print("LangGraphAgent initialized.")
55
- self.counter = 0
56
- self.total = 0 # Set this to the total number of GAIA questions
57
-
58
- def __call__(self, question: str) -> str:
59
- # Decide which key to use
60
- if self.total == 0:
61
- self.total = 100 # Replace with actual total if known
62
- halfway = self.total // 2
63
- if self.counter < halfway:
64
- api_key = os.getenv("OPENROUTER_API_KEY")
65
- else:
66
- api_key = os.getenv("OPENROUTER_API_KEY_1")
67
- llm = get_llm(api_key)
68
- llm_with_tools = llm.bind_tools(tools)
69
- input_state = {"messages": [HumanMessage(content=question)]}
70
- print(f"Running LangGraphAgent with input: {question[:150]}...")
71
- config = RunnableConfig(
72
- config={
73
- "run_name": "GAIA Agent",
74
- "tags": ["gaia", "langgraph", "agent"],
75
- "metadata": {"user_input": question},
76
- "recursion_limit": 30,
77
- "tracing": True
78
- }
79
- )
80
- result = gaia_agent.invoke(input_state, config)
81
- final_response = result["messages"][-1].content
82
- self.counter += 1
83
- try:
84
- return final_response.split("FINAL ANSWER:")[-1].strip()
85
- except Exception:
86
- print("Could not split on 'FINAL ANSWER:'")
87
- return final_response
88
  python_tool = PythonAstREPLTool()
89
  search_tool = BraveSearch.from_api_key(
90
  api_key=os.getenv("BRAVE_SEARCH_API"),
@@ -163,12 +119,54 @@ gaia_agent = builder.compile() # converts my builder into a runnable agent by u
163
  class LangGraphAgent:
164
  def __init__(self):
165
  print("LangGraphAgent initialized.")
 
166
 
167
  def __call__(self, question: str) -> str:
168
- input_state = {"messages": [HumanMessage(content=question)]} # prepare the initial user message
169
- print(f"Running LangGraphAgent with input: {question[:150]}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
- # tracing configuration for LangSmith
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  config = RunnableConfig(
173
  config={
174
  "run_name": "GAIA Agent",
@@ -178,11 +176,13 @@ class LangGraphAgent:
178
  "tracing": True
179
  }
180
  )
181
- result = gaia_agent.invoke(input_state, config) # prevents infinite looping when the LLM keeps calling tools over and over
 
 
182
  final_response = result["messages"][-1].content
183
 
184
  try:
185
- return final_response.split("FINAL ANSWER:")[-1].strip() # parse out only what's after "FINAL ANSWER:"
186
  except Exception:
187
  print("Could not split on 'FINAL ANSWER:'")
188
  return final_response
 
19
  langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
20
  langsmith_tracing = os.getenv("LANGSMITH_TRACING")
21
 
22
+ def get_llm():
23
+ api_keys = [os.getenv("OPENROUTER_API_KEY"), os.getenv("OPENROUTER_API_KEY_1")]
 
 
 
 
 
 
 
24
  last_exception = None
25
  for key in api_keys:
26
  if not key:
 
32
  model="qwen/qwen3-coder:free",
33
  temperature=1
34
  )
35
+ # Optionally, test the key with a trivial call to ensure it's valid
36
+ # llm.invoke([SystemMessage(content="ping")])
37
  return llm
38
  except Exception as e:
39
  last_exception = e
40
  continue
41
  raise RuntimeError(f"All OpenRouter API keys failed: {last_exception}")
42
 
43
+ llm = get_llm()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  python_tool = PythonAstREPLTool()
45
  search_tool = BraveSearch.from_api_key(
46
  api_key=os.getenv("BRAVE_SEARCH_API"),
 
119
  class LangGraphAgent:
120
  def __init__(self):
121
  print("LangGraphAgent initialized.")
122
+ self.question_count = 0 # Track the number of questions processed
123
 
124
  def __call__(self, question: str) -> str:
125
+ # Determine which API key to use based on question count
126
+ # First 50% of questions use OPENROUTER_API_KEY, rest use OPENROUTER_API_KEY_1
127
+ api_key = os.getenv("OPENROUTER_API_KEY") if self.question_count % 2 == 0 else os.getenv("OPENROUTER_API_KEY_1")
128
+
129
+ # Create a new LLM instance with the selected API key
130
+ current_llm = ChatOpenAI(
131
+ base_url="https://openrouter.ai/api/v1",
132
+ api_key=api_key,
133
+ model="qwen/qwen3-coder:free",
134
+ temperature=1
135
+ )
136
+
137
+ # Bind tools to the current LLM
138
+ current_llm_with_tools = current_llm.bind_tools(tools)
139
+
140
+ # Increment question counter for next call
141
+ self.question_count += 1
142
 
143
+ print(f"Running LangGraphAgent with input: {question[:150]}... (Using API key {self.question_count % 2 + 1})")
144
+
145
+ # Create a custom LLM node for this specific question
146
+ def custom_llm_call(state: MessagesState):
147
+ return {
148
+ "messages": [
149
+ current_llm_with_tools.invoke(
150
+ [SystemMessage(content=system_prompt)] + state["messages"]
151
+ )
152
+ ]
153
+ }
154
+
155
+ # Build a new workflow with the custom LLM
156
+ custom_builder = StateGraph(MessagesState)
157
+ custom_builder.add_node("llm_call", custom_llm_call)
158
+ custom_builder.add_node("environment", tool_node)
159
+ custom_builder.add_edge(START, "llm_call")
160
+ custom_builder.add_conditional_edges(
161
+ "llm_call",
162
+ should_continue,
163
+ {"Action": "environment", END: END}
164
+ )
165
+ custom_builder.add_edge("environment", "llm_call")
166
+ custom_agent = custom_builder.compile()
167
+
168
+ # Prepare the initial state and config
169
+ input_state = {"messages": [HumanMessage(content=question)]}
170
  config = RunnableConfig(
171
  config={
172
  "run_name": "GAIA Agent",
 
176
  "tracing": True
177
  }
178
  )
179
+
180
+ # Run the agent
181
+ result = custom_agent.invoke(input_state, config)
182
  final_response = result["messages"][-1].content
183
 
184
  try:
185
+ return final_response.split("FINAL ANSWER:")[-1].strip()
186
  except Exception:
187
  print("Could not split on 'FINAL ANSWER:'")
188
  return final_response