File size: 3,258 Bytes
ca8728d 809f87e ca8728d 809f87e ca8728d 809f87e 3568413 809f87e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# env variable needed: HF_TOKEN
from llama_index.core import PromptTemplate
from llama_index.core.workflow import Context
from llama_index.core.agent.workflow import ReActAgent, AgentStream, ToolCallResult
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
from llama_index.tools.wikipedia import WikipediaToolSpec
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
from llama_index.tools.code_interpreter import CodeInterpreterToolSpec
from .prompt import custom_react_system_header_str
from .custom_tools import query_image_tool, automatic_speech_recognition_tool
class LLamaIndexAgent:
    """ReAct agent built on LlamaIndex with web-search, Wikipedia, code-interpreter,
    image-query and speech-recognition tools.

    Requires the ``HF_TOKEN`` environment variable for the Hugging Face
    Inference API. Ask questions by awaiting the instance directly:
    ``answer = await agent("...")``.
    """

    def __init__(self,
                 model_name: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
                 provider: str = "hf-inference",
                 show_tools_desc: bool = True,
                 show_prompt: bool = True):
        """Build the LLM, assemble the tool list, and create the ReAct agent.

        Args:
            model_name: Hugging Face model id served via the Inference API.
            provider: Inference provider name passed to HuggingFaceInferenceAPI.
            show_tools_desc: If True, print each tool's description at startup.
            show_prompt: If True, print the agent's prompt templates at startup.
        """
        # LLM definition (needs HF_TOKEN in env)
        llm = HuggingFaceInferenceAPI(model_name=model_name,
                                      provider=provider)
        print(f"LLamaIndexAgent initialized with model \"{model_name}\"")

        # tools definition: spec-provided tool lists plus custom callables
        tool_spec_list = []
        tool_spec_list += WikipediaToolSpec().to_tool_list()
        tool_spec_list += DuckDuckGoSearchToolSpec().to_tool_list()
        tool_spec_list += CodeInterpreterToolSpec().to_tool_list()
        tool_spec_list += [query_image_tool, automatic_speech_recognition_tool]

        # agent definition
        self.agent = ReActAgent(llm=llm, tools=tool_spec_list)

        # replace the default ReAct system header with our custom prompt
        custom_react_system_header = PromptTemplate(custom_react_system_header_str)
        self.agent.update_prompts({"react_header": custom_react_system_header})

        # per-conversation context (reset after each question in __call__)
        self.ctx = Context(self.agent)

        if show_tools_desc:
            for i, tool in enumerate(tool_spec_list):
                print("\n" + "=" * 30 + f" Tool {i + 1} " + "=" * 30)
                print(tool.metadata.description)

        if show_prompt:
            prompt_dict = self.agent.get_prompts()
            for k, v in prompt_dict.items():
                print("\n" + "=" * 30 + f" Prompt: {k} " + "=" * 30)
                print(v.template)

    async def __call__(self, question: str) -> str:
        """Run the agent on *question*, streaming its reasoning to stdout.

        Returns:
            The text after the last ``"FINAL ANSWER:"`` marker (stripped),
            for exact-match evaluation; the full response if the marker is
            absent (a warning is printed in that case).
        """
        print("\n\n" + "*" * 50)
        print(f"Agent received question: {question}")
        print("*" * 50)

        handler = self.agent.run(question, ctx=self.ctx)
        async for ev in handler.stream_events():
            # To log tool calls, also handle ToolCallResult events here
            # (ev.tool_name, ev.tool_kwargs, ev.tool_output).
            if isinstance(ev, AgentStream):
                print(f"{ev.delta}", end="", flush=True)
        response = str(await handler)

        # Keep only what's after "FINAL ANSWER:" for the exact match.
        # NOTE: str.split never raises, so no try/except is needed; warn
        # explicitly when the marker is missing instead.
        if "FINAL ANSWER:" not in response:
            print('Could not split response on "FINAL ANSWER:"')
        response = response.split("FINAL ANSWER:")[-1].strip()

        print("\n\n" + "-" * 50)
        print(f"Agent returning with answer: {response}")

        # Reset memory with a fresh Context so the next question starts clean
        # (Context exposes no public clear() method in current llama_index).
        self.ctx = Context(self.agent)
        return response