File size: 2,426 Bytes
4fb4269
e3e865e
d5ce935
e3e865e
d5ce935
e3e865e
 
4fb4269
 
e3e865e
 
 
4fb4269
d5ce935
4fb4269
 
e3e865e
d5ce935
e3e865e
 
 
4fb4269
 
 
 
 
 
 
 
 
d5ce935
 
 
 
 
 
 
 
e3e865e
 
 
 
 
 
 
 
 
 
4fb4269
e3e865e
 
 
 
 
 
 
 
 
d5ce935
 
 
 
 
 
4fb4269
 
 
d5ce935
 
e3e865e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import logging
import os
import re
from typing import List, Optional

from langchain_core.messages import BaseMessage, SystemMessage

from args import Args
from llm_factory import LLMFactory, AgentPreset


class IAgent:
    """Agent that pairs a system prompt loaded from disk with an LLM.

    The system prompt is read from ``<cwd>/system_prompts/<sys_prompt_filename>``
    and is prepended to every conversation passed to :meth:`query`.

    Attributes set by ``__init__``:
        name: Agent name derived from the prompt filename (see ``_format_name``).
        interface: Whatever ``agent_preset.get_interface()`` returns.
        system_prompt: Stripped text content of the prompt file.
        model: The LLM built by ``LLMFactory.create``, with tools bound when
            any were supplied.
    """

    def __init__(self, sys_prompt_filename, agent_preset: AgentPreset, tools: Optional[List] = None, parallel_tool_calls=False):
        """Load the system prompt and build the underlying model.

        Args:
            sys_prompt_filename: File name (not a path) of the prompt inside
                the ``system_prompts`` directory of the current working
                directory.
            agent_preset: Preset handed to ``LLMFactory.create`` and queried
                for its interface.
            tools: Optional tools to bind to the model. ``None`` or an empty
                list leaves the model unbound.
            parallel_tool_calls: Forwarded to ``bind_tools`` when tools are
                bound.

        Raises:
            OSError: If the prompt file cannot be opened or read.
        """
        self.name = self._format_name(sys_prompt_filename)
        self.interface = agent_preset.get_interface()

        # Load the system prompt from a file. The encoding is explicit so the
        # prompt decodes the same way regardless of the platform locale.
        system_prompt_path = os.path.join(os.getcwd(), "system_prompts", sys_prompt_filename)
        with open(system_prompt_path, "r", encoding="utf-8") as file:
            self.system_prompt = file.read().strip()

        # Define LLM
        llm = LLMFactory.create(agent_preset)

        # Bind tools only when some were provided; otherwise use the bare LLM.
        if tools:
            self.model = llm.bind_tools(tools, parallel_tool_calls=parallel_tool_calls)
        else:
            self.model = llm

    @staticmethod
    def _format_name(sys_prompt_filename: str) -> str:
        """Derive an agent name from a prompt file name.

        Drops the file extension, then strips every leading character that is
        not an ASCII letter (e.g. ordering prefixes such as ``"01_"``).

        Args:
            sys_prompt_filename: The prompt file name.

        Returns:
            The cleaned name; empty if the stem contains no ASCII letter.
        """
        # Remove file extension
        name_without_ext = os.path.splitext(sys_prompt_filename)[0]
        # Remove numbers and special characters from the beginning
        return re.sub(r'^[^a-zA-Z]+', '', name_without_ext)

    def get_system_prompt(self) -> str:
        """
        Retrieves the system prompt.

        Returns:
            str: The system prompt string.
        """
        return self.system_prompt

    def query(self, messages: List[BaseMessage]) -> BaseMessage:
        """
        Queries the agent with a conversation and returns the model response.

        The system prompt is prepended to ``messages`` as a ``SystemMessage``
        and the resulting conversation is passed to ``self.model.invoke``.
        Both the input messages and the response are logged at INFO level
        through ``Args.LOGGER``.

        Args:
            messages (List[BaseMessage]): The conversation to send to the agent.

        Returns:
            BaseMessage: The response returned by the model.

        Raises:
            RuntimeError: If ``Args.LOGGER`` is ``None``.
        """
        if Args.LOGGER is None:
            raise RuntimeError("LOGGER must be defined before querying the agent.")

        separator = "=============================="
        # Render the actual input messages, one per line. The log line used to
        # reference a non-existent `question` variable, which raised NameError
        # on every call.
        formatted_input = "\n".join(str(message) for message in messages)
        Args.LOGGER.log(logging.INFO, f"\n{separator}\nAgent '{self.name}' has been queried !\nINPUT:\n{formatted_input}\n")

        system_prompt = self.get_system_prompt()
        conversation = [SystemMessage(content=system_prompt)] + messages
        response = self.model.invoke(conversation)

        Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' produced OUTPUT:\n{response}\n{separator}\n")
        return response