Commit
·
283e426
1
Parent(s):
809f87e
add tools to langgraph
Browse files- app.py +6 -3
- langgraph_dir/agent.py +21 -13
- langgraph_dir/custom_tools.py +50 -15
- requirements.txt +4 -1
app.py
CHANGED
@@ -10,8 +10,8 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
|
10 |
|
11 |
|
12 |
# --- Choice of framework (either "langgraph" or "llamaindex") ---
|
13 |
-
|
14 |
-
FRAMEWORK = 'llamaindex'
|
15 |
|
16 |
|
17 |
async def run_and_submit_all(profile: gr.OAuthProfile | None):
|
@@ -98,7 +98,10 @@ async def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
98 |
pass
|
99 |
|
100 |
# call the agent
|
101 |
-
|
|
|
|
|
|
|
102 |
answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
|
103 |
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
104 |
agent.ctx.clear() # clear context for next question
|
|
|
10 |
|
11 |
|
12 |
# --- Choice of framework (either "langgraph" or "llamaindex") ---
|
13 |
+
FRAMEWORK = 'langgraph'
|
14 |
+
# FRAMEWORK = 'llamaindex'
|
15 |
|
16 |
|
17 |
async def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
|
98 |
pass
|
99 |
|
100 |
# call the agent
|
101 |
+
if FRAMEWORK == 'llamaindex':
|
102 |
+
submitted_answer = await agent(question_text)
|
103 |
+
else:
|
104 |
+
submitted_answer = agent(question_text)
|
105 |
answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
|
106 |
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
107 |
agent.ctx.clear() # clear context for next question
|
langgraph_dir/agent.py
CHANGED
@@ -1,12 +1,15 @@
|
|
1 |
-
|
2 |
|
|
|
3 |
from langchain_openai import ChatOpenAI
|
4 |
from langgraph.graph import MessagesState
|
5 |
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
|
6 |
from langgraph.graph import StateGraph, START, END
|
|
|
|
|
7 |
|
8 |
from .prompt import system_prompt
|
9 |
-
from .custom_tools import multiply, add, divide
|
10 |
|
11 |
|
12 |
class LangGraphAgent:
|
@@ -16,11 +19,18 @@ class LangGraphAgent:
|
|
16 |
show_prompt=True):
|
17 |
|
18 |
# =========== LLM definition ===========
|
19 |
-
llm = ChatOpenAI(model=model_name, temperature=0)
|
20 |
print(f"LangGraphAgent initialized with model \"{model_name}\"")
|
21 |
|
22 |
# =========== Augment the LLM with tools ===========
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
tools_by_name = {tool.name: tool for tool in tools}
|
25 |
llm_with_tools = llm.bind_tools(tools)
|
26 |
|
@@ -95,17 +105,15 @@ class LangGraphAgent:
|
|
95 |
# Compile the agent
|
96 |
self.agent = agent_builder.compile()
|
97 |
|
|
|
|
|
|
|
|
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
# print(tool.metadata.description)
|
103 |
|
104 |
-
# if show_prompt:
|
105 |
-
# prompt_dict = self.agent.get_prompts()
|
106 |
-
# for k, v in prompt_dict.items():
|
107 |
-
# print("\n" + "="*30 + f" Prompt: {k} " + "="*30)
|
108 |
-
# print(v.template)
|
109 |
|
110 |
def __call__(self, question: str) -> str:
|
111 |
print("\n\n"+"*"*50)
|
|
|
1 |
+
import json
|
2 |
|
3 |
+
from typing import Literal
|
4 |
from langchain_openai import ChatOpenAI
|
5 |
from langgraph.graph import MessagesState
|
6 |
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
|
7 |
from langgraph.graph import StateGraph, START, END
|
8 |
+
from langchain.agents import load_tools
|
9 |
+
from langchain_community.tools.riza.command import ExecPython
|
10 |
|
11 |
from .prompt import system_prompt
|
12 |
+
from .custom_tools import multiply, add, subtract, divide, modulus, power
|
13 |
|
14 |
|
15 |
class LangGraphAgent:
|
|
|
19 |
show_prompt=True):
|
20 |
|
21 |
# =========== LLM definition ===========
|
22 |
+
llm = ChatOpenAI(model=model_name, temperature=0) # needs OPENAI_API_KEY
|
23 |
print(f"LangGraphAgent initialized with model \"{model_name}\"")
|
24 |
|
25 |
# =========== Augment the LLM with tools ===========
|
26 |
+
community_tool_names = [
|
27 |
+
"ddg-search", # DuckDuckGo search
|
28 |
+
"wikipedia",
|
29 |
+
]
|
30 |
+
community_tools = load_tools(community_tool_names)
|
31 |
+
community_tools += [ExecPython()] # Riza code interpreter (needs RIZA_API_KEY) (not supported by load_tools)
|
32 |
+
custom_tools = [multiply, add, subtract, divide, modulus, power]
|
33 |
+
tools = community_tools + custom_tools
|
34 |
tools_by_name = {tool.name: tool for tool in tools}
|
35 |
llm_with_tools = llm.bind_tools(tools)
|
36 |
|
|
|
105 |
# Compile the agent
|
106 |
self.agent = agent_builder.compile()
|
107 |
|
108 |
+
if show_tools_desc:
|
109 |
+
for i, tool in enumerate(llm_with_tools.kwargs['tools']):
|
110 |
+
print("\n" + "="*30 + f" Tool {i+1} " + "="*30)
|
111 |
+
print(json.dumps(tool[tool['type']], indent=4))
|
112 |
|
113 |
+
if show_prompt:
|
114 |
+
print("\n" + "="*30 + f" System prompt " + "="*30)
|
115 |
+
print(system_prompt)
|
|
|
116 |
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
def __call__(self, question: str) -> str:
|
119 |
print("\n\n"+"*"*50)
|
langgraph_dir/custom_tools.py
CHANGED
@@ -1,33 +1,68 @@
|
|
1 |
from langchain_core.tools import tool
|
2 |
|
3 |
@tool
|
4 |
-
def multiply(a:
|
5 |
-
"""
|
6 |
-
|
7 |
Args:
|
8 |
-
a: first
|
9 |
-
b: second
|
10 |
"""
|
11 |
return a * b
|
12 |
|
13 |
|
14 |
@tool
|
15 |
-
def add(a:
|
16 |
-
"""
|
17 |
-
|
18 |
Args:
|
19 |
-
a: first
|
20 |
-
b: second
|
21 |
"""
|
22 |
return a + b
|
23 |
|
24 |
|
25 |
@tool
|
26 |
-
def
|
27 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
|
|
|
|
|
|
|
|
29 |
Args:
|
30 |
-
a: first
|
31 |
-
b: second
|
32 |
"""
|
33 |
-
return a
|
|
|
1 |
from langchain_core.tools import tool
|
2 |
|
3 |
@tool
|
4 |
+
def multiply(a: float, b: float) -> float:
|
5 |
+
"""
|
6 |
+
Multiplies two numbers.
|
7 |
Args:
|
8 |
+
a (float): the first number
|
9 |
+
b (float): the second number
|
10 |
"""
|
11 |
return a * b
|
12 |
|
13 |
|
14 |
@tool
|
15 |
+
def add(a: float, b: float) -> float:
|
16 |
+
"""
|
17 |
+
Adds two numbers.
|
18 |
Args:
|
19 |
+
a (float): the first number
|
20 |
+
b (float): the second number
|
21 |
"""
|
22 |
return a + b
|
23 |
|
24 |
|
25 |
@tool
|
26 |
+
def subtract(a: float, b: float) -> int:
|
27 |
+
"""
|
28 |
+
Subtracts two numbers.
|
29 |
+
Args:
|
30 |
+
a (float): the first number
|
31 |
+
b (float): the second number
|
32 |
+
"""
|
33 |
+
return a - b
|
34 |
+
|
35 |
+
|
36 |
+
@tool
|
37 |
+
def divide(a: float, b: float) -> float:
|
38 |
+
"""
|
39 |
+
Divides two numbers.
|
40 |
+
Args:
|
41 |
+
a (float): the first float number
|
42 |
+
b (float): the second float number
|
43 |
+
"""
|
44 |
+
if b == 0:
|
45 |
+
raise ValueError("Cannot divided by zero.")
|
46 |
+
return a / b
|
47 |
+
|
48 |
+
|
49 |
+
@tool
|
50 |
+
def modulus(a: int, b: int) -> int:
|
51 |
+
"""
|
52 |
+
Get the modulus of two numbers.
|
53 |
+
Args:
|
54 |
+
a (int): the first number
|
55 |
+
b (int): the second number
|
56 |
+
"""
|
57 |
+
return a % b
|
58 |
+
|
59 |
|
60 |
+
@tool
|
61 |
+
def power(a: float, b: float) -> float:
|
62 |
+
"""
|
63 |
+
Get the power of two numbers.
|
64 |
Args:
|
65 |
+
a (float): the first number
|
66 |
+
b (float): the second number
|
67 |
"""
|
68 |
+
return a**b
|
requirements.txt
CHANGED
@@ -7,4 +7,7 @@ llama_index.tools.duckduckgo
|
|
7 |
llama_index.tools.code_interpreter
|
8 |
langchain
|
9 |
langgraph
|
10 |
-
langchain-openai
|
|
|
|
|
|
|
|
7 |
llama_index.tools.code_interpreter
|
8 |
langchain
|
9 |
langgraph
|
10 |
+
langchain-openai
|
11 |
+
langchain-community
|
12 |
+
duckduckgo-search
|
13 |
+
rizaio
|