Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
from contextlib import closing
from io import StringIO
from os import path
from typing import Optional

import numpy as np

from gym import Env, logger, spaces, utils
from gym.envs.toy_text.utils import categorical_sample
from gym.error import DependencyNotInstalled
# ASCII layout of the 5x5 taxi grid.
# Row r of the grid is MAP[r + 1]; column c sits at character index 2 * c + 1.
# Between two horizontally adjacent cells, ":" is a passable boundary and "|"
# is a wall. R, G, Y and B mark the four pickup/drop-off locations.
MAP = [
    "+---------+",
    "|R: | : :G|",
    "| : | : : |",
    "| : : : : |",
    "| | : | : |",
    "|Y| : |B: |",
    "+---------+",
]

# Size in pixels (width, height) of the pygame window / off-screen surface.
WINDOW_SIZE = (550, 350)
class TaxiEnv(Env):
    """
    The Taxi Problem
    from "Hierarchical Reinforcement Learning with the MAXQ Value Function Decomposition"
    by Tom Dietterich

    ### Description
    There are four designated locations in the grid world indicated by R(ed),
    G(reen), Y(ellow), and B(lue). When the episode starts, the taxi starts off
    at a random square and the passenger is at a random location. The taxi
    drives to the passenger's location, picks up the passenger, drives to the
    passenger's destination (another one of the four specified locations), and
    then drops off the passenger. Once the passenger is dropped off, the episode ends.

    Map:

        +---------+
        |R: | : :G|
        | : | : : |
        | : : : : |
        | | : | : |
        |Y| : |B: |
        +---------+

    ### Actions
    There are 6 discrete deterministic actions:
    - 0: move south
    - 1: move north
    - 2: move east
    - 3: move west
    - 4: pickup passenger
    - 5: drop off passenger

    ### Observations
    There are 500 discrete states since there are 25 taxi positions, 5 possible
    locations of the passenger (including the case when the passenger is in the
    taxi), and 4 destination locations.

    Note that there are 400 states that can actually be reached during an
    episode. The missing states correspond to situations in which the passenger
    is at the same location as their destination, as this typically signals the
    end of an episode. Four additional states can be observed right after a
    successful episode, when both the passenger and the taxi are at the destination.
    This gives a total of 404 reachable discrete states.

    Each state is represented by the tuple:
    (taxi_row, taxi_col, passenger_location, destination)

    An observation is an integer that encodes the corresponding state.
    The state tuple can then be decoded with the "decode" method.

    Passenger locations:
    - 0: R(ed)
    - 1: G(reen)
    - 2: Y(ellow)
    - 3: B(lue)
    - 4: in taxi

    Destinations:
    - 0: R(ed)
    - 1: G(reen)
    - 2: Y(ellow)
    - 3: B(lue)

    ### Info

    ``step()`` and ``reset()`` return an info dictionary with two keys:
    - "prob": the probability of the transition that was taken.
    - "action_mask": a mask of the actions that will change the state, which
      can be used to speed up training.

    Every entry of the transition table has probability 1.0, so ``step()``
    always reports ``"prob": 1.0``. ``reset()`` also reports ``"prob": 1.0``
    even though the initial state is sampled from a distribution; that value
    is a placeholder, not the probability of the sampled initial state.

    For some cases, taking an action will have no effect on the state of the agent.
    In v0.25.0, ``info["action_mask"]`` contains a np.ndarray with one entry per action specifying
    if the action will change the state.

    To sample a modifying action, use ``action = env.action_space.sample(info["action_mask"])``
    Or with a Q-value based algorithm ``action = np.argmax(q_values[obs, np.where(info["action_mask"] == 1)[0]])``.

    ### Rewards
    - -1 per step unless other reward is triggered.
    - +20 delivering passenger.
    - -10 executing "pickup" and "drop-off" actions illegally.

    ### Arguments

    ```
    gym.make('Taxi-v3')
    ```

    ### Version History
    * v3: Map Correction + Cleaner Domain Description, v0.25.0 action masking added to the reset and step information
    * v2: Disallow Taxi start location = goal location, Update Taxi observations in the rollout, Update Taxi reward threshold.
    * v1: Remove (3,2) from locs, add passidx<4 check
    * v0: Initial versions release
    """

    metadata = {
        "render_modes": ["human", "ansi", "rgb_array"],
        "render_fps": 4,
    }

    def __init__(self, render_mode: Optional[str] = None):
        """Build the full transition table and set up lazy rendering state.

        Args:
            render_mode: One of ``metadata["render_modes"]`` or ``None``.
        """
        self.desc = np.asarray(MAP, dtype="c")

        # (row, col) of the R, G, Y, B locations and their display colours.
        self.locs = locs = [(0, 0), (0, 4), (4, 0), (4, 3)]
        self.locs_colors = [(255, 0, 0), (0, 255, 0), (255, 255, 0), (0, 0, 255)]

        num_states = 500
        num_rows = 5
        num_columns = 5
        max_row = num_rows - 1
        max_col = num_columns - 1
        num_actions = 6

        self.initial_state_distrib = np.zeros(num_states)
        # P[state][action] -> list of (probability, next_state, reward, terminated)
        self.P = {
            state: {action: [] for action in range(num_actions)}
            for state in range(num_states)
        }

        for row in range(num_rows):
            for col in range(num_columns):
                taxi_loc = (row, col)
                for pass_idx in range(len(locs) + 1):  # +1 for being inside taxi
                    for dest_idx in range(len(locs)):
                        state = self.encode(row, col, pass_idx, dest_idx)
                        # Valid start states: passenger waiting somewhere other
                        # than the destination.
                        if pass_idx < 4 and pass_idx != dest_idx:
                            self.initial_state_distrib[state] += 1

                        for action in range(num_actions):
                            # Defaults: nothing moves, ordinary step cost.
                            new_row, new_col, new_pass_idx = row, col, pass_idx
                            reward = -1
                            terminated = False

                            if action == 0:  # south
                                new_row = min(row + 1, max_row)
                            elif action == 1:  # north
                                new_row = max(row - 1, 0)
                            elif action == 2:  # east, blocked by a "|" wall
                                if self.desc[1 + row, 2 * col + 2] == b":":
                                    new_col = min(col + 1, max_col)
                            elif action == 3:  # west, blocked by a "|" wall
                                if self.desc[1 + row, 2 * col] == b":":
                                    new_col = max(col - 1, 0)
                            elif action == 4:  # pickup
                                if pass_idx < 4 and taxi_loc == locs[pass_idx]:
                                    new_pass_idx = 4
                                else:  # passenger not at location
                                    reward = -10
                            elif action == 5:  # dropoff
                                if taxi_loc == locs[dest_idx] and pass_idx == 4:
                                    new_pass_idx = dest_idx
                                    terminated = True
                                    reward = 20
                                elif taxi_loc in locs and pass_idx == 4:
                                    # Dropped at a non-destination stop:
                                    # passenger now waits there.
                                    new_pass_idx = locs.index(taxi_loc)
                                else:  # dropoff at wrong location
                                    reward = -10

                            new_state = self.encode(
                                new_row, new_col, new_pass_idx, dest_idx
                            )
                            self.P[state][action].append(
                                (1.0, new_state, reward, terminated)
                            )

        self.initial_state_distrib /= self.initial_state_distrib.sum()
        self.action_space = spaces.Discrete(num_actions)
        self.observation_space = spaces.Discrete(num_states)

        self.render_mode = render_mode

        # pygame utils (all created lazily on the first GUI render)
        self.window = None
        self.clock = None
        self.cell_size = (
            WINDOW_SIZE[0] / self.desc.shape[1],
            WINDOW_SIZE[1] / self.desc.shape[0],
        )
        self.taxi_imgs = None
        self.taxi_orientation = 0
        self.passenger_img = None
        self.destination_img = None
        self.median_horiz = None
        self.median_vert = None
        self.background_img = None

    def encode(self, taxi_row, taxi_col, pass_loc, dest_idx):
        """Pack (taxi_row, taxi_col, pass_loc, dest_idx) into one integer.

        Mixed-radix encoding with radices (5), 5, 5, 4.
        """
        return ((taxi_row * 5 + taxi_col) * 5 + pass_loc) * 4 + dest_idx

    def decode(self, i):
        """Inverse of ``encode``.

        Returns an iterator yielding taxi_row, taxi_col, pass_loc, dest_idx
        (in that order); unpack it or wrap it in ``list``/``tuple``.

        Raises:
            AssertionError: if the decoded taxi row is outside ``[0, 5)``.
        """
        i, dest_idx = divmod(i, 4)
        i, pass_loc = divmod(i, 5)
        taxi_row, taxi_col = divmod(i, 5)
        assert 0 <= taxi_row < 5
        return reversed([dest_idx, pass_loc, taxi_col, taxi_row])

    def action_mask(self, state: int):
        """Computes an action mask for the action space using the state information.

        Returns an ``np.int8`` array of length 6 where entry ``a`` is 1 when
        taking action ``a`` in ``state`` changes the state and 0 otherwise.
        """
        mask = np.zeros(6, dtype=np.int8)
        taxi_row, taxi_col, pass_loc, _dest_idx = self.decode(state)
        taxi_loc = (taxi_row, taxi_col)

        if taxi_row < 4:
            mask[0] = 1
        if taxi_row > 0:
            mask[1] = 1
        if taxi_col < 4 and self.desc[taxi_row + 1, 2 * taxi_col + 2] == b":":
            mask[2] = 1
        if taxi_col > 0 and self.desc[taxi_row + 1, 2 * taxi_col] == b":":
            mask[3] = 1
        if pass_loc < 4 and taxi_loc == self.locs[pass_loc]:
            mask[4] = 1
        # A drop-off changes the state at any of the four stops; the
        # destination is one of them, so a single membership test suffices.
        if pass_loc == 4 and taxi_loc in self.locs:
            mask[5] = 1
        return mask

    def step(self, a):
        """Apply action ``a`` and return (obs, reward, terminated, truncated, info).

        ``truncated`` is always ``False``; ``info`` carries "prob" and
        "action_mask" for the new state.
        """
        transitions = self.P[self.s][a]
        i = categorical_sample([t[0] for t in transitions], self.np_random)
        p, s, r, t = transitions[i]
        self.s = s
        self.lastaction = a

        if self.render_mode == "human":
            self.render()
        return (int(s), r, t, False, {"prob": p, "action_mask": self.action_mask(s)})

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ):
        """Sample a new initial state and return (obs, info).

        ``options`` is accepted for API compatibility and is not used.
        """
        super().reset(seed=seed)
        self.s = categorical_sample(self.initial_state_distrib, self.np_random)
        self.lastaction = None
        self.taxi_orientation = 0

        if self.render_mode == "human":
            self.render()
        return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}

    def render(self):
        """Render according to ``self.render_mode``.

        Returns the text frame for "ansi", an RGB array for "rgb_array", and
        ``None`` for "human" or when no render mode was configured (in which
        case only a warning is logged).
        """
        if self.render_mode is None:
            # The env may have been constructed directly, without a spec.
            spec = getattr(self, "spec", None)
            env_id = spec.id if spec is not None else "Taxi-v3"
            logger.warn(
                "You are calling render method without specifying any render mode. "
                "You can specify the render_mode at initialization, "
                f'e.g. gym.make("{env_id}", render_mode="rgb_array")'
            )
            # Nothing to draw; falling through would fail inside _render_gui.
            return None
        if self.render_mode == "ansi":
            return self._render_text()
        # self.render_mode in {"human", "rgb_array"}
        return self._render_gui(self.render_mode)

    def _render_gui(self, mode):
        """Draw the current state with pygame.

        For ``mode == "human"`` the window is updated and the frame rate is
        capped at ``metadata["render_fps"]``; for ``mode == "rgb_array"`` the
        frame is returned as an array with axes transposed to (1, 0, 2).

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame  # dependency to pygame only if rendering with human
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gym[toy_text]`"
            ) from e

        def load_scaled(relative_name):
            # Load an asset that ships next to this module, scaled to one cell.
            file_name = path.join(path.dirname(__file__), relative_name)
            return pygame.transform.scale(pygame.image.load(file_name), self.cell_size)

        if self.window is None:
            pygame.init()
            pygame.display.set_caption("Taxi")
            if mode == "human":
                self.window = pygame.display.set_mode(WINDOW_SIZE)
            elif mode == "rgb_array":
                self.window = pygame.Surface(WINDOW_SIZE)

        assert (
            self.window is not None
        ), "Something went wrong with pygame. This should never happen."
        if self.clock is None:
            self.clock = pygame.time.Clock()
        if self.taxi_imgs is None:
            # Index matches the move actions: 0 south, 1 north, 2 east, 3 west.
            self.taxi_imgs = [
                load_scaled(name)
                for name in (
                    "img/cab_front.png",
                    "img/cab_rear.png",
                    "img/cab_right.png",
                    "img/cab_left.png",
                )
            ]
        if self.passenger_img is None:
            self.passenger_img = load_scaled("img/passenger.png")
        if self.destination_img is None:
            self.destination_img = load_scaled("img/hotel.png")
            self.destination_img.set_alpha(170)
        if self.median_horiz is None:
            self.median_horiz = [
                load_scaled(name)
                for name in (
                    "img/gridworld_median_left.png",
                    "img/gridworld_median_horiz.png",
                    "img/gridworld_median_right.png",
                )
            ]
        if self.median_vert is None:
            self.median_vert = [
                load_scaled(name)
                for name in (
                    "img/gridworld_median_top.png",
                    "img/gridworld_median_vert.png",
                    "img/gridworld_median_bottom.png",
                )
            ]
        if self.background_img is None:
            self.background_img = load_scaled("img/taxi_background.png")

        desc = self.desc
        last_y = desc.shape[0] - 1
        last_x = desc.shape[1] - 1

        for y in range(0, desc.shape[0]):
            for x in range(0, desc.shape[1]):
                cell = (x * self.cell_size[0], y * self.cell_size[1])
                self.window.blit(self.background_img, cell)
                char = desc[y][x]
                if char == b"|":
                    # Pick the top / bottom / middle piece of a vertical wall.
                    if y == 0 or desc[y - 1][x] != b"|":
                        self.window.blit(self.median_vert[0], cell)
                    elif y == last_y or desc[y + 1][x] != b"|":
                        self.window.blit(self.median_vert[2], cell)
                    else:
                        self.window.blit(self.median_vert[1], cell)
                elif char == b"-":
                    # Pick the left / right / middle piece of a horizontal wall.
                    if x == 0 or desc[y][x - 1] != b"-":
                        self.window.blit(self.median_horiz[0], cell)
                    elif x == last_x or desc[y][x + 1] != b"-":
                        self.window.blit(self.median_horiz[2], cell)
                    else:
                        self.window.blit(self.median_horiz[1], cell)

        for cell, color in zip(self.locs, self.locs_colors):
            color_cell = pygame.Surface(self.cell_size)
            color_cell.set_alpha(128)
            color_cell.fill(color)
            loc = self.get_surf_loc(cell)
            self.window.blit(color_cell, (loc[0], loc[1] + 10))

        taxi_row, taxi_col, pass_idx, dest_idx = self.decode(self.s)

        if pass_idx < 4:
            self.window.blit(self.passenger_img, self.get_surf_loc(self.locs[pass_idx]))

        # Only move actions turn the taxi; pickup/dropoff keep the last facing.
        if self.lastaction in [0, 1, 2, 3]:
            self.taxi_orientation = self.lastaction
        dest_loc = self.get_surf_loc(self.locs[dest_idx])
        taxi_location = self.get_surf_loc((taxi_row, taxi_col))
        dest_draw_pos = (dest_loc[0], dest_loc[1] - self.cell_size[1] // 2)
        taxi_img = self.taxi_imgs[self.taxi_orientation]

        if dest_loc[1] <= taxi_location[1]:
            self.window.blit(self.destination_img, dest_draw_pos)
            self.window.blit(taxi_img, taxi_location)
        else:  # change blit order for overlapping appearance
            self.window.blit(taxi_img, taxi_location)
            self.window.blit(self.destination_img, dest_draw_pos)

        if mode == "human":
            pygame.display.update()
            self.clock.tick(self.metadata["render_fps"])
        elif mode == "rgb_array":
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(self.window)), axes=(1, 0, 2)
            )

    def get_surf_loc(self, map_loc):
        """Convert a grid (row, col) into the (x, y) pixel offset of its map cell."""
        return (map_loc[1] * 2 + 1) * self.cell_size[0], (
            map_loc[0] + 1
        ) * self.cell_size[1]

    def _render_text(self):
        """Return the current state as a colourised ASCII map.

        The map is followed by a line naming the last action, or by an empty
        line when no action has been taken yet.
        """
        desc = self.desc.copy().tolist()
        outfile = StringIO()

        out = [[c.decode("utf-8") for c in line] for line in desc]
        taxi_row, taxi_col, pass_idx, dest_idx = self.decode(self.s)

        def ul(x):
            # Show an occupied taxi on an empty cell as "_" instead of " ".
            return "_" if x == " " else x

        if pass_idx < 4:
            out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
                out[1 + taxi_row][2 * taxi_col + 1], "yellow", highlight=True
            )
            pi, pj = self.locs[pass_idx]
            out[1 + pi][2 * pj + 1] = utils.colorize(
                out[1 + pi][2 * pj + 1], "blue", bold=True
            )
        else:  # passenger in taxi
            out[1 + taxi_row][2 * taxi_col + 1] = utils.colorize(
                ul(out[1 + taxi_row][2 * taxi_col + 1]), "green", highlight=True
            )

        di, dj = self.locs[dest_idx]
        out[1 + di][2 * dj + 1] = utils.colorize(out[1 + di][2 * dj + 1], "magenta")
        outfile.write("\n".join(["".join(row) for row in out]) + "\n")
        if self.lastaction is not None:
            outfile.write(
                f" ({['South', 'North', 'East', 'West', 'Pickup', 'Dropoff'][self.lastaction]})\n"
            )
        else:
            outfile.write("\n")

        with closing(outfile):
            return outfile.getvalue()

    def close(self):
        """Shut down pygame if a window or surface was ever created."""
        if self.window is not None:
            import pygame

            pygame.display.quit()
            pygame.quit()
            # Drop the dead handle so close() is idempotent and a later
            # render() creates a fresh window.
            self.window = None
# Taxi rider from https://franuka.itch.io/rpg-asset-pack
# All other assets by Mel Tillery http://www.cyaneus.com/
