File size: 3,258 Bytes
ca8728d
 
809f87e
 
 
ca8728d
809f87e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca8728d
809f87e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3568413
 
 
 
809f87e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# env variable needed: HF_TOKEN

from llama_index.core import PromptTemplate
from llama_index.core.workflow import Context
from llama_index.core.agent.workflow import ReActAgent, AgentStream, ToolCallResult
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
from llama_index.tools.wikipedia import WikipediaToolSpec
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
from llama_index.tools.code_interpreter import CodeInterpreterToolSpec

from .prompt import custom_react_system_header_str
from .custom_tools import query_image_tool, automatic_speech_recognition_tool

class LLamaIndexAgent:
    """ReAct agent built on LlamaIndex with Wikipedia, DuckDuckGo search,
    code-interpreter, image-query and speech-recognition tools.

    Requires the ``HF_TOKEN`` environment variable so the Hugging Face
    Inference API client can authenticate.
    """

    def __init__(self,
                 model_name="Qwen/Qwen2.5-Coder-32B-Instruct",
                 provider="hf-inference",
                 show_tools_desc=True,
                 show_prompt=True):
        """Build the LLM, the tool list, the ReAct agent, and its context.

        Args:
            model_name: Hugging Face model id served through the inference API.
            provider: Inference provider name passed to HuggingFaceInferenceAPI.
            show_tools_desc: If True, print each tool's description at startup.
            show_prompt: If True, print the agent's prompt templates at startup.
        """
        # LLM definition (needs HF_TOKEN in env)
        llm = HuggingFaceInferenceAPI(model_name=model_name,
                                      provider=provider)
        print(f"LLamaIndexAgent initialized with model \"{model_name}\"")

        # tools definition: stock tool specs plus the two project-local tools
        tool_spec_list = []
        tool_spec_list += WikipediaToolSpec().to_tool_list()
        tool_spec_list += DuckDuckGoSearchToolSpec().to_tool_list()
        tool_spec_list += CodeInterpreterToolSpec().to_tool_list()
        tool_spec_list += [query_image_tool, automatic_speech_recognition_tool]

        # agent definition
        self.agent = ReActAgent(llm=llm, tools=tool_spec_list)

        # replace the default ReAct system header with the project's custom one
        custom_react_system_header = PromptTemplate(custom_react_system_header_str)
        self.agent.update_prompts({"react_header": custom_react_system_header})

        # context carries conversation state across agent.run() calls
        self.ctx = Context(self.agent)

        if show_tools_desc:
            for i, tool in enumerate(tool_spec_list):
                print("\n" + "="*30 + f" Tool {i+1} " + "="*30)
                print(tool.metadata.description)

        if show_prompt:
            prompt_dict = self.agent.get_prompts()
            for k, v in prompt_dict.items():
                print("\n" + "="*30 + f" Prompt: {k} " + "="*30)
                print(v.template)

    async def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its final answer.

        Streams model tokens to stdout while the agent runs, then keeps only
        the text after the last "FINAL ANSWER:" marker (for exact-match
        grading) and clears the context for the next question.
        """
        print("\n\n"+"*"*50)
        print(f"Agent received question: {question}")
        print("*"*50)

        handler = self.agent.run(question, ctx=self.ctx)
        async for ev in handler.stream_events():
            # stream the model's tokens as they arrive
            if isinstance(ev, AgentStream):
                print(f"{ev.delta}", end="", flush=True)

        response = await handler

        # post-process: cast AgentOutput to str and keep only what's after
        # "FINAL ANSWER:" for the exact match.
        # BUGFIX: str.split never raises, so the previous bare `except:` was
        # unreachable and the warning could never be printed. Warn explicitly
        # when the marker is missing; the split/strip result is unchanged in
        # all cases (split on an absent separator yields the whole string).
        response = str(response)
        if "FINAL ANSWER:" not in response:
            print('Could not split response on "FINAL ANSWER:"')
        response = response.split("FINAL ANSWER:")[-1].strip()

        print("\n\n"+"-"*50)
        print(f"Agent returning with answer: {response}")

        # clear context for next question before returning
        self.ctx.clear()

        return response