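"""LlamaIndex ReAct agent for question answering.

Wires a Hugging Face Inference API LLM to Wikipedia, DuckDuckGo search,
and code-interpreter tool specs, plus the project's custom image-query
and speech-recognition tools.
"""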
from llama_index.core import PromptTemplate
from llama_index.core.workflow import Context
from llama_index.core.agent.workflow import ReActAgent, AgentStream, ToolCallResult
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
from llama_index.tools.wikipedia import WikipediaToolSpec
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
from llama_index.tools.code_interpreter import CodeInterpreterToolSpec

from .prompt import custom_react_system_header_str
from .custom_tools import query_image_tool, automatic_speech_recognition_tool


class LLamaIndexAgent:
    def __init__(self,
                 model_name="Qwen/Qwen2.5-Coder-32B-Instruct",
                 provider="hf-inference",
                 show_tools_desc=True,
                 show_prompt=True):
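        # "hf-inference" is Hugging Face's own serverless backend; other
        # providers accepted by huggingface_hub's InferenceClient should
        # work here too.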
        llm = HuggingFaceInferenceAPI(model_name=model_name,
                                      provider=provider)
        print(f"LLamaIndexAgent initialized with model \"{model_name}\"")

        # Built-in tool specs (each expands to a list of tools) plus the
        # project's custom image-query and speech-recognition tools
        tool_spec_list = []
        tool_spec_list += WikipediaToolSpec().to_tool_list()
        tool_spec_list += DuckDuckGoSearchToolSpec().to_tool_list()
        tool_spec_list += CodeInterpreterToolSpec().to_tool_list()
        tool_spec_list += [query_image_tool, automatic_speech_recognition_tool]

        # ReAct agent driven by the LLM and equipped with the tools above
        self.agent = ReActAgent(llm=llm, tools=tool_spec_list)

        # Replace the default ReAct system header with the custom prompt
        custom_react_system_header = PromptTemplate(custom_react_system_header_str)
        self.agent.update_prompts({"react_header": custom_react_system_header})

        # Context persists conversation state across agent runs
        self.ctx = Context(self.agent)

        # Optional startup debugging: dump tool descriptions and prompts
        if show_tools_desc:
            for i, tool in enumerate(tool_spec_list):
                print("\n" + "=" * 30 + f" Tool {i + 1} " + "=" * 30)
                print(tool.metadata.description)

        if show_prompt:
            prompt_dict = self.agent.get_prompts()
            for k, v in prompt_dict.items():
                print("\n" + "=" * 30 + f" Prompt: {k} " + "=" * 30)
                print(v.template)

    async def __call__(self, question: str) -> str:
        print("\n\n" + "*" * 50)
        print(f"Agent received question: {question}")
        print("*" * 50)
        # Run the agent, streaming events so progress is visible as it happens
        handler = self.agent.run(question, ctx=self.ctx)
        async for ev in handler.stream_events():
            if isinstance(ev, ToolCallResult):
                # A tool call finished: log its name, arguments, and output
                print(f"\nUsed tool {ev.tool_name} with {ev.tool_kwargs}, "
                      f"returned: {ev.tool_output}")
            elif isinstance(ev, AgentStream):
                # Incremental LLM output: print token deltas as they arrive
                print(f"{ev.delta}", end="", flush=True)

        response = await handler

        # Keep only the text after the "FINAL ANSWER:" marker, if present
        response = str(response)
        if "FINAL ANSWER:" in response:
            response = response.split("FINAL ANSWER:")[-1].strip()
        else:
            print('Could not split response on "FINAL ANSWER:"')
        print("\n\n" + "-" * 50)
        print(f"Agent returning with answer: {response}")

        # Recreate the context so the next question starts with fresh state
        self.ctx = Context(self.agent)

        return response
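

# Minimal smoke test: an illustrative sketch, not part of the original module.
# Because of the relative imports above, run it from the package root with
# `python -m <package>.<module>` (names are placeholders); assumes HF
# credentials (e.g. HF_TOKEN) are configured for the chosen provider.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        agent = LLamaIndexAgent()
        # Hypothetical example question
        print(await agent("What is the capital of France?"))

    asyncio.run(_demo())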