Spaces:
Sleeping
Sleeping
from __future__ import print_function | |
from dizoo.beergame.envs import clBeerGame | |
from torch import Tensor | |
import numpy as np | |
import random | |
from .utils import get_config, update_config | |
import gym | |
import os | |
from typing import Optional | |
class BeerGame(): | |
def __init__(self, role: int, agent_type: str, demandDistribution: int) -> None: | |
self._cfg, unparsed = get_config() | |
self._role = role | |
# prepare loggers and directories | |
# prepare_dirs_and_logger(self._cfg) | |
self._cfg = update_config(self._cfg) | |
# set agent type | |
if agent_type == 'bs': | |
self._cfg.agentTypes = ["bs", "bs", "bs", "bs"] | |
elif agent_type == 'Strm': | |
self._cfg.agentTypes = ["Strm", "Strm", "Strm", "Strm"] | |
self._cfg.agentTypes[role] = "srdqn" | |
self._cfg.demandDistribution = demandDistribution | |
# load demands:0=uniform, 1=normal distribution, 2=the sequence of 4,4,4,4,8,..., 3= basket data, 4= forecast data | |
if self._cfg.observation_data: | |
adsr = 'data/demandTr-obs-' | |
elif self._cfg.demandDistribution == 3: | |
if self._cfg.scaled: | |
adsr = 'data/basket_data/scaled' | |
else: | |
adsr = 'data/basket_data' | |
direc = os.path.realpath(adsr + '/demandTr-' + str(self._cfg.data_id) + '.npy') | |
self._demandTr = np.load(direc) | |
print("loaded training set=", direc) | |
elif self._cfg.demandDistribution == 4: | |
if self._cfg.scaled: | |
adsr = 'data/forecast_data/scaled' | |
else: | |
adsr = 'data/forecast_data' | |
direc = os.path.realpath(adsr + '/demandTr-' + str(self._cfg.data_id) + '.npy') | |
self._demandTr = np.load(direc) | |
print("loaded training set=", direc) | |
else: | |
if self._cfg.demandDistribution == 0: # uniform | |
self._demandTr = np.random.randint(0, self._cfg.demandUp, size=[self._cfg.demandSize, self._cfg.TUp]) | |
elif self._cfg.demandDistribution == 1: # normal distribution | |
self._demandTr = np.round( | |
np.random.normal( | |
self._cfg.demandMu, self._cfg.demandSigma, size=[self._cfg.demandSize, self._cfg.TUp] | |
) | |
).astype(int) | |
elif self._cfg.demandDistribution == 2: # the sequence of 4,4,4,4,8,... | |
self._demandTr = np.concatenate( | |
(4 * np.ones((self._cfg.demandSize, 4)), 8 * np.ones((self._cfg.demandSize, 98))), axis=1 | |
).astype(int) | |
# initilize an instance of Beergame | |
self._env = clBeerGame(self._cfg) | |
self.observation_space = gym.spaces.Box( | |
low=float("-inf"), | |
high=float("inf"), | |
shape=(self._cfg.stateDim * self._cfg.multPerdInpt, ), | |
dtype=np.float32 | |
) # state_space = state_dim * m (considering the reward delay) | |
self.action_space = gym.spaces.Discrete(self._cfg.actionListLen) # length of action list | |
self.reward_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32) | |
# get the length of the demand. | |
self._demand_len = np.shape(self._demandTr)[0] | |
def reset(self): | |
self._env.resetGame(demand=self._demandTr[random.randint(0, self._demand_len - 1)]) | |
obs = [i for item in self._env.players[self._role].currentState for i in item] | |
return obs | |
def seed(self, seed: int) -> None: | |
self._seed = seed | |
np.random.seed(self._seed) | |
def close(self) -> None: | |
pass | |
def step(self, action: np.ndarray): | |
self._env.handelAction(action) | |
self._env.next() | |
newstate = np.append( | |
self._env.players[self._role].currentState[1:, :], [self._env.players[self._role].nextObservation], axis=0 | |
) | |
self._env.players[self._role].currentState = newstate | |
obs = [i for item in newstate for i in item] | |
rew = self._env.players[self._role].curReward | |
done = (self._env.curTime == self._env.T) | |
info = {} | |
return obs, rew, done, info | |
def reward_shaping(self, reward: Tensor) -> Tensor: | |
self._totRew, self._cumReward = self._env.distTotReward(self._role) | |
reward += (self._cfg.distCoeff / 3) * ((self._totRew - self._cumReward) / (self._env.T)) | |
return reward | |
def enable_save_figure(self, figure_path: Optional[str] = None) -> None: | |
self._cfg.ifSaveFigure = True | |
if figure_path is None: | |
figure_path = './' | |
self._cfg.figure_dir = figure_path | |
self._env.doTestMid(self._demandTr[random.randint(0, self._demand_len - 1)]) | |