Spaces:
Build error
Build error
| from __future__ import annotations | |
| import logging | |
| import re | |
| from typing import TYPE_CHECKING, Any, List, Optional | |
| from . import order_registry as OrderRegistry | |
| from .base import BaseOrder | |
| if TYPE_CHECKING: | |
| from agentverse.environments import BaseEnvironment | |
# Matches e.g. "[CallOn] Yes, student Alice" or "[CallOn] Yes, Alice".
# Group 2 captures the (single-word) name of the student being called on.
_CALL_ON_PATTERN = re.compile(r"\[CallOn\] Yes, ([sS]tudent )?(\w+)")
# Agent names of students are expected to carry this prefix; the remainder is
# the name the professor uses in a [CallOn] message.
_STUDENT_PREFIX = "Student "


# NOTE(review): `OrderRegistry` is imported at the top of this file but is not
# referenced in the visible code. If this order is meant to be looked up by
# name through the registry, a registration decorator may be missing -- confirm.
class ClassroomOrder(BaseOrder):
    """The order for a classroom discussion.

    The agents speak in the following order:

    1. The professor speaks first.
    2. Then the professor can continue to speak, and the students can raise
       hands.
    3. The professor can call on a student, then the student can speak or ask
       a question.
    4. In the group discussion, the students in the group speak in turn.

    Agent index 0 is treated as the professor: it is the index returned
    whenever only the professor should speak next.
    """

    def get_next_agent_idx(self, environment: BaseEnvironment) -> List[int]:
        """Return the indices of the agents allowed to act in the next turn.

        Two flags are read from ``environment.rule_params``:

        * ``is_grouped_ended``: the group discussion just ended, so only the
          professor (index 0) speaks next.
        * ``is_grouped``: a group discussion is currently in progress.
        """
        if environment.rule_params.get("is_grouped_ended", False):
            return [0]
        if environment.rule_params.get("is_grouped", False):
            return self.get_next_agent_idx_grouped(environment)
        return self.get_next_agent_idx_ungrouped(environment)

    def get_next_agent_idx_ungrouped(self, environment: BaseEnvironment) -> List[int]:
        """Choose the next speakers outside of group discussion.

        The decision is based on ``environment.last_messages``:

        * no message, or more than one message -> professor only (``[0]``);
        * exactly one message from the professor:

          - a ``[CallOn]`` message naming a known student -> that student only;
          - any other message (including a ``[CallOn]`` that cannot be
            resolved to a student) -> every agent may act;

        * exactly one message from a student -> professor only.

        Raises:
            ValueError: if the single last message comes from a sender whose
                name starts with neither "Professor" nor "Student".
        """
        last_messages = environment.last_messages
        if len(last_messages) != 1:
            # 0 messages: the class just began or nobody spoke last turn.
            # >1 messages: at least one student raised a hand or spoke, or the
            # group discussion is just over. Either way the professor leads.
            return [0]

        message = last_messages[0]
        sender = message.sender
        content = message.content

        if sender.startswith("Professor"):
            if content.startswith("[CallOn]"):
                callee = self._resolve_call_on(environment, content)
                if callee is not None:
                    # The called-on student speaks (or asks a question) next.
                    return [callee]
                # Previously this crashed (AssertionError / KeyError); degrade
                # gracefully by treating it as ordinary professor speech.
                logging.warning(
                    "Could not resolve the student in CallOn message %r; "
                    "letting every agent act.",
                    content,
                )
            # Professor spoke normally: anyone can act (e.g. raise a hand).
            return list(range(len(environment.agents)))

        if sender.startswith("Student"):
            # A called-on student asked a question, a single student raised a
            # hand and the professor happens to listen, or the group discussion
            # just ended with a single student speaking.
            return [0]

        # A real exception instead of `assert False`, which `python -O` strips
        # (the method would then silently return None).
        raise ValueError(
            f"Should not reach here, last_messages: {environment.last_messages}"
        )

    @staticmethod
    def _resolve_call_on(environment: BaseEnvironment, content: str) -> Optional[int]:
        """Return the agent index named in a ``[CallOn]`` message, or None.

        None is returned when the message does not match the expected
        "[CallOn] Yes, [student ]<Name>" form, or when no agent's name (with the
        "Student " prefix length stripped) equals the captured name.
        """
        match = _CALL_ON_PATTERN.search(content)
        if match is None:
            return None
        name_to_id = {
            agent.name[len(_STUDENT_PREFIX) :]: i
            for i, agent in enumerate(environment.agents)
        }
        return name_to_id.get(match.group(2))

    def get_next_agent_idx_grouped(self, environment: BaseEnvironment) -> List[int]:
        """Pick one speaker per group, rotating through each group's members.

        Reads from ``environment.rule_params``:

        * ``groups``: a list of lists of agent ids; the i-th list holds the
          agent ids of the i-th group. It should be set by the corresponding
          visibility component. If absent, a warning is logged and all agents
          form a single group.
        * ``group_speaker_mapping``: maps group id -> position (within that
          group) of the next speaker. It is maintained here and written back
          to ``rule_params`` on every call.

        Empty groups contribute no speaker. A stored position that is out of
        range for its (possibly changed) group wraps around instead of raising.
        """
        if "groups" not in environment.rule_params:
            logging.warning(
                "The environment is grouped, but the grouping information is not provided."
            )
        groups = environment.rule_params.get(
            "groups", [list(range(len(environment.agents)))]
        )
        group_speaker_mapping = environment.rule_params.get("group_speaker_mapping", {})

        next_agent_idx = []
        updated_mapping = {}
        for group_id, group in enumerate(groups):
            if not group:
                # Nobody to speak; also avoids modulo-by-zero below.
                continue
            # Wrap in case the group shrank since the mapping was stored.
            speaker_index = group_speaker_mapping.get(group_id, 0) % len(group)
            next_agent_idx.append(group[speaker_index])
            # Members of each group speak in turn.
            updated_mapping[group_id] = (speaker_index + 1) % len(group)

        # Rebuilt from the current groups, so stale keys cannot cause IndexError.
        environment.rule_params["group_speaker_mapping"] = updated_mapping
        return next_agent_idx