Martin Bär committed on
Commit
dbb14b6
·
1 Parent(s): 8ea3490

Change agent into WorkFlow with sub-agents and use Google Gemini

Browse files
Files changed (4) hide show
  1. app.py +0 -7
  2. basic_agent.py +150 -27
  3. multimodality_tools.py +8 -1
  4. requirements.txt +3 -3
app.py CHANGED
@@ -12,12 +12,6 @@ from basic_agent import BasicAgent
12
  # --- Constants ---
13
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
14
 
15
- # For Llamaindex's LoadAndSearchTool
16
- Settings.llm = None # disable LLM for Index Retrieval
17
- Settings.chunk_size = 512 # Smaller chunk size for retrieval
18
-
19
- Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
20
-
21
  def run_and_submit_all( profile: gr.OAuthProfile | None):
22
  """
23
  Fetches all questions, runs the BasicAgent on them, submits all answers,
@@ -139,7 +133,6 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
139
  return status_message, results_df
140
 
141
  async def handle_agent_input(user_input):
142
- # TODO initialize agent at a different place
143
  agent = BasicAgent()
144
  response = await agent(user_input)
145
  return response
 
12
  # --- Constants ---
13
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
14
 
 
 
 
 
 
 
15
  def run_and_submit_all( profile: gr.OAuthProfile | None):
16
  """
17
  Fetches all questions, runs the BasicAgent on them, submits all answers,
 
133
  return status_message, results_df
134
 
135
  async def handle_agent_input(user_input):
 
136
  agent = BasicAgent()
137
  response = await agent(user_input)
138
  return response
basic_agent.py CHANGED
@@ -1,10 +1,20 @@
 
 
 
 
1
  from llama_index.core.tools import FunctionTool
2
  from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
3
  from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
4
  from llama_index.tools.wikipedia import WikipediaToolSpec
5
  from langfuse.llama_index import LlamaIndexInstrumentor
6
  from llama_index.llms.ollama import Ollama
7
- from llama_index.core.agent.workflow import FunctionAgent, ReActAgent
 
 
 
 
 
 
8
 
9
  from multimodality_tools import get_image_qa_tool, get_transcription_tool, \
10
  get_excel_analysis_tool, get_excel_tool, get_csv_analysis_tool, get_csv_tool
@@ -12,7 +22,8 @@ from multimodality_tools import get_image_qa_tool, get_transcription_tool, \
12
  class BasicAgent:
13
  def __init__(self, ollama=False, langfuse=False):
14
  if not ollama:
15
- llm = HuggingFaceInferenceAPI(model_name="Qwen/Qwen3-32B") #"Qwen/Qwen2.5-Coder-32B-Instruct")
 
16
  else:
17
  llm = Ollama(model="mistral:latest", request_timeout=120.0)
18
 
@@ -22,29 +33,15 @@ class BasicAgent:
22
  self.instrumentor = LlamaIndexInstrumentor()
23
  self.instrumentor.start()
24
 
25
- # Initialize tools
26
- tool_spec = DuckDuckGoSearchToolSpec()
27
- search_tool = FunctionTool.from_defaults(tool_spec.duckduckgo_full_search)
28
 
29
- # Convert into a LoadAndSearchToolSpec because the wikipedia search tool returns
30
- # entire Wikipedia pages and this can pollute the context window of the LLM
31
- wiki_spec = WikipediaToolSpec()
32
- wiki_search_tool = wiki_spec.to_tool_list()[1]
33
-
34
- # Convert into a LoadAndSearchToolSpec because the wikipedia search tool returns
35
- # entire Wikipedia pages and this can pollute the context window of the LLM
36
- # TODO this does not work so well. We need to make the retriever return the top 5 chunks or sth.
37
- # wiki_search_tool_las = LoadAndSearchToolSpec.from_defaults(wiki_search_tool).to_tool_list()
38
-
39
- self.agent = ReActAgent(
40
- tools=[search_tool, wiki_search_tool, get_image_qa_tool(),
41
- get_transcription_tool(), get_excel_analysis_tool(), get_excel_tool(),
42
- get_csv_analysis_tool(), get_csv_tool()],
43
- llm=llm,
44
- verbose=True,
45
- system_prompt = (
46
  "You are a general AI assistant. I will ask you a question. "
47
- "Report your thoughts, and finish your answer with the following template: "
 
48
  "FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number "
49
  "OR as few words as possible OR a comma separated list of numbers and/or "
50
  "strings. If you are asked for a number, don't use comma to write your "
@@ -53,19 +50,145 @@ class BasicAgent:
53
  "for cities), and write the digits in plain text unless specified otherwise. If "
54
  "you are asked for a comma separated list, apply the above rules depending on "
55
  "whether the element to be put in the list is a number or a string."
56
- )
 
 
 
57
  )
58
 
59
- # self.ctx = Context(self.agent)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  async def __call__(self, question: str, task_id: str = None) -> str:
62
  file_str = ""
63
  if task_id:
64
  file_str = f'\nIf you need to load a file, do so by providing the id "{task_id}".'
65
 
66
- response = await self.agent.run(user_msg=question + file_str) # ctx=self.ctx)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  if self.langfuse:
69
  self.instrumentor.flush()
70
 
71
- return response.response.content.replace("FINAL ANSWER:", "").strip()
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ from tavily import AsyncTavilyClient
5
  from llama_index.core.tools import FunctionTool
6
  from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
7
  from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
8
  from llama_index.tools.wikipedia import WikipediaToolSpec
9
  from langfuse.llama_index import LlamaIndexInstrumentor
10
  from llama_index.llms.ollama import Ollama
11
+ from llama_index.llms.google_genai import GoogleGenAI
12
+ from llama_index.core.agent.workflow import FunctionAgent, AgentWorkflow
13
+ from llama_index.core.agent.workflow import (
14
+ AgentOutput,
15
+ ToolCall,
16
+ ToolCallResult,
17
+ )
18
 
19
  from multimodality_tools import get_image_qa_tool, get_transcription_tool, \
20
  get_excel_analysis_tool, get_excel_tool, get_csv_analysis_tool, get_csv_tool
 
22
  class BasicAgent:
23
  def __init__(self, ollama=False, langfuse=False):
24
  if not ollama:
25
+ llm = GoogleGenAI(model="gemini-2.0-flash", api_key=os.getenv("GEMINI_API_KEY"))
26
+ # llm = HuggingFaceInferenceAPI(model_name="Qwen/Qwen3-32B") #"Qwen/Qwen2.5-Coder-32B-Instruct")
27
  else:
28
  llm = Ollama(model="mistral:latest", request_timeout=120.0)
29
 
 
33
  self.instrumentor = LlamaIndexInstrumentor()
34
  self.instrumentor.start()
35
 
36
+ # Initialize sub-agents
 
 
37
 
38
+ main_agent = FunctionAgent(
39
+ name="MainAgent",
40
+ description="Can organize and delegate work to different agents and can compile a final answer to a question from other agents' outputs.",
41
+ system_prompt=(
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  "You are a general AI assistant. I will ask you a question. "
43
+ "Report your thoughts, delegate work to other agents if necessary, and "
44
+ "finish your answer with the following template: "
45
  "FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number "
46
  "OR as few words as possible OR a comma separated list of numbers and/or "
47
  "strings. If you are asked for a number, don't use comma to write your "
 
50
  "for cities), and write the digits in plain text unless specified otherwise. If "
51
  "you are asked for a comma separated list, apply the above rules depending on "
52
  "whether the element to be put in the list is a number or a string."
53
+ ),
54
+ llm=llm,
55
+ tools=[],
56
+ can_handoff_to=["WikiAgent", "WebAgent", "StatsAgent", "AudioAgent", "ImageAgent"],
57
  )
58
 
59
+ # Wikipedia tool does not return the tables from the page...
60
+ wiki_spec = WikipediaToolSpec()
61
+ wiki_search_tool = wiki_spec.to_tool_list()[1]
62
+
63
+ wiki_agent = FunctionAgent(
64
+ name="WikiAgent",
65
+ description="Uses wikipedia to answer a question.",
66
+ system_prompt=(
67
+ "You are a Wikipedia agent that can search Wikipedia for information to answer a question. "
68
+ "You only give concise answers and if you don't find an answer to the given query on Wikipedia, "
69
+ "you communicate this clearly. Always hand off your answer to MainAgent."
70
+ ),
71
+ llm=llm,
72
+ tools=[wiki_search_tool],
73
+ can_handoff_to=["MainAgent"],
74
+ )
75
+
76
+ tool_spec = DuckDuckGoSearchToolSpec()
77
+ search_tool = FunctionTool.from_defaults(tool_spec.duckduckgo_full_search)
78
+ # In case DuckDuckGo is not good enough
79
+ async def search_web(query: str) -> str:
80
+ """Searches the web to answer questions."""
81
+ client = AsyncTavilyClient(api_key=os.getenv("TAVILY"))
82
+ return str(await client.search(query))
83
+
84
+ web_search_agent = FunctionAgent(
85
+ name="WebAgent",
86
+ description="Uses the web to answer a question.",
87
+ system_prompt=(
88
+ "You are a Web agent that can search the Web for information to answer a question. "
89
+ "You only give concise answers and if you don't find an answer to the given query with your tool, "
90
+ "you communicate this clearly. Always hand off your answer to MainAgent."
91
+ ),
92
+ llm=llm,
93
+ tools=[search_web],
94
+ can_handoff_to=["MainAgent"],
95
+ )
96
+
97
+ audio_agent = FunctionAgent(
98
+ name="AudioAgent",
99
+ description="Uses transcription tools to analyze audio files.",
100
+ system_prompt=(
101
+ "You are an audio agent that can transcribe an audio file identified by its id and answer questions about it. "
102
+ "You only give concise answers and if you cannot answer the given query using your tool, "
103
+ "you communicate this clearly. Always hand off your answer to MainAgent."
104
+ ),
105
+ llm=llm,
106
+ tools=[get_transcription_tool()],
107
+ can_handoff_to=["MainAgent"],
108
+ )
109
+
110
+ image_agent = FunctionAgent(
111
+ name="ImageAgent",
112
+ description="Uses image analysis tools to analyze images and respond to questions.",
113
+ system_prompt=(
114
+ "You are an agent that can read images from a file identified by its id and answer questions about it. "
115
+ "You only give concise answers and if you cannot answer the given query using your tool, "
116
+ "you communicate this clearly. Always hand off your answer to MainAgent."
117
+ ),
118
+ llm=llm,
119
+ tools=[get_image_qa_tool()],
120
+ can_handoff_to=["MainAgent"],
121
+ )
122
+
123
+ stats_agent = FunctionAgent(
124
+ name="StatsAgent",
125
+ description="Uses statistical tools to read and analyse excel and csv files.",
126
+ system_prompt=(
127
+ "You are an agent that can read excel and csv files and run simple statistical analysis on them. "
128
+ "You can use this information or the loaded file to answer questions about it. "
129
+ "You only give concise answers and if you cannot answer the given query using your tool, "
130
+ "you communicate this clearly. Always hand off your answer to MainAgent."
131
+ ),
132
+ llm=llm,
133
+ tools=[get_csv_analysis_tool(), get_csv_tool(),
134
+ get_excel_analysis_tool(), get_excel_tool()],
135
+ can_handoff_to=["MainAgent"],
136
+ )
137
+
138
+ # Main AgentWorkflow
139
+ self.agent = AgentWorkflow(
140
+ agents=[main_agent, wiki_agent, web_search_agent,
141
+ audio_agent, image_agent, stats_agent],
142
+ root_agent=main_agent.name,
143
+ )
144
 
145
  async def __call__(self, question: str, task_id: str = None) -> str:
146
  file_str = ""
147
  if task_id:
148
  file_str = f'\nIf you need to load a file, do so by providing the id "{task_id}".'
149
 
150
+ msg = f"{question}{file_str}"
151
+
152
+ # Stream events
153
+ handler = self.agent.run(user_msg=msg)
154
+
155
+ current_agent = None
156
+ current_tool_calls = ""
157
+ async for event in handler.stream_events():
158
+ if (
159
+ hasattr(event, "current_agent_name")
160
+ and event.current_agent_name != current_agent
161
+ ):
162
+ current_agent = event.current_agent_name
163
+ print(f"\n{'='*50}")
164
+ print(f"🤖 Agent: {current_agent}")
165
+ print(f"{'='*50}\n")
166
+
167
+ # if isinstance(event, AgentStream):
168
+ # if event.delta:
169
+ # print(event.delta, end="", flush=True)
170
+ # elif isinstance(event, AgentInput):
171
+ # print("📥 Input:", event.input)
172
+ elif isinstance(event, AgentOutput):
173
+ if event.response.content:
174
+ print("📤 Output:", event.response.content)
175
+ if event.tool_calls:
176
+ print(
177
+ "🛠️ Planning to use tools:",
178
+ [call.tool_name for call in event.tool_calls],
179
+ )
180
+ elif isinstance(event, ToolCallResult):
181
+ print(f"🔧 Tool Result ({event.tool_name}):")
182
+ print(f" Arguments: {event.tool_kwargs}")
183
+ print(f" Output: {event.tool_output}")
184
+ elif isinstance(event, ToolCall):
185
+ print(f"🔨 Calling Tool: {event.tool_name}")
186
+ print(f" With arguments: {event.tool_kwargs}")
187
 
188
  if self.langfuse:
189
  self.instrumentor.flush()
190
 
191
+ res = await handler
192
+ res = res.response.content.strip()
193
+ res = re.sub(r'^.*?FINAL ANSWER:', 'FINAL ANSWER:', res, flags=re.DOTALL)
194
+ return res
multimodality_tools.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  import os
4
  import io
 
5
  import requests
6
 
7
  import librosa
@@ -86,7 +87,7 @@ def answer_image_question(question: str, file_id: str) -> str:
86
  max_tokens=512,
87
  )
88
 
89
- return completion.choices[0].message.content
90
 
91
  def get_image_qa_tool():
92
  return FunctionTool.from_defaults(
@@ -153,3 +154,9 @@ def _get_file(task_id: str) -> io.BytesIO:
153
  raise FileNotFoundError("Invalid file or task id.")
154
  file_like = io.BytesIO(res.content)
155
  return file_like
 
 
 
 
 
 
 
2
 
3
  import os
4
  import io
5
+ import re
6
  import requests
7
 
8
  import librosa
 
87
  max_tokens=512,
88
  )
89
 
90
+ return remove_think(completion.choices[0].message.content)
91
 
92
  def get_image_qa_tool():
93
  return FunctionTool.from_defaults(
 
154
  raise FileNotFoundError("Invalid file or task id.")
155
  file_like = io.BytesIO(res.content)
156
  return file_like
157
+
158
+ def remove_think(output: str) -> str:
159
+ """Removes the <think> part of an LLM output."""
160
+ if output:
161
+ return re.sub("<think>.*</think>", "", output).strip()
162
+ return output
requirements.txt CHANGED
@@ -4,9 +4,8 @@ llama-index
4
  llama-index-llms-huggingface-api
5
  llama_index-tools-duckduckgo
6
  llama_index-tools-wikipedia
7
- llama-index-embeddings-huggingface
8
- llama-index-readers-web
9
  llama-index-llms-ollama
 
10
  langfuse
11
  tabulate
12
  soundfile
@@ -14,4 +13,5 @@ librosa
14
  pillow
15
  pandas
16
  huggingface_hub
17
- transformers
 
 
4
  llama-index-llms-huggingface-api
5
  llama_index-tools-duckduckgo
6
  llama_index-tools-wikipedia
 
 
7
  llama-index-llms-ollama
8
+ llama-index-llms-google-genai
9
  langfuse
10
  tabulate
11
  soundfile
 
13
  pillow
14
  pandas
15
  huggingface_hub
16
+ transformers
17
+ tavily-python