wt002 committed on
Commit
9027aff
·
verified ·
1 Parent(s): 208c9b2

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +15 -18
agent.py CHANGED
@@ -136,6 +136,8 @@ with open("config.yaml", "r") as f:
136
  config = yaml.safe_load(f)
137
 
138
  provider = config["provider"]
 
 
139
  #prompt_path = config["system_prompt_path"]
140
  enabled_tool_names = config["tools"]
141
 
@@ -245,8 +247,6 @@ question_retriever_tool = create_retriever_tool(
245
 
246
 
247
 
248
-
249
-
250
  tools = [
251
  multiply,
252
  add,
@@ -258,27 +258,24 @@ tools = [
258
  arvix_search,
259
  ]
260
 
261
- # Build graph function
262
- def build_graph(provider: str = "google"):
263
- """Build the graph"""
264
- # Load environment variables from .env file
265
  if provider == "google":
266
- # Google Gemini
267
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
268
  elif provider == "groq":
269
- # Groq https://console.groq.com/docs/models
270
- llm = ChatGroq(model="gemma2-9b-it", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
271
  elif provider == "huggingface":
272
- # TODO: Add huggingface endpoint
273
- llm = ChatHuggingFace(
274
- llm=HuggingFaceEndpoint(
275
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
276
- temperature=0,
277
- ),
278
  )
279
  else:
280
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
281
- # Bind tools to LLM
 
 
 
 
 
282
  llm_with_tools = llm.bind_tools(tools)
283
 
284
  # Node
 
136
  config = yaml.safe_load(f)
137
 
138
  provider = config["provider"]
139
+ model_config = config["models"][provider]
140
+
141
  #prompt_path = config["system_prompt_path"]
142
  enabled_tool_names = config["tools"]
143
 
 
247
 
248
 
249
 
 
 
250
  tools = [
251
  multiply,
252
  add,
 
258
  arvix_search,
259
  ]
260
 
261
+
262
+ def get_llm(provider: str, config: dict):
 
 
263
  if provider == "google":
264
+ return ChatGoogleGenerativeAI(model=config["model"], temperature=config["temperature"])
 
265
  elif provider == "groq":
266
+ return ChatGroq(model=config["model"], temperature=config["temperature"])
 
267
  elif provider == "huggingface":
268
+ return ChatHuggingFace(
269
+ llm=HuggingFaceEndpoint(url=config["url"], temperature=config["temperature"])
 
 
 
 
270
  )
271
  else:
272
+ raise ValueError(f"Invalid provider: {provider}")
273
+
274
+
275
+ # Build graph function
276
+ def build_graph():
277
+ """Build the graph based on provider"""
278
+ llm = get_llm(provider, model_config)
279
  llm_with_tools = llm.bind_tools(tools)
280
 
281
  # Node