Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
@@ -136,6 +136,8 @@ with open("config.yaml", "r") as f:
|
|
136 |
config = yaml.safe_load(f)
|
137 |
|
138 |
provider = config["provider"]
|
|
|
|
|
139 |
#prompt_path = config["system_prompt_path"]
|
140 |
enabled_tool_names = config["tools"]
|
141 |
|
@@ -245,8 +247,6 @@ question_retriever_tool = create_retriever_tool(
|
|
245 |
|
246 |
|
247 |
|
248 |
-
|
249 |
-
|
250 |
tools = [
|
251 |
multiply,
|
252 |
add,
|
@@ -258,27 +258,24 @@ tools = [
|
|
258 |
arvix_search,
|
259 |
]
|
260 |
|
261 |
-
|
262 |
-
def
|
263 |
-
"""Build the graph"""
|
264 |
-
# Load environment variables from .env file
|
265 |
if provider == "google":
|
266 |
-
|
267 |
-
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
|
268 |
elif provider == "groq":
|
269 |
-
|
270 |
-
llm = ChatGroq(model="gemma2-9b-it", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
|
271 |
elif provider == "huggingface":
|
272 |
-
|
273 |
-
|
274 |
-
llm=HuggingFaceEndpoint(
|
275 |
-
url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
|
276 |
-
temperature=0,
|
277 |
-
),
|
278 |
)
|
279 |
else:
|
280 |
-
raise ValueError("Invalid provider
|
281 |
-
|
|
|
|
|
|
|
|
|
|
|
282 |
llm_with_tools = llm.bind_tools(tools)
|
283 |
|
284 |
# Node
|
|
|
136 |
config = yaml.safe_load(f)
|
137 |
|
138 |
provider = config["provider"]
|
139 |
+
model_config = config["models"][provider]
|
140 |
+
|
141 |
#prompt_path = config["system_prompt_path"]
|
142 |
enabled_tool_names = config["tools"]
|
143 |
|
|
|
247 |
|
248 |
|
249 |
|
|
|
|
|
250 |
tools = [
|
251 |
multiply,
|
252 |
add,
|
|
|
258 |
arvix_search,
|
259 |
]
|
260 |
|
261 |
+
|
262 |
+
def get_llm(provider: str, config: dict):
    """Instantiate the chat model for *provider* from its config section.

    Args:
        provider: One of ``"google"``, ``"groq"``, or ``"huggingface"``.
        config: Provider-specific settings. Expects ``"model"`` and
            ``"temperature"`` keys for google/groq, and ``"url"`` plus
            ``"temperature"`` for huggingface.

    Returns:
        A configured chat model instance for the requested provider.

    Raises:
        ValueError: If *provider* is not one of the supported names.
        KeyError: If a required key is missing from *config* (propagates
            from the dict lookups below).
    """
    # Guard-clause dispatch: each supported provider returns immediately.
    if provider == "google":
        return ChatGoogleGenerativeAI(
            model=config["model"],
            temperature=config["temperature"],
        )
    if provider == "groq":
        return ChatGroq(
            model=config["model"],
            temperature=config["temperature"],
        )
    if provider == "huggingface":
        # The HF endpoint is wrapped so it exposes the chat-model interface.
        endpoint = HuggingFaceEndpoint(
            url=config["url"],
            temperature=config["temperature"],
        )
        return ChatHuggingFace(llm=endpoint)
    raise ValueError(f"Invalid provider: {provider}")
|
273 |
+
|
274 |
+
|
275 |
+
# Build graph function
|
276 |
+
def build_graph():
|
277 |
+
"""Build the graph based on provider"""
|
278 |
+
llm = get_llm(provider, model_config)
|
279 |
llm_with_tools = llm.bind_tools(tools)
|
280 |
|
281 |
# Node
|