File size: 3,005 Bytes
6340080 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import numpy as np
import torch
def get_agent_id_feature(agent_id, agent_num):
    """Encode ``agent_id`` as a one-hot vector over ``agent_num`` agents.

    Returns a float tensor of length ``agent_num`` that is 1 at index
    ``agent_id`` and 0 everywhere else.
    """
    one_hot = torch.zeros(agent_num)
    one_hot[agent_id] = 1
    return one_hot
def get_movement_feature():
    """Return a placeholder movement feature.

    The result is a random length-8 integer tensor whose entries are 0 or 1.
    """
    return torch.randint(0, 2, size=(8,))
def get_own_feature():
    """Return a placeholder own-unit feature: 10 standard-normal samples."""
    own = torch.randn(10)
    return own
def get_ally_visible_feature():
    """Return the 4-dim visibility slot for a single ally.

    Placeholder logic: one draw from ``np.random.random()`` decides visibility.
    A draw above 0.5 (roughly half the time) yields a random standard-normal
    vector. Otherwise the slot is all zeros, which stands for "not visible".
    """
    is_visible = np.random.random() > 0.5
    return torch.randn(4) if is_visible else torch.zeros(4)
def get_enemy_visible_feature():
    """Return the 4-dim visibility slot for a single enemy.

    Placeholder logic: one draw from ``np.random.random()`` decides visibility.
    A draw above 0.8 (roughly 20% of the time) yields a random standard-normal
    vector. Otherwise the slot is all zeros, which stands for "not visible".
    """
    is_visible = np.random.random() > 0.8
    return torch.randn(4) if is_visible else torch.zeros(4)
def get_ind_global_state(agent_id, ally_agent_num, enemy_agent_num):
    """Return the independent global state of one agent (its own observation).

    The state is the concatenation, in this order, of:
      * the one-hot agent id over ``ally_agent_num`` allies,
      * the movement feature (cast to float),
      * the agent's own feature,
      * one visibility slot for each of the other ``ally_agent_num - 1`` allies,
      * one visibility slot for each of the ``enemy_agent_num`` enemies.

    Invisible units contribute all-zero slots, so the length depends only on
    the agent counts, not on what happens to be visible.

    Args:
        agent_id: index of the agent, used for the one-hot id.
        ally_agent_num: number of allied agents, including this one.
        enemy_agent_num: number of enemy agents.

    Returns:
        A 1-D float tensor.
    """
    features = [
        get_agent_id_feature(agent_id, ally_agent_num),
        # randint produces int64; cast so every piece is float before
        # concatenation instead of relying on torch.cat dtype promotion.
        get_movement_feature().float(),
        get_own_feature(),
    ]
    # One slot per *other* ally; the agent itself is described by its own feature.
    features.extend(get_ally_visible_feature() for _ in range(ally_agent_num - 1))
    features.extend(get_enemy_visible_feature() for _ in range(enemy_agent_num))
    # A single list means no empty torch.cat when there are no other allies.
    return torch.cat(features)
def get_ep_global_state(agent_id, ally_agent_num, enemy_agent_num):
    """Return the environment-provided global state.

    In environments such as SMAC, this state is a condensed summary of all
    agents' independent states. Its exact layout depends on the environment.

    Here it is a placeholder: an 8-dim random "ally center" block followed by
    an 8-dim random "enemy center" block, 16 values in total. The arguments
    keep the signature parallel to the other state builders and are unused.
    """
    ally_center = torch.randn(8)
    enemy_center = torch.randn(8)
    return torch.cat((ally_center, enemy_center))
def get_as_global_state(agent_id, ally_agent_num, enemy_agent_num):
    """Return the naive agent-specific global state for ``agent_id``.

    This is the agent's independent state (its own view, which differs per
    agent) followed by the environment-provided global state (meant to be
    shared by all agents). The result is therefore specific to each agent
    while still carrying global information, as used for centralized critics.

    Args:
        agent_id: index of the agent.
        ally_agent_num: number of allied agents, including this one.
        enemy_agent_num: number of enemy agents.

    Returns:
        A 1-D tensor: the independent-state part, then the environment part.
    """
    ind_state = get_ind_global_state(agent_id, ally_agent_num, enemy_agent_num)
    ep_state = get_ep_global_state(agent_id, ally_agent_num, enemy_agent_num)
    return torch.cat([ind_state, ep_state])
def test_global_state():
    """Smoke-test the three global-state builders for every ally agent.

    Uses 3 allies and 5 enemies. Each builder is called for every agent id in
    turn, and each result must be a ``torch.Tensor``.
    """
    ally_agent_num, enemy_agent_num = 3, 5
    builders = (
        # Independent global state: usually used in decentralized training.
        get_ind_global_state,
        # Environment-provided global state: intended to be the same for all
        # agents; used in centralized training.
        get_ep_global_state,
        # Naive agent-specific global state: differs per agent; used in
        # centralized training.
        get_as_global_state,
    )
    for build in builders:
        for agent_id in range(ally_agent_num):
            state = build(agent_id, ally_agent_num, enemy_agent_num)
            assert isinstance(state, torch.Tensor)
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not on import.
    test_global_state()
|