Spaces:
Sleeping
Sleeping
| from typing import Optional | |
| from .communicator import Communicator, PollCallback | |
| from .environment import UnityEnvironment | |
| from mlagents_envs.communicator_objects.unity_rl_output_pb2 import UnityRLOutputProto | |
| from mlagents_envs.communicator_objects.brain_parameters_pb2 import ( | |
| BrainParametersProto, | |
| ActionSpecProto, | |
| ) | |
| from mlagents_envs.communicator_objects.unity_rl_initialization_output_pb2 import ( | |
| UnityRLInitializationOutputProto, | |
| ) | |
| from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto | |
| from mlagents_envs.communicator_objects.unity_output_pb2 import UnityOutputProto | |
| from mlagents_envs.communicator_objects.agent_info_pb2 import AgentInfoProto | |
| from mlagents_envs.communicator_objects.observation_pb2 import ( | |
| ObservationProto, | |
| NONE as COMPRESSION_TYPE_NONE, | |
| PNG as COMPRESSION_TYPE_PNG, | |
| ) | |
class MockCommunicator(Communicator):
    """
    Communicator test double that fabricates its responses locally.

    Every reply is assembled in-process from fixed values; none of the methods
    defined here open a connection or read the ``inputs`` they are given.
    """

    def __init__(
        self,
        discrete_action=False,
        visual_inputs=0,
        num_agents=3,
        brain_name="RealFakeBrain",
        vec_obs_size=3,
    ):
        """
        Configure the canned data this mock will report.

        :param discrete_action: If True, ``initialize`` advertises a discrete
            action spec; otherwise a continuous one.
        :param visual_inputs: Number of PNG-typed visual observations attached
            to every agent, in addition to the single vector observation.
        :param num_agents: Number of agents reported on every step.
        :param brain_name: Behavior name advertised in the brain parameters and
            used as the key of the reported agent infos.
        :param vec_obs_size: Stored on the instance only; the vector
            observation produced by this class is always ``[1, 2, 3]``
            regardless of this value.
        """
        super().__init__()
        self.is_discrete = discrete_action
        self.steps = 0
        self.visual_inputs = visual_inputs
        # Flipped by close() so callers can check that shutdown was requested.
        self.has_been_closed = False
        self.num_agents = num_agents
        self.brain_name = brain_name
        self.vec_obs_size = vec_obs_size

    def initialize(
        self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
    ) -> UnityOutputProto:
        """
        Build the fake handshake response.

        :param inputs: Ignored.
        :param poll_callback: Ignored.
        :return: A UnityOutputProto carrying both the initialization output
            (academy name, API version, one brain) and a first batch of agent
            infos.
        """
        if self.is_discrete:
            # Two discrete branches, with 3 and 2 choices respectively.
            action_spec = ActionSpecProto(
                num_discrete_actions=2, discrete_branch_sizes=[3, 2]
            )
        else:
            action_spec = ActionSpecProto(num_continuous_actions=2)
        bp = BrainParametersProto(
            brain_name=self.brain_name, is_training=True, action_spec=action_spec
        )
        rl_init = UnityRLInitializationOutputProto(
            name="RealFakeAcademy",
            communication_version=UnityEnvironment.API_VERSION,
            package_version="mock_package_version",
            log_path="",
            brain_parameters=[bp],
        )
        output = UnityRLOutputProto(agentInfos=self._get_agent_infos())
        return UnityOutputProto(rl_initialization_output=rl_init, rl_output=output)

    def _get_agent_infos(self):
        """
        Build the per-behavior agent info mapping reported on every step.

        Each agent receives ``self.visual_inputs`` visual observations of shape
        [30, 40, 3] followed by one vector observation ``[1, 2, 3]``, and a
        reward of 1. The agent whose id is 2 is reported as done.

        :return: A dict with a single entry, keyed by ``self.brain_name``,
            holding the ListAgentInfoProto for all ``self.num_agents`` agents.
        """
        vector_obs = [1, 2, 3]
        observations = [
            ObservationProto(
                compressed_data=None,
                shape=[30, 40, 3],
                compression_type=COMPRESSION_TYPE_PNG,
            )
            for _ in range(self.visual_inputs)
        ]
        # The vector observation always comes after the visual ones.
        observations.append(
            ObservationProto(
                float_data=ObservationProto.FloatData(data=vector_obs),
                shape=[len(vector_obs)],
                compression_type=COMPRESSION_TYPE_NONE,
            )
        )
        list_agent_info = [
            AgentInfoProto(
                reward=1,
                done=(i == 2),
                max_step_reached=False,
                id=i,
                observations=observations,
            )
            for i in range(self.num_agents)
        ]
        # Key by the configured brain name so the agent infos line up with the
        # BrainParametersProto advertised by initialize(); a hard-coded name
        # only matches when brain_name is left at its default.
        return {
            self.brain_name: UnityRLOutputProto.ListAgentInfoProto(
                value=list_agent_info
            )
        }

    def exchange(
        self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
    ) -> UnityOutputProto:
        """
        Build the fake response for one step.

        :param inputs: Ignored.
        :param poll_callback: Ignored.
        :return: A UnityOutputProto whose rl_output holds a freshly built batch
            of agent infos.
        """
        result = UnityRLOutputProto(agentInfos=self._get_agent_infos())
        return UnityOutputProto(rl_output=result)

    def close(self):
        """
        Record that the communicator was closed by setting ``has_been_closed``.
        """
        self.has_been_closed = True