import logging
import os
import re
from typing import List, Optional

from langchain_core.messages import BaseMessage, SystemMessage

from args import Args
from llm_factory import LLMFactory, AgentPreset


class IAgent:
    """Base agent that wraps an LLM with a file-based system prompt and optional tools."""

    def __init__(self, sys_prompt_filename: str, agent_preset: AgentPreset,
                 tools: Optional[List] = None, parallel_tool_calls: bool = False):
        self.name = self._format_name(sys_prompt_filename)
        self.interface = agent_preset.get_interface()

        # Load the system prompt from a file
        system_prompt_path = os.path.join(os.getcwd(), "system_prompts", sys_prompt_filename)
        self.system_prompt = ""
        with open(system_prompt_path, "r") as file:
            self.system_prompt = file.read().strip()

        # Define LLM
        llm = LLMFactory.create(agent_preset)

        # Bind tools to the model if any were provided
        if tools:
            self.model = llm.bind_tools(tools, parallel_tool_calls=parallel_tool_calls)
        else:
            self.model = llm

    @staticmethod
    def _format_name(sys_prompt_filename: str) -> str:
        # Remove file extension
        name_without_ext = os.path.splitext(sys_prompt_filename)[0]
        # Remove numbers and special characters from the beginning
        cleaned_name = re.sub(r'^[^a-zA-Z]+', '', name_without_ext)
        return cleaned_name

    def get_system_prompt(self) -> str:
        """
        Retrieves the system prompt.

        Returns:
            str: The system prompt string.
        """
        return self.system_prompt

    def query(self, messages: List[BaseMessage]) -> BaseMessage:
        """
        Queries the agent with a list of messages and returns the response.

        Args:
            messages (List[BaseMessage]): The conversation messages to send to the agent.

        Returns:
            BaseMessage: The response from the agent.
        """
        if Args.LOGGER is None:
            raise RuntimeError("LOGGER must be defined before querying the agent.")

        separator = "=============================="
        Args.LOGGER.log(logging.INFO, f"\n{separator}\nAgent '{self.name}' has been queried!\nINPUT:\n{messages}\n")

        # Prepend the system prompt to the conversation and invoke the model
        system_prompt = self.get_system_prompt()
        conversation = [SystemMessage(content=system_prompt)] + messages
        response = self.model.invoke(conversation)

        Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' produced OUTPUT:\n{response}\n{separator}\n")
        return response
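

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# It assumes an AgentPreset instance is constructed elsewhere (its constructor
# is not shown in this file) and that a prompt file named
# "00_example_agent.txt" exists under ./system_prompts; both names are
# hypothetical placeholders. Args.LOGGER must be configured beforehand,
# otherwise query() raises a RuntimeError.
# ---------------------------------------------------------------------------
def _example_usage(preset: AgentPreset) -> None:
    from langchain_core.messages import HumanMessage

    # Build an agent whose system prompt lives in
    # system_prompts/00_example_agent.txt (hypothetical filename).
    agent = IAgent("00_example_agent.txt", preset)

    # query() prepends the system prompt to the conversation and returns the
    # model's reply as a BaseMessage.
    reply = agent.query([HumanMessage(content="Summarise your role.")])
    print(reply.content)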