Spaces:
Sleeping
Sleeping
| """A compatibility wrapper converting an old-style environment into a valid environment.""" | |
| import sys | |
| from typing import Any, Dict, Optional, Tuple | |
| import gym | |
| from gym.core import ObsType | |
| from gym.utils.step_api_compatibility import convert_to_terminated_truncated_step_api | |
| if sys.version_info >= (3, 8): | |
| from typing import Protocol, runtime_checkable | |
| elif sys.version_info >= (3, 7): | |
| from typing_extensions import Protocol, runtime_checkable | |
| else: | |
| Protocol = object | |
| runtime_checkable = lambda x: x # noqa: E731 | |
@runtime_checkable
class LegacyEnv(Protocol):
    """A protocol for environments using the old step API.

    In the old API ``step()`` returns ``(observation, reward, done, info)``
    and ``reset()`` returns only the observation.

    The protocol is decorated with ``runtime_checkable`` so that
    ``isinstance(env, LegacyEnv)`` can be used to detect old-style
    environments. Such a check only verifies that the members below are
    present; it does not validate their signatures or return types. On
    interpreters where the module falls back to ``Protocol = object`` the
    decorator is a no-op and ``isinstance`` behaves like an ordinary class
    check.
    """

    # Spaces describing valid observations and actions of the environment.
    observation_space: gym.Space
    action_space: gym.Space

    def reset(self) -> Any:
        """Reset the environment and return the initial observation."""
        ...

    def step(self, action: Any) -> Tuple[Any, float, bool, Dict]:
        """Run one timestep of the environment's dynamics.

        Returns:
            (observation, reward, done, info)
        """
        ...

    def render(self, mode: Optional[str] = "human") -> Any:
        """Render the environment using the given ``mode``."""
        ...

    def close(self):
        """Close the environment."""
        ...

    def seed(self, seed: Optional[int] = None):
        """Set the seed for this env's random number generator(s)."""
        ...
class EnvCompatibility(gym.Env):
    r"""A wrapper which can transform an environment from the old API to the new API.

    Old step API refers to step() method returning (observation, reward, done, info), and reset() only returning the observation.
    New step API refers to step() method returning (observation, reward, terminated, truncated, info) and reset() returning (observation, info).
    (Refer to docs for details on the API change)

    Known limitations:
    - Environments that use `self.np_random` might not work as expected.
    """

    def __init__(self, old_env: LegacyEnv, render_mode: Optional[str] = None):
        """A wrapper which converts old-style envs to valid modern envs.

        Some information may be lost in the conversion, so we recommend updating your environment.

        Args:
            old_env (LegacyEnv): the env to wrap, implemented with the old API
            render_mode (str): the render mode to use when rendering the environment, passed automatically to env.render
        """
        # Mirror the wrapped env's descriptive attributes, falling back to
        # neutral defaults when the legacy env does not define them.
        self.metadata = getattr(old_env, "metadata", {"render_modes": []})
        self.render_mode = render_mode
        self.reward_range = getattr(old_env, "reward_range", None)
        self.spec = getattr(old_env, "spec", None)
        self.env = old_env

        self.observation_space = old_env.observation_space
        self.action_space = old_env.action_space

    def reset(
        self, seed: Optional[int] = None, options: Optional[dict] = None
    ) -> Tuple[ObsType, dict]:
        """Resets the environment.

        Args:
            seed: the seed to reset the environment with; when not None it is
                forwarded to the wrapped env's ``seed()`` before resetting
            options: the options to reset the environment with (ignored, the
                old-style ``reset()`` takes no arguments)

        Returns:
            (observation, info) where info is always an empty dict
        """
        if seed is not None:
            self.env.seed(seed)
        # Options are ignored

        obs = self.env.reset()

        # Render only after the wrapped env has been reset so that human mode
        # shows the new episode's initial state rather than whatever state
        # was left before the reset; this matches step(), which renders
        # after stepping.
        if self.render_mode == "human":
            self.render()

        return obs, {}

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
        """Steps through the environment.

        Args:
            action: action to step through the environment with

        Returns:
            (observation, reward, terminated, truncated, info)
        """
        obs, reward, done, info = self.env.step(action)

        if self.render_mode == "human":
            self.render()

        # The old 4-tuple is handed to the gym helper, which produces the
        # new 5-tuple form.
        return convert_to_terminated_truncated_step_api((obs, reward, done, info))

    def render(self) -> Any:
        """Renders the environment.

        The stored ``render_mode`` is forwarded as the ``mode`` keyword of the
        wrapped env's ``render()`` (including when it is None).

        Returns:
            The rendering of the environment, depending on the render mode
        """
        return self.env.render(mode=self.render_mode)

    def close(self):
        """Closes the environment."""
        self.env.close()

    def __str__(self):
        """Returns the wrapper name and the unwrapped environment string."""
        return f"<{type(self).__name__}{self.env}>"

    def __repr__(self):
        """Returns the string representation of the wrapper."""
        return str(self)