guillaumefrd committed
Commit 283e426 · Parent: 809f87e

add tools to langgraph

app.py CHANGED
@@ -10,8 +10,8 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
 
 # --- Choice of framework (either "langgraph" or "llamaindex") ---
-# FRAMEWORK = 'langgraph'
-FRAMEWORK = 'llamaindex'
+FRAMEWORK = 'langgraph'
+# FRAMEWORK = 'llamaindex'
 
 
 async def run_and_submit_all(profile: gr.OAuthProfile | None):
@@ -98,7 +98,10 @@ async def run_and_submit_all(profile: gr.OAuthProfile | None):
             pass
 
         # call the agent
-        submitted_answer = await agent(question_text)
+        if FRAMEWORK == 'llamaindex':
+            submitted_answer = await agent(question_text)
+        else:
+            submitted_answer = agent(question_text)
         answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
         results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
         agent.ctx.clear() # clear context for next question
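
The branch on FRAMEWORK is needed because the two agent wrappers have different call conventions: the LlamaIndex agent must be awaited, while LangGraphAgent.__call__ (defined in langgraph_dir/agent.py below) is synchronous. A minimal sketch of that convention follows; the instantiation and the model name are assumptions, since they sit outside this hunk.

# Sketch only (not part of the commit): how the FRAMEWORK flag drives the agent call.
from langgraph_dir.agent import LangGraphAgent

async def answer_one(agent, question_text: str, framework: str) -> str:
    if framework == 'llamaindex':
        return await agent(question_text)   # LlamaIndex agent is awaitable
    return agent(question_text)             # LangGraphAgent.__call__ is a plain sync call

# hypothetical usage inside run_and_submit_all:
#   agent = LangGraphAgent(model_name="gpt-4o-mini")   # model name is an assumption
#   submitted_answer = await answer_one(agent, question_text, FRAMEWORK)
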
langgraph_dir/agent.py CHANGED
@@ -1,12 +1,15 @@
-from typing import Literal
+import json
 
+from typing import Literal
 from langchain_openai import ChatOpenAI
 from langgraph.graph import MessagesState
 from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
 from langgraph.graph import StateGraph, START, END
+from langchain.agents import load_tools
+from langchain_community.tools.riza.command import ExecPython
 
 from .prompt import system_prompt
-from .custom_tools import multiply, add, divide
+from .custom_tools import multiply, add, subtract, divide, modulus, power
 
 
 class LangGraphAgent:
@@ -16,11 +19,18 @@ class LangGraphAgent:
                  show_prompt=True):
 
         # =========== LLM definition ===========
-        llm = ChatOpenAI(model=model_name, temperature=0)
+        llm = ChatOpenAI(model=model_name, temperature=0) # needs OPENAI_API_KEY
         print(f"LangGraphAgent initialized with model \"{model_name}\"")
 
         # =========== Augment the LLM with tools ===========
-        tools = [add, multiply, divide]
+        community_tool_names = [
+            "ddg-search", # DuckDuckGo search
+            "wikipedia",
+        ]
+        community_tools = load_tools(community_tool_names)
+        community_tools += [ExecPython()] # Riza code interpreter (needs RIZA_API_KEY) (not supported by load_tools)
+        custom_tools = [multiply, add, subtract, divide, modulus, power]
+        tools = community_tools + custom_tools
         tools_by_name = {tool.name: tool for tool in tools}
         llm_with_tools = llm.bind_tools(tools)
 
@@ -95,17 +105,15 @@ class LangGraphAgent:
         # Compile the agent
         self.agent = agent_builder.compile()
 
-        # if show_tools_desc:
-        #     for i, tool in enumerate(tool_spec_list):
-        #         print("\n" + "="*30 + f" Tool {i+1} " + "="*30)
-        #         print(tool.metadata.description)
+        if show_tools_desc:
+            for i, tool in enumerate(llm_with_tools.kwargs['tools']):
+                print("\n" + "="*30 + f" Tool {i+1} " + "="*30)
+                print(json.dumps(tool[tool['type']], indent=4))
 
-        # if show_prompt:
-        #     prompt_dict = self.agent.get_prompts()
-        #     for k, v in prompt_dict.items():
-        #         print("\n" + "="*30 + f" Prompt: {k} " + "="*30)
-        #         print(v.template)
+        if show_prompt:
+            print("\n" + "="*30 + f" System prompt " + "="*30)
+            print(system_prompt)
 
     def __call__(self, question: str) -> str:
         print("\n\n"+"*"*50)
langgraph_dir/custom_tools.py CHANGED
@@ -1,33 +1,68 @@
 from langchain_core.tools import tool
 
 @tool
-def multiply(a: int, b: int) -> int:
-    """Multiply a and b.
-
+def multiply(a: float, b: float) -> float:
+    """
+    Multiplies two numbers.
     Args:
-        a: first int
-        b: second int
+        a (float): the first number
+        b (float): the second number
     """
     return a * b
 
 
 @tool
-def add(a: int, b: int) -> int:
-    """Adds a and b.
-
+def add(a: float, b: float) -> float:
+    """
+    Adds two numbers.
     Args:
-        a: first int
-        b: second int
+        a (float): the first number
+        b (float): the second number
     """
     return a + b
 
 
 @tool
-def divide(a: int, b: int) -> float:
-    """Divide a and b.
-
+def subtract(a: float, b: float) -> float:
+    """
+    Subtracts two numbers.
     Args:
-        a: first int
-        b: second int
+        a (float): the first number
+        b (float): the second number
+    """
+    return a - b
+
+
+@tool
+def divide(a: float, b: float) -> float:
+    """
+    Divides two numbers.
+    Args:
+        a (float): the first float number
+        b (float): the second float number
     """
-    return a / b
+    if b == 0:
+        raise ValueError("Cannot divide by zero.")
+    return a / b
+
+
+@tool
+def modulus(a: int, b: int) -> int:
+    """
+    Get the modulus of two numbers.
+    Args:
+        a (int): the first number
+        b (int): the second number
+    """
+    return a % b
+
+
+@tool
+def power(a: float, b: float) -> float:
+    """
+    Get the power of two numbers.
+    Args:
+        a (float): the first number
+        b (float): the second number
+    """
+    return a**b
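
Each function above is wrapped by the @tool decorator into a LangChain structured tool, which is why agent.py can build tools_by_name from tool.name and why the docstring doubles as the description the LLM sees. A quick usage sketch with illustrative values:

# How the tool node ends up calling these wrappers (values are illustrative).
from langgraph_dir.custom_tools import add

print(add.name)                          # "add"
print(add.description)                   # docstring-derived description shown to the LLM
print(add.invoke({"a": 2.0, "b": 3.0}))  # 5.0
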
requirements.txt CHANGED
@@ -7,4 +7,7 @@ llama_index.tools.duckduckgo
 llama_index.tools.code_interpreter
 langchain
 langgraph
-langchain-openai
+langchain-openai
+langchain-community
+duckduckgo-search
+rizaio
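
The new packages back the tools loaded in agent.py: langchain-community and duckduckgo-search for the "ddg-search" and "wikipedia" loaders, and rizaio for the ExecPython interpreter. The inline comments in that file also point at two environment variables the agent needs at runtime; a small startup check along those lines (assumed, not part of the commit):

import os

# Keys referenced in the agent.py comments; checking early gives a clearer failure.
for key in ("OPENAI_API_KEY", "RIZA_API_KEY"):
    if not os.environ.get(key):
        raise RuntimeError(f"{key} is not set; ChatOpenAI / ExecPython will fail without it.")
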