Spaces:
Runtime error
Runtime error
from dataclasses import dataclass | |
import sys | |
from typing import Dict, List, Optional, Union | |
import logging | |
from autogen import Agent, ConversableAgent | |
logger = logging.getLogger(__name__) | |
@dataclass
class GroupChat:
    """A group chat class that contains the following data fields:
    - agents: a list of participating agents.
    - messages: a list of messages in the group chat.
    - max_round: the maximum number of rounds.
    - admin_name: the name of the admin agent if there is one. Default is "Admin".
        KeyboardInterrupt will make the admin agent take over.
    - func_call_filter: whether to enforce function call filter. Default is True.
        When set to True and when a message is a function call suggestion,
        the next speaker will be chosen from an agent which contains the corresponding function name
        in its `function_map`.

    `max_round` and `admin_name` are only stored here; no method of this class reads them.
    """

    agents: List[Agent]
    messages: List[Dict]
    max_round: int = 10
    admin_name: str = "Admin"
    func_call_filter: bool = True

    @property
    def agent_names(self) -> List[str]:
        """Return the names of the agents in the group chat, in `agents` order."""
        return [agent.name for agent in self.agents]

    def reset(self):
        """Reset the group chat by clearing the message list in place."""
        self.messages.clear()

    def agent_by_name(self, name: str) -> Agent:
        """Return the agent whose name equals `name`.

        Raises:
            ValueError: if no agent in the group chat has that name.
        """
        return self.agents[self.agent_names.index(name)]

    def next_agent(self, agent: Agent, agents: List[Agent]) -> Optional[Agent]:
        """Return the agent that follows `agent` in round-robin order.

        Args:
            agent: the current agent; its name must be one of `agent_names`.
            agents: the candidates the result must be drawn from.

        Returns:
            The first agent after `agent` (wrapping around `self.agents`) that
            is contained in `agents`, or None when no member of `self.agents`
            is a candidate.

        Raises:
            ValueError: if `agent.name` is not the name of a group chat member.
        """
        if agents == self.agents:
            return agents[(self.agent_names.index(agent.name) + 1) % len(agents)]

        # Walk the full roster starting right after `agent` so that the
        # original speaking order is respected even for a filtered candidate list.
        offset = self.agent_names.index(agent.name) + 1
        total = len(self.agents)
        for step in range(total):
            candidate = self.agents[(offset + step) % total]
            if candidate in agents:
                return candidate
        return None

    def select_speaker_msg(self, agents: List[Agent]):
        """Return the system message that asks the selector to pick the next speaker.

        The message lists every participant's role (name and system message)
        and restricts the answer to the names of `agents`.
        """
        return f"""You are in a role play game. The following roles are available:
{self._participant_roles()}.
Ignoring the order in which the above roles appear.
Think about the dependency relationships between different roles.
Read the following conversation.
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""

    def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
        """Select the next speaker.

        When `func_call_filter` is on and the last message is a function call
        suggestion, the candidates are narrowed to agents able to execute that
        function (falling back to agents with any `function_map`). Otherwise
        every agent is a candidate. If a single candidate remains it is
        returned directly; else `selector` is asked to choose by name.

        Args:
            last_speaker: the agent that spoke last; used for the round-robin
                fallback when the selector gives no usable answer.
            selector: the agent whose `generate_oai_reply` picks the next role.

        Returns:
            The selected agent (see `next_agent` for the fallback result).

        Raises:
            ValueError: if the last message suggests a function call and no
                agent has a function map at all.
        """
        if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
            function_name = self.messages[-1]["function_call"]["name"]
            # find agents with the right function_map which contains the function name
            agents = [agent for agent in self.agents if agent.can_execute_function(function_name)]
            if len(agents) == 1:
                # only one agent can execute the function
                return agents[0]
            elif not agents:
                # find all the agents with function_map
                agents = [agent for agent in self.agents if agent.function_map]
                if len(agents) == 1:
                    return agents[0]
                elif not agents:
                    raise ValueError(
                        f"No agent can execute the function {function_name}. "
                        "Please check the function_map of the agents."
                    )
        else:
            agents = self.agents
            # Warn if GroupChat is underpopulated
            n_agents = len(agents)
            if n_agents < 3:
                logger.warning(
                    "GroupChat is underpopulated with %s agents. Direct communication would be more efficient.",
                    n_agents,
                )

        selector.update_system_message(self.select_speaker_msg(agents))
        # Select the next speaker based on the five most recent messages only.
        prompt = self.messages[-5:] + [
            {
                "role": "system",
                "content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.",
            }
        ]
        logger.debug("Speaker selection prompt: %s", prompt)
        final, name = selector.generate_oai_reply(prompt)
        if not final:
            # The selector produced no final reply: fall back to round-robin.
            return self.next_agent(last_speaker, agents)
        try:
            return self.agent_by_name(name)
        except ValueError:
            # The reply is not the name of any participant: fall back to round-robin.
            return self.next_agent(last_speaker, agents)

    def _participant_roles(self):
        """Return one "name: system_message" line per agent, joined by newlines."""
        return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents])
class GroupChatManager(ConversableAgent):
    """(In preview) A chat manager agent that can manage a group chat of multiple agents."""

    def __init__(
        self,
        groupchat: GroupChat,
        name: Optional[str] = "chat_manager",
        max_consecutive_auto_reply: Optional[int] = sys.maxsize,
        human_input_mode: Optional[str] = "NEVER",
        system_message: Optional[str] = "Group chat manager.",
        **kwargs,
    ):
        """Create the manager and bind it to `groupchat`.

        The named options and any extra `kwargs` are forwarded to the base
        class constructor. Afterwards the system message is replaced by the
        group chat's speaker-selection message built for all of its agents.
        """
        base_options = {
            "name": name,
            "max_consecutive_auto_reply": max_consecutive_auto_reply,
            "human_input_mode": human_input_mode,
            "system_message": system_message,
        }
        super().__init__(**base_options, **kwargs)

        self.groupchat = groupchat
        chat = self.groupchat
        selection_prompt = chat.select_speaker_msg(chat.agents)
        self.update_system_message(selection_prompt)

    def broadcast(
        self,
        message: Optional[str] = None,
        sender: Optional[Agent] = None,
    ) -> Union[str, Dict, None]:
        """Send `message` to every agent of the group chat other than `sender`.

        Each message is sent with `request_reply=False`. Nothing is returned
        explicitly.
        """
        # Lazy generator: the sender check for each agent happens right before
        # its send, exactly as in a plain filtered loop.
        recipients = (member for member in self.groupchat.agents if member != sender)
        for recipient in recipients:
            self.send(message, recipient, request_reply=False)