wt002 committed
Commit ab6c455 · verified · 1 Parent(s): a4bdca3

Update agent.py

Files changed (1)
  1. agent.py +26 -10
agent.py CHANGED
@@ -152,25 +152,41 @@ tools = [
 ]
 
 # Build graph function
-def build_graph(provider: str = "openai"):
+def build_graph(provider: str = "huggingface", huggingface_model: str = "mistral"):
     """Build the graph"""
-    # Load environment variables from .env file
+
     if provider == "google":
         # Google Gemini
-        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
-    elif provider == "openai":
-        # OpenAI (e.g., GPT-4 or GPT-3.5 free)
-        llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
+        llm = ChatGoogleGenerativeAI(
+            model="gemini-2.0-flash",
+            temperature=0,
+            google_api_key=os.getenv("GOOGLE_API_KEY")
+        )
+
     elif provider == "huggingface":
-        # Hugging Face endpoint
+        # Choose between supported Hugging Face models
+        if huggingface_model == "mistral":
+            model_url = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1"
+        elif huggingface_model == "llama":
+            model_url = "https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf"
+        else:
+            raise ValueError("Unsupported Hugging Face model")
+
+        hf_token = os.getenv("HUGGINGFACE_API_TOKEN")
+        headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
+
         llm = ChatHuggingFace(
             llm=HuggingFaceEndpoint(
-                url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
+                url=model_url,
                 temperature=0,
-            ),
+                headers=headers
+            )
         )
+
     else:
-        raise ValueError("Invalid provider. Choose 'google', 'openai' or 'huggingface'.")
+        raise ValueError("Invalid provider. Choose 'google' or 'huggingface'.")
+
+    return llm
 
     # Bind tools to LLM
     llm_with_tools = llm.bind_tools(tools)
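
For context, a minimal usage sketch of the updated signature (not part of the commit). It assumes build_graph is importable from agent.py and that the GOOGLE_API_KEY / HUGGINGFACE_API_TOKEN environment variables read in the diff are already set.

# Usage sketch for the new build_graph signature (illustration only).
# Assumes GOOGLE_API_KEY / HUGGINGFACE_API_TOKEN are exported in the environment.
from agent import build_graph

# Default: Hugging Face provider with the Mistral 7B Instruct endpoint
llm = build_graph()

# Explicitly select the Llama 2 chat endpoint instead
llm = build_graph(provider="huggingface", huggingface_model="llama")

# Google Gemini path (requires GOOGLE_API_KEY)
llm = build_graph(provider="google")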