Spaces:
Runtime error
Runtime error
File size: 5,886 Bytes
b628e9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
from dataclasses import dataclass
import sys
from typing import Dict, List, Optional, Union
import logging
from autogen import Agent, ConversableAgent
# Module-level logger; GroupChat.select_speaker uses it to warn about underpopulated chats.
logger = logging.getLogger(__name__)
@dataclass
class GroupChat:
    """A group chat class that contains the following data fields:

    - agents: a list of participating agents.
    - messages: a list of messages in the group chat.
    - max_round: the maximum number of rounds.
    - admin_name: the name of the admin agent if there is one. Default is "Admin".
        KeyboardInterrupt will make the admin agent take over.
    - func_call_filter: whether to enforce function call filter. Default is True.
        When set to True and when a message is a function call suggestion,
        the next speaker will be chosen from an agent which contains the corresponding function name
        in its `function_map`.
    """

    agents: List[Agent]
    messages: List[Dict]
    max_round: int = 10
    admin_name: str = "Admin"
    func_call_filter: bool = True

    @property
    def agent_names(self) -> List[str]:
        """Return the names of the agents in the group chat."""
        return [agent.name for agent in self.agents]

    def reset(self):
        """Reset the group chat by clearing the message list in place."""
        self.messages.clear()

    def agent_by_name(self, name: str) -> Agent:
        """Return the agent whose name equals `name`.

        Raises:
            ValueError: if no agent in the group chat has that name.
        """
        return self.agents[self.agent_names.index(name)]

    def next_agent(self, agent: Agent, agents: List[Agent]) -> Optional[Agent]:
        """Return the agent that follows `agent` in round-robin order.

        The search starts right after `agent`'s position in `self.agents`, wraps
        around, and returns the first agent that is also contained in `agents`.

        Args:
            agent: the current speaker; looked up by name in `self.agents`.
            agents: the candidates the next speaker may be chosen from.

        Returns:
            The next eligible agent, or None when none of `self.agents` is in `agents`.

        Raises:
            ValueError: if `agent.name` is not the name of any agent in the group chat.
        """
        offset = self.agent_names.index(agent.name) + 1
        total = len(self.agents)
        # When `agents` is the full roster the first candidate always matches,
        # so this single loop also covers the plain "next in list" case.
        for step in range(total):
            candidate = self.agents[(offset + step) % total]
            if candidate in agents:
                return candidate
        return None

    def select_speaker_msg(self, agents: List[Agent]):
        """Return the system message that asks the selector to pick the next speaker.

        The message lists every participant's role and restricts the choice to
        the names of `agents`.
        """
        return f"""You are in a role play game. The following roles are available:
{self._participant_roles()}.
Ignoring the order in which the above roles appear.
Think about the dependency relationships between different roles.
Read the following conversation.
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""

    def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
        """Select the next speaker.

        When `func_call_filter` is enabled and the last message is a function
        call suggestion, the candidates are narrowed to the agents able to
        execute that function (falling back to every agent that has a non-empty
        `function_map`). A single remaining candidate is returned directly.
        Otherwise `selector` is asked to choose among the candidates based on
        the last five messages; if it gives no final reply, or names an unknown
        agent, the choice falls back to round-robin order via `next_agent`.

        Args:
            last_speaker: the agent that spoke last; used for the round-robin fallback.
            selector: the agent whose `generate_oai_reply` picks the next speaker.

        Raises:
            ValueError: if the last message suggests a function call and no agent
                has a function map at all.
        """
        # NOTE(review): block nesting was reconstructed from an indentation-stripped
        # source; the `else` below is paired with the function-call check - confirm.
        if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
            func_name = self.messages[-1]["function_call"]["name"]
            # find agents with the right function_map which contains the function name
            agents = [agent for agent in self.agents if agent.can_execute_function(func_name)]
            if len(agents) == 1:
                # only one agent can execute the function
                return agents[0]
            elif not agents:
                # find all the agents with function_map
                agents = [agent for agent in self.agents if agent.function_map]
                if len(agents) == 1:
                    return agents[0]
                elif not agents:
                    # The function name lives under "function_call"; a top-level
                    # "name" key is not guaranteed to exist on such a message.
                    raise ValueError(
                        f"No agent can execute the function {func_name}. "
                        "Please check the function_map of the agents."
                    )
        else:
            agents = self.agents
            # Warn if GroupChat is underpopulated
            n_agents = len(agents)
            if n_agents < 3:
                logger.warning(
                    "GroupChat is underpopulated with %s agents. Direct communication would be more efficient.",
                    n_agents,
                )
        selector.update_system_message(self.select_speaker_msg(agents))
        # Select the next speaker based on the last five messages only.
        prompt = self.messages[-5:] + [{
            "role": "system",
            "content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.",
        }]
        logger.debug("Speaker selection prompt: %s", prompt)
        final, name = selector.generate_oai_reply(prompt)
        if not final:
            # The selector produced no final reply: fall back to round-robin order.
            return self.next_agent(last_speaker, agents)
        try:
            return self.agent_by_name(name)
        except ValueError:
            # The reply is not the name of a known agent: fall back to round-robin order.
            return self.next_agent(last_speaker, agents)

    def _participant_roles(self):
        """Return one "name: system_message" line per agent, joined by newlines."""
        return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents])
class GroupChatManager(ConversableAgent):
    """(In preview) A chat manager agent that can manage a group chat of multiple agents."""

    def __init__(
        self,
        groupchat: GroupChat,
        name: Optional[str] = "chat_manager",
        max_consecutive_auto_reply: Optional[int] = sys.maxsize,
        human_input_mode: Optional[str] = "NEVER",
        system_message: Optional[str] = "Group chat manager.",
        **kwargs,
    ):
        """Create a manager bound to `groupchat`.

        The named options and any extra keyword arguments are forwarded to the
        base class initializer. Afterwards the group chat is stored on the
        instance and the system message is replaced by the group chat's
        speaker-selection message for all of its agents.
        """
        base_options = dict(
            name=name,
            max_consecutive_auto_reply=max_consecutive_auto_reply,
            human_input_mode=human_input_mode,
            system_message=system_message,
        )
        super().__init__(**base_options, **kwargs)
        self.groupchat = groupchat
        selection_prompt = self.groupchat.select_speaker_msg(self.groupchat.agents)
        self.update_system_message(selection_prompt)

    def broadcast(
        self,
        message: Optional[str] = None,
        sender: Optional[Agent] = None,
    ) -> Union[str, Dict, None]:
        """Send `message` to every agent of the group chat except `sender`.

        Each message is sent with `request_reply=False`. Nothing is returned.
        """
        # Lazy generator keeps the compare-then-send order of a plain loop.
        recipients = (member for member in self.groupchat.agents if member != sender)
        for recipient in recipients:
            self.send(message, recipient, request_reply=False)
|