"""Agent interface: wraps an LLM with a file-based system prompt."""
import logging
import os
import re
from typing import List, Optional

from langchain_core.messages import BaseMessage, SystemMessage

from args import Args
from llm_factory import LLMFactory, AgentPreset
class IAgent():
    """Base agent wrapping a chat model with a file-based system prompt.

    The system prompt is read from ``system_prompts/<sys_prompt_filename>``
    under the current working directory; the agent's display name is derived
    from that filename.
    """

    def __init__(self, sys_prompt_filename, agent_preset: AgentPreset, tools: Optional[List] = None, parallel_tool_calls=False):
        """Build the agent: load its prompt, create the LLM, bind tools.

        Args:
            sys_prompt_filename: File name inside the ``system_prompts`` directory.
            agent_preset: Preset describing which LLM/interface to create.
            tools: Optional list of tools to bind to the model. Defaults to
                ``None`` (no tools) — avoids the mutable-default-argument trap.
            parallel_tool_calls: Whether the model may call tools in parallel.
        """
        self.name = self._format_name(sys_prompt_filename)
        self.interface = agent_preset.get_interface()
        # Load the system prompt from a file.
        system_prompt_path = os.path.join(os.getcwd(), "system_prompts", sys_prompt_filename)
        with open(system_prompt_path, "r", encoding="utf-8") as file:
            self.system_prompt = file.read().strip()
        # Define LLM from the preset.
        llm = LLMFactory.create(agent_preset)
        # Bind tools only when some were provided; otherwise use the bare model.
        if tools:
            self.model = llm.bind_tools(tools, parallel_tool_calls=parallel_tool_calls)
        else:
            self.model = llm

    # BUG FIX: was an instance method missing `self`, so `self._format_name(...)`
    # raised TypeError on every construction. It uses no instance state, so it
    # is now a staticmethod.
    @staticmethod
    def _format_name(sys_prompt_filename: str) -> str:
        """Derive a readable agent name from the prompt filename.

        Strips the file extension, then any leading non-letter characters
        (e.g. ordering prefixes): ``"01_planner.txt"`` -> ``"planner"``.
        """
        # Remove file extension
        name_without_ext = os.path.splitext(sys_prompt_filename)[0]
        # Remove numbers and special characters from the beginning
        return re.sub(r'^[^a-zA-Z]+', '', name_without_ext)

    def get_system_prompt(self) -> str:
        """Return the system prompt string loaded at construction time."""
        return self.system_prompt

    def query(self, messages: List[BaseMessage]) -> BaseMessage:
        """Synchronously query the model with the given conversation.

        The system prompt is prepended to ``messages`` before invoking the
        model; input and output are logged through ``Args.LOGGER``.

        Args:
            messages: Conversation history to send to the model.

        Returns:
            BaseMessage: The model's response.

        Raises:
            RuntimeError: If ``Args.LOGGER`` has not been configured.
        """
        if Args.LOGGER is None:
            raise RuntimeError("LOGGER must be defined before querying the agent.")
        separator = "=============================="
        # BUG FIX: original logged `{question}`, an undefined name (NameError);
        # the actual input is `messages`.
        Args.LOGGER.log(logging.INFO, f"\n{separator}\nAgent '{self.name}' has been queried !\nINPUT:\n{messages}\n")
        system_prompt = self.get_system_prompt()
        conversation = [SystemMessage(content=system_prompt)] + messages
        response = self.model.invoke(conversation)
        Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' produced OUTPUT:\n{response}\n{separator}\n")
        return response