wt002 commited on
Commit
e7e6762
·
verified ·
1 Parent(s): dd8df2c

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +12 -7
agent.py CHANGED
@@ -4,7 +4,10 @@ from langgraph.graph import START, StateGraph, MessagesState
4
  from langgraph.prebuilt import tools_condition
5
  from langgraph.prebuilt import ToolNode
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
- from langchain_groq import ChatGroq
 
 
 
8
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
9
  from langchain_community.tools.tavily_search import TavilySearchResults
10
  from langchain_community.document_loaders import WikipediaLoader
@@ -150,17 +153,17 @@ tools = [
150
  ]
151
 
152
  # Build graph function
153
- def build_graph(provider: str = "groq"):
154
  """Build the graph"""
155
  # Load environment variables from .env file
156
  if provider == "google":
157
  # Google Gemini
158
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
159
- elif provider == "groq":
160
- # Groq https://console.groq.com/docs/models
161
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
162
  elif provider == "huggingface":
163
- # TODO: Add huggingface endpoint
164
  llm = ChatHuggingFace(
165
  llm=HuggingFaceEndpoint(
166
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
@@ -168,9 +171,11 @@ def build_graph(provider: str = "groq"):
168
  ),
169
  )
170
  else:
171
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
 
172
  # Bind tools to LLM
173
  llm_with_tools = llm.bind_tools(tools)
 
174
 
175
  # Node
176
  def assistant(state: MessagesState):
 
4
  from langgraph.prebuilt import tools_condition
5
  from langgraph.prebuilt import ToolNode
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
9
+ from langchain_huggingface import ChatHuggingFace
10
+ from langchain_core.llms import HuggingFaceEndpoint
11
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
12
  from langchain_community.tools.tavily_search import TavilySearchResults
13
  from langchain_community.document_loaders import WikipediaLoader
 
153
  ]
154
 
155
  # Build graph function
156
+ def build_graph(provider: str = "openai"):
157
  """Build the graph"""
158
  # Load environment variables from .env file
159
  if provider == "google":
160
  # Google Gemini
161
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
162
+ elif provider == "openai":
163
+ # OpenAI (e.g., GPT-4 or GPT-3.5)
164
+ llm = ChatOpenAI(model="gpt-4", temperature=0)
165
  elif provider == "huggingface":
166
+ # Hugging Face endpoint
167
  llm = ChatHuggingFace(
168
  llm=HuggingFaceEndpoint(
169
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
 
171
  ),
172
  )
173
  else:
174
+ raise ValueError("Invalid provider. Choose 'google', 'openai' or 'huggingface'.")
175
+
176
  # Bind tools to LLM
177
  llm_with_tools = llm.bind_tools(tools)
178
+ return llm_with_tools
179
 
180
  # Node
181
  def assistant(state: MessagesState):