supratipb commited on
Commit
c39e7b4
Β·
verified Β·
1 Parent(s): 89922e7

Upload agent.py

Browse files
Files changed (1) hide show
  1. agent.py +6 -8
agent.py CHANGED
@@ -279,7 +279,7 @@ tools = [
279
  ]
280
 
281
  # Build graph function
282
- def build_graph(provider: str = "anthropic"):
283
  """Build the graph"""
284
  # Load environment variables from .env file
285
  if provider == "openai":
@@ -287,15 +287,12 @@ def build_graph(provider: str = "anthropic"):
287
  elif provider == "anthropic":
288
  llm = ChatAnthropic(model="claude-v1", temperature=0)
289
  elif provider == "google":
290
- # Google Gemini
291
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
292
  elif provider == "groq":
293
- # Groq https://console.groq.com/docs/models
294
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
295
  elif provider == "huggingface":
296
- # TODO: Add huggingface endpoint
297
  llm = ChatHuggingFace(
298
- llm=HuggingFaceEndpoint(
299
  endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
300
  temperature=0,
301
  ),
@@ -307,8 +304,9 @@ def build_graph(provider: str = "anthropic"):
307
 
308
  # Node
309
  def assistant(state: MessagesState):
310
- """Assistant node"""
311
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
312
 
313
  def retriever(state: MessagesState):
314
  """Retriever node"""
 
279
  ]
280
 
281
  # Build graph function
282
+ def build_graph(provider: str = "groq"):
283
  """Build the graph"""
284
  # Load environment variables from .env file
285
  if provider == "openai":
 
287
  elif provider == "anthropic":
288
  llm = ChatAnthropic(model="claude-v1", temperature=0)
289
  elif provider == "google":
 
290
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
291
  elif provider == "groq":
292
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
 
293
  elif provider == "huggingface":
 
294
  llm = ChatHuggingFace(
295
+ llm = HuggingFaceEndpoint(
296
  endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
297
  temperature=0,
298
  ),
 
304
 
305
  # Node
306
  def assistant(state: MessagesState):
307
+ messages_with_sys = [sys_msg] + state["messages"]
308
+ return {"messages": [llm_with_tools.invoke(messages_with_sys)]}
309
+
310
 
311
  def retriever(state: MessagesState):
312
  """Retriever node"""