Martin Bär
committed on
Commit
·
dbb14b6
1
Parent(s):
8ea3490
Change agent into WorkFlow with sub-agents and use Google Gemini
Browse files- app.py +0 -7
- basic_agent.py +150 -27
- multimodality_tools.py +8 -1
- requirements.txt +3 -3
app.py
CHANGED
@@ -12,12 +12,6 @@ from basic_agent import BasicAgent
|
|
12 |
# --- Constants ---
|
13 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
14 |
|
15 |
-
# For Llamaindex's LoadAndSearchTool
|
16 |
-
Settings.llm = None # disable LLM for Index Retrieval
|
17 |
-
Settings.chunk_size = 512 # Smaller chunk size for retrieval
|
18 |
-
|
19 |
-
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
|
20 |
-
|
21 |
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
22 |
"""
|
23 |
Fetches all questions, runs the BasicAgent on them, submits all answers,
|
@@ -139,7 +133,6 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
139 |
return status_message, results_df
|
140 |
|
141 |
async def handle_agent_input(user_input):
|
142 |
-
# TODO initialize agent at a different place
|
143 |
agent = BasicAgent()
|
144 |
response = await agent(user_input)
|
145 |
return response
|
|
|
12 |
# --- Constants ---
|
13 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
16 |
"""
|
17 |
Fetches all questions, runs the BasicAgent on them, submits all answers,
|
|
|
133 |
return status_message, results_df
|
134 |
|
135 |
async def handle_agent_input(user_input):
|
|
|
136 |
agent = BasicAgent()
|
137 |
response = await agent(user_input)
|
138 |
return response
|
basic_agent.py
CHANGED
@@ -1,10 +1,20 @@
|
|
|
|
|
|
|
|
|
|
1 |
from llama_index.core.tools import FunctionTool
|
2 |
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
|
3 |
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
|
4 |
from llama_index.tools.wikipedia import WikipediaToolSpec
|
5 |
from langfuse.llama_index import LlamaIndexInstrumentor
|
6 |
from llama_index.llms.ollama import Ollama
|
7 |
-
from llama_index.
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
from multimodality_tools import get_image_qa_tool, get_transcription_tool, \
|
10 |
get_excel_analysis_tool, get_excel_tool, get_csv_analysis_tool, get_csv_tool
|
@@ -12,7 +22,8 @@ from multimodality_tools import get_image_qa_tool, get_transcription_tool, \
|
|
12 |
class BasicAgent:
|
13 |
def __init__(self, ollama=False, langfuse=False):
|
14 |
if not ollama:
|
15 |
-
llm =
|
|
|
16 |
else:
|
17 |
llm = Ollama(model="mistral:latest", request_timeout=120.0)
|
18 |
|
@@ -22,29 +33,15 @@ class BasicAgent:
|
|
22 |
self.instrumentor = LlamaIndexInstrumentor()
|
23 |
self.instrumentor.start()
|
24 |
|
25 |
-
# Initialize
|
26 |
-
tool_spec = DuckDuckGoSearchToolSpec()
|
27 |
-
search_tool = FunctionTool.from_defaults(tool_spec.duckduckgo_full_search)
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
# Convert into a LoadAndSearchToolSpec because the wikipedia search tool returns
|
35 |
-
# entire Wikipedia pages and this can pollute the context window of the LLM
|
36 |
-
# TODO this does not work so well. We need to make the retriever return the top 5 chunks or sth.
|
37 |
-
# wiki_search_tool_las = LoadAndSearchToolSpec.from_defaults(wiki_search_tool).to_tool_list()
|
38 |
-
|
39 |
-
self.agent = ReActAgent(
|
40 |
-
tools=[search_tool, wiki_search_tool, get_image_qa_tool(),
|
41 |
-
get_transcription_tool(), get_excel_analysis_tool(), get_excel_tool(),
|
42 |
-
get_csv_analysis_tool(), get_csv_tool()],
|
43 |
-
llm=llm,
|
44 |
-
verbose=True,
|
45 |
-
system_prompt = (
|
46 |
"You are a general AI assistant. I will ask you a question. "
|
47 |
-
"Report your thoughts,
|
|
|
48 |
"FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number "
|
49 |
"OR as few words as possible OR a comma separated list of numbers and/or "
|
50 |
"strings. If you are asked for a number, don't use comma to write your "
|
@@ -53,19 +50,145 @@ class BasicAgent:
|
|
53 |
"for cities), and write the digits in plain text unless specified otherwise. If "
|
54 |
"you are asked for a comma separated list, apply the above rules depending of "
|
55 |
"whether the element to be put in the list is a number or a string."
|
56 |
-
)
|
|
|
|
|
|
|
57 |
)
|
58 |
|
59 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
async def __call__(self, question: str, task_id: str = None) -> str:
|
62 |
file_str = ""
|
63 |
if task_id:
|
64 |
file_str = f'\nIf you need to load a file, do so by providing the id "{task_id}".'
|
65 |
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
if self.langfuse:
|
69 |
self.instrumentor.flush()
|
70 |
|
71 |
-
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
|
4 |
+
from tavily import AsyncTavilyClient
|
5 |
from llama_index.core.tools import FunctionTool
|
6 |
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
|
7 |
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
|
8 |
from llama_index.tools.wikipedia import WikipediaToolSpec
|
9 |
from langfuse.llama_index import LlamaIndexInstrumentor
|
10 |
from llama_index.llms.ollama import Ollama
|
11 |
+
from llama_index.llms.google_genai import GoogleGenAI
|
12 |
+
from llama_index.core.agent.workflow import FunctionAgent, AgentWorkflow
|
13 |
+
from llama_index.core.agent.workflow import (
|
14 |
+
AgentOutput,
|
15 |
+
ToolCall,
|
16 |
+
ToolCallResult,
|
17 |
+
)
|
18 |
|
19 |
from multimodality_tools import get_image_qa_tool, get_transcription_tool, \
|
20 |
get_excel_analysis_tool, get_excel_tool, get_csv_analysis_tool, get_csv_tool
|
|
|
22 |
class BasicAgent:
|
23 |
def __init__(self, ollama=False, langfuse=False):
|
24 |
if not ollama:
|
25 |
+
llm = GoogleGenAI(model="gemini-2.0-flash", api_key=os.getenv("GEMINI_API_KEY"))
|
26 |
+
# llm = HuggingFaceInferenceAPI(model_name="Qwen/Qwen3-32B") #"Qwen/Qwen2.5-Coder-32B-Instruct")
|
27 |
else:
|
28 |
llm = Ollama(model="mistral:latest", request_timeout=120.0)
|
29 |
|
|
|
33 |
self.instrumentor = LlamaIndexInstrumentor()
|
34 |
self.instrumentor.start()
|
35 |
|
36 |
+
# Initialize sub-agents
|
|
|
|
|
37 |
|
38 |
+
main_agent = FunctionAgent(
|
39 |
+
name="MainAgent",
|
40 |
+
description="Can organize and delegate work to different agents and can compile a final answer to a question from other agents' outputs.",
|
41 |
+
system_prompt=(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
"You are a general AI assistant. I will ask you a question. "
|
43 |
+
"Report your thoughts, delegate work to other agents if necessary, and"
|
44 |
+
"finish your answer with the following template: "
|
45 |
"FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number "
|
46 |
"OR as few words as possible OR a comma separated list of numbers and/or "
|
47 |
"strings. If you are asked for a number, don't use comma to write your "
|
|
|
50 |
"for cities), and write the digits in plain text unless specified otherwise. If "
|
51 |
"you are asked for a comma separated list, apply the above rules depending of "
|
52 |
"whether the element to be put in the list is a number or a string."
|
53 |
+
),
|
54 |
+
llm=llm,
|
55 |
+
tools=[],
|
56 |
+
can_handoff_to=["WikiAgent", "WebAgent", "StatsAgent", "AudioAgent", "ImageAgent"],
|
57 |
)
|
58 |
|
59 |
+
# Wikipedia tool does not return the tables from the page...
|
60 |
+
wiki_spec = WikipediaToolSpec()
|
61 |
+
wiki_search_tool = wiki_spec.to_tool_list()[1]
|
62 |
+
|
63 |
+
wiki_agent = FunctionAgent(
|
64 |
+
name="WikiAgent",
|
65 |
+
description="Uses wikipedia to answer a question.",
|
66 |
+
system_prompt=(
|
67 |
+
"You are a Wikipedia agent that can search Wikipedia for information to answer a question. "
|
68 |
+
"You only give concise answers and if you don't find an answer to the given query on Wikipedia, "
|
69 |
+
"you communicate this clearly. Always hand off your answer to MainAgent."
|
70 |
+
),
|
71 |
+
llm=llm,
|
72 |
+
tools=[wiki_search_tool],
|
73 |
+
can_handoff_to=["MainAgent"],
|
74 |
+
)
|
75 |
+
|
76 |
+
tool_spec = DuckDuckGoSearchToolSpec()
|
77 |
+
search_tool = FunctionTool.from_defaults(tool_spec.duckduckgo_full_search)
|
78 |
+
# In case DuckDuckGo is not good enough
|
79 |
+
async def search_web(query: str) -> str:
|
80 |
+
"""Searches the web to answer questions."""
|
81 |
+
client = AsyncTavilyClient(api_key=os.getenv("TAVILY"))
|
82 |
+
return str(await client.search(query))
|
83 |
+
|
84 |
+
web_search_agent = FunctionAgent(
|
85 |
+
name="WebAgent",
|
86 |
+
description="Uses the web to answer a question.",
|
87 |
+
system_prompt=(
|
88 |
+
"You are a Web agent that can search the Web for information to answer a question. "
|
89 |
+
"You only give concise answers and if you don't find an answer to the given query with your tool, "
|
90 |
+
"you communicate this clearly. Always hand off your answer to MainAgent."
|
91 |
+
),
|
92 |
+
llm=llm,
|
93 |
+
tools=[search_web],
|
94 |
+
can_handoff_to=["MainAgent"],
|
95 |
+
)
|
96 |
+
|
97 |
+
audio_agent = FunctionAgent(
|
98 |
+
name="AudioAgent",
|
99 |
+
description="Uses transcription tools to analyze audio files.",
|
100 |
+
system_prompt=(
|
101 |
+
"You are an audio agent that can transcribe an audio file identified by its id and answer questions about it. "
|
102 |
+
"You only give concise answers and if you cannot answer the given query using your tool, "
|
103 |
+
"you communicate this clearly. Always hand off your answer to MainAgent."
|
104 |
+
),
|
105 |
+
llm=llm,
|
106 |
+
tools=[get_transcription_tool()],
|
107 |
+
can_handoff_to=["MainAgent"],
|
108 |
+
)
|
109 |
+
|
110 |
+
image_agent = FunctionAgent(
|
111 |
+
name="ImageAgent",
|
112 |
+
description="Uses image analysis tools to analyze images and respond to questions.",
|
113 |
+
system_prompt=(
|
114 |
+
"You are an agent that can read images from a file identified by its id and answer questions about it. "
|
115 |
+
"You only give concise answers and if you cannot answer the given query using your tool, "
|
116 |
+
"you communicate this clearly. Always hand off your answer to MainAgent."
|
117 |
+
),
|
118 |
+
llm=llm,
|
119 |
+
tools=[get_image_qa_tool()],
|
120 |
+
can_handoff_to=["MainAgent"],
|
121 |
+
)
|
122 |
+
|
123 |
+
# Sub-agent for excel/csv data. Its name has to be "StatsAgent" so that it
# matches the "StatsAgent" entry in MainAgent's can_handoff_to list; it
# previously duplicated image_agent's name ("ImageAgent"), which left the
# "StatsAgent" handoff target without a matching agent.
stats_agent = FunctionAgent(
    name="StatsAgent",
    description="Uses statistical tools to read and analyse excel and csv files.",
    system_prompt=(
        "You are an agent that can read excel and csv files and run simple statistical analysis on them. "
        "You can use this information or the loaded file to answer questions about it. "
        "You only give concise answers and if you cannot answer the given query using your tool, "
        "you communicate this clearly. Always hand off your answer to MainAgent."
    ),
    llm=llm,
    tools=[get_csv_analysis_tool(), get_csv_tool(),
           get_excel_analysis_tool(), get_excel_tool()],
    can_handoff_to=["MainAgent"],
)
|
137 |
+
|
138 |
+
# Main AgentWorkflow
|
139 |
+
self.agent = AgentWorkflow(
|
140 |
+
agents=[main_agent, wiki_agent, web_search_agent,
|
141 |
+
audio_agent, image_agent, stats_agent],
|
142 |
+
root_agent=main_agent.name,
|
143 |
+
)
|
144 |
|
145 |
async def __call__(self, question: str, task_id: str = None) -> str:
    """Run the agent workflow on a question and return its final answer.

    Args:
        question: The question to answer.
        task_id: Optional id. When given, a hint is appended to the user
            message telling the agents which id to use for loading a file.

    Returns:
        The final response text, stripped, with everything before the first
        "FINAL ANSWER:" marker removed. Text that does not contain the
        marker is returned unchanged (apart from stripping).
    """
    file_str = ""
    if task_id:
        file_str = f'\nIf you need to load a file, do so by providing the id "{task_id}".'

    msg = f"{question}{file_str}"

    # Stream events so the progress of every agent and tool call is printed.
    handler = self.agent.run(user_msg=msg)

    current_agent = None
    async for event in handler.stream_events():
        # Announce agent switches independently of the event type, so the
        # event that carries the new agent name is still logged below.
        if (
            hasattr(event, "current_agent_name")
            and event.current_agent_name != current_agent
        ):
            current_agent = event.current_agent_name
            print(f"\n{'='*50}")
            print(f"🤖 Agent: {current_agent}")
            print(f"{'='*50}\n")

        if isinstance(event, AgentOutput):
            if event.response.content:
                print("📤 Output:", event.response.content)
            if event.tool_calls:
                print(
                    "🛠️  Planning to use tools:",
                    [call.tool_name for call in event.tool_calls],
                )
        elif isinstance(event, ToolCallResult):
            print(f"🔧 Tool Result ({event.tool_name}):")
            print(f"  Arguments: {event.tool_kwargs}")
            print(f"  Output: {event.tool_output}")
        elif isinstance(event, ToolCall):
            print(f"🔨 Calling Tool: {event.tool_name}")
            print(f"  With arguments: {event.tool_kwargs}")

    if self.langfuse:
        self.instrumentor.flush()

    res = await handler
    # Guard against a final response without text content before stripping.
    res = (res.response.content or "").strip()
    # Drop any reasoning that precedes the answer template.
    res = re.sub(r'^.*?FINAL ANSWER:', 'FINAL ANSWER:', res, flags=re.DOTALL)
    return res
|
multimodality_tools.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2 |
|
3 |
import os
|
4 |
import io
|
|
|
5 |
import requests
|
6 |
|
7 |
import librosa
|
@@ -86,7 +87,7 @@ def answer_image_question(question: str, file_id: str) -> str:
|
|
86 |
max_tokens=512,
|
87 |
)
|
88 |
|
89 |
-
return completion.choices[0].message.content
|
90 |
|
91 |
def get_image_qa_tool():
|
92 |
return FunctionTool.from_defaults(
|
@@ -153,3 +154,9 @@ def _get_file(task_id: str) -> io.BytesIO:
|
|
153 |
raise FileNotFoundError("Invalid file or task id.")
|
154 |
file_like = io.BytesIO(res.content)
|
155 |
return file_like
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
import os
|
4 |
import io
|
5 |
+
import re
|
6 |
import requests
|
7 |
|
8 |
import librosa
|
|
|
87 |
max_tokens=512,
|
88 |
)
|
89 |
|
90 |
+
return remove_think(completion.choices[0].message.content)
|
91 |
|
92 |
def get_image_qa_tool():
|
93 |
return FunctionTool.from_defaults(
|
|
|
154 |
raise FileNotFoundError("Invalid file or task id.")
|
155 |
file_like = io.BytesIO(res.content)
|
156 |
return file_like
|
157 |
+
|
158 |
+
def remove_think(output: str) -> str:
    """Remove ``<think>...</think>`` sections from an LLM output.

    The pattern is matched with ``re.DOTALL`` so that think sections spanning
    several lines are removed as well, and it is non-greedy so that text
    between two separate think sections is kept.

    Args:
        output: Raw model output; may be empty or ``None``.

    Returns:
        The output without think sections and with surrounding whitespace
        stripped. Falsy input is returned unchanged.
    """
    if not output:
        return output
    return re.sub(r"<think>.*?</think>", "", output, flags=re.DOTALL).strip()
|
requirements.txt
CHANGED
@@ -4,9 +4,8 @@ llama-index
|
|
4 |
llama-index-llms-huggingface-api
|
5 |
llama_index-tools-duckduckgo
|
6 |
llama_index-tools-wikipedia
|
7 |
-
llama-index-embeddings-huggingface
|
8 |
-
llama-index-readers-web
|
9 |
llama-index-llms-ollama
|
|
|
10 |
langfuse
|
11 |
tabulate
|
12 |
soundfile
|
@@ -14,4 +13,5 @@ librosa
|
|
14 |
pillow
|
15 |
pandas
|
16 |
huggingface_hub
|
17 |
-
transformers
|
|
|
|
4 |
llama-index-llms-huggingface-api
|
5 |
llama_index-tools-duckduckgo
|
6 |
llama_index-tools-wikipedia
|
|
|
|
|
7 |
llama-index-llms-ollama
|
8 |
+
llama-index-llms-google-genai
|
9 |
langfuse
|
10 |
tabulate
|
11 |
soundfile
|
|
|
13 |
pillow
|
14 |
pandas
|
15 |
huggingface_hub
|
16 |
+
transformers
|
17 |
+
tavily-python
|