Spaces:
Sleeping
Sleeping
File size: 3,844 Bytes
e4f6727 e3e865e d5ce935 e3e865e d5ce935 e3e865e f49023b e3e865e 4fb4269 d5ce935 4fb4269 e4f6727 4fb4269 e3e865e d5ce935 e3e865e 4fb4269 d5ce935 e3e865e e4f6727 e3e865e e4f6727 e3e865e e4f6727 d5ce935 98ada6c d5ce935 e4f6727 4fb4269 e4f6727 d5ce935 e3e865e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import logging
import os
import re
from typing import List, Optional

from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage

from args import Args, AgentPreset
from llm_factory import LLMFactory
class IAgent:
    """Agent that pairs a file-backed system prompt with an LLM built from a preset.

    The system prompt is read from ``<cwd>/system_prompts/<sys_prompt_filename>``
    and the model is created by ``LLMFactory.create``, optionally with tools
    bound to it. When the preset's model name is ``"groot"`` the agent runs in
    mock mode and answers every query with a canned reply instead of calling
    the model.
    """

    # Visual delimiter framing each query/response pair in the log output.
    _LOG_SEPARATOR = "=============================="

    def __init__(self, sys_prompt_filename, agent_preset: AgentPreset, tools: Optional[List] = None, parallel_tool_calls=False):
        """Load the system prompt and build the underlying model.

        Args:
            sys_prompt_filename: File name inside ``<cwd>/system_prompts``; it is
                also used (via ``_format_name``) to derive the agent's name.
            agent_preset: Preset providing the interface, the model name and
                the configuration passed to ``LLMFactory.create``.
            tools: Optional tools to bind to the model. ``None`` or an empty
                list means no tools are bound.
            parallel_tool_calls: Forwarded to ``bind_tools`` when tools are given.

        Raises:
            OSError: If the system prompt file cannot be opened.
        """
        self.name = self._format_name(sys_prompt_filename)
        self.interface = agent_preset.get_interface()
        # The special model name "groot" short-circuits every query with a canned reply.
        self.mock = agent_preset.get_model_name() == "groot"

        system_prompt_path = os.path.join(os.getcwd(), "system_prompts", sys_prompt_filename)
        # Decode explicitly as UTF-8 so non-ASCII prompts don't depend on the platform locale.
        with open(system_prompt_path, "r", encoding="utf-8") as file:
            self.system_prompt = file.read().strip()

        llm = LLMFactory.create(agent_preset)
        # Bind tools only when some were supplied; otherwise use the bare model.
        if tools:
            self.model = llm.bind_tools(tools, parallel_tool_calls=parallel_tool_calls)
        else:
            self.model = llm

    @staticmethod
    def _format_name(sys_prompt_filename: str) -> str:
        """Derive an agent name from a prompt file name.

        Drops the file extension, then strips every leading character that is
        not an ASCII letter, e.g. ``"01_planner.txt"`` -> ``"planner"``.
        """
        name_without_ext = os.path.splitext(sys_prompt_filename)[0]
        return re.sub(r'^[^a-zA-Z]+', '', name_without_ext)

    @staticmethod
    def _bake_roles(messages: List[str]) -> List[AnyMessage]:
        """
        Assigns roles to messages in reverse order: last message is HumanMessage,
        previous is AIMessage, and so on, alternating backwards.

        Args:
            messages (List[str]): List of message strings.

        Returns:
            List[AnyMessage]: List of messages wrapped with appropriate role classes.

        Raises:
            ValueError: If messages is empty.
        """
        if not messages:
            raise ValueError("The list of messages cannot be empty !")
        last_idx = len(messages) - 1
        # Distance from the end decides the role: even -> human turn, odd -> AI turn.
        return [
            HumanMessage(content=msg) if (last_idx - idx) % 2 == 0 else AIMessage(content=msg)
            for idx, msg in enumerate(messages)
        ]

    def get_system_prompt(self) -> str:
        """
        Retrieves the system prompt.

        Returns:
            str: The system prompt string.
        """
        return self.system_prompt

    def _log_output(self, response: str) -> None:
        """Log the agent's response, closing the separator frame opened in ``query``."""
        separator = self._LOG_SEPARATOR
        Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' produced OUTPUT:\n{response}\n{separator}\n")

    def query(self, messages: List[str]) -> str:
        """
        Synchronously queries the agent with a conversation and returns its reply.

        The system prompt is prepended, then the messages are sent in the given
        order with alternating roles: the last entry is the human turn, the one
        before it an AI turn, and so on. Input and output are logged at INFO.

        Args:
            messages (List[str]): Conversation turns in order; the last one is
                sent as the human message.

        Returns:
            str: The response content from the model, or "I am GROOT !" in mock mode.

        Raises:
            RuntimeError: If ``Args.LOGGER`` has not been set.
            ValueError: If ``messages`` is empty (non-mock mode only).
        """
        if Args.LOGGER is None:
            raise RuntimeError("LOGGER must be defined before querying the agent.")
        separator = self._LOG_SEPARATOR
        Args.LOGGER.log(logging.INFO, f"\n{separator}\nAgent '{self.name}' has been queried !\nINPUT:\n{messages}\n")

        if self.mock:
            response = "I am GROOT !"
            self._log_output(response)
            return response

        conversation = [SystemMessage(content=self.get_system_prompt())] + self._bake_roles(messages)
        response = str(self.model.invoke(conversation).content)
        self._log_output(response)
        return response
|